Auto merge of #86988 - thomcc:chunky-splitz-says-no-checking, r=the8472

Carefully remove bounds checks from some chunk iterator functions

So, I was writing code that requires the equivalent of `rchunks(N).rev()` (which isn't the same as forward `chunks(N)` — in particular, if the buffer length is not a multiple of `N`, I must handle the "remainder" first).

I happened to look at the codegen output of the function (I was actually interested in whether or not a nested loop was being unrolled — it was), and noticed that in the outer `rchunks(n).rev()` loop, LLVM seemed to be unable to remove the bounds checks from the iteration: https://rust.godbolt.org/z/Tnz4MYY8f (this panic was from the split_at in `RChunks::next_back`).

After doing some experimentation, it seems all of the `next_back` in the non-exact chunk iterators have the issue: (`Chunks::next_back`, `RChunks::next_back`, `ChunksMut::next_back`, and `RChunksMut::next_back`)...

Even worse, the forward `rchunks` iterators sometimes have the issue as well (... but only sometimes). For example https://rust.godbolt.org/z/oGhbqv53r has bounds checks, but if I uncomment the loop body, it manages to remove the check (which is bizarre, since I'd expect the opposite...). I suspect it's highly dependent on the surrounding code, so I decided to remove the bounds checks from them anyway. Overall, this change includes:
- All `next_back` functions on the non-`Exact` iterators (e.g. `R?Chunks(Mut)?`).
- All `next` functions on the non-exact rchunks iterators (e.g. `RChunks(Mut)?`).

I wasn't able to catch any of the other chunk iterators failing to remove the bounds checks (I checked iterations over `r?chunks(_exact)?(_mut)?` with constant chunk sizes under `-O3`, `-Os`, and `-Oz`), which makes sense, since these were the cases where it was harder to prove the bounds check correct to remove...

In fact, it took quite a bit of thinking to convince myself that using unchecked_ here was valid — so I'm not really surprised that LLVM had trouble (although compilers are slightly better at this sort of reasoning than humans). A consequence of that is the fact that the `// SAFETY` comment for these are... kinda long...

---

I didn't do this for, or even think about it for, any of the other iteration methods; just `next` and `next_back` (where it mattered). If this PR is accepted, I'll file a follow up for someone (possibly me) to look at the others later (in particular, `nth`/`nth_back` looked like they had similar logic), but I wanted to do this now, as IMO `next`/`next_back` are the most important here, since they're what gets used by the iteration protocol.

---

Note: While I don't expect this to impact performance directly, the panic is a side effect, which would otherwise not exist in these loops. That is, this could prevent the compiler from being able to move/remove/otherwise rework a loop over these iterators (as an example, it could not delete the code for a loop whose body computes a value which doesn't get used).

Also, some like to be able to have confidence this code has no panicking branches in the optimized code, and "no bounds checks" is kinda part of the selling point of Rust's iterators anyway.
This commit is contained in:
bors 2022-02-01 10:11:59 +00:00
commit 547f2ba06b
2 changed files with 137 additions and 7 deletions

View file

@ -1476,7 +1476,21 @@ impl<'a, T> DoubleEndedIterator for Chunks<'a, T> {
} else {
let remainder = self.v.len() % self.chunk_size;
let chunksz = if remainder != 0 { remainder } else { self.chunk_size };
let (fst, snd) = self.v.split_at(self.v.len() - chunksz);
// SAFETY: split_at_unchecked requires the argument be less than or
// equal to the length. This is guaranteed, but subtle: `chunksz`
// will always either be `self.v.len() % self.chunk_size`, which
// will always evaluate to strictly less than `self.v.len()` (or
// panic, in the case that `self.chunk_size` is zero), or it can be
// `self.chunk_size`, in the case that the length is exactly
// divisible by the chunk size.
//
// While it seems like using `self.chunk_size` in this case could
// lead to a value greater than `self.v.len()`, it cannot: if
// `self.chunk_size` were greater than `self.v.len()`, then
// `self.v.len() % self.chunk_size` would return nonzero (note that
// in this branch of the `if`, we already know that `self.v` is
// non-empty).
let (fst, snd) = unsafe { self.v.split_at_unchecked(self.v.len() - chunksz) };
self.v = fst;
Some(snd)
}
@ -1641,7 +1655,8 @@ impl<'a, T> DoubleEndedIterator for ChunksMut<'a, T> {
let sz = if remainder != 0 { remainder } else { self.chunk_size };
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let (head, tail) = tmp.split_at_mut(tmp_len - sz);
// SAFETY: Similar to `Chunks::next_back`
let (head, tail) = unsafe { tmp.split_at_mut_unchecked(tmp_len - sz) };
self.v = head;
Some(tail)
}
@ -2410,8 +2425,14 @@ impl<'a, T> Iterator for RChunks<'a, T> {
if self.v.is_empty() {
None
} else {
let chunksz = cmp::min(self.v.len(), self.chunk_size);
let (fst, snd) = self.v.split_at(self.v.len() - chunksz);
let len = self.v.len();
let chunksz = cmp::min(len, self.chunk_size);
// SAFETY: split_at_unchecked just requires the argument be less
// than the length. This could only happen if the expression `len -
// chunksz` overflows. This could only happen if `chunksz > len`,
// which is impossible as we initialize it as the `min` of `len` and
// `self.chunk_size`.
let (fst, snd) = unsafe { self.v.split_at_unchecked(len - chunksz) };
self.v = fst;
Some(snd)
}
@ -2485,7 +2506,8 @@ impl<'a, T> DoubleEndedIterator for RChunks<'a, T> {
} else {
let remainder = self.v.len() % self.chunk_size;
let chunksz = if remainder != 0 { remainder } else { self.chunk_size };
let (fst, snd) = self.v.split_at(chunksz);
// SAFETY: similar to Chunks::next_back
let (fst, snd) = unsafe { self.v.split_at_unchecked(chunksz) };
self.v = snd;
Some(fst)
}
@ -2571,7 +2593,12 @@ impl<'a, T> Iterator for RChunksMut<'a, T> {
let sz = cmp::min(self.v.len(), self.chunk_size);
let tmp = mem::replace(&mut self.v, &mut []);
let tmp_len = tmp.len();
let (head, tail) = tmp.split_at_mut(tmp_len - sz);
// SAFETY: split_at_mut_unchecked just requires the argument be less
// than the length. This could only happen if the expression
// `tmp_len - sz` overflows. This could only happen if `sz >
// tmp_len`, which is impossible as we initialize it as the `min` of
// `self.v.len()` (e.g. `tmp_len`) and `self.chunk_size`.
let (head, tail) = unsafe { tmp.split_at_mut_unchecked(tmp_len - sz) };
self.v = head;
Some(tail)
}
@ -2649,7 +2676,8 @@ impl<'a, T> DoubleEndedIterator for RChunksMut<'a, T> {
let remainder = self.v.len() % self.chunk_size;
let sz = if remainder != 0 { remainder } else { self.chunk_size };
let tmp = mem::replace(&mut self.v, &mut []);
let (head, tail) = tmp.split_at_mut(sz);
// SAFETY: Similar to `Chunks::next_back`
let (head, tail) = unsafe { tmp.split_at_mut_unchecked(sz) };
self.v = tail;
Some(head)
}

View file

@ -251,6 +251,40 @@ fn test_chunks_nth() {
assert_eq!(c2.next(), None);
}
#[test]
fn test_chunks_next() {
let v = [0, 1, 2, 3, 4, 5];
let mut c = v.chunks(2);
assert_eq!(c.next().unwrap(), &[0, 1]);
assert_eq!(c.next().unwrap(), &[2, 3]);
assert_eq!(c.next().unwrap(), &[4, 5]);
assert_eq!(c.next(), None);
let v = [0, 1, 2, 3, 4, 5, 6, 7];
let mut c = v.chunks(3);
assert_eq!(c.next().unwrap(), &[0, 1, 2]);
assert_eq!(c.next().unwrap(), &[3, 4, 5]);
assert_eq!(c.next().unwrap(), &[6, 7]);
assert_eq!(c.next(), None);
}
#[test]
fn test_chunks_next_back() {
let v = [0, 1, 2, 3, 4, 5];
let mut c = v.chunks(2);
assert_eq!(c.next_back().unwrap(), &[4, 5]);
assert_eq!(c.next_back().unwrap(), &[2, 3]);
assert_eq!(c.next_back().unwrap(), &[0, 1]);
assert_eq!(c.next_back(), None);
let v = [0, 1, 2, 3, 4, 5, 6, 7];
let mut c = v.chunks(3);
assert_eq!(c.next_back().unwrap(), &[6, 7]);
assert_eq!(c.next_back().unwrap(), &[3, 4, 5]);
assert_eq!(c.next_back().unwrap(), &[0, 1, 2]);
assert_eq!(c.next_back(), None);
}
#[test]
fn test_chunks_nth_back() {
let v: &[i32] = &[0, 1, 2, 3, 4, 5];
@ -809,6 +843,40 @@ fn test_rchunks_nth_back() {
assert_eq!(c2.next_back(), None);
}
#[test]
fn test_rchunks_next() {
let v = [0, 1, 2, 3, 4, 5];
let mut c = v.rchunks(2);
assert_eq!(c.next().unwrap(), &[4, 5]);
assert_eq!(c.next().unwrap(), &[2, 3]);
assert_eq!(c.next().unwrap(), &[0, 1]);
assert_eq!(c.next(), None);
let v = [0, 1, 2, 3, 4, 5, 6, 7];
let mut c = v.rchunks(3);
assert_eq!(c.next().unwrap(), &[5, 6, 7]);
assert_eq!(c.next().unwrap(), &[2, 3, 4]);
assert_eq!(c.next().unwrap(), &[0, 1]);
assert_eq!(c.next(), None);
}
#[test]
fn test_rchunks_next_back() {
let v = [0, 1, 2, 3, 4, 5];
let mut c = v.rchunks(2);
assert_eq!(c.next_back().unwrap(), &[0, 1]);
assert_eq!(c.next_back().unwrap(), &[2, 3]);
assert_eq!(c.next_back().unwrap(), &[4, 5]);
assert_eq!(c.next_back(), None);
let v = [0, 1, 2, 3, 4, 5, 6, 7];
let mut c = v.rchunks(3);
assert_eq!(c.next_back().unwrap(), &[0, 1]);
assert_eq!(c.next_back().unwrap(), &[2, 3, 4]);
assert_eq!(c.next_back().unwrap(), &[5, 6, 7]);
assert_eq!(c.next_back(), None);
}
#[test]
fn test_rchunks_last() {
let v: &[i32] = &[0, 1, 2, 3, 4, 5];
@ -874,6 +942,40 @@ fn test_rchunks_mut_nth_back() {
assert_eq!(c2.next_back(), None);
}
#[test]
fn test_rchunks_mut_next() {
let mut v = [0, 1, 2, 3, 4, 5];
let mut c = v.rchunks_mut(2);
assert_eq!(c.next().unwrap(), &mut [4, 5]);
assert_eq!(c.next().unwrap(), &mut [2, 3]);
assert_eq!(c.next().unwrap(), &mut [0, 1]);
assert_eq!(c.next(), None);
let mut v = [0, 1, 2, 3, 4, 5, 6, 7];
let mut c = v.rchunks_mut(3);
assert_eq!(c.next().unwrap(), &mut [5, 6, 7]);
assert_eq!(c.next().unwrap(), &mut [2, 3, 4]);
assert_eq!(c.next().unwrap(), &mut [0, 1]);
assert_eq!(c.next(), None);
}
#[test]
fn test_rchunks_mut_next_back() {
let mut v = [0, 1, 2, 3, 4, 5];
let mut c = v.rchunks_mut(2);
assert_eq!(c.next_back().unwrap(), &mut [0, 1]);
assert_eq!(c.next_back().unwrap(), &mut [2, 3]);
assert_eq!(c.next_back().unwrap(), &mut [4, 5]);
assert_eq!(c.next_back(), None);
let mut v = [0, 1, 2, 3, 4, 5, 6, 7];
let mut c = v.rchunks_mut(3);
assert_eq!(c.next_back().unwrap(), &mut [0, 1]);
assert_eq!(c.next_back().unwrap(), &mut [2, 3, 4]);
assert_eq!(c.next_back().unwrap(), &mut [5, 6, 7]);
assert_eq!(c.next_back(), None);
}
#[test]
fn test_rchunks_mut_last() {
let v: &mut [i32] = &mut [0, 1, 2, 3, 4, 5];