From 687dea61aab8352b969739a244e6db0e3993a5a7 Mon Sep 17 00:00:00 2001 From: Patrick Marks Date: Mon, 29 Apr 2019 08:45:59 -0700 Subject: [PATCH] Pmarks/ldot (#1) * implement Dot from linalg * rustfmt * generalize types in sparse-dense multiplication routines --- src/lib.rs | 1 + src/sparse/csmat.rs | 51 ++++++++++++++++ src/sparse/prod.rs | 146 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 176 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 065a2cb4..4985ef04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ assert_eq!(a, b.to_csc()); */ +#![feature(re_rebalance_coherence)] #![deny(warnings)] #[cfg(feature = "alga")] diff --git a/src/sparse/csmat.rs b/src/sparse/csmat.rs index 56ba84e8..e74f3c31 100644 --- a/src/sparse/csmat.rs +++ b/src/sparse/csmat.rs @@ -1815,6 +1815,57 @@ where } } +impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot> + for ArrayBase +where + N: 'a + Copy + Num + Default + std::fmt::Debug, + I: 'a + SpIndex, + IpS: 'a + Deref, + IS: 'a + Deref, + DS: 'a + Deref, + DS2: 'b + ndarray::Data, +{ + type Output = Array; + + fn dot(&self, rhs: &CsMatBase) -> Array { + let rhs_t = rhs.transpose_view(); + let lhs_t = self.t(); + + let rows = rhs_t.rows(); + let cols = lhs_t.cols(); + // when the number of colums is small, it is more efficient + // to perform the product by iterating over the columns of + // the rhs, otherwise iterating by rows can take advantage of + // vectorized axpy. + let rres = match (rhs_t.storage(), cols >= 8) { + (CSR, true) => { + let mut res = Array::zeros((rows, cols)); + prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + (CSR, false) => { + let mut res = Array::zeros((rows, cols).f()); + prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + (CSC, true) => { + let mut res = Array::zeros((rows, cols)); + prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + (CSC, false) => { + let mut res = Array::zeros((rows, cols).f()); + prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + }; + + assert_eq!(self.shape()[0], rres.shape()[0]); + assert_eq!(rhs.cols(), rres.shape()[1]); + rres + } +} + impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot> for CsMatBase where diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs index 882b7f95..048b70ce 100644 --- a/src/sparse/prod.rs +++ b/src/sparse/prod.rs @@ -246,12 +246,15 @@ where /// CSR-dense rowmaj multiplication /// /// Performs better if rhs has a decent number of colums. -pub fn csr_mulacc_dense_rowmaj<'a, N, I>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csr_mulacc_dense_rowmaj<'a, N1, N2, NOut, I>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, { if lhs.cols() != rhs.shape()[0] { @@ -284,12 +287,15 @@ pub fn csr_mulacc_dense_rowmaj<'a, N, I>( /// CSC-dense rowmaj multiplication /// /// Performs better if rhs has a decent number of colums. -pub fn csc_mulacc_dense_rowmaj<'a, N, I>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csc_mulacc_dense_rowmaj<'a, N1, N2, NOut, I>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, { if lhs.cols() != rhs.shape()[0] { @@ -319,12 +325,15 @@ pub fn csc_mulacc_dense_rowmaj<'a, N, I>( /// CSC-dense colmaj multiplication /// /// Performs better if rhs has few columns. -pub fn csc_mulacc_dense_colmaj<'a, N, I>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csc_mulacc_dense_colmaj<'a, N1, N2, NOut, I>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, { if lhs.cols() != rhs.shape()[0] { @@ -355,12 +364,15 @@ pub fn csc_mulacc_dense_colmaj<'a, N, I>( /// CSR-dense colmaj multiplication /// /// Performs better if rhs has few columns. -pub fn csr_mulacc_dense_colmaj<'a, N, I>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csr_mulacc_dense_colmaj<'a, N1, N2, NOut, I>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, { if lhs.cols() != rhs.shape()[0] { @@ -391,7 +403,8 @@ pub fn csr_mulacc_dense_colmaj<'a, N, I>( #[cfg(test)] mod test { use super::*; - use ndarray::{arr2, Array, ShapeBuilder}; + use ndarray::linalg::Dot; + use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder}; use sparse::csmat::CompressedStorage::{CSC, CSR}; use sparse::{CsMat, CsMatView, CsVec}; use test_data::{ @@ -555,7 +568,7 @@ mod test { #[test] fn mul_csr_dense_rowmaj() { - let a = Array::eye(3); + let a: Array2 = Array::eye(3); let e: CsMat = CsMat::eye(3); let mut res = Array::zeros((3, 3)); super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut()); @@ -647,4 +660,93 @@ mod test { let c = &a * &b; assert_eq!(c, expected_output); } + + // stolen from ndarray - not currently exported. + fn assert_close(a: ArrayView, b: ArrayView) + where + D: Dimension, + { + let diff = (&a - &b).mapv_into(f64::abs); + + let rtol = 1e-7; + let atol = 1e-12; + let crtol = b.mapv(|x| x.abs() * rtol); + let tol = crtol + atol; + let tol_m_diff = &diff - &tol; + let maxdiff = tol_m_diff.fold(0. / 0., |x, y| f64::max(x, *y)); + println!("diff offset from tolerance level= {:.2e}", maxdiff); + if maxdiff > 0. { + println!("{:.4?}", a); + println!("{:.4?}", b); + panic!("results differ"); + } + } + + #[test] + fn test_sparse_dot_dense() { + let sparse = [ + mat1(), + mat1_csc(), + mat2(), + mat2().transpose_into(), + mat4(), + mat5(), + ]; + let dense = [ + mat_dense1(), + mat_dense1_colmaj(), + mat_dense1().reversed_axes(), + mat_dense2(), + mat_dense2().reversed_axes(), + ]; + + // test sparse.dot(dense) + for s in sparse.iter() { + for d in dense.iter() { + if d.shape()[0] < s.cols() { + continue; + } + + let d = d.slice(s![0..s.cols(), ..]); + + let truth = s.to_dense().dot(&d); + let test = s.dot(&d); + assert_close(test.view(), truth.view()); + } + } + } + + #[test] + fn test_dense_dot_sparse() { + let sparse = [ + mat1(), + mat1_csc(), + mat2(), + mat2().transpose_into(), + mat4(), + mat5(), + ]; + let dense = [ + mat_dense1(), + mat_dense1_colmaj(), + mat_dense1().reversed_axes(), + mat_dense2(), + mat_dense2().reversed_axes(), + ]; + + // test sparse.ldot(dense) + for s in sparse.iter() { + for d in dense.iter() { + if d.shape()[1] < s.rows() { + continue; + } + + let d = d.slice(s![.., 0..s.rows()]); + + let truth = d.dot(&s.to_dense()); + let test = d.dot(s); + assert_close(test.view(), truth.view()); + } + } + } }