Skip to content

Commit

Permalink
shape: Use reshape_dim function in .to_shape()
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed May 10, 2021
1 parent f96977f commit b70085c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
45 changes: 28 additions & 17 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ use crate::dimension::{
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::dimension::broadcast::co_broadcast;
use crate::dimension::reshape_dim;
use crate::error::{self, ErrorKind, ShapeError, from_kind};
use crate::math_cell::MathCell;
use crate::itertools::zip;
use crate::AxisDescription;
use crate::Layout;
use crate::order::Order;
use crate::shape_builder::ShapeArg;
use crate::zip::{IntoNdProducer, Zip};
Expand Down Expand Up @@ -1641,27 +1641,38 @@ where
A: Clone,
S: Data,
{
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
let len = self.dim.size();
if size_of_shape_checked(&shape) != Ok(len) {
return Err(error::incompatible_shapes(&self.dim, &shape));
}
let layout = self.layout_impl();

unsafe {
if layout.is(Layout::CORDER) && order == Order::RowMajor {
let strides = shape.default_strides();
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
} else if layout.is(Layout::FORDER) && order == Order::ColumnMajor {
let strides = shape.fortran_strides();
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
} else {
let (shape, view) = match order {
Order::RowMajor => (shape.set_f(false), self.view()),
Order::ColumnMajor => (shape.set_f(true), self.t()),
};
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
shape, view.into_iter(), A::clone)))
// Create a view if the length is 0, safe because the array and new shape is empty.
if len == 0 {
unsafe {
return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr())));
}
}

// Try to reshape the array as a view into the existing data
match reshape_dim(&self.dim, &self.strides, &shape, order) {
Ok(to_strides) => unsafe {
return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides)));
}
Err(err) if err.kind() == ErrorKind::IncompatibleShape => {
return Err(error::incompatible_shapes(&self.dim, &shape));
}
_otherwise => { }
}

// otherwise create a new array and copy the elements
unsafe {
let (shape, view) = match order {
Order::RowMajor => (shape.set_f(false), self.view()),
Order::ColumnMajor => (shape.set_f(true), self.t()),
};
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
shape, view.into_iter(), A::clone)))
}
}

/// Transform the array into `shape`; any shape with the same number of
Expand Down
21 changes: 20 additions & 1 deletion tests/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fn to_shape_add_axis() {
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();

assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view());
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned());
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view());
}


Expand All @@ -150,10 +150,29 @@ fn to_shape_copy_stride() {
assert!(lin2.is_owned());
}


#[test]
fn to_shape_zero_len() {
let v = array![[1, 2, 3, 4], [5, 6, 7, 8]];
let vs = v.slice(s![.., ..0]);
let lin1 = vs.to_shape(0).unwrap();
assert_eq!(lin1, array![]);
assert!(lin1.is_view());
}

#[test]
#[should_panic(expected = "IncompatibleShape")]
fn to_shape_error1() {
let data = [1, 2, 3, 4, 5, 6, 7, 8];
let v = aview1(&data);
let _u = v.to_shape((2, 5)).unwrap();
}

#[test]
#[should_panic(expected = "IncompatibleShape")]
fn to_shape_error2() {
// overflow
let data = [3, 4, 5, 6, 7, 8];
let v = aview1(&data);
let _u = v.to_shape((2, usize::MAX)).unwrap();
}

0 comments on commit b70085c

Please sign in to comment.