Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
turn off cudnn by default
Browse files · Browse the repository at this point in the history
  • Loading branch information
szha committed Jan 27, 2019
1 parent 54f6e23 commit 8a7707c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
.describe("Whether to only turn on dropout during training or to also turn on for inference.");
DMLC_DECLARE_FIELD(axes).set_default(TShape())
.describe("Axes for variational dropout kernel.");
DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(false))
DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(true))
.describe("Whether to turn off cudnn in dropout operator.");
}
}; // struct DropoutParam
Expand Down
28 changes: 23 additions & 5 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5902,10 +5902,10 @@ def check_correctness(executor, input, ratio):
elif ratio == 0:
assert output_zeroes == 0

def check_dropout_ratio(ratio, shape):
def check_dropout_ratio(ratio, shape, cudnn_off=True):
# test dropout
x = mx.sym.var('data')
y = mx.sym.Dropout(x, p=ratio)
y = mx.sym.Dropout(x, p=ratio, cudnn_off=cudnn_off)
exe = y.simple_bind(ctx=default_context(), data=shape)

if ratio == 1:
Expand Down Expand Up @@ -5943,7 +5943,7 @@ def check_dropout_ratio(ratio, shape):

# test permanent dropout
x = mx.sym.var('data')
y = mx.sym.Dropout(x, p=ratio, mode='always')
y = mx.sym.Dropout(x, p=ratio, mode='always', cudnn_off=cudnn_off)
exe = y.simple_bind(ctx=default_context(), data=shape)

exe.arg_arrays[0][:] = 1
Expand All @@ -5968,13 +5968,13 @@ def get_slice(x, axis, idx):
ix += (slice(None, None, None),)
return x[ix]

def check_dropout_axes(ratio, shape, axes):
def check_dropout_axes(ratio, shape, axes, cudnn_off=False):
compactshape = list(shape)
for axis in axes:
compactshape[axis] = 1
compactx = mx.random.uniform(shape=tuple(compactshape))
broadcastx = compactx.broadcast_to(shape)
dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes)
dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes, cudnn_off=cudnn_off)
for axis in axes:
target = get_slice(dropouty, axis, 0).asnumpy()
for i in range(1, shape[axis]):
Expand All @@ -5986,6 +5986,11 @@ def check_dropout_axes(ratio, shape, axes):
check_dropout_ratio(1.0, shape)
check_dropout_ratio(0.75, shape)
check_dropout_ratio(0.25, shape)
check_dropout_ratio(0.5, shape, cudnn_off=False)
check_dropout_ratio(0.0, shape, cudnn_off=False)
check_dropout_ratio(1.0, shape, cudnn_off=False)
check_dropout_ratio(0.75, shape, cudnn_off=False)
check_dropout_ratio(0.25, shape, cudnn_off=False)

nshape = (10, 10, 10, 10)
with mx.autograd.train_mode():
Expand All @@ -6002,6 +6007,19 @@ def check_dropout_axes(ratio, shape, axes):
check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
check_dropout_axes(0.25, nshape, axes = (0,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (2,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (3,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 1), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 2), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1, 2), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (2, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 1, 2), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 2, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1, 2, 3), cudnn_off=False)


@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at /~https://github.com/apache/incubator-mxnet/issues/11290")
Expand Down

0 comments on commit 8a7707c

Please sign in to comment.