Override Iterator::advance(_back)_by for array::IntoIter

Because I happened to notice that `nth` is currently getting codegen'd as a loop even for `Copy` types: <https://rust.godbolt.org/z/fPqv7Gvs7>
This commit is contained in:
Scott McMurray 2021-12-03 21:36:51 -08:00
parent 2a9e0831d6
commit eb846dbaca
2 changed files with 149 additions and 1 deletions

View file

@ -1,7 +1,7 @@
//! Defines the `IntoIter` owned iterator for arrays.
use crate::{
fmt,
cmp, fmt,
iter::{self, ExactSizeIterator, FusedIterator, TrustedLen},
mem::{self, MaybeUninit},
ops::Range,
@ -150,6 +150,27 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {
fn last(mut self) -> Option<Self::Item> {
self.next_back()
}
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let len = self.len();
// The number of elements to drop. Always in-bounds by construction.
let delta = cmp::min(n, len);
let range_to_drop = self.alive.start..(self.alive.start + delta);
// Moving the start marks them as conceptually "dropped", so if anything
// goes bad then our drop impl won't double-free them.
self.alive.start += delta;
// SAFETY: These elements are currently initialized, so it's fine to drop them.
unsafe {
let slice = self.data.get_unchecked_mut(range_to_drop);
ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(slice));
}
if n > len { Err(len) } else { Ok(()) }
}
}
#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
@ -170,6 +191,27 @@ impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
unsafe { self.data.get_unchecked(idx).assume_init_read() }
})
}
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let len = self.len();
// The number of elements to drop. Always in-bounds by construction.
let delta = cmp::min(n, len);
let range_to_drop = (self.alive.end - delta)..self.alive.end;
// Moving the end marks them as conceptually "dropped", so if anything
// goes bad then our drop impl won't double-free them.
self.alive.end -= delta;
// SAFETY: These elements are currently initialized, so it's fine to drop them.
unsafe {
let slice = self.data.get_unchecked_mut(range_to_drop);
ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(slice));
}
if n > len { Err(len) } else { Ok(()) }
}
}
#[stable(feature = "array_value_iter_impls", since = "1.40.0")]

View file

@ -474,3 +474,109 @@ fn array_split_array_mut_out_of_bounds() {
v.split_array_mut::<7>();
}
#[test]
fn array_intoiter_advance_by() {
use std::cell::Cell;
struct DropCounter<'a>(usize, &'a Cell<usize>);
impl Drop for DropCounter<'_> {
fn drop(&mut self) {
let x = self.1.get();
self.1.set(x + 1);
}
}
let counter = Cell::new(0);
let a: [_; 100] = std::array::from_fn(|i| DropCounter(i, &counter));
let mut it = IntoIterator::into_iter(a);
let r = it.advance_by(1);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 99);
assert_eq!(counter.get(), 1);
let r = it.advance_by(0);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 99);
assert_eq!(counter.get(), 1);
let r = it.advance_by(11);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 88);
assert_eq!(counter.get(), 12);
let x = it.next();
assert_eq!(x.as_ref().map(|x| x.0), Some(12));
assert_eq!(it.len(), 87);
assert_eq!(counter.get(), 12);
drop(x);
assert_eq!(counter.get(), 13);
let r = it.advance_by(123456);
assert_eq!(r, Err(87));
assert_eq!(it.len(), 0);
assert_eq!(counter.get(), 100);
let r = it.advance_by(0);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 0);
assert_eq!(counter.get(), 100);
let r = it.advance_by(10);
assert_eq!(r, Err(0));
assert_eq!(it.len(), 0);
assert_eq!(counter.get(), 100);
}
#[test]
fn array_intoiter_advance_back_by() {
use std::cell::Cell;
struct DropCounter<'a>(usize, &'a Cell<usize>);
impl Drop for DropCounter<'_> {
fn drop(&mut self) {
let x = self.1.get();
self.1.set(x + 1);
}
}
let counter = Cell::new(0);
let a: [_; 100] = std::array::from_fn(|i| DropCounter(i, &counter));
let mut it = IntoIterator::into_iter(a);
let r = it.advance_back_by(1);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 99);
assert_eq!(counter.get(), 1);
let r = it.advance_back_by(0);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 99);
assert_eq!(counter.get(), 1);
let r = it.advance_back_by(11);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 88);
assert_eq!(counter.get(), 12);
let x = it.next_back();
assert_eq!(x.as_ref().map(|x| x.0), Some(87));
assert_eq!(it.len(), 87);
assert_eq!(counter.get(), 12);
drop(x);
assert_eq!(counter.get(), 13);
let r = it.advance_back_by(123456);
assert_eq!(r, Err(87));
assert_eq!(it.len(), 0);
assert_eq!(counter.get(), 100);
let r = it.advance_back_by(0);
assert_eq!(r, Ok(()));
assert_eq!(it.len(), 0);
assert_eq!(counter.get(), 100);
let r = it.advance_back_by(10);
assert_eq!(r, Err(0));
assert_eq!(it.len(), 0);
assert_eq!(counter.get(), 100);
}