diff --git a/benches/bench1.rs b/benches/bench1.rs index a6fa86deb..c7f18e3c4 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -11,7 +11,7 @@ extern crate test; use std::mem::MaybeUninit; -use ndarray::ShapeBuilder; +use ndarray::{ShapeBuilder, Array3, Array4}; use ndarray::{arr0, arr1, arr2, azip, s}; use ndarray::{Array, Array1, Array2, Axis, Ix, Zip}; use ndarray::{Ix1, Ix2, Ix3, Ix5, IxDyn}; @@ -998,3 +998,21 @@ fn into_dyn_dyn(bench: &mut test::Bencher) { let a = a.view(); bench.iter(|| a.clone().into_dyn()); } + +#[bench] +fn broadcast_same_dim(bench: &mut test::Bencher) { + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array4::from_shape_vec((2, 2, 3, 2), s.to_vec()).unwrap(); + let a = s.slice(s![.., ..;-1, ..;2, ..]); + let b = s.slice(s![.., .., ..;2, ..]); + bench.iter(|| &a + &b); +} + +#[bench] +fn broadcast_one_side(bench: &mut test::Bencher) { + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s2 = [1 ,2 ,3 ,4 ,5 ,6]; + let a = Array4::from_shape_vec((4, 1, 3, 2), s.to_vec()).unwrap(); + let b = Array3::from_shape_vec((1, 3, 2), s2.to_vec()).unwrap(); + bench.iter(|| &a + &b); +} \ No newline at end of file diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4045f2b59..efea69c4c 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1804,12 +1804,21 @@ where E: Dimension, { let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; - if let Some(view1) = self.broadcast(shape.clone()) { - if let Some(view2) = other.broadcast(shape) { - return Ok((view1, view2)); - } - } - Err(from_kind(ErrorKind::IncompatibleShape)) + let view1 = if shape.slice() == self.dim.slice() { + self.view().into_dimensionality::<>::Output>().unwrap() + } else if let Some(view1) = self.broadcast(shape.clone()) { + view1 + } else { + return Err(from_kind(ErrorKind::IncompatibleShape)) + }; + let view2 = if shape.slice() == other.dim.slice() { + other.view().into_dimensionality::<>::Output>().unwrap() + } else if let Some(view2) = other.broadcast(shape) { + view2 + } else { + return Err(from_kind(ErrorKind::IncompatibleShape)) + }; + Ok((view1, view2)) } /// Swap axes `ax` and `bx`. diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 663dc7183..4c255dfff 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -179,8 +179,14 @@ where { type Output = Array>::Output>; fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { - let (lhs, rhs) = self.broadcast_with(rhs).unwrap(); - Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) + let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { + let lhs = self.view().into_dimensionality::<>::Output>().unwrap(); + let rhs = rhs.view().into_dimensionality::<>::Output>().unwrap(); + (lhs, rhs) + } else { + self.broadcast_with(rhs).unwrap() + }; + Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth)) } } diff --git a/tests/array.rs b/tests/array.rs index 8e084e49e..9e45d161f 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1627,6 +1627,12 @@ fn arithmetic_broadcast() { sa2 + &sb2 + sc2.into_owned(), arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) ); + + // Same shape + let a = s.slice(s![..;-1, ..;2, ..]); + let b = s.slice(s![.., ..;2, ..]); + assert_eq!(a.shape(), b.shape()); + assert_eq!(&a + &b, arr3(&[[[3, 7], [19, 23]], [[3, 7], [19, 23]]])); } #[test]