Skip to content

Commit

Permalink
Merge pull request #588 from jturner314/generalize-op-types
Browse files Browse the repository at this point in the history
Allow ops on arrays with elems of different types
  • Loading branch information
jturner314 authored May 5, 2019
2 parents 638ac16 + 9b51170 commit 7733b7c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 26 deletions.
11 changes: 6 additions & 5 deletions src/arraytraits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ impl<S, D, I> IndexMut<I> for ArrayBase<S, D>

/// Return `true` if the array shapes and all elements of `self` and
/// `rhs` are equal. Return `false` otherwise.
impl<S, S2, D> PartialEq<ArrayBase<S2, D>> for ArrayBase<S, D>
where D: Dimension,
S: Data,
S2: Data<Elem = S::Elem>,
S::Elem: PartialEq
impl<A, B, S, S2, D> PartialEq<ArrayBase<S2, D>> for ArrayBase<S, D>
where
A: PartialEq<B>,
S: Data<Elem = A>,
S2: Data<Elem = B>,
D: Dimension,
{
fn eq(&self, rhs: &ArrayBase<S2, D>) -> bool {
if self.shape() != rhs.shape() {
Expand Down
42 changes: 24 additions & 18 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ macro_rules! impl_binary_op(
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where A: Clone + $trt<A, Output=A>,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=A>,
D: Dimension,
E: Dimension,
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> ArrayBase<S, D>
Expand All @@ -82,12 +84,14 @@ impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where A: Clone + $trt<A, Output=A>,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=A>,
D: Dimension,
E: Dimension,
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, rhs: &ArrayBase<S2, E>) -> ArrayBase<S, D>
Expand All @@ -107,12 +111,14 @@ impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where A: Clone + $trt<A, Output=A>,
S: Data<Elem=A>,
S2: Data<Elem=A>,
D: Dimension,
E: Dimension,
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Array<A, D> {
Expand Down
5 changes: 3 additions & 2 deletions src/numeric_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ pub fn unrolled_dot<A>(xs: &[A], ys: &[A]) -> A
/// Compute pairwise equality
///
/// `xs` and `ys` must be the same length
pub fn unrolled_eq<A>(xs: &[A], ys: &[A]) -> bool
where A: PartialEq
pub fn unrolled_eq<A, B>(xs: &[A], ys: &[B]) -> bool
where
A: PartialEq<B>,
{
debug_assert_eq!(xs.len(), ys.len());
// eightfold unrolled for performance (this is not done by llvm automatically)
Expand Down
2 changes: 1 addition & 1 deletion tests/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn inner_iter() {

#[test]
fn inner_iter_corner_cases() {
let a0 = ArcArray::zeros(());
let a0 = ArcArray::<i32, _>::zeros(());
assert_equal(a0.genrows(), vec![aview1(&[0])]);

let a2 = ArcArray::<i32, _>::zeros((0, 3));
Expand Down

0 comments on commit 7733b7c

Please sign in to comment.