Skip to content

Commit

Permalink
Add grid parallelism parameter to h2o.gbm
Browse files Browse the repository at this point in the history
  • Loading branch information
anqi committed Jan 13, 2015
1 parent f0f9635 commit 6e9a357
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
18 changes: 11 additions & 7 deletions R/h2o-package/R/Algorithms.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ h2o.coxph <- function(x, y, data, key = "", weights = NULL, offset = NULL,
# ----------------------- Generalized Boosting Machines (GBM) ----------------------- #
# TODO: don't support missing x; default to everything?
h2o.gbm <- function(x, y, distribution = 'multinomial', data, key = "", n.trees = 10, interaction.depth = 5, n.minobsinnode = 10, shrinkage = 0.1,
n.bins = 20, group_split = TRUE, importance = FALSE, nfolds = 0, validation, holdout.fraction = 0, balance.classes = FALSE, max.after.balance.size = 5, class.sampling.factors = NULL) {
n.bins = 20, group_split = TRUE, importance = FALSE, nfolds = 0, validation, holdout.fraction = 0, balance.classes = FALSE,
max.after.balance.size = 5, class.sampling.factors = NULL, grid.parallelism = 1) {
args <- .verify_dataxy(data, x, y)

if(!is.character(key)) stop("key must be of class character")
Expand Down Expand Up @@ -173,22 +174,24 @@ h2o.gbm <- function(x, y, distribution = 'multinomial', data, key = "", n.trees
stop("validation must be an H2O parsed dataset")
if(!is.numeric(holdout.fraction)) stop("holdout.fraction must be numeric")
if(as.numeric(holdout.fraction) > 0 && (!missing(validation) || nfolds>1) ) stop("holdout.fraction cannot be combined with validation or nfolds")

# Validate grid.parallelism: it must be a single whole number in 1..4.
# A plain range check (x < 1 || x > 4) would let fractional values such as 2.5
# through, would fail with an unrelated "missing value where TRUE/FALSE needed"
# error on NA, and would misbehave on vectors of length > 1. %in% never returns
# NA and matches only the exact values 1, 2, 3 and 4 (integer or double).
if(!is.numeric(grid.parallelism)) stop("grid.parallelism must be numeric")
if(length(grid.parallelism) != 1 || !(grid.parallelism %in% 1:4)) stop("grid.parallelism must be 1, 2, 3, or 4")

# NB: externally, 1 based indexing; internally, 0 based
cols = paste(args$x_i - 1, collapse=",")
group_split <- as.numeric(group_split)
if(missing(validation) && nfolds == 0) {
res = .h2o.__remoteSend(data@h2o, .h2o.__PAGE_GBM, source=data@key, holdout_fraction = as.numeric(holdout.fraction), destination_key=key, response=args$y, cols=cols, ntrees=n.trees, max_depth=interaction.depth, learn_rate=shrinkage, family=family, group_split = group_split,
min_rows=n.minobsinnode, classification=classification, nbins=n.bins, importance=as.numeric(importance), balance_classes=as.numeric(balance.classes), max_after_balance_size=as.numeric(max.after.balance.size), class_sampling_factors = class.sampling.factors)
min_rows=n.minobsinnode, classification=classification, nbins=n.bins, importance=as.numeric(importance), balance_classes=as.numeric(balance.classes), max_after_balance_size=as.numeric(max.after.balance.size), class_sampling_factors = class.sampling.factors, grid_parallelism = grid.parallelism)
} else if(missing(validation) && nfolds >= 2) {
res = .h2o.__remoteSend(data@h2o, .h2o.__PAGE_GBM, source=data@key, destination_key=key, response=args$y, cols=cols, ntrees=n.trees, max_depth=interaction.depth, learn_rate=shrinkage, family=family, group_split = group_split,
min_rows=n.minobsinnode, classification=classification, nbins=n.bins, importance=as.numeric(importance), n_folds=nfolds, balance_classes=as.numeric(balance.classes), max_after_balance_size=as.numeric(max.after.balance.size), class_sampling_factors = class.sampling.factors)
min_rows=n.minobsinnode, classification=classification, nbins=n.bins, importance=as.numeric(importance), n_folds=nfolds, balance_classes=as.numeric(balance.classes), max_after_balance_size=as.numeric(max.after.balance.size), class_sampling_factors = class.sampling.factors, grid_parallelism = grid.parallelism)
} else if(!missing(validation) && nfolds == 0) {
res = .h2o.__remoteSend(data@h2o, .h2o.__PAGE_GBM, source=data@key, destination_key=key, response=args$y, cols=cols, ntrees=n.trees, max_depth=interaction.depth, learn_rate=shrinkage, family=family, group_split = group_split,
min_rows=n.minobsinnode, classification=classification, nbins=n.bins, importance=as.numeric(importance), validation=validation@key, balance_classes=as.numeric(balance.classes), max_after_balance_size=as.numeric(max.after.balance.size), class_sampling_factors = class.sampling.factors)
min_rows=n.minobsinnode, classification=classification, nbins=n.bins, importance=as.numeric(importance), validation=validation@key, balance_classes=as.numeric(balance.classes), max_after_balance_size=as.numeric(max.after.balance.size), class_sampling_factors = class.sampling.factors, grid_parallelism = grid.parallelism)
} else stop("Cannot set both validation and nfolds at the same time")
params = list(x=args$x, y=args$y, distribution=distribution, n.trees=n.trees, interaction.depth=interaction.depth, shrinkage=shrinkage, n.minobsinnode=n.minobsinnode, n.bins=n.bins, importance=importance, nfolds=nfolds, balance.classes=balance.classes, max.after.balance.size=max.after.balance.size, class.sampling.factors = class.sampling.factors,
h2o = data@h2o, group_split = group_split)
h2o = data@h2o, group_split = group_split, grid_parallelism = grid.parallelism)

if(.is_singlerun("GBM", params))
.h2o.singlerun.internal("GBM", data, res, nfolds, validation, params)
Expand All @@ -207,7 +210,7 @@ h2o.gbm <- function(x, y, distribution = 'multinomial', data, key = "", n.trees
mySum$balance.classes = res$balance_classes
mySum$max.after.balance.size = res$max_after_balance_size
mySum$class.sampling.factors = res$class_sampling_factors

# if(params$distribution == "multinomial") {
# temp = matrix(unlist(res$cm), nrow = length(res$cm))
# mySum$prediction_error = 1-sum(diag(temp))/sum(temp)
Expand All @@ -231,6 +234,7 @@ h2o.gbm <- function(x, y, distribution = 'multinomial', data, key = "", n.trees
params$balance.classes = res$balance_classes
params$max.after.balance.size = res$max_after_balance_size
params$class.sampling.factors = res$class_sampling_factors
params$grid.parallelism = res$grid_parallelism
result$params = params

if(result$params$distribution %in% c("multinomial", "bernoulli")) {
Expand Down
4 changes: 3 additions & 1 deletion R/h2o-package/man/h2o.gbm.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ H2O: Gradient Boosted Machines
h2o.gbm(x, y, distribution = "multinomial", data, key = "", n.trees = 10,
interaction.depth = 5, n.minobsinnode = 10, shrinkage = 0.1, n.bins = 20,
group_split = TRUE, importance = FALSE, nfolds = 0, validation, holdout.fraction = 0,
balance.classes = FALSE, max.after.balance.size = 5, class.sampling.factors = NULL)
balance.classes = FALSE, max.after.balance.size = 5, class.sampling.factors = NULL,
grid.parallelism = 1)
}
\arguments{
\item{x}{
Expand Down Expand Up @@ -61,6 +62,7 @@ An \code{\linkS4class{H2OParsedData}} object containing the variables in the mod
\item{balance.classes}{(Optional) Balance training data class counts via over/under-sampling (for imbalanced data)}
\item{max.after.balance.size}{Maximum relative size of the training data after balancing class counts (can be less than 1.0)}
\item{class.sampling.factors}{ Desired over/under-sampling ratios per class (lexicographic order). }
\item{grid.parallelism}{A single integer between 1 and 4 (inclusive) indicating how many parallel threads to run during grid search. Defaults to 1.}
}
\value{
An object of class \code{\linkS4class{H2OGBMModel}} with slots key, data, valid (the validation dataset) and model, where the last is a list of the following components:
Expand Down

0 comments on commit 6e9a357

Please sign in to comment.