diff --git a/src/lib.rs b/src/lib.rs index 10189a41e..25959e011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1928,9 +1928,9 @@ pub trait Itertools: Iterator { /// /// assert_eq!(Some([1, 2]), iter.next_array()); /// ``` - fn next_array(&mut self) -> Option<[T; N]> + fn next_array(&mut self) -> Option<[Self::Item; N]> where - Self: Sized + Iterator, + Self: Sized, { next_array::next_array(self) } @@ -1952,9 +1952,9 @@ pub trait Itertools: Iterator { /// panic!("Expected two elements") /// } /// ``` - fn collect_array(mut self) -> Option<[T; N]> + fn collect_array(mut self) -> Option<[Self::Item; N]> where - Self: Sized + Iterator, + Self: Sized, { self.next_array().filter(|_| self.next().is_none()) } diff --git a/src/next_array.rs b/src/next_array.rs index e9747e52c..fa89f012a 100644 --- a/src/next_array.rs +++ b/src/next_array.rs @@ -1,5 +1,4 @@ use core::mem::{self, MaybeUninit}; -use core::ptr; /// An array of at most `N` elements. struct ArrayBuilder { @@ -17,7 +16,7 @@ struct ArrayBuilder { impl ArrayBuilder { /// Initializes a new, empty `ArrayBuilder`. pub fn new() -> Self { - // SAFETY: the validity invariant trivially hold for a zero-length array. + // SAFETY: The safety invariant of `arr` trivially holds for `len = 0`. Self { arr: [(); N].map(|_| MaybeUninit::uninit()), len: 0, @@ -28,50 +27,96 @@ impl ArrayBuilder { /// /// # Panics /// - /// This panics if `self.len() >= N`. + /// This panics if `self.len >= N` or if `self.len == usize::MAX`. pub fn push(&mut self, value: T) { - // SAFETY: we maintain the invariant here that arr[..len] is valid. - // Indexing with self.len also ensures self.len < N, and thus <= N after - // the increment. + // PANICS: This will panic if `self.len >= N`. + // SAFETY: The safety invariant of `self.arr` applies to elements at + // indices `0..self.len` — not to the element at `self.len`. Writing to + // the element at index `self.len` therefore does not violate the safety + // invariant of `self.arr`. Even if this line panics, we have not + // created any intermediate invalid state. self.arr[self.len] = MaybeUninit::new(value); - self.len += 1; + // PANICS: This will panic if `self.len == usize::MAX`. + // SAFETY: By invariant on `self.arr`, all elements at indicies + // `0..self.len` are valid. Due to the above write, the element at + // `self.len` is now also valid. Consequently, all elements at indicies + // `0..(self.len + 1)` are valid, and `self.len` can be safely + // incremented without violating `self.arr`'s invariant. It is fine if + // this increment panics, as we have not created any intermediate + // invalid state. + self.len = match self.len.checked_add(1) { + Some(sum) => sum, + None => panic!("`self.len == usize::MAX`; cannot increment `len`"), + }; } - /// Consumes the elements in the `ArrayBuilder` and returns them as an array `[T; N]`. + /// Consumes the elements in the `ArrayBuilder` and returns them as an array + /// `[T; N]`. /// /// If `self.len() < N`, this returns `None`. pub fn take(&mut self) -> Option<[T; N]> { if self.len == N { - // Take the array, resetting our length back to zero. + // SAFETY: Decreasing the value of `self.len` cannot violate the + // safety invariant on `self.arr`. self.len = 0; + + // SAFETY: Since `self.len` is 0, `self.arr` may safely contain + // uninitialized elements. let arr = mem::replace(&mut self.arr, [(); N].map(|_| MaybeUninit::uninit())); - // SAFETY: we had len == N, so all elements in arr are valid. - Some(unsafe { arr.map(|v| v.assume_init()) }) + Some(arr.map(|v| { + // SAFETY: We know that all elements of `arr` are valid because + // we checked that `len == N`. + unsafe { v.assume_init() } + })) } else { None } } } +impl AsMut<[T]> for ArrayBuilder { + fn as_mut(&mut self) -> &mut [T] { + let valid = &mut self.arr[..self.len]; + // SAFETY: By invariant on `self.arr`, the elements of `self.arr` at + // indices `0..self.len` are in a valid state. Since `valid` references + // only these elements, the safety precondition of + // `slice_assume_init_mut` is satisfied. + unsafe { slice_assume_init_mut(valid) } + } +} + impl Drop for ArrayBuilder { + // We provide a non-trivial `Drop` impl, because the trivial impl would be a + // no-op; `MaybeUninit` has no innate awareness of its own validity, and + // so it can only forget its contents. By leveraging the safety invariant of + // `self.arr`, we do know which elements of `self.arr` are valid, and can + // selectively run their destructors. fn drop(&mut self) { - unsafe { - // SAFETY: arr[..len] is valid, so must be dropped. First we create - // a pointer to this valid slice, then drop that slice in-place. - // The cast from *mut MaybeUninit to *mut T is always sound by - // the layout guarantees of MaybeUninit. - let ptr_to_first: *mut MaybeUninit = self.arr.as_mut_ptr(); - let ptr_to_slice = ptr::slice_from_raw_parts_mut(ptr_to_first.cast::(), self.len); - ptr::drop_in_place(ptr_to_slice); - } + let valid = self.as_mut(); + // SAFETY: TODO + unsafe { core::ptr::drop_in_place(valid) } } } +/// Assuming all the elements are initialized, get a mutable slice to them. +/// +/// # Safety +/// +/// The caller guarantees that the elements `T` referenced by `slice` are in a +/// valid state. +unsafe fn slice_assume_init_mut(slice: &mut [MaybeUninit]) -> &mut [T] { + // SAFETY: Casting `&mut [MaybeUninit]` to `&mut [T]` is sound, because + // `MaybeUninit` is guaranteed to have the same size, alignment and ABI + // as `T`, and because the caller has guaranteed that `slice` is in the + // valid state. + unsafe { &mut *(slice as *mut [MaybeUninit] as *mut [T]) } +} + /// Equivalent to `it.next_array()`. -pub fn next_array(it: &mut I) -> Option<[T; N]> +pub fn next_array(it: &mut I) -> Option<[I::Item; N]> where - I: Iterator, + I: Iterator, { let mut builder = ArrayBuilder::new(); for _ in 0..N { diff --git a/tests/test_core.rs b/tests/test_core.rs index f98790b71..493616085 100644 --- a/tests/test_core.rs +++ b/tests/test_core.rs @@ -380,7 +380,7 @@ fn next_array() { assert_eq!(iter.next_array(), Some([])); assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([1, 2])); assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([3, 4])); - assert_eq!(iter.next_array::<_, 2>(), None); + assert_eq!(iter.next_array::<2>(), None); } #[test] @@ -391,9 +391,9 @@ fn collect_array() { let v = [1]; let iter = v.iter().cloned(); - assert_eq!(iter.collect_array::<_, 2>(), None); + assert_eq!(iter.collect_array::<2>(), None); let v = [1, 2, 3]; let iter = v.iter().cloned(); - assert_eq!(iter.collect_array::<_, 2>(), None); + assert_eq!(iter.collect_array::<2>(), None); }