Skip to content

Commit

Permalink
Merge pull request #1106 from ethanhs/complexmatmul
Browse files Browse the repository at this point in the history
Complex dot()
  • Loading branch information
bluss authored Nov 12, 2021
2 parents 1c685ef + 84cc038 commit 0172657
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 21 deletions.
65 changes: 45 additions & 20 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,32 @@
// except according to those terms.

use crate::imp_prelude::*;
use crate::numeric_util;

#[cfg(feature = "blas")]
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::numeric_util;

use crate::{LinalgScalar, Zip};

use std::any::TypeId;
use std::mem::MaybeUninit;
use alloc::vec::Vec;

#[cfg(feature = "blas")]
use libc::c_int;
#[cfg(feature = "blas")]
use std::cmp;
#[cfg(feature = "blas")]
use std::mem::swap;
#[cfg(feature = "blas")]
use libc::c_int;

#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};

#[cfg(feature = "blas")]
use num_complex::{Complex32 as c32, Complex64 as c64};

/// len of vector before we use blas
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
Expand Down Expand Up @@ -377,7 +381,12 @@ fn mat_mul_impl<A>(
// size cutoff for using BLAS
let cut = GEMM_BLAS_CUTOFF;
let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
if !(m > cut || n > cut || a > cut)
|| !(same_type::<A, f32>()
|| same_type::<A, f64>()
|| same_type::<A, c32>()
|| same_type::<A, c64>())
{
return mat_mul_general(alpha, lhs, rhs, beta, c);
}
{
Expand Down Expand Up @@ -407,8 +416,23 @@ fn mat_mul_impl<A>(
rhs_trans = CblasTrans;
}

// Build the `alpha`/`beta` scalar argument for a BLAS gemm call.
//
// The first token selects the element type the call is being instantiated
// for; the second is the identifier of the local scalar (of generic type `A`).
//
// NOTE(review): the two expansions differ presumably because the real gemm
// routines take their scalars by value while the complex ones take them by
// pointer — confirm against the `cblas_sys` signatures.
macro_rules! gemm_scalar_cast {
// Real types: hand the scalar to `cast_as` (defined elsewhere in this file).
(f32, $var:ident) => {
cast_as(&$var)
};
(f64, $var:ident) => {
cast_as(&$var)
};
// Complex types: pass a raw pointer to the scalar itself. The pointee type
// is left to inference (`*const _`) so it adapts to the callee's parameter.
// The pointer is derived from a borrow of `$var`, so `$var` must remain
// alive for as long as the callee reads through it.
(c32, $var:ident) => {
&$var as *const A as *const _
};
(c64, $var:ident) => {
&$var as *const A as *const _
};
}

macro_rules! gemm {
($ty:ty, $gemm:ident) => {
($ty:tt, $gemm:ident) => {
if blas_row_major_2d::<$ty, _>(&lhs_)
&& blas_row_major_2d::<$ty, _>(&rhs_)
&& blas_row_major_2d::<$ty, _>(&c_)
Expand All @@ -428,25 +452,25 @@ fn mat_mul_impl<A>(
let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);

// gemm is C ← αA^Op B^Op + βC
// Where Op is notrans/trans/conjtrans
unsafe {
blas_sys::$gemm(
CblasRowMajor,
lhs_trans,
rhs_trans,
m as blas_index, // m, rows of Op(a)
n as blas_index, // n, cols of Op(b)
k as blas_index, // k, cols of Op(a)
cast_as(&alpha), // alpha
lhs_.ptr.as_ptr() as *const _, // a
lhs_stride, // lda
rhs_.ptr.as_ptr() as *const _, // b
rhs_stride, // ldb
cast_as(&beta), // beta
c_.ptr.as_ptr() as *mut _, // c
c_stride, // ldc
m as blas_index, // m, rows of Op(a)
n as blas_index, // n, cols of Op(b)
k as blas_index, // k, cols of Op(a)
gemm_scalar_cast!($ty, alpha), // alpha
lhs_.ptr.as_ptr() as *const _, // a
lhs_stride, // lda
rhs_.ptr.as_ptr() as *const _, // b
rhs_stride, // ldb
gemm_scalar_cast!($ty, beta), // beta
c_.ptr.as_ptr() as *mut _, // c
c_stride, // ldc
);
}
return;
Expand All @@ -455,6 +479,9 @@ fn mat_mul_impl<A>(
}
gemm!(f32, cblas_sgemm);
gemm!(f64, cblas_dgemm);

gemm!(c32, cblas_cgemm);
gemm!(c64, cblas_zgemm);
}
mat_mul_general(alpha, lhs, rhs, beta, c)
}
Expand Down Expand Up @@ -603,9 +630,7 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
S3: DataMut<Elem = A>,
A: LinalgScalar,
{
unsafe {
general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
}
unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
}

/// General matrix-vector multiplication
Expand Down
1 change: 1 addition & 0 deletions xtest-blas/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ test = false
approx = "0.4"
defmac = "0.2"
num-traits = "0.2"
num-complex = { version = "0.4", default-features = false }

[dependencies]
ndarray = { path = "../", features = ["approx", "blas"] }
Expand Down
90 changes: 89 additions & 1 deletion xtest-blas/tests/oper.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
extern crate approx;
extern crate blas_src;
extern crate defmac;
extern crate ndarray;
extern crate num_complex;
extern crate num_traits;
extern crate blas_src;

use ndarray::prelude::*;

Expand All @@ -12,6 +13,8 @@ use ndarray::{Data, Ix, LinalgScalar};

use approx::assert_relative_eq;
use defmac::defmac;
use num_complex::Complex32;
use num_complex::Complex64;

#[test]
fn mat_vec_product_1d() {
Expand Down Expand Up @@ -52,6 +55,20 @@ fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
.unwrap()
}

/// Build an `m` x `n` matrix of `Complex32` values.
///
/// The real parts are `m * n` evenly spaced values running from `0` to
/// `m * n - 1`; every imaginary part is zero.
fn range_mat_complex(m: Ix, n: Ix) -> Array2<Complex32> {
    let len = m * n;
    let reals = Array::linspace(0., len as f32 - 1., len)
        .into_shape((m, n))
        .unwrap();
    reals.mapv(|re| Complex32::new(re, 0.))
}

/// Build an `m` x `n` matrix of `Complex64` values.
///
/// The real parts are `m * n` evenly spaced values running from `0` to
/// `m * n - 1`; every imaginary part is zero.
fn range_mat_complex64(m: Ix, n: Ix) -> Array2<Complex64> {
    let len = m * n;
    let reals = Array::linspace(0., len as f64 - 1., len)
        .into_shape((m, n))
        .unwrap();
    reals.mapv(|re| Complex64::new(re, 0.))
}

/// Build a 1-D `f64` array of `m` evenly spaced values from `0` to `m - 1`.
fn range1_mat64(m: Ix) -> Array1<f64> {
    let last = m as f64 - 1.;
    Array::linspace(0., last, m)
}
Expand Down Expand Up @@ -250,6 +267,77 @@ fn gemm_64_1_f() {
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
}

/// `general_mat_mul` on `Complex64` with an axis-reversed 64x64 left operand
/// and a single-column right operand, checked against `reference_mat_mul`.
#[test]
fn gemm_c64_1_f() {
    let lhs = range_mat_complex64(64, 64).reversed_axes();
    let (rows, cols) = lhs.dim();
    // (rows x cols) * (cols x 1) == (rows x 1)
    let rhs = range_mat_complex64(cols, 1);
    let mut out = range_mat_complex64(rows, 1);
    // Expected value of 1 * lhs * rhs + 1 * out, computed before `out` is overwritten.
    let expected = reference_mat_mul(&lhs, &rhs) + &out;

    let one = Complex64::new(1.0, 0.);
    general_mat_mul(one, &lhs, &rhs, one, &mut out);

    // Compare squared magnitudes, which are plain `f64` arrays.
    let got_sq = out.mapv(|z| z.norm_sqr());
    let want_sq = expected.mapv(|z| z.norm_sqr());
    assert_relative_eq!(got_sq, want_sq, epsilon = 1e-12, max_relative = 1e-7);
}

/// `general_mat_mul` on `Complex32` with an axis-reversed 64x64 left operand
/// and a single-column right operand, checked against `reference_mat_mul`.
#[test]
fn gemm_c32_1_f() {
    let lhs = range_mat_complex(64, 64).reversed_axes();
    let (rows, cols) = lhs.dim();
    // (rows x cols) * (cols x 1) == (rows x 1)
    let rhs = range_mat_complex(cols, 1);
    let mut out = range_mat_complex(rows, 1);
    // Expected value of 1 * lhs * rhs + 1 * out, computed before `out` is overwritten.
    let expected = reference_mat_mul(&lhs, &rhs) + &out;

    let one = Complex32::new(1.0, 0.);
    general_mat_mul(one, &lhs, &rhs, one, &mut out);

    // Compare squared magnitudes, which are plain `f32` arrays.
    let got_sq = out.mapv(|z| z.norm_sqr());
    let want_sq = expected.mapv(|z| z.norm_sqr());
    assert_relative_eq!(got_sq, want_sq, epsilon = 1e-12, max_relative = 1e-7);
}

/// `general_mat_mul` on `Complex64` where both operands and both scalars have
/// non-zero imaginary parts, checked against `reference_mat_mul`.
///
/// `range_mat_complex64` produces entries whose imaginary part is zero, so
/// conjugating them changes nothing. Each operand entry is therefore given a
/// non-zero imaginary part first, and only then are the entries above the
/// threshold conjugated, yielding a mix of positive and negative imaginary
/// parts.
///
/// NOTE(review): `mat_mul_impl` falls back to `mat_mul_general` unless some
/// dimension exceeds `GEMM_BLAS_CUTOFF`, whose value is not visible here.
/// Confirm that 4x4 * 4x6 actually reaches `cblas_zgemm`; if it does not,
/// enlarge the operands.
#[test]
fn gemm_c64_actually_complex() {
    let a = range_mat_complex64(4, 4).mapv(|z| {
        let z = Complex64::new(z.re, z.re + 1.0);
        if z.re > 8. {
            z.conj()
        } else {
            z
        }
    });
    let b = range_mat_complex64(4, 6).mapv(|z| {
        let z = Complex64::new(z.re, z.re + 1.0);
        if z.re > 4. {
            z.conj()
        } else {
            z
        }
    });
    let mut y = range_mat_complex64(4, 6);
    let alpha = Complex64::new(0., 1.0);
    let beta = Complex64::new(1.0, 1.0);
    // Expected alpha * a * b + beta * y, computed before `y` is overwritten.
    let answer = alpha * reference_mat_mul(&a, &b) + beta * &y;
    general_mat_mul(alpha, &a, &b, beta, &mut y);

    // Compare real and imaginary parts separately. Comparing only the squared
    // magnitude would accept a result with the wrong phase (for example a
    // conjugated or negated entry).
    assert_relative_eq!(
        y.mapv(|z| z.re),
        answer.mapv(|z| z.re),
        epsilon = 1e-12,
        max_relative = 1e-7
    );
    assert_relative_eq!(
        y.mapv(|z| z.im),
        answer.mapv(|z| z.im),
        epsilon = 1e-12,
        max_relative = 1e-7
    );
}

#[test]
fn gen_mat_vec_mul() {
let alpha = -2.3;
Expand Down

0 comments on commit 0172657

Please sign in to comment.