From b70085c51b2a766285d0ed469446dd141429e56b Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 9 May 2021 20:01:26 +0200 Subject: [PATCH] shape: Use reshape_dim function in .to_shape() --- src/impl_methods.rs | 45 ++++++++++++++++++++++++++++----------------- tests/reshape.rs | 21 ++++++++++++++++++++- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index cda2cd95e..6c51b4515 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -23,11 +23,11 @@ use crate::dimension::{ offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; +use crate::dimension::reshape_dim; use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::AxisDescription; -use crate::Layout; use crate::order::Order; use crate::shape_builder::ShapeArg; use crate::zip::{IntoNdProducer, Zip}; @@ -1641,27 +1641,38 @@ where A: Clone, S: Data, { - if size_of_shape_checked(&shape) != Ok(self.dim.size()) { + let len = self.dim.size(); + if size_of_shape_checked(&shape) != Ok(len) { return Err(error::incompatible_shapes(&self.dim, &shape)); } - let layout = self.layout_impl(); - unsafe { - if layout.is(Layout::CORDER) && order == Order::RowMajor { - let strides = shape.default_strides(); - Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides))) - } else if layout.is(Layout::FORDER) && order == Order::ColumnMajor { - let strides = shape.fortran_strides(); - Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides))) - } else { - let (shape, view) = match order { - Order::RowMajor => (shape.set_f(false), self.view()), - Order::ColumnMajor => (shape.set_f(true), self.t()), - }; - Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked( - shape, view.into_iter(), A::clone))) + // Create a view if the length is 0, safe because the array and new shape is empty. + if len == 0 { + unsafe { + return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr()))); } } + + // Try to reshape the array as a view into the existing data + match reshape_dim(&self.dim, &self.strides, &shape, order) { + Ok(to_strides) => unsafe { + return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides))); + } + Err(err) if err.kind() == ErrorKind::IncompatibleShape => { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + _otherwise => { } + } + + // otherwise create a new array and copy the elements + unsafe { + let (shape, view) = match order { + Order::RowMajor => (shape.set_f(false), self.view()), + Order::ColumnMajor => (shape.set_f(true), self.t()), + }; + Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked( + shape, view.into_iter(), A::clone))) + } } /// Transform the array into `shape`; any shape with the same number of diff --git a/tests/reshape.rs b/tests/reshape.rs index f03f4ccf1..19f5b4ae1 100644 --- a/tests/reshape.rs +++ b/tests/reshape.rs @@ -133,7 +133,7 @@ fn to_shape_add_axis() { let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view()); - assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned()); + assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view()); } @@ -150,6 +150,16 @@ fn to_shape_copy_stride() { assert!(lin2.is_owned()); } + +#[test] +fn to_shape_zero_len() { + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..0]); + let lin1 = vs.to_shape(0).unwrap(); + assert_eq!(lin1, array![]); + assert!(lin1.is_view()); +} + #[test] #[should_panic(expected = "IncompatibleShape")] fn to_shape_error1() { @@ -157,3 +167,12 @@ fn to_shape_error1() { let v = aview1(&data); let _u = v.to_shape((2, 5)).unwrap(); } + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error2() { + // overflow + let data = [3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, usize::MAX)).unwrap(); +}