Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
dropout passthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jan 31, 2019
1 parent 1ad98ff commit b3a2af4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
5 changes: 4 additions & 1 deletion python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,10 @@ def __init__(self, rate, axes=(), **kwargs):
self._axes = axes

def hybrid_forward(self, F, x):
    """Apply dropout to ``x``.

    Parameters
    ----------
    F : module
        The namespace providing the operators (presumably ``mxnet.ndarray``
        or ``mxnet.symbol`` — TODO confirm against the enclosing class).
    x : array-like / symbol
        Input to be (possibly) dropped out.

    Returns
    -------
    ``F.Dropout(x, ...)`` when the configured rate is positive; otherwise
    ``x`` passed through unchanged via ``F.identity`` so a zero-rate layer
    costs no dropout kernel launch.
    """
    # Stale pre-change `return F.Dropout(...)` removed: it made the
    # rate==0 passthrough branch below unreachable.
    if self._rate > 0:
        return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
    else:
        return F.identity(x)

def __repr__(self):
s = '{name}(p = {_rate}, axes={_axes})'
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class DropoutOp {
const TBlob &in = in_data[dropout::kData];
const TBlob &out = out_data[dropout::kOut];
const TBlob &mask = out_data[dropout::kMask];
if (ctx.is_train || this->mode_ == dropout::kAlways) {
if (this->pkeep < 1 && (ctx.is_train || this->mode_ == dropout::kAlways)) {
this->dropout_passthrough_ = false;
if (this->axes_.ndim() == 0) {
#if MXNET_USE_MKL_DROPOUT
Expand Down
3 changes: 2 additions & 1 deletion src/operator/nn/dropout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,10 @@ Example::
.set_attr<FResourceRequestEx>("FResourceRequestEx",
[](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
std::vector<ResourceRequest> request;
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
if (param.p == 0) return param;
if (dev_mask == kGPU) {
#if MXNET_USE_CUDNN_DROPOUT
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
// if cudnn is used, parallel random is not needed.
if (1.0f - param.p > 0
&& !(param.cudnn_off && param.cudnn_off.value())
Expand Down

0 comments on commit b3a2af4

Please sign in to comment.