From b3a2af4e9f22bbe7182c9401de3f9df2a3af5612 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Wed, 30 Jan 2019 21:53:53 -0800 Subject: [PATCH] dropout passthrough --- python/mxnet/gluon/nn/basic_layers.py | 5 ++++- src/operator/nn/dropout-inl.h | 2 +- src/operator/nn/dropout.cc | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index ace814275d61..f8566dd05aa5 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -262,7 +262,10 @@ def __init__(self, rate, axes=(), **kwargs): self._axes = axes def hybrid_forward(self, F, x): - return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False) + if self._rate > 0: + return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False) + else: + return F.identity(x) def __repr__(self): s = '{name}(p = {_rate}, axes={_axes})' diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index e8a808350eb5..d4da11517ad0 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -338,7 +338,7 @@ class DropoutOp { const TBlob &in = in_data[dropout::kData]; const TBlob &out = out_data[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; - if (ctx.is_train || this->mode_ == dropout::kAlways) { + if (this->pkeep_ < 1 && (ctx.is_train || this->mode_ == dropout::kAlways)) { this->dropout_passthrough_ = false; if (this->axes_.ndim() == 0) { #if MXNET_USE_MKL_DROPOUT diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index 3205fe9fb320..4c098d86bd92 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -128,9 +128,10 @@ Example:: .set_attr<FResourceRequestEx>("FResourceRequestEx", [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) { std::vector<ResourceRequest> request; + const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed); + if (param.p == 0) return request; if (dev_mask == kGPU) { #if 
MXNET_USE_CUDNN_DROPOUT - const DropoutParam& param = nnvm::get(attrs.parsed); // if cudnn is used, parallel random is not needed. if (1.0f - param.p > 0 && !(param.cudnn_off && param.cudnn_off.value())