Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating get_row functions to return MatrixSlice #99

Merged
merged 13 commits into from
Dec 11, 2016
4 changes: 2 additions & 2 deletions src/matrix/decomposition/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ mod tests {
assert!(!row.iter().take(idx).any(|&x| x > 1e-10));
assert!(!row.iter().skip(idx + 1).any(|&x| x > 1e-10));
// Assert non-negativity of diagonal elements
assert!(row[idx] >= 0.0);
assert!(row.raw_slice()[idx] >= 0.0);
}

let recovered = u * b * v.transpose();
Expand All @@ -349,7 +349,7 @@ mod tests {

let mut singular_triplets = u_transposed.iter_rows().zip(b.diag().into_iter()).zip(v_transposed.iter_rows())
// chained zipping results in nested tuple. Flatten it.
.map(|((u_col, singular_value), v_col)| (Vector::new(u_col), singular_value, Vector::new(v_col)));
.map(|((u_col, singular_value), v_col)| (Vector::new(u_col.raw_slice()), singular_value, Vector::new(v_col.raw_slice())));

assert!(singular_triplets.by_ref()
// For a matrix M, each singular value σ and left and right singular vectors u and v respectively
Expand Down
48 changes: 48 additions & 0 deletions src/matrix/deref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use super::{MatrixSlice, MatrixSliceMut};
use super::{Row, RowMut, Column, ColumnMut};

use std::ops::{Deref, DerefMut};

impl<'a, T: 'a> Deref for Row<'a, T> {
type Target = MatrixSlice<'a, T>;

fn deref(&self) -> &MatrixSlice<'a, T> {
&self.row
}
}

impl<'a, T: 'a> Deref for RowMut<'a, T> {
type Target = MatrixSliceMut<'a, T>;

fn deref(&self) -> &MatrixSliceMut<'a, T> {
&self.row
}
}

impl<'a, T: 'a> DerefMut for RowMut<'a, T> {
fn deref_mut(&mut self) -> &mut MatrixSliceMut<'a, T> {
&mut self.row
}
}

impl<'a, T: 'a> Deref for Column<'a, T> {
type Target = MatrixSlice<'a, T>;

fn deref(&self) -> &MatrixSlice<'a, T> {
&self.col
}
}

impl<'a, T: 'a> Deref for ColumnMut<'a, T> {
type Target = MatrixSliceMut<'a, T>;

fn deref(&self) -> &MatrixSliceMut<'a, T> {
&self.col
}
}

impl<'a, T: 'a> DerefMut for ColumnMut<'a, T> {
fn deref_mut(&mut self) -> &mut MatrixSliceMut<'a, T> {
&mut self.col
}
}
36 changes: 21 additions & 15 deletions src/matrix/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,10 @@ impl<'a, T> $assign_trt<Matrix<T>> for MatrixSliceMut<'a, T>
where T: Copy + $trt<T, Output=T>
{
fn $op_assign(&mut self, _rhs: Matrix<T>) {
for (slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row, target_row, |x, &y| {*x = (*x).$op(y) });
for (mut slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row.raw_slice_mut(),
target_row.raw_slice(),
|x, &y| {*x = (*x).$op(y) });
}
}
}
Expand All @@ -678,10 +680,10 @@ impl<'a, 'b, T> $assign_trt<&'b Matrix<T>> for MatrixSliceMut<'a, T>
where T: Copy + $trt<T, Output=T>
{
fn $op_assign(&mut self, _rhs: &Matrix<T>) {
for (slice_row, target_row) in self.iter_rows_mut()
for (mut slice_row, target_row) in self.iter_rows_mut()
.zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row,
target_row,
utils::in_place_vec_bin_op(slice_row.raw_slice_mut(),
target_row.raw_slice(),
|x, &y| {*x = (*x).$op(y) });
}
}
Expand All @@ -704,10 +706,10 @@ impl<'a, 'b, T> $assign_trt<$target_slice<'b, T>> for MatrixSliceMut<'a, T>
where T: Copy + $trt<T, Output=T>
{
fn $op_assign(&mut self, _rhs: $target_slice<T>) {
for (slice_row, target_row) in self.iter_rows_mut()
for (mut slice_row, target_row) in self.iter_rows_mut()
.zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row,
target_row,
utils::in_place_vec_bin_op(slice_row.raw_slice_mut(),
target_row.raw_slice(),
|x, &y| {*x = (*x).$op(y) });
}
}
Expand All @@ -720,10 +722,10 @@ impl<'a, 'b, 'c, T> $assign_trt<&'c $target_slice<'b, T>> for MatrixSliceMut<'a,
where T: Copy + $trt<T, Output=T>
{
fn $op_assign(&mut self, _rhs: &$target_slice<T>) {
for (slice_row, target_row) in self.iter_rows_mut()
for (mut slice_row, target_row) in self.iter_rows_mut()
.zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row,
target_row,
utils::in_place_vec_bin_op(slice_row.raw_slice_mut(),
target_row.raw_slice(),
|x, &y| {*x = (*x).$op(y) });
}
}
Expand All @@ -745,8 +747,10 @@ macro_rules! impl_op_assign_mat_slice (
impl<'a, T> $assign_trt<$target_mat<'a, T>> for Matrix<T>
where T: Copy + $trt<T, Output=T> {
fn $op_assign(&mut self, _rhs: $target_mat<T>) {
for (slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row, target_row, |x, &y| {*x = (*x).$op(y) });
for (mut slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row.raw_slice_mut(),
target_row.raw_slice(),
|x, &y| {*x = (*x).$op(y) });
}
}
}
Expand All @@ -757,8 +761,10 @@ impl<'a, T> $assign_trt<$target_mat<'a, T>> for Matrix<T>
impl<'a, 'b, T> $assign_trt<&'b $target_mat<'a, T>> for Matrix<T>
where T: Copy + $trt<T, Output=T> {
fn $op_assign(&mut self, _rhs: &$target_mat<T>) {
for (slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row, target_row, |x, &y| {*x = (*x).$op(y) });
for (mut slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) {
utils::in_place_vec_bin_op(slice_row.raw_slice_mut(),
target_row.raw_slice(),
|x, &y| {*x = (*x).$op(y) });
}
}
}
Expand Down
117 changes: 88 additions & 29 deletions src/matrix/iter.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::iter::{ExactSizeIterator, FromIterator};
use std::mem;
use std::slice;
//use std::slice;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before final merge: Remove this?


use super::{Matrix, MatrixSlice, MatrixSliceMut, Rows, RowsMut, Diagonal, DiagonalMut};
use super::{Matrix, MatrixSlice, MatrixSliceMut};
use super::{Row, RowMut, Rows, RowsMut, Diagonal, DiagonalMut};
use super::slice::{BaseMatrix, BaseMatrixMut, SliceIter, SliceIterMut};



macro_rules! impl_iter_diag (
($diag:ident, $diag_base:ident, $diag_type:ty, $as_ptr:ident) => (

Expand Down Expand Up @@ -67,16 +66,14 @@ impl<'a, T, M: $diag_base<T>> Iterator for $diag<'a, T, M> {
}

impl<'a, T, M: $diag_base<T>> ExactSizeIterator for $diag<'a, T, M> {}

);

);

impl_iter_diag!(Diagonal, BaseMatrix, &'a T, as_ptr);
impl_iter_diag!(DiagonalMut, BaseMatrixMut, &'a mut T, as_mut_ptr);

macro_rules! impl_iter_rows (
($rows:ident, $row_type:ty, $slice_from_parts:ident) => (
($rows:ident, $row_type:ty, $row_base:ident, $slice_base:ident) => (

/// Iterates over the rows in the matrix.
impl<'a, T> Iterator for $rows<'a, T> {
Expand All @@ -89,7 +86,10 @@ impl<'a, T> Iterator for $rows<'a, T> {
unsafe {
// Get pointer and create a slice from raw parts
let ptr = self.slice_start.offset(self.row_pos as isize * self.row_stride);
row = slice::$slice_from_parts(ptr, self.slice_cols);
//row = slice::$slice_from_parts(ptr, self.slice_cols);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before final merge: remove this?

row = $row_base {
row: $slice_base::from_raw_parts(ptr, 1, self.slice_cols, self.row_stride as usize)
};
}

self.row_pos += 1;
Expand All @@ -105,7 +105,10 @@ impl<'a, T> Iterator for $rows<'a, T> {
unsafe {
// Get pointer to last row and create a slice from raw parts
let ptr = self.slice_start.offset((self.slice_rows - 1) as isize * self.row_stride);
Some(slice::$slice_from_parts(ptr, self.slice_cols))
//Some(slice::$slice_from_parts(ptr, self.slice_cols))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before final merge: Remove this?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting that I kept this here to make it a little easier to benchmark later. Will remove later.

Some($row_base {
row: $slice_base::from_raw_parts(ptr, 1, self.slice_cols, self.row_stride as usize)
})
}
} else {
None
Expand All @@ -117,7 +120,10 @@ impl<'a, T> Iterator for $rows<'a, T> {
let row: $row_type;
unsafe {
let ptr = self.slice_start.offset((self.row_pos + n) as isize * self.row_stride);
row = slice::$slice_from_parts(ptr, self.slice_cols);
//row = slice::$slice_from_parts(ptr, self.slice_cols);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before final merge: remove this?

Copy link
Owner Author

@AtheMathmo AtheMathmo Dec 4, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left these here for easier benchmarking later. Just so I can quickly check if something is funky with this iterator.

Edit: I mean to say that I will indeed remove them, thank you for flagging.

row = $row_base {
row: $slice_base::from_raw_parts(ptr, 1, self.slice_cols, self.row_stride as usize)
}
}

self.row_pos += n + 1;
Expand All @@ -138,8 +144,10 @@ impl<'a, T> Iterator for $rows<'a, T> {
);
);

impl_iter_rows!(Rows, &'a [T], from_raw_parts);
impl_iter_rows!(RowsMut, &'a mut [T], from_raw_parts_mut);
// impl_iter_rows!(Rows, &'a [T], from_raw_parts);
// impl_iter_rows!(RowsMut, &'a mut [T], from_raw_parts_mut);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before final merge: remove this?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, will remove this soon!

impl_iter_rows!(Rows, Row<'a, T>, Row, MatrixSlice);
impl_iter_rows!(RowsMut, RowMut<'a, T>, RowMut, MatrixSliceMut);

impl<'a, T> ExactSizeIterator for Rows<'a, T> {}
impl<'a, T> ExactSizeIterator for RowsMut<'a, T> {}
Expand Down Expand Up @@ -176,7 +184,7 @@ impl<'a, T> ExactSizeIterator for RowsMut<'a, T> {}
/// // where the first entry is less than 6.
/// let b = a.iter_rows()
/// .skip(1)
/// .filter(|x| x[0] < 6)
/// .filter(|x| x.raw_slice()[0] < 6)
/// .collect::<Matrix<usize>>();
///
/// // We take the middle rows
Expand Down Expand Up @@ -225,6 +233,57 @@ impl<'a, T: 'a + Copy> FromIterator<&'a [T]> for Matrix<T> {
}
}

macro_rules! impl_from_iter_row(
($row_type:ty) => (
impl<'a, T: 'a + Copy> FromIterator<$row_type> for Matrix<T> {
fn from_iter<I: IntoIterator<Item = $row_type>>(iterable: I) -> Self {
let mut mat_data: Vec<T>;
let cols: usize;
let mut rows = 0;

let mut iterator = iterable.into_iter();

match iterator.next() {
None => {
return Matrix {
data: Vec::new(),
rows: 0,
cols: 0,
}
}
Some(row) => {
rows += 1;
// Here we set the capacity - get iterator size and the cols
let (lower_rows, _) = iterator.size_hint();
cols = row.row.cols();

mat_data = Vec::with_capacity(lower_rows.saturating_add(1).saturating_mul(cols));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why saturating arithmetic here? I'm a bit confused about this part.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good question. I'm not sure what my initial motivation was but there is no real reason.

mat_data.extend_from_slice(row.raw_slice());
}
}

for row in iterator {
assert!(row.row.cols() == cols, "Iterator slice length must be constant.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The message could perhaps be a bit clearer for the user. For example, "Iterator of rows must have rows of compatible dimensions." or so?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you're right. I overlooked this when I copied the existing FromIterator implementation. Also - do you think we should remove the FromIterator<&'a [T]> implementation? I think we probably should.

mat_data.extend_from_slice(row.raw_slice());
rows += 1;
}

mat_data.shrink_to_fit();

Matrix {
data: mat_data,
rows: rows,
cols: cols,
}
}
}
);
);

impl_from_iter_row!(Row<'a, T>);
impl_from_iter_row!(RowMut<'a, T>);


impl<'a, T> IntoIterator for MatrixSlice<'a, T> {
type Item = &'a T;
type IntoIter = SliceIter<'a, T>;
Expand Down Expand Up @@ -500,15 +559,15 @@ mod tests {
let data = [[0, 1, 2], [3, 4, 5], [6, 7, 8]];

for (i, row) in a.iter_rows().enumerate() {
assert_eq!(data[i], *row);
assert_eq!(data[i], *row.raw_slice());
}

for (i, row) in a.iter_rows_mut().enumerate() {
assert_eq!(data[i], *row);
assert_eq!(data[i], *row.raw_slice());
}

for row in a.iter_rows_mut() {
for r in row {
for mut row in a.iter_rows_mut() {
for r in row.raw_slice_mut() {
*r = 0;
}
}
Expand All @@ -527,7 +586,7 @@ mod tests {
let data = [[0, 1], [3, 4]];

for (i, row) in b.iter_rows().enumerate() {
assert_eq!(data[i], *row);
assert_eq!(data[i], *row.raw_slice());
}
}

Expand All @@ -543,15 +602,15 @@ mod tests {
let data = [[0, 1], [3, 4]];

for (i, row) in b.iter_rows().enumerate() {
assert_eq!(data[i], *row);
assert_eq!(data[i], *row.raw_slice());
}

for (i, row) in b.iter_rows_mut().enumerate() {
assert_eq!(data[i], *row);
assert_eq!(data[i], *row.raw_slice());
}

for row in b.iter_rows_mut() {
for r in row {
for mut row in b.iter_rows_mut() {
for r in row.raw_slice_mut() {
*r = 0;
}
}
Expand All @@ -568,10 +627,10 @@ mod tests {

let mut row_iter = a.iter_rows();

assert_eq!([0, 1, 2], *row_iter.nth(0).unwrap());
assert_eq!([6, 7, 8], *row_iter.nth(1).unwrap());
assert_eq!([0, 1, 2], *row_iter.nth(0).unwrap().raw_slice());
assert_eq!([6, 7, 8], *row_iter.nth(1).unwrap().raw_slice());

assert_eq!(None, row_iter.next());
assert!(row_iter.next().is_none());
}

#[test]
Expand All @@ -582,12 +641,12 @@ mod tests {

let row_iter = a.iter_rows();

assert_eq!([6, 7, 8], *row_iter.last().unwrap());
assert_eq!([6, 7, 8], *row_iter.last().unwrap().raw_slice());

let mut row_iter = a.iter_rows();

row_iter.next();
assert_eq!([6, 7, 8], *row_iter.last().unwrap());
assert_eq!([6, 7, 8], *row_iter.last().unwrap().raw_slice());

let mut row_iter = a.iter_rows();

Expand All @@ -596,7 +655,7 @@ mod tests {
row_iter.next();
row_iter.next();

assert_eq!(None, row_iter.last());
assert!(row_iter.last().is_none());
}

#[test]
Expand Down Expand Up @@ -632,7 +691,7 @@ mod tests {

assert_eq!((0, Some(0)), row_iter.size_hint());

assert_eq!(None, row_iter.next());
assert!(row_iter.next().is_none());
assert_eq!((0, Some(0)), row_iter.size_hint());
}

Expand Down
Loading