Skip to content

Commit

Permalink
Merge pull request #25 from maciejkula/parallel_predict
Browse files Browse the repository at this point in the history
Parallel model fitting and prediction
  • Loading branch information
maciejkula authored Aug 9, 2016
2 parents 130d00a + 5c36a54 commit d6a7a7e
Show file tree
Hide file tree
Showing 15 changed files with 623 additions and 149 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ build = "build.rs"
[dependencies]
rand = "0.3"
rustc-serialize = "0.3"
crossbeam = "0.2.9"

[build-dependencies]
gcc = "0.3"
Expand Down
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Changelog

## [unreleased][unreleased]
### Added
- factorization machines
- parallel fitting and prediction for one-vs-rest models

## [0.3.1][2016-03-01]
### Changed
- NonzeroIterable now takes &self
Expand Down
6 changes: 5 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ For full usage details, see the [API documentation](https://maciejkula.github.io

## Introduction

This crate is mostly an excuse for me to learn Rust. Nevertheless, it contains reasonably effective
This crate contains reasonably effective
implementations of a number of common machine learning algorithms.

At the moment, `rustlearn` uses its own basic dense and sparse array types, but I will be happy
Expand Down Expand Up @@ -43,6 +43,10 @@ should be roughly competitive with Python `sklearn` implementations, both in acc
- [accuracy](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/fn.accuracy_score.html)
- [ROC AUC score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.roc_auc_score.html)

## Parallelization

A number of models support both parallel model fitting and prediction.

### Model serialization

Model serialization is supported via `rustc_serialize`. This will probably change to `serde` once compiler plugins land in stable.
Expand Down
58 changes: 52 additions & 6 deletions src/array/dense.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@

use std::iter::Iterator;
use std::ops::Range;

use array::traits::*;

Expand Down Expand Up @@ -144,6 +145,7 @@ pub struct ArrayView<'a> {

/// Iterator over row or column views of a dense matrix.
pub struct ArrayIterator<'a> {
stop: usize,
idx: usize,
axis: ArrayIteratorAxis,
array: &'a Array,
Expand All @@ -155,12 +157,7 @@ impl<'a> Iterator for ArrayIterator<'a> {

fn next(&mut self) -> Option<ArrayView<'a>> {

let bound = match self.axis {
ArrayIteratorAxis::Row => self.array.rows,
ArrayIteratorAxis::Column => self.array.cols,
};

let result = if self.idx < bound {
let result = if self.idx < self.stop {
Some(ArrayView {
idx: self.idx,
axis: self.axis,
Expand All @@ -182,6 +179,7 @@ impl<'a> RowIterable for &'a Array {
type Output = ArrayIterator<'a>;
fn iter_rows(self) -> ArrayIterator<'a> {
ArrayIterator {
stop: self.rows(),
idx: 0,
axis: ArrayIteratorAxis::Row,
array: self,
Expand All @@ -196,6 +194,21 @@ impl<'a> RowIterable for &'a Array {
array: self,
}
}

fn iter_rows_range(self, range: Range<usize>) -> ArrayIterator<'a> {
let stop = if range.end > self.rows {
self.rows
} else {
range.end
};

ArrayIterator {
stop: stop,
idx: range.start,
axis: ArrayIteratorAxis::Row,
array: self,
}
}
}


Expand All @@ -204,6 +217,7 @@ impl<'a> ColumnIterable for &'a Array {
type Output = ArrayIterator<'a>;
fn iter_columns(self) -> ArrayIterator<'a> {
ArrayIterator {
stop: self.cols(),
idx: 0,
axis: ArrayIteratorAxis::Column,
array: self,
Expand All @@ -218,6 +232,21 @@ impl<'a> ColumnIterable for &'a Array {
array: self,
}
}

fn iter_columns_range(self, range: Range<usize>) -> ArrayIterator<'a> {
let stop = if range.end > self.cols {
self.cols
} else {
range.end
};

ArrayIterator {
stop: stop,
idx: range.start,
axis: ArrayIteratorAxis::Column,
array: self,
}
}
}


Expand Down Expand Up @@ -1018,4 +1047,21 @@ mod tests {
}
}
}

use datasets::iris;

#[test]
fn range_iteration() {
let (data, _) = iris::load_data();

let (start, stop) = (5, 10);

for (row_num, row) in data.iter_rows_range(start..stop).enumerate() {
for (col_idx, value) in row.iter_nonzero() {
assert!(value == data.get(start + row_num, col_idx));
}

assert!(row_num < (stop - start));
}
}
}
69 changes: 65 additions & 4 deletions src/array/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
//!
//! ```
use std::iter::Iterator;
use std::ops::Range;

use array::dense::*;
use array::traits::*;
Expand Down Expand Up @@ -79,8 +80,8 @@ pub struct SparseArrayViewIterator<'a> {

/// Iterator over row or column views of a sparse matrix.
pub struct SparseArrayIterator<'a> {
stop: usize,
idx: usize,
dim: usize,
indices: &'a Vec<Vec<usize>>,
data: &'a Vec<Vec<f32>>,
}
Expand Down Expand Up @@ -290,7 +291,22 @@ impl<'a> RowIterable for &'a SparseRowArray {
fn iter_rows(self) -> SparseArrayIterator<'a> {
SparseArrayIterator {
idx: 0,
dim: self.rows,
stop: self.rows,
indices: &self.indices,
data: &self.data,
}
}

fn iter_rows_range(self, range: Range<usize>) -> SparseArrayIterator<'a> {
let stop = if range.end > self.rows {
self.rows
} else {
range.end
};

SparseArrayIterator {
stop: stop,
idx: range.start,
indices: &self.indices,
data: &self.data,
}
Expand Down Expand Up @@ -388,11 +404,27 @@ impl<'a> ColumnIterable for &'a SparseColumnArray {
fn iter_columns(self) -> SparseArrayIterator<'a> {
SparseArrayIterator {
idx: 0,
dim: self.cols,
stop: self.cols,
indices: &self.indices,
data: &self.data,
}
}

fn iter_columns_range(self, range: Range<usize>) -> SparseArrayIterator<'a> {
let stop = if range.end > self.cols {
self.cols
} else {
range.end
};

SparseArrayIterator {
stop: stop,
idx: range.start,
indices: &self.indices,
data: &self.data,
}
}

fn view_column(self, idx: usize) -> SparseArrayView<'a> {
SparseArrayView {
indices: &self.indices[idx],
Expand Down Expand Up @@ -457,7 +489,7 @@ impl<'a> Iterator for SparseArrayIterator<'a> {

fn next(&mut self) -> Option<SparseArrayView<'a>> {

let result = if self.idx < self.dim {
let result = if self.idx < self.stop {
Some(SparseArrayView {
indices: &self.indices[self.idx][..],
data: &self.data[self.idx][..],
Expand Down Expand Up @@ -595,4 +627,33 @@ mod tests {
&dense_arr.get_rows(&vec![1, 0])));
assert!(allclose(&arr.get_rows(&(..)).todense(), &dense_arr.get_rows(&(..))));
}

use datasets::iris;

#[test]
fn range_iteration() {
let (data, _) = iris::load_data();

let (start, stop) = (5, 10);

let data = SparseRowArray::from(&data);

for (row_num, row) in data.iter_rows_range(start..stop).enumerate() {
for (col_idx, value) in row.iter_nonzero() {
assert!(value == data.get(start + row_num, col_idx));
}

assert!(row_num < (stop - start));
}

let (start, stop) = (1, 3);

let data = SparseColumnArray::from(&data);

for (col_num, col) in data.iter_columns_range(start..stop).enumerate() {
for (row_idx, value) in col.iter_nonzero() {
assert!(value == data.get(row_idx, start + col_num));
}
}
}
}
4 changes: 4 additions & 0 deletions src/array/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ pub trait RowIterable {
type Output: Iterator<Item = Self::Item>;
/// Iterate over rows of the matrix.
fn iter_rows(self) -> Self::Output;
/// Iterate over a subset of rows of the matrix.
fn iter_rows_range(self, range: Range<usize>) -> Self::Output;
/// View a row of the matrix.
fn view_row(self, idx: usize) -> Self::Item;
}
Expand All @@ -88,6 +90,8 @@ pub trait ColumnIterable {
type Output: Iterator<Item = Self::Item>;
/// Iterate over columns of a the matrix.
fn iter_columns(self) -> Self::Output;
/// Iterate over a subset of columns of the matrix.
fn iter_columns_range(self, range: Range<usize>) -> Self::Output;
/// View a column of the matrix.
fn view_column(self, idx: usize) -> Self::Item;
}
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
pub mod iris;

#[cfg(test)]
#[cfg(feature = "all_tests")]
#[cfg(any(feature = "all_tests", feature = "bench"))]
pub mod newsgroups;
48 changes: 44 additions & 4 deletions src/ensemble/random_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,14 @@ impl Hyperparameters {
}


#[derive(RustcEncodable, RustcDecodable)]
#[derive(Clone)]
#[derive(RustcEncodable, RustcDecodable, Clone)]
pub struct RandomForest {
trees: Vec<decision_tree::DecisionTree>,
rng: EncodableRng,
}


impl SupervisedModel<Array> for RandomForest {
impl<'a> SupervisedModel<&'a Array> for RandomForest {
fn fit(&mut self, X: &Array, y: &Array) -> Result<(), &'static str> {

let mut rng = self.rng.clone();
Expand Down Expand Up @@ -145,7 +144,7 @@ impl SupervisedModel<Array> for RandomForest {
}


impl SupervisedModel<SparseRowArray> for RandomForest {
impl<'a> SupervisedModel<&'a SparseRowArray> for RandomForest {
fn fit(&mut self, X: &SparseRowArray, y: &Array) -> Result<(), &'static str> {

let mut rng = self.rng.clone();
Expand Down Expand Up @@ -253,6 +252,47 @@ mod tests {
assert!(test_accuracy > 0.96);
}

#[test]
fn test_random_forest_iris_parallel() {
let (data, target) = load_data();

let mut test_accuracy = 0.0;

let no_splits = 10;

let mut cv = CrossValidation::new(data.rows(), no_splits);
cv.set_rng(StdRng::from_seed(&[100]));

for (train_idx, test_idx) in cv {

let x_train = data.get_rows(&train_idx);
let x_test = data.get_rows(&test_idx);

let y_train = target.get_rows(&train_idx);

let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
.max_features(4)
.rng(StdRng::from_seed(&[100]));

let mut model = Hyperparameters::new(tree_params, 10)
.rng(StdRng::from_seed(&[100]))
.one_vs_rest();

model.fit_parallel(&x_train, &y_train, 2).unwrap();

let test_prediction = model.predict_parallel(&x_test, 2).unwrap();

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);
}

test_accuracy /= no_splits as f32;

println!("Accuracy {}", test_accuracy);

assert!(test_accuracy > 0.96);
}

#[test]
fn test_random_forest_iris_sparse() {
let (data, target) = load_data();
Expand Down
Loading

0 comments on commit d6a7a7e

Please sign in to comment.