Skip to content

Commit

Permalink
Pmarks/ldot (#1)
Browse files Browse the repository at this point in the history
* implement Dot from linalg
* rustfmt
* generalize types in sparse-dense multiplication routines
  • Loading branch information
pmarks authored Apr 29, 2019
1 parent 4d65a64 commit 687dea6
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ assert_eq!(a, b.to_csc());
*/

#![feature(re_rebalance_coherence)]
#![deny(warnings)]

#[cfg(feature = "alga")]
Expand Down
51 changes: 51 additions & 0 deletions src/sparse/csmat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,57 @@ where
}
}

impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
    for ArrayBase<DS2, Ix2>
where
    N: 'a + Copy + Num + Default + std::fmt::Debug,
    I: 'a + SpIndex,
    IpS: 'a + Deref<Target = [I]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix2>;

    /// Dense-times-sparse product `self * rhs`, returned as a dense array.
    ///
    /// The product is evaluated on the transposed operands (sparse operand on
    /// the left, as the `prod::*_mulacc_dense_*` routines expect) and the
    /// accumulated result has its axes reversed before being returned.
    fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
        let sparse_t = rhs.transpose_view();
        let dense_t = self.t();

        let out_rows = sparse_t.rows();
        let out_cols = dense_t.cols();

        // With only a few columns it is cheaper to walk the dense operand
        // column by column; from 8 columns upwards the row-wise kernels are
        // used, since iterating by rows can take advantage of vectorized axpy.
        let many_columns = out_cols >= 8;
        let product_t = match sparse_t.storage() {
            CSR if many_columns => {
                let mut acc = Array::zeros((out_rows, out_cols));
                prod::csr_mulacc_dense_rowmaj(sparse_t, dense_t, acc.view_mut());
                acc
            }
            CSR => {
                // `.f()` requests a column-major accumulator for the colmaj kernel.
                let mut acc = Array::zeros((out_rows, out_cols).f());
                prod::csr_mulacc_dense_colmaj(sparse_t, dense_t, acc.view_mut());
                acc
            }
            CSC if many_columns => {
                let mut acc = Array::zeros((out_rows, out_cols));
                prod::csc_mulacc_dense_rowmaj(sparse_t, dense_t, acc.view_mut());
                acc
            }
            CSC => {
                // `.f()` requests a column-major accumulator for the colmaj kernel.
                let mut acc = Array::zeros((out_rows, out_cols).f());
                prod::csc_mulacc_dense_colmaj(sparse_t, dense_t, acc.view_mut());
                acc
            }
        };

        // Undo the transposition so the result is laid out as `self * rhs`.
        let product = product_t.reversed_axes();

        // Sanity checks: rows come from `self`, columns from `rhs`.
        assert_eq!(self.shape()[0], product.shape()[0]);
        assert_eq!(rhs.cols(), product.shape()[1]);
        product
    }
}

impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
for CsMatBase<N, I, IpS, IS, DS>
where
Expand Down
146 changes: 124 additions & 22 deletions src/sparse/prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,15 @@ where
/// CSR-dense rowmaj multiplication
///
/// Performs better if rhs has a decent number of columns.
pub fn csr_mulacc_dense_rowmaj<'a, N, I>(
lhs: CsMatViewI<N, I>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csr_mulacc_dense_rowmaj<'a, N1, N2, NOut, I>(
lhs: CsMatViewI<N1, I>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
{
if lhs.cols() != rhs.shape()[0] {
Expand Down Expand Up @@ -284,12 +287,15 @@ pub fn csr_mulacc_dense_rowmaj<'a, N, I>(
/// CSC-dense rowmaj multiplication
///
/// Performs better if rhs has a decent number of columns.
pub fn csc_mulacc_dense_rowmaj<'a, N, I>(
lhs: CsMatViewI<N, I>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csc_mulacc_dense_rowmaj<'a, N1, N2, NOut, I>(
lhs: CsMatViewI<N1, I>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
{
if lhs.cols() != rhs.shape()[0] {
Expand Down Expand Up @@ -319,12 +325,15 @@ pub fn csc_mulacc_dense_rowmaj<'a, N, I>(
/// CSC-dense colmaj multiplication
///
/// Performs better if rhs has few columns.
pub fn csc_mulacc_dense_colmaj<'a, N, I>(
lhs: CsMatViewI<N, I>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csc_mulacc_dense_colmaj<'a, N1, N2, NOut, I>(
lhs: CsMatViewI<N1, I>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
{
if lhs.cols() != rhs.shape()[0] {
Expand Down Expand Up @@ -355,12 +364,15 @@ pub fn csc_mulacc_dense_colmaj<'a, N, I>(
/// CSR-dense colmaj multiplication
///
/// Performs better if rhs has few columns.
pub fn csr_mulacc_dense_colmaj<'a, N, I>(
lhs: CsMatViewI<N, I>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csr_mulacc_dense_colmaj<'a, N1, N2, NOut, I>(
lhs: CsMatViewI<N1, I>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
{
if lhs.cols() != rhs.shape()[0] {
Expand Down Expand Up @@ -391,7 +403,8 @@ pub fn csr_mulacc_dense_colmaj<'a, N, I>(
#[cfg(test)]
mod test {
use super::*;
use ndarray::{arr2, Array, ShapeBuilder};
use ndarray::linalg::Dot;
use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder};
use sparse::csmat::CompressedStorage::{CSC, CSR};
use sparse::{CsMat, CsMatView, CsVec};
use test_data::{
Expand Down Expand Up @@ -555,7 +568,7 @@ mod test {

#[test]
fn mul_csr_dense_rowmaj() {
let a = Array::eye(3);
let a: Array2<f64> = Array::eye(3);
let e: CsMat<f64> = CsMat::eye(3);
let mut res = Array::zeros((3, 3));
super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut());
Expand Down Expand Up @@ -647,4 +660,93 @@ mod test {
let c = &a * &b;
assert_eq!(c, expected_output);
}

// Borrowed from ndarray, which does not currently export this helper.
//
// Panics with "results differ" when any element of `a` is further from the
// matching element of `b` than `atol + rtol * |b|`.
fn assert_close<D>(a: ArrayView<f64, D>, b: ArrayView<f64, D>)
where
    D: Dimension,
{
    let rtol = 1e-7;
    let atol = 1e-12;

    // Element-wise absolute error and the tolerance allowed for each element.
    let abs_err = (&a - &b).mapv_into(f64::abs);
    let allowed = b.mapv(|v| v.abs() * rtol) + atol;

    // Largest amount by which the error exceeds its tolerance. The fold is
    // seeded with NaN, which `f64::max` discards in favour of the other operand.
    let excess = &abs_err - &allowed;
    let worst = excess.fold(std::f64::NAN, |acc, e| f64::max(acc, *e));
    println!("diff offset from tolerance level= {:.2e}", worst);

    if worst > 0. {
        println!("{:.4?}", a);
        println!("{:.4?}", b);
        panic!("results differ");
    }
}

#[test]
fn test_sparse_dot_dense() {
    let sparse_mats = [
        mat1(),
        mat1_csc(),
        mat2(),
        mat2().transpose_into(),
        mat4(),
        mat5(),
    ];
    let dense_mats = [
        mat_dense1(),
        mat_dense1_colmaj(),
        mat_dense1().reversed_axes(),
        mat_dense2(),
        mat_dense2().reversed_axes(),
    ];

    // Compare sparse.dot(dense) with the same product computed densely.
    for sp in sparse_mats.iter() {
        let inner = sp.cols();

        // Dense operands with fewer than `inner` rows cannot be cropped to fit.
        for full in dense_mats.iter().filter(|m| m.shape()[0] >= inner) {
            let cropped = full.slice(s![0..inner, ..]);

            let expected = sp.to_dense().dot(&cropped);
            let actual = sp.dot(&cropped);
            assert_close(actual.view(), expected.view());
        }
    }
}

#[test]
fn test_dense_dot_sparse() {
    let sparse_mats = [
        mat1(),
        mat1_csc(),
        mat2(),
        mat2().transpose_into(),
        mat4(),
        mat5(),
    ];
    let dense_mats = [
        mat_dense1(),
        mat_dense1_colmaj(),
        mat_dense1().reversed_axes(),
        mat_dense2(),
        mat_dense2().reversed_axes(),
    ];

    // Compare dense.dot(sparse) with the same product computed densely.
    for sp in sparse_mats.iter() {
        let inner = sp.rows();

        // Dense operands with fewer than `inner` columns cannot be cropped to fit.
        for full in dense_mats.iter().filter(|m| m.shape()[1] >= inner) {
            let cropped = full.slice(s![.., 0..inner]);

            let expected = cropped.dot(&sp.to_dense());
            let actual = cropped.dot(sp);
            assert_close(actual.view(), expected.view());
        }
    }
}
}

0 comments on commit 687dea6

Please sign in to comment.