This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Cudnn dropout #13896

Merged: 9 commits, Feb 5, 2019
Changes from 8 commits
20 changes: 20 additions & 0 deletions include/mxnet/resource.h
@@ -44,6 +44,11 @@ struct ResourceRequest {
kTempSpace,
/*! \brief common::RandGenerator<xpu> object, which can be used in GPU kernel functions */
kParallelRandom
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
,
/*! \brief cudnnDropoutDescriptor_t object for GPU dropout kernel functions */
kCuDNNDropoutDesc
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
};
/*! \brief type of resources */
Type type;
@@ -157,6 +162,21 @@ struct Resource {
reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))),
shape, shape[ndim - 1], stream);
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
  /*!
   * \brief Get a cuDNN dropout descriptor from shared state space.
   *
   * \param dropout_desc pointer to a previously created cuDNN dropout descriptor.
   * \param stream the stream on which the dropout state space will be used.
   * \param dropout the dropout rate to configure in the descriptor.
   * \param seed seed for the dropout random number generator.
   */
void get_cudnn_dropout_desc(
cudnnDropoutDescriptor_t* dropout_desc,
mshadow::Stream<gpu> *stream,
const float dropout,
uint64_t seed) const;
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7

/*!
* \brief Get CPU space as mshadow Tensor in specified type.
* The caller can request arbitrary size.
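To illustrate how the new kCuDNNDropoutDesc resource and Resource::get_cudnn_dropout_desc are meant to be consumed, here is a minimal sketch of a GPU operator requesting and initializing a dropout descriptor. The operator name my_dropout_like_op, the function MyOpForwardGPU, and the dropout_p/seed arguments are hypothetical; only the resource request, get_cudnn_dropout_desc, and the cuDNN descriptor calls come from this change or from cuDNN itself.

// Hypothetical sketch (inside namespace mxnet::op, built with MXNET_USE_CUDNN == 1 and cuDNN >= 7).
// The operator declares that it needs the shared cuDNN dropout descriptor resource.
NNVM_REGISTER_OP(my_dropout_like_op)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kCuDNNDropoutDesc};
  });

// Hypothetical GPU forward path: create the descriptor once, then let the
// requested resource bind it to the shared dropout state space.
void MyOpForwardGPU(const OpContext& ctx, float dropout_p, uint64_t seed) {
  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
  cudnnDropoutDescriptor_t desc;
  CUDNN_CALL(cudnnCreateDropoutDescriptor(&desc));
  ctx.requested[0].get_cudnn_dropout_desc(&desc, s, dropout_p, seed);
  // ... use desc with cudnnDropoutForward or an RNN descriptor ...
  CUDNN_CALL(cudnnDestroyDropoutDescriptor(&desc));
}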
5 changes: 4 additions & 1 deletion python/mxnet/gluon/nn/basic_layers.py
@@ -262,7 +262,10 @@ def __init__(self, rate, axes=(), **kwargs):
self._axes = axes

def hybrid_forward(self, F, x):
return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd')
if self._rate > 0:
return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
else:
return F.identity(x)

def __repr__(self):
s = '{name}(p = {_rate}, axes={_axes})'
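For context, a small usage sketch of the Gluon layer being changed; the network below is hypothetical and not part of the patch. A Dropout block constructed with rate 0 now short-circuits to F.identity, while a non-zero rate keeps the (cuDNN-backed) Dropout operator.

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(16), nn.Dropout(0.5),  # non-zero rate: uses the Dropout operator
        nn.Dense(8), nn.Dropout(0.0))   # zero rate: becomes an identity pass-through
net.initialize()
net.hybridize()

x = mx.nd.random.uniform(shape=(4, 32))
with mx.autograd.record():  # dropout is only applied in training mode
    y = net(x)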
39 changes: 26 additions & 13 deletions src/executor/attach_op_resource_pass.cc
@@ -64,20 +64,33 @@ void AttachOpResources(
: fresource[op](inode.source->attrs);
// Get the resource of temporal space.
for (const ResourceRequest& req : reqs) {
if (req.type == ResourceRequest::kTempSpace) {
if (cached_temp.count(ctx) != 0) {
requested.push_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
requested.push_back(r);
cached_temp[ctx] = r;
switch (req.type) {
case ResourceRequest::kTempSpace: {
if (cached_temp.count(ctx) != 0) {
requested.push_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
requested.push_back(r);
cached_temp[ctx] = r;
}
break;
}
} else if (req.type == ResourceRequest::kRandom) {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
} else if (req.type == ResourceRequest::kParallelRandom) {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
} else {
LOG(FATAL) << "resource type not yet supported";
case ResourceRequest::kRandom: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
case ResourceRequest::kParallelRandom: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default:
LOG(FATAL) << "resource type " << req.type << " is not yet supported";
}
}
CHECK(vdispatch[nid] != DispatchMode::kUndefined);
6 changes: 6 additions & 0 deletions src/imperative/imperative_utils.h
@@ -241,6 +241,12 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
requested.push_back(ResourceManager::Get()->Request(ctx, req));
write_vars.push_back(requested.back().var);
break;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc:
requested.push_back(ResourceManager::Get()->Request(ctx, req));
write_vars.push_back(requested.back().var);
break;
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default:
LOG(FATAL) << "resource type not yet supported";
}
4 changes: 2 additions & 2 deletions src/operator/cudnn_rnn-inl.h
@@ -699,7 +699,7 @@ class CuDNNRNNOp : public Operator {
if (param_.p > 0) {
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
dropout_size_ = dropout_byte_ / sizeof(DType);
dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));
} else {
dropout_states_ = {};
dropout_byte_ = 0;
@@ -764,7 +764,7 @@
&reserve_space_byte_));
workspace_size_ = workspace_byte_ / sizeof(DType);
// Allocate the reserve space
reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU());
reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));

// Check that number of params are correct
size_t cudnn_param_size;
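The two Context::GPU(s->dev_id) changes above make the dropout state and the reserve space follow the device of the executing stream instead of always landing on GPU 0. A minimal sketch of that intent, assuming a GPU stream is available from the operator context; apart from the Storage, Context, and cuDNN calls already shown in the diff, the variable names are illustrative.

mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
size_t dropout_bytes = 0;
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_bytes));
// Allocating with Context::GPU() would always target device 0; passing
// s->dev_id keeps the state on the GPU that actually runs the operator.
Storage::Handle states =
    Storage::Get()->Alloc(dropout_bytes, Context::GPU(s->dev_id));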