Skip to content

Commit

Permalink
Merge pull request #965 from SparrowLii/to_dim
Browse files Browse the repository at this point in the history
Improve performance of "no-broadcasting-needed" scenario in &array +&array operation
  • Loading branch information
bluss authored Apr 2, 2021
2 parents 25b1eeb + 0136cc3 commit d50f4ea
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 9 deletions.
20 changes: 19 additions & 1 deletion benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ extern crate test;

use std::mem::MaybeUninit;

use ndarray::ShapeBuilder;
use ndarray::{ShapeBuilder, Array3, Array4};
use ndarray::{arr0, arr1, arr2, azip, s};
use ndarray::{Array, Array1, Array2, Axis, Ix, Zip};
use ndarray::{Ix1, Ix2, Ix3, Ix5, IxDyn};
Expand Down Expand Up @@ -998,3 +998,21 @@ fn into_dyn_dyn(bench: &mut test::Bencher) {
let a = a.view();
bench.iter(|| a.clone().into_dyn());
}

#[bench]
fn broadcast_same_dim(bench: &mut test::Bencher) {
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let s = Array4::from_shape_vec((2, 2, 3, 2), s.to_vec()).unwrap();
let a = s.slice(s![.., ..;-1, ..;2, ..]);
let b = s.slice(s![.., .., ..;2, ..]);
bench.iter(|| &a + &b);
}

#[bench]
fn broadcast_one_side(bench: &mut test::Bencher) {
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let s2 = [1 ,2 ,3 ,4 ,5 ,6];
let a = Array4::from_shape_vec((4, 1, 3, 2), s.to_vec()).unwrap();
let b = Array3::from_shape_vec((1, 3, 2), s2.to_vec()).unwrap();
bench.iter(|| &a + &b);
}
21 changes: 15 additions & 6 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1804,12 +1804,21 @@ where
E: Dimension,
{
let shape = co_broadcast::<D, E, <D as DimMax<E>>::Output>(&self.dim, &other.dim)?;
if let Some(view1) = self.broadcast(shape.clone()) {
if let Some(view2) = other.broadcast(shape) {
return Ok((view1, view2));
}
}
Err(from_kind(ErrorKind::IncompatibleShape))
let view1 = if shape.slice() == self.dim.slice() {
self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap()
} else if let Some(view1) = self.broadcast(shape.clone()) {
view1
} else {
return Err(from_kind(ErrorKind::IncompatibleShape))
};
let view2 = if shape.slice() == other.dim.slice() {
other.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap()
} else if let Some(view2) = other.broadcast(shape) {
view2
} else {
return Err(from_kind(ErrorKind::IncompatibleShape))
};
Ok((view1, view2))
}

/// Swap axes `ax` and `bx`.
Expand Down
10 changes: 8 additions & 2 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,14 @@ where
{
type Output = Array<A, <D as DimMax<E>>::Output>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
(lhs, rhs)
} else {
self.broadcast_with(rhs).unwrap()
};
Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
}
}

Expand Down
6 changes: 6 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,12 @@ fn arithmetic_broadcast() {
sa2 + &sb2 + sc2.into_owned(),
arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]])
);

// Same shape
let a = s.slice(s![..;-1, ..;2, ..]);
let b = s.slice(s![.., ..;2, ..]);
assert_eq!(a.shape(), b.shape());
assert_eq!(&a + &b, arr3(&[[[3, 7], [19, 23]], [[3, 7], [19, 23]]]));
}

#[test]
Expand Down

0 comments on commit d50f4ea

Please sign in to comment.