From 2c1ea12472621d4c7dd2047affe563878d73f358 Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Wed, 30 Jan 2019 21:53:53 -0800
Subject: [PATCH] dropout passthrough

---
 python/mxnet/gluon/nn/basic_layers.py | 5 ++++-
 src/operator/nn/dropout-inl.h         | 2 +-
 src/operator/nn/dropout.cc            | 3 ++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index ace814275d61..f8566dd05aa5 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -262,7 +262,10 @@ def __init__(self, rate, axes=(), **kwargs):
         self._axes = axes

     def hybrid_forward(self, F, x):
-        return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+        if self._rate > 0:
+            return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+        else:
+            return F.identity(x)

     def __repr__(self):
         s = '{name}(p = {_rate}, axes={_axes})'
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index e8a808350eb5..55fb03283e55 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -338,7 +338,7 @@ class DropoutOp {
     const TBlob &in = in_data[dropout::kData];
     const TBlob &out = out_data[dropout::kOut];
     const TBlob &mask = out_data[dropout::kMask];
-    if (ctx.is_train || this->mode_ == dropout::kAlways) {
+    if (this->pkeep_ < 1 && (ctx.is_train || this->mode_ == dropout::kAlways)) {
       this->dropout_passthrough_ = false;
       if (this->axes_.ndim() == 0) {
 #if MXNET_USE_MKL_DROPOUT
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 3205fe9fb320..d6cbeb4e561d 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -128,9 +128,10 @@ Example::
 .set_attr<FResourceRequestEx>("FResourceRequestEx",
   [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
     std::vector<ResourceRequest> request;
+    const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+    if (param.p == 0) return request;
     if (dev_mask == kGPU) {
 #if MXNET_USE_CUDNN_DROPOUT
-      const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
       // if cudnn is used, parallel random is not needed.
       if (1.0f - param.p > 0 && !(param.cudnn_off && param.cudnn_off.value()
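
Reviewer note (not part of the patch): a minimal sketch of the user-visible effect, assuming a build with this change applied. With rate == 0, the Gluon Dropout block hybridizes to F.identity, and the operator-side checks above (pkeep_ < 1, param.p == 0) short-circuit, so no dropout kernel runs and no parallel-random or cuDNN dropout resources are requested.

    # sketch only: exercises a Dropout block with rate 0 and checks
    # that it behaves as an identity in both inference and training mode
    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dropout(0.0)
    net.hybridize()
    x = mx.nd.random.uniform(shape=(2, 3))

    # inference: identity, as before
    assert (net(x) == x).asnumpy().all()

    # training mode: with this patch, still a pass-through (no masking)
    with mx.autograd.record(train_mode=True):
        y_train = net(x)
    assert (y_train == x).asnumpy().all()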