Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Mark cuDNN Dropout as fully CUDA Graphs compatible. Reenable tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
DickJC123 committed Feb 15, 2022
1 parent 4a2dae4 commit 64e8555
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/operator/nn/dropout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,18 @@ namespace op {

NNVM_REGISTER_OP(Dropout)
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
[](const NodeAttrs&, const bool is_train) {
// Dropout is just passthrough during inference
return !is_train;
})
[](const NodeAttrs& attrs, const bool is_train) {
// Dropout is just passthrough during inference for all impls
if (!is_train)
return true;

// cuDNN impl is compatible during training as well
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
real_t pkeep = 1.0f - param.p;
bool cudnn_off = param.cudnn_off && param.cudnn_off.value();
bool cudnn_available = pkeep > 0 && !cudnn_off;
return MXNET_USE_CUDNN_DROPOUT && cudnn_available;
})
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);

NNVM_REGISTER_OP(_backward_Dropout)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def generate_inputs(self):
TestDesc('ConvTranspose', lambda: mx.gluon.nn.Conv2DTranspose(channels=32, kernel_size=(1,1))),
TestDesc('Dense', lambda: mx.gluon.nn.Dense(units=128)),
TestDesc('Activation', lambda: mx.gluon.nn.Activation('tanh')),
#TestDesc('Dropout', lambda: mx.gluon.nn.Dropout(0.5)),
TestDesc('Dropout', lambda: mx.gluon.nn.Dropout(0.5)),
TestDesc('Flatten', lambda: mx.gluon.nn.Flatten()),
TestDesc('MaxPool', lambda: mx.gluon.nn.MaxPool2D()),
TestDesc('AvgPool', lambda: mx.gluon.nn.AvgPool2D()),
Expand Down

0 comments on commit 64e8555

Please sign in to comment.