diff --git a/src/matrix/decomposition/svd.rs b/src/matrix/decomposition/svd.rs index 78ff0bb..1ef81f6 100644 --- a/src/matrix/decomposition/svd.rs +++ b/src/matrix/decomposition/svd.rs @@ -349,7 +349,7 @@ mod tests { let mut singular_triplets = u_transposed.iter_rows().zip(b.diag().into_iter()).zip(v_transposed.iter_rows()) // chained zipping results in nested tuple. Flatten it. - .map(|((u_col, singular_value), v_col)| (Vector::new(u_col), singular_value, Vector::new(v_col))); + .map(|((u_col, singular_value), v_col)| (Vector::new(u_col.raw_slice()), singular_value, Vector::new(v_col.raw_slice()))); assert!(singular_triplets.by_ref() // For a matrix M, each singular value σ and left and right singular vectors u and v respectively diff --git a/src/matrix/deref.rs b/src/matrix/deref.rs new file mode 100644 index 0000000..11f2dbc --- /dev/null +++ b/src/matrix/deref.rs @@ -0,0 +1,48 @@ +use super::{MatrixSlice, MatrixSliceMut}; +use super::{Row, RowMut, Column, ColumnMut}; + +use std::ops::{Deref, DerefMut}; + +impl<'a, T: 'a> Deref for Row<'a, T> { + type Target = MatrixSlice<'a, T>; + + fn deref(&self) -> &MatrixSlice<'a, T> { + &self.row + } +} + +impl<'a, T: 'a> Deref for RowMut<'a, T> { + type Target = MatrixSliceMut<'a, T>; + + fn deref(&self) -> &MatrixSliceMut<'a, T> { + &self.row + } +} + +impl<'a, T: 'a> DerefMut for RowMut<'a, T> { + fn deref_mut(&mut self) -> &mut MatrixSliceMut<'a, T> { + &mut self.row + } +} + +impl<'a, T: 'a> Deref for Column<'a, T> { + type Target = MatrixSlice<'a, T>; + + fn deref(&self) -> &MatrixSlice<'a, T> { + &self.col + } +} + +impl<'a, T: 'a> Deref for ColumnMut<'a, T> { + type Target = MatrixSliceMut<'a, T>; + + fn deref(&self) -> &MatrixSliceMut<'a, T> { + &self.col + } +} + +impl<'a, T: 'a> DerefMut for ColumnMut<'a, T> { + fn deref_mut(&mut self) -> &mut MatrixSliceMut<'a, T> { + &mut self.col + } +} diff --git a/src/matrix/impl_ops.rs b/src/matrix/impl_ops.rs index 57332b1..95dfe06 100644 --- a/src/matrix/impl_ops.rs +++ b/src/matrix/impl_ops.rs @@ -1,6 +1,5 @@ -use super::Matrix; -use super::MatrixSlice; -use super::MatrixSliceMut; +use super::{Matrix, MatrixSlice, MatrixSliceMut}; +use super::{Row, RowMut, Column, ColumnMut}; use super::slice::{BaseMatrix, BaseMatrixMut}; use super::super::utils; @@ -79,6 +78,46 @@ impl IndexMut<[usize; 2]> for Matrix { } } +impl<'a, T> Index for Row<'a, T> { + type Output = T; + fn index(&self, idx: usize) -> &T { + &self.row[[0, idx]] + } +} + +impl<'a, T> Index for RowMut<'a, T> { + type Output = T; + fn index(&self, idx: usize) -> &T { + &self.row[[0, idx]] + } +} + +impl<'a, T> IndexMut for RowMut<'a, T> { + fn index_mut(&mut self, idx: usize) -> &mut T { + &mut self.row[[0, idx]] + } +} + +impl<'a, T> Index for Column<'a, T> { + type Output = T; + fn index(&self, idx: usize) -> &T { + &self.col[[idx, 0]] + } +} + +impl<'a, T> Index for ColumnMut<'a, T> { + type Output = T; + fn index(&self, idx: usize) -> &T { + &self.col[[idx, 0]] + } +} + +impl<'a, T> IndexMut for ColumnMut<'a, T> { + fn index_mut(&mut self, idx: usize) -> &mut T { + &mut self.col[[idx, 0]] + } +} + macro_rules! impl_bin_op_scalar_slice ( ($trt:ident, $op:ident, $slice:ident, $doc:expr) => ( @@ -665,8 +704,10 @@ impl<'a, T> $assign_trt> for MatrixSliceMut<'a, T> where T: Copy + $trt { fn $op_assign(&mut self, _rhs: Matrix) { - for (slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) { - utils::in_place_vec_bin_op(slice_row, target_row, |x, &y| {*x = (*x).$op(y) }); + for (mut slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) { + utils::in_place_vec_bin_op(slice_row.raw_slice_mut(), + target_row.raw_slice(), + |x, &y| {*x = (*x).$op(y) }); } } } @@ -678,10 +719,10 @@ impl<'a, 'b, T> $assign_trt<&'b Matrix> for MatrixSliceMut<'a, T> where T: Copy + $trt { fn $op_assign(&mut self, _rhs: &Matrix) { - for (slice_row, target_row) in self.iter_rows_mut() + for (mut slice_row, target_row) in self.iter_rows_mut() .zip(_rhs.iter_rows()) { - utils::in_place_vec_bin_op(slice_row, - target_row, + utils::in_place_vec_bin_op(slice_row.raw_slice_mut(), + target_row.raw_slice(), |x, &y| {*x = (*x).$op(y) }); } } @@ -704,10 +745,10 @@ impl<'a, 'b, T> $assign_trt<$target_slice<'b, T>> for MatrixSliceMut<'a, T> where T: Copy + $trt { fn $op_assign(&mut self, _rhs: $target_slice) { - for (slice_row, target_row) in self.iter_rows_mut() + for (mut slice_row, target_row) in self.iter_rows_mut() .zip(_rhs.iter_rows()) { - utils::in_place_vec_bin_op(slice_row, - target_row, + utils::in_place_vec_bin_op(slice_row.raw_slice_mut(), + target_row.raw_slice(), |x, &y| {*x = (*x).$op(y) }); } } @@ -720,10 +761,10 @@ impl<'a, 'b, 'c, T> $assign_trt<&'c $target_slice<'b, T>> for MatrixSliceMut<'a, where T: Copy + $trt { fn $op_assign(&mut self, _rhs: &$target_slice) { - for (slice_row, target_row) in self.iter_rows_mut() + for (mut slice_row, target_row) in self.iter_rows_mut() .zip(_rhs.iter_rows()) { - utils::in_place_vec_bin_op(slice_row, - target_row, + utils::in_place_vec_bin_op(slice_row.raw_slice_mut(), + target_row.raw_slice(), |x, &y| {*x = (*x).$op(y) }); } } @@ -745,8 +786,10 @@ macro_rules! impl_op_assign_mat_slice ( impl<'a, T> $assign_trt<$target_mat<'a, T>> for Matrix where T: Copy + $trt { fn $op_assign(&mut self, _rhs: $target_mat) { - for (slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) { - utils::in_place_vec_bin_op(slice_row, target_row, |x, &y| {*x = (*x).$op(y) }); + for (mut slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) { + utils::in_place_vec_bin_op(slice_row.raw_slice_mut(), + target_row.raw_slice(), + |x, &y| {*x = (*x).$op(y) }); } } } @@ -757,8 +800,10 @@ impl<'a, T> $assign_trt<$target_mat<'a, T>> for Matrix impl<'a, 'b, T> $assign_trt<&'b $target_mat<'a, T>> for Matrix where T: Copy + $trt { fn $op_assign(&mut self, _rhs: &$target_mat) { - for (slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) { - utils::in_place_vec_bin_op(slice_row, target_row, |x, &y| {*x = (*x).$op(y) }); + for (mut slice_row, target_row) in self.iter_rows_mut().zip(_rhs.iter_rows()) { + utils::in_place_vec_bin_op(slice_row.raw_slice_mut(), + target_row.raw_slice(), + |x, &y| {*x = (*x).$op(y) }); } } } diff --git a/src/matrix/iter.rs b/src/matrix/iter.rs index 0273833..7dfa6a2 100644 --- a/src/matrix/iter.rs +++ b/src/matrix/iter.rs @@ -1,12 +1,11 @@ use std::iter::{ExactSizeIterator, FromIterator}; use std::mem; -use std::slice; +//use std::slice; -use super::{Matrix, MatrixSlice, MatrixSliceMut, Rows, RowsMut, Diagonal, DiagonalMut}; +use super::{Matrix, MatrixSlice, MatrixSliceMut}; +use super::{Row, RowMut, Rows, RowsMut, Diagonal, DiagonalMut}; use super::slice::{BaseMatrix, BaseMatrixMut, SliceIter, SliceIterMut}; - - macro_rules! impl_iter_diag ( ($diag:ident, $diag_base:ident, $diag_type:ty, $as_ptr:ident) => ( @@ -67,16 +66,14 @@ impl<'a, T, M: $diag_base> Iterator for $diag<'a, T, M> { } impl<'a, T, M: $diag_base> ExactSizeIterator for $diag<'a, T, M> {} - ); - ); impl_iter_diag!(Diagonal, BaseMatrix, &'a T, as_ptr); impl_iter_diag!(DiagonalMut, BaseMatrixMut, &'a mut T, as_mut_ptr); macro_rules! impl_iter_rows ( - ($rows:ident, $row_type:ty, $slice_from_parts:ident) => ( + ($rows:ident, $row_type:ty, $row_base:ident, $slice_base:ident) => ( /// Iterates over the rows in the matrix. impl<'a, T> Iterator for $rows<'a, T> { @@ -89,7 +86,9 @@ impl<'a, T> Iterator for $rows<'a, T> { unsafe { // Get pointer and create a slice from raw parts let ptr = self.slice_start.offset(self.row_pos as isize * self.row_stride); - row = slice::$slice_from_parts(ptr, self.slice_cols); + row = $row_base { + row: $slice_base::from_raw_parts(ptr, 1, self.slice_cols, self.row_stride as usize) + }; } self.row_pos += 1; @@ -105,7 +104,9 @@ impl<'a, T> Iterator for $rows<'a, T> { unsafe { // Get pointer to last row and create a slice from raw parts let ptr = self.slice_start.offset((self.slice_rows - 1) as isize * self.row_stride); - Some(slice::$slice_from_parts(ptr, self.slice_cols)) + Some($row_base { + row: $slice_base::from_raw_parts(ptr, 1, self.slice_cols, self.row_stride as usize) + }) } } else { None @@ -117,7 +118,9 @@ impl<'a, T> Iterator for $rows<'a, T> { let row: $row_type; unsafe { let ptr = self.slice_start.offset((self.row_pos + n) as isize * self.row_stride); - row = slice::$slice_from_parts(ptr, self.slice_cols); + row = $row_base { + row: $slice_base::from_raw_parts(ptr, 1, self.slice_cols, self.row_stride as usize) + } } self.row_pos += n + 1; @@ -138,8 +141,8 @@ impl<'a, T> Iterator for $rows<'a, T> { ); ); -impl_iter_rows!(Rows, &'a [T], from_raw_parts); -impl_iter_rows!(RowsMut, &'a mut [T], from_raw_parts_mut); +impl_iter_rows!(Rows, Row<'a, T>, Row, MatrixSlice); +impl_iter_rows!(RowsMut, RowMut<'a, T>, RowMut, MatrixSliceMut); impl<'a, T> ExactSizeIterator for Rows<'a, T> {} impl<'a, T> ExactSizeIterator for RowsMut<'a, T> {} @@ -225,6 +228,57 @@ impl<'a, T: 'a + Copy> FromIterator<&'a [T]> for Matrix { } } +macro_rules! impl_from_iter_row( + ($row_type:ty) => ( +impl<'a, T: 'a + Copy> FromIterator<$row_type> for Matrix { + fn from_iter>(iterable: I) -> Self { + let mut mat_data: Vec; + let cols: usize; + let mut rows = 0; + + let mut iterator = iterable.into_iter(); + + match iterator.next() { + None => { + return Matrix { + data: Vec::new(), + rows: 0, + cols: 0, + } + } + Some(row) => { + rows += 1; + // Here we set the capacity - get iterator size and the cols + let (lower_rows, _) = iterator.size_hint(); + cols = row.row.cols(); + + mat_data = Vec::with_capacity(lower_rows.saturating_add(1).saturating_mul(cols)); + mat_data.extend_from_slice(row.raw_slice()); + } + } + + for row in iterator { + assert!(row.row.cols() == cols, "Iterator row size must be constant."); + mat_data.extend_from_slice(row.raw_slice()); + rows += 1; + } + + mat_data.shrink_to_fit(); + + Matrix { + data: mat_data, + rows: rows, + cols: cols, + } + } +} + ); +); + +impl_from_iter_row!(Row<'a, T>); +impl_from_iter_row!(RowMut<'a, T>); + + impl<'a, T> IntoIterator for MatrixSlice<'a, T> { type Item = &'a T; type IntoIter = SliceIter<'a, T>; @@ -500,15 +554,15 @@ mod tests { let data = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]; for (i, row) in a.iter_rows().enumerate() { - assert_eq!(data[i], *row); + assert_eq!(data[i], *row.raw_slice()); } for (i, row) in a.iter_rows_mut().enumerate() { - assert_eq!(data[i], *row); + assert_eq!(data[i], *row.raw_slice()); } - for row in a.iter_rows_mut() { - for r in row { + for mut row in a.iter_rows_mut() { + for r in row.raw_slice_mut() { *r = 0; } } @@ -527,7 +581,7 @@ mod tests { let data = [[0, 1], [3, 4]]; for (i, row) in b.iter_rows().enumerate() { - assert_eq!(data[i], *row); + assert_eq!(data[i], *row.raw_slice()); } } @@ -543,15 +597,15 @@ mod tests { let data = [[0, 1], [3, 4]]; for (i, row) in b.iter_rows().enumerate() { - assert_eq!(data[i], *row); + assert_eq!(data[i], *row.raw_slice()); } for (i, row) in b.iter_rows_mut().enumerate() { - assert_eq!(data[i], *row); + assert_eq!(data[i], *row.raw_slice()); } - for row in b.iter_rows_mut() { - for r in row { + for mut row in b.iter_rows_mut() { + for r in row.raw_slice_mut() { *r = 0; } } @@ -568,10 +622,10 @@ mod tests { let mut row_iter = a.iter_rows(); - assert_eq!([0, 1, 2], *row_iter.nth(0).unwrap()); - assert_eq!([6, 7, 8], *row_iter.nth(1).unwrap()); + assert_eq!([0, 1, 2], *row_iter.nth(0).unwrap().raw_slice()); + assert_eq!([6, 7, 8], *row_iter.nth(1).unwrap().raw_slice()); - assert_eq!(None, row_iter.next()); + assert!(row_iter.next().is_none()); } #[test] @@ -582,12 +636,12 @@ mod tests { let row_iter = a.iter_rows(); - assert_eq!([6, 7, 8], *row_iter.last().unwrap()); + assert_eq!([6, 7, 8], *row_iter.last().unwrap().raw_slice()); let mut row_iter = a.iter_rows(); row_iter.next(); - assert_eq!([6, 7, 8], *row_iter.last().unwrap()); + assert_eq!([6, 7, 8], *row_iter.last().unwrap().raw_slice()); let mut row_iter = a.iter_rows(); @@ -596,7 +650,7 @@ mod tests { row_iter.next(); row_iter.next(); - assert_eq!(None, row_iter.last()); + assert!(row_iter.last().is_none()); } #[test] @@ -632,7 +686,7 @@ mod tests { assert_eq!((0, Some(0)), row_iter.size_hint()); - assert_eq!(None, row_iter.next()); + assert!(row_iter.next().is_none()); assert_eq!((0, Some(0)), row_iter.size_hint()); } diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index 59ea9d3..b30ddd0 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -6,6 +6,7 @@ //! Most of the logic for manipulating matrices is generically implemented //! via `BaseMatrix` and `BaseMatrixMut` trait. +use std; use std::any::Any; use std::fmt; use std::marker::PhantomData; @@ -20,6 +21,7 @@ mod decomposition; mod impl_ops; mod mat_mul; mod iter; +mod deref; pub mod slice; pub use self::slice::{BaseMatrix, BaseMatrixMut}; @@ -73,6 +75,91 @@ pub struct MatrixSliceMut<'a, T: 'a> { marker: PhantomData<&'a mut T>, } +/// Row of a matrix. +/// +/// This struct points to a slice making up +/// a row in a matrix. You can deref this +/// struct to retrieve a `MatrixSlice` of +/// the row. +/// +/// # Example +/// +/// ``` +/// # #[macro_use] extern crate rulinalg; fn main() { +/// use rulinalg::matrix::BaseMatrix; +/// +/// let mat = matrix![1.0, 2.0; +/// 3.0, 4.0]; +/// +/// let row = mat.row(1); +/// assert_eq!((*row + 2.0).sum(), 11.0); +/// # } +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct Row<'a, T: 'a> { + row: MatrixSlice<'a, T> +} + +/// Mutable row of a matrix. +/// +/// This struct points to a mutable slice +/// making up a row in a matrix. You can deref +/// this struct to retrieve a `MatrixSlice` +/// of the row. +/// +/// # Example +/// +/// ``` +/// # #[macro_use] extern crate rulinalg; fn main() { +/// use rulinalg::matrix::BaseMatrixMut; +/// +/// let mut mat = matrix![1.0, 2.0; +/// 3.0, 4.0]; +/// +/// { +/// let mut row = mat.row_mut(1); +/// *row += 2.0; +/// } +/// let expected = matrix![1.0, 2.0; +/// 5.0, 6.0]; +/// assert_matrix_eq!(mat, expected); +/// # } +/// ``` +#[derive(Debug)] +pub struct RowMut<'a, T: 'a> { + row: MatrixSliceMut<'a, T> +} + + +// +// MAYBE WE SHOULD MOVE SOME OF THIS STUFF OUT +// + +impl<'a, T: 'a> Row<'a, T> { + /// Returns the row as a slice. + pub fn raw_slice(&self) -> &'a [T] { + unsafe { + std::slice::from_raw_parts(self.row.as_ptr(), self.row.cols()) + } + } +} + +impl<'a, T: 'a> RowMut<'a, T> { + /// Returns the row as a slice. + pub fn raw_slice(&self) -> &'a [T] { + unsafe { + std::slice::from_raw_parts(self.row.as_ptr(), self.row.cols()) + } + } + + /// Returns the row as a slice. + pub fn raw_slice_mut(&mut self) -> &'a mut [T] { + unsafe { + std::slice::from_raw_parts_mut(self.row.as_mut_ptr(), self.row.cols()) + } + } +} + /// Row iterator. #[derive(Debug)] pub struct Rows<'a, T: 'a> { @@ -95,6 +182,60 @@ pub struct RowsMut<'a, T: 'a> { _marker: PhantomData<&'a mut T>, } +/// Column of a matrix. +/// +/// This struct points to a `MatrixSlice` +/// making up a column in a matrix. +/// You can deref this struct to retrieve +/// the raw column `MatrixSlice`. +/// +/// # Example +/// +/// ``` +/// # #[macro_use] extern crate rulinalg; fn main() { +/// use rulinalg::matrix::BaseMatrix; +/// +/// let mat = matrix![1.0, 2.0; +/// 3.0, 4.0]; +/// +/// let col = mat.col(1); +/// assert_eq!((*col + 2.0).sum(), 10.0); +/// # } +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct Column<'a, T: 'a> { + col: MatrixSlice<'a, T> +} + +/// Mutable column of a matrix. +/// +/// This struct points to a `MatrixSliceMut` +/// making up a column in a matrix. +/// You can deref this struct to retrieve +/// the raw column `MatrixSliceMut`. +/// +/// # Example +/// +/// ``` +/// # #[macro_use] extern crate rulinalg; fn main() { +/// use rulinalg::matrix::BaseMatrixMut; +/// +/// let mut mat = matrix![1.0, 2.0; +/// 3.0, 4.0]; +/// { +/// let mut column = mat.col_mut(1); +/// *column += 2.0; +/// } +/// let expected = matrix![1.0, 4.0; +/// 3.0, 6.0]; +/// assert_matrix_eq!(mat, expected); +/// # } +/// ``` +#[derive(Debug)] +pub struct ColumnMut<'a, T: 'a> { + col: MatrixSliceMut<'a, T> +} + /// Diagonal offset (used by Diagonal iterator). #[derive(Debug, PartialEq)] pub enum DiagOffset { @@ -633,7 +774,8 @@ impl<'a, T: Float> Metric for MatrixSlice<'a, T> { let mut s = T::zero(); for row in self.iter_rows() { - s = s + utils::dot(row, row); + let raw_slice = row.raw_slice(); + s = s + utils::dot(raw_slice, raw_slice); } s.sqrt() } @@ -658,7 +800,8 @@ impl<'a, T: Float> Metric for MatrixSliceMut<'a, T> { let mut s = T::zero(); for row in self.iter_rows() { - s = s + utils::dot(row, row); + let raw_slice = row.raw_slice(); + s = s + utils::dot(raw_slice, raw_slice); } s.sqrt() } @@ -811,7 +954,11 @@ fn parity(m: &M) -> T while !visited[next] { len += 1; visited[next] = true; - next = utils::find(&m.get_row(next).unwrap(), T::one()); + unsafe { + next = utils::find(&m.row_unchecked(next) + .raw_slice(), + T::one()); + } } if len % 2 == 0 { diff --git a/src/matrix/slice.rs b/src/matrix/slice.rs index bf481b7..115448a 100644 --- a/src/matrix/slice.rs +++ b/src/matrix/slice.rs @@ -19,7 +19,8 @@ //! let _new_mat = &mat_slice.transpose() * &a; //! ``` -use matrix::{Matrix, MatrixSlice, MatrixSliceMut, Rows, RowsMut, Axes}; +use matrix::{Matrix, MatrixSlice, MatrixSliceMut}; +use matrix::{Row, RowMut, Column, ColumnMut, Rows, RowsMut, Axes}; use matrix::{DiagOffset, Diagonal, DiagonalMut}; use matrix::{back_substitution, forward_substitution}; use vector::Vector; @@ -75,26 +76,87 @@ pub trait BaseMatrix: Sized { &*(self.as_ptr().offset((index[0] * self.row_stride() + index[1]) as isize)) } - /// Returns the row of a matrix at the given index. + /// Returns the column of a matrix at the given index. /// `None` if the index is out of bounds. /// /// # Examples /// /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrix}; /// - /// let a = Matrix::new(3,3, (0..9).collect::>()); - /// let slice = a.sub_slice([1,1], 2, 2); - /// let row = slice.get_row(1); - /// let expected = vec![7usize, 8]; - /// assert_eq!(row, Some(&*expected)); - /// assert!(slice.get_row(5).is_none()); + /// let mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let col = mat.col(1); + /// let expected = matrix![1usize; 4; 7]; + /// assert_matrix_eq!(*col, expected); + /// # } + /// ``` + /// + /// # Panics + /// + /// Will panic if the column index is out of bounds. + fn col(&self, index: usize) -> Column { + if index < self.cols() { + unsafe { self.col_unchecked(index) } + } else { + panic!("Column index out of bounds.") + } + } + + /// Returns the column of a matrix at the given + /// index without doing a bounds check. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrix}; + /// + /// let mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let col = unsafe { mat.col_unchecked(2) }; + /// let expected = matrix![2usize; 5; 8]; + /// assert_matrix_eq!(*col, expected); + /// # } /// ``` - fn get_row(&self, index: usize) -> Option<&[T]> { + unsafe fn col_unchecked(&self, index: usize) -> Column { + let ptr = self.as_ptr().offset(index as isize); + Column{ + col: MatrixSlice::from_raw_parts(ptr, + self.rows(), + 1, + self.row_stride()) + } + } + + /// Returns the row of a matrix at the given index. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrix}; + /// + /// let mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let row = mat.row(1); + /// let expected = matrix![3usize, 4, 5]; + /// assert_matrix_eq!(*row, expected); + /// # } + /// ``` + /// + /// # Panics + /// + /// Will panic if the row index is out of bounds. + fn row(&self, index: usize) -> Row { if index < self.rows() { - unsafe { Some(self.get_row_unchecked(index)) } + unsafe { self.row_unchecked(index) } } else { - None + panic!("Row index out of bounds.") } } @@ -103,17 +165,25 @@ pub trait BaseMatrix: Sized { /// # Examples /// /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrix}; /// - /// let a = Matrix::new(3,3, (0..9).collect::>()); - /// let slice = a.sub_slice([1,1], 2, 2); - /// let row = unsafe { slice.get_row_unchecked(1) }; - /// let mut expected = vec![7usize, 8]; - /// assert_eq!(row, &*expected); + /// let mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let row = unsafe { mat.row_unchecked(2) }; + /// let expected = matrix![6usize, 7, 8]; + /// assert_matrix_eq!(*row, expected); + /// # } /// ``` - unsafe fn get_row_unchecked(&self, index: usize) -> &[T] { + unsafe fn row_unchecked(&self, index: usize) -> Row { let ptr = self.as_ptr().offset((self.row_stride() * index) as isize); - ::std::slice::from_raw_parts(ptr, self.cols()) + Row { + row: MatrixSlice::from_raw_parts(ptr, + 1, + self.cols(), + self.row_stride()) + } } /// Returns an iterator over the matrix data. @@ -121,13 +191,17 @@ pub trait BaseMatrix: Sized { /// # Examples /// /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrix}; /// - /// let a = Matrix::new(3,3, (0..9).collect::>()); - /// let slice = a.sub_slice([1,1], 2, 2); + /// let mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let slice = mat.sub_slice([1,1], 2, 2); /// /// let slice_data = slice.iter().map(|v| *v).collect::>(); /// assert_eq!(slice_data, vec![4,5,7,8]); + /// # } /// ``` fn iter<'a>(&self) -> SliceIter<'a, T> where T: 'a @@ -154,7 +228,7 @@ pub trait BaseMatrix: Sized { /// /// // Prints "2" three times. /// for row in a.iter_rows() { - /// println!("{}", row.len()); + /// println!("{}", row.cols()); /// } /// ``` fn iter_rows(&self) -> Rows { @@ -243,7 +317,7 @@ pub trait BaseMatrix: Sized { where T: Copy + Zero + Add { let sum_rows = self.iter_rows().fold(vec![T::zero(); self.cols()], |row_sum, r| { - utils::vec_bin_op(&row_sum, r, |sum, val| sum + val) + utils::vec_bin_op(&row_sum, r.raw_slice(), |sum, val| sum + val) }); Vector::new(sum_rows) } @@ -269,7 +343,7 @@ pub trait BaseMatrix: Sized { where T: Copy + Zero + Add { let mut col_sum = Vec::with_capacity(self.rows()); - col_sum.extend(self.iter_rows().map(|row| utils::unrolled_sum(row))); + col_sum.extend(self.iter_rows().map(|row| utils::unrolled_sum(row.raw_slice()))); Vector::new(col_sum) } @@ -289,7 +363,7 @@ pub trait BaseMatrix: Sized { where T: Copy + Zero + Add { self.iter_rows() - .fold(T::zero(), |sum, row| sum + utils::unrolled_sum(row)) + .fold(T::zero(), |sum, row| sum + utils::unrolled_sum(row.raw_slice())) } /// Convert the matrix struct into a owned Matrix. @@ -333,10 +407,10 @@ pub trait BaseMatrix: Sized { "Row index is greater than number of rows."); } - for row in row_iter.clone() { + for row_idx in row_iter.clone() { unsafe { - let slice = self.get_row_unchecked(*row); - mat_vec.extend_from_slice(slice); + let row = self.row_unchecked(*row_idx); + mat_vec.extend_from_slice(row.raw_slice()); } } @@ -421,7 +495,9 @@ pub trait BaseMatrix: Sized { let mut data = Vec::with_capacity(self.rows() * self.cols()); for (self_r, m_r) in self.iter_rows().zip(m.iter_rows()) { - data.extend_from_slice(&utils::vec_bin_op(self_r, m_r, T::mul)); + data.extend_from_slice(&utils::vec_bin_op(self_r.raw_slice(), + m_r.raw_slice(), + T::mul)); } Matrix::new(self.rows(), self.cols(), data) } @@ -452,7 +528,9 @@ pub trait BaseMatrix: Sized { let mut data = Vec::with_capacity(self.rows() * self.cols()); for (self_r, m_r) in self.iter_rows().zip(m.iter_rows()) { - data.extend_from_slice(&utils::vec_bin_op(self_r, m_r, T::div)); + data.extend_from_slice(&utils::vec_bin_op(self_r.raw_slice(), + m_r.raw_slice(), + T::div)); } Matrix::new(self.rows(), self.cols(), data) } @@ -536,8 +614,8 @@ pub trait BaseMatrix: Sized { let mut new_data = Vec::with_capacity((self.cols() + m.cols()) * self.rows()); for (self_row, m_row) in self.iter_rows().zip(m.iter_rows()) { - new_data.extend_from_slice(self_row); - new_data.extend_from_slice(m_row); + new_data.extend_from_slice(self_row.raw_slice()); + new_data.extend_from_slice(m_row.raw_slice()); } Matrix { @@ -575,7 +653,7 @@ pub trait BaseMatrix: Sized { let mut new_data = Vec::with_capacity((self.rows() + m.rows()) * self.cols()); for row in self.iter_rows().chain(m.iter_rows()) { - new_data.extend_from_slice(row); + new_data.extend_from_slice(row.raw_slice()); } Matrix { @@ -887,48 +965,136 @@ pub trait BaseMatrixMut: BaseMatrix { } } + /// Returns a mutable reference to the column of a matrix at the given index. + /// `None` if the index is out of bounds. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] + /// # extern crate rulinalg; + /// + /// # fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrixMut}; + /// + /// let mut mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let mut slice = mat.sub_slice_mut([1,1], 2, 2); + /// { + /// let col = slice.col_mut(1); + /// let mut expected = matrix![5usize; 8]; + /// assert_matrix_eq!(*col, expected); + /// } + /// # } + /// ``` + /// + /// # Panics + /// + /// Will panic if the column index is out of bounds. + fn col_mut(&mut self, index: usize) -> ColumnMut { + if index < self.cols() { + unsafe { self.col_unchecked_mut(index) } + } else { + panic!("Column index out of bounds.") + } + } + + /// Returns a mutable reference to the column of a matrix at the given index + /// without doing a bounds check. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] + /// # extern crate rulinalg; + /// + /// # fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrixMut}; + /// + /// let mut mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let mut slice = mat.sub_slice_mut([1,1], 2, 2); + /// let col = unsafe { slice.col_unchecked_mut(1) }; + /// let mut expected = matrix![5usize; 8]; + /// assert_matrix_eq!(*col, expected); + /// # } + /// ``` + unsafe fn col_unchecked_mut(&mut self, index: usize) -> ColumnMut { + let ptr = self.as_mut_ptr().offset(index as isize); + ColumnMut { + col: MatrixSliceMut::from_raw_parts(ptr, + self.rows(), + 1, + self.row_stride()) + } + } + /// Returns a mutable reference to the row of a matrix at the given index. /// `None` if the index is out of bounds. /// /// # Examples /// /// ``` + /// # #[macro_use] + /// # extern crate rulinalg; + /// + /// # fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrixMut}; /// - /// let mut a = Matrix::new(3,3, (0..9).collect::>()); - /// let mut slice = a.sub_slice_mut([1,1], 2, 2); + /// let mut mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let mut slice = mat.sub_slice_mut([1,1], 2, 2); /// { - /// let row = slice.get_row_mut(1); - /// let mut expected = vec![7usize, 8]; - /// assert_eq!(row, Some(&mut *expected)); + /// let row = slice.row_mut(1); + /// let mut expected = matrix![7usize, 8]; + /// assert_matrix_eq!(*row, expected); /// } - /// assert!(slice.get_row_mut(5).is_none()); + /// # } /// ``` - fn get_row_mut(&mut self, index: usize) -> Option<&mut [T]> { + /// + /// # Panics + /// + /// Will panic if the row index is out of bounds. + fn row_mut(&mut self, index: usize) -> RowMut { if index < self.rows() { - unsafe { Some(self.get_row_unchecked_mut(index)) } + unsafe { self.row_unchecked_mut(index) } } else { - None + panic!("Row index out of bounds.") } } /// Returns a mutable reference to the row of a matrix at the given index - /// without doing unbounds checking + /// without doing a bounds check. /// /// # Examples /// /// ``` + /// # #[macro_use] + /// # extern crate rulinalg; + /// + /// # fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrixMut}; /// - /// let mut a = Matrix::new(3,3, (0..9).collect::>()); - /// let mut slice = a.sub_slice_mut([1,1], 2, 2); - /// let row = unsafe { slice.get_row_unchecked_mut(1) }; - /// let mut expected = vec![7usize, 8]; - /// assert_eq!(row, &mut *expected); + /// let mut mat = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; + /// let mut slice = mat.sub_slice_mut([1,1], 2, 2); + /// let row = unsafe { slice.row_unchecked_mut(1) }; + /// let mut expected = matrix![7usize, 8]; + /// assert_matrix_eq!(*row, expected); + /// # } /// ``` - unsafe fn get_row_unchecked_mut(&mut self, index: usize) -> &mut [T] { + unsafe fn row_unchecked_mut(&mut self, index: usize) -> RowMut { let ptr = self.as_mut_ptr().offset((self.row_stride() * index) as isize); - ::std::slice::from_raw_parts_mut(ptr, self.cols()) + RowMut { + row: MatrixSliceMut::from_raw_parts(ptr, + 1, + self.cols(), + self.row_stride()) + } } /// Swaps two rows in a matrix. @@ -1037,10 +1203,8 @@ pub trait BaseMatrixMut: BaseMatrix { /// /// let mut a = Matrix::new(3, 2, (0..6).collect::>()); /// - /// for row in a.iter_rows_mut() { - /// for r in row { - /// *r = *r + 1; - /// } + /// for mut row in a.iter_rows_mut() { + /// *row += 1; /// } /// /// // Now contains the range 1..7 @@ -1143,9 +1307,11 @@ pub trait BaseMatrixMut: BaseMatrix { "Target has different row count to self."); assert!(self.cols() == target.cols(), "Target has different column count to self."); - for (s, t) in self.iter_rows_mut().zip(target.iter_rows()) { + for (mut s, t) in self.iter_rows_mut().zip(target.iter_rows()) { // Vectorized assignment per row. - utils::in_place_vec_bin_op(s, t, |x, &y| *x = y); + utils::in_place_vec_bin_op(s.raw_slice_mut(), + t.raw_slice(), + |x, &y| *x = y); } } @@ -1312,7 +1478,7 @@ impl BaseMatrix for Matrix { new_data.reserve(m.rows() * m.cols()); for row in m.iter_rows() { - new_data.extend_from_slice(row); + new_data.extend_from_slice(row.raw_slice()); } Matrix { @@ -1367,6 +1533,84 @@ impl<'a, T> BaseMatrixMut for MatrixSliceMut<'a, T> { } } +impl<'a, T> BaseMatrix for Row<'a, T> { + fn rows(&self) -> usize { + 1 + } + fn cols(&self) -> usize { + self.row.cols() + } + fn row_stride(&self) -> usize { + self.row.row_stride() + } + + fn as_ptr(&self) -> *const T { + self.row.as_ptr() + } +} + +impl<'a, T> BaseMatrix for RowMut<'a, T> { + fn rows(&self) -> usize { + 1 + } + fn cols(&self) -> usize { + self.row.cols() + } + fn row_stride(&self) -> usize { + self.row.row_stride() + } + + fn as_ptr(&self) -> *const T { + self.row.as_ptr() + } +} + +impl<'a, T> BaseMatrixMut for RowMut<'a, T> { + /// Top left index of the slice. + fn as_mut_ptr(&mut self) -> *mut T { + self.row.as_mut_ptr() + } +} + +impl<'a, T> BaseMatrix for Column<'a, T> { + fn rows(&self) -> usize { + self.col.rows() + } + fn cols(&self) -> usize { + 1 + } + fn row_stride(&self) -> usize { + self.col.row_stride() + } + + fn as_ptr(&self) -> *const T { + self.col.as_ptr() + } +} + +impl<'a, T> BaseMatrix for ColumnMut<'a, T> { + fn rows(&self) -> usize { + self.col.rows() + } + fn cols(&self) -> usize { + 1 + } + fn row_stride(&self) -> usize { + self.col.row_stride() + } + + fn as_ptr(&self) -> *const T { + self.col.as_ptr() + } +} + +impl<'a, T> BaseMatrixMut for ColumnMut<'a, T> { + /// Top left index of the slice. + fn as_mut_ptr(&mut self) -> *mut T { + self.col.as_mut_ptr() + } +} + impl<'a, T> MatrixSlice<'a, T> { /// Produce a `MatrixSlice` from a `Matrix` ///