
reuse dropout state space
szha committed Jan 31, 2019
1 parent 7f76582 commit 9db5fde
Showing 8 changed files with 169 additions and 67 deletions.
20 changes: 20 additions & 0 deletions include/mxnet/resource.h
@@ -44,6 +44,11 @@ struct ResourceRequest {
kTempSpace,
/*! \brief common::RandGenerator<xpu> object, which can be used in GPU kernel functions */
kParallelRandom
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
,
/*! \brief cudnnDropoutDescriptor_t object for GPU dropout kernel functions */
kCuDNNDropoutDesc
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
};
/*! \brief type of resources */
Type type;
@@ -157,6 +162,21 @@ struct Resource {
reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))),
shape, shape[ndim - 1], stream);
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
/*!
* \brief Get cudnn dropout descriptor from shared state space.
*
* \param dropout_desc pointer to a previously created cudnn dropout descriptor.
* \param stream the stream to use for initializing or restoring the dropout state.
* \param dropout the dropout probability to set in the descriptor.
* \param seed seed for the cudnn random number generator.
*/
void get_cudnn_dropout_desc(
cudnnDropoutDescriptor_t* dropout_desc,
mshadow::Stream<gpu> *stream,
const float dropout,
uint64_t seed) const;
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7

/*!
* \brief Get CPU space as mshadow Tensor in specified type.
* The caller can request arbitrary size.
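For orientation, the sketch below (not part of this commit) shows how an operator's GPU forward pass consumes the new resource. The helper name FillSharedDropoutDesc and the assumption that the dropout-descriptor request sits at index 0 of ctx.requested are illustrative; only the get_cudnn_dropout_desc call itself mirrors this change.

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#include <mxnet/op_attr_types.h>   // mxnet::OpContext
#include <mxnet/resource.h>

// Illustrative sketch only: fill a previously created descriptor from the
// shared, lazily initialized dropout state space on each forward call.
void FillSharedDropoutDesc(const mxnet::OpContext& ctx,
                           cudnnDropoutDescriptor_t dropout_desc,  // from cudnnCreateDropoutDescriptor
                           float pkeep, uint64_t seed) {
  mshadow::Stream<mshadow::gpu>* s = ctx.get_stream<mshadow::gpu>();
  // Assumes kCuDNNDropoutDesc is the operator's first requested resource.
  // The first use of the shared space initializes it via cudnnSetDropoutDescriptor;
  // later uses reuse the same state buffer through cudnnRestoreDropoutDescriptor.
  ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc, s, 1.0f - pkeep, seed);
  // The operator would now pass dropout_desc to cudnnDropoutForward/Backward.
}
#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7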
39 changes: 26 additions & 13 deletions src/executor/attach_op_resource_pass.cc
@@ -64,20 +64,33 @@ void AttachOpResources(
: fresource[op](inode.source->attrs);
// Get the resource of temporal space.
for (const ResourceRequest& req : reqs) {
if (req.type == ResourceRequest::kTempSpace) {
if (cached_temp.count(ctx) != 0) {
requested.push_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
requested.push_back(r);
cached_temp[ctx] = r;
switch (req.type) {
case ResourceRequest::kTempSpace: {
if (cached_temp.count(ctx) != 0) {
requested.push_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
requested.push_back(r);
cached_temp[ctx] = r;
}
break;
}
} else if (req.type == ResourceRequest::kRandom) {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
} else if (req.type == ResourceRequest::kParallelRandom) {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
} else {
LOG(FATAL) << "resource type not yet supported";
case ResourceRequest::kRandom: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
case ResourceRequest::kParallelRandom: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default:
LOG(FATAL) << "resource type " << req.type << " is not yet supported";
}
}
CHECK(vdispatch[nid] != DispatchMode::kUndefined);
6 changes: 6 additions & 0 deletions src/imperative/imperative_utils.h
@@ -241,6 +241,12 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
requested.push_back(ResourceManager::Get()->Request(ctx, req));
write_vars.push_back(requested.back().var);
break;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc:
requested.push_back(ResourceManager::Get()->Request(ctx, req));
write_vars.push_back(requested.back().var);
break;
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default:
LOG(FATAL) << "resource type not yet supported";
}
22 changes: 4 additions & 18 deletions src/operator/nn/dropout-inl.h
@@ -78,7 +78,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
DMLC_DECLARE_FIELD(axes).set_default(TShape())
.describe("Axes for variational dropout kernel.");
DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(true))
.describe("Whether to turn off cudnn in dropout operator.");
.describe("Whether to turn off cudnn in dropout operator. "
"This option is ignored if axes is specified.");
}
}; // struct DropoutParam

@@ -211,7 +212,6 @@ class DropoutOp {
this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value();
this->ctx_ = ctx;
if (ctx.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) {
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc_));
@@ -230,9 +230,6 @@ class DropoutOp {
CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_));
CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
if (init_cudnn_) {
Storage::Get()->Free(dropout_states_);
}
}
#endif // MXNET_USE_CUDNN_DROPOUT
}
@@ -249,16 +246,7 @@ class DropoutOp {
Stream<xpu> *s = ctx.get_stream<xpu>();

// set dropout state.
// TODO(szha): expensive call, should be cached and reused across operators.
if (!init_cudnn_) {
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_state_byte_));
dropout_states_ = Storage::Get()->Alloc(dropout_state_byte_, Context::GPU(s->dev_id));
CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
1.0f - this->pkeep_,
dropout_states_.dptr, dropout_state_byte_,
seed_));
init_cudnn_ = true;
}
ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc_, s, 1.0f - this->pkeep_, seed_);

// describe input/output tensor
int dim[4], stride[4];
@@ -493,10 +481,8 @@ class DropoutOp {
Context ctx_;
cudnnDataType_t dtype_;
cudnnDropoutDescriptor_t dropout_desc_;
bool init_cudnn_;
uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn)
size_t dropout_state_byte_, dropout_reserve_byte_;
Storage::Handle dropout_states_;
size_t dropout_reserve_byte_;
cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_;
#endif // MXNET_USE_CUDNN_DROPOUT
}; // class DropoutOp
1 change: 1 addition & 0 deletions src/operator/nn/dropout.cc
@@ -135,6 +135,7 @@ Example::
if (1.0f - param.p > 0
&& !(param.cudnn_off && param.cudnn_off.value())
&& param.axes.ndim() == 0) {
request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
return request;
}
#endif
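The hunk above only shows the lines inside the operator's FResourceRequest attribute where the new request is added. The hypothetical registration sketch below shows where such a declaration lives; the operator name HypotheticalDropoutLikeOp and the kParallelRandom placeholder are illustrative, and only the guarded block mirrors the actual change (names are assumed to resolve inside the mxnet::op namespace, as in dropout.cc).

NNVM_REGISTER_OP(HypotheticalDropoutLikeOp)
.set_attr<FResourceRequest>("FResourceRequest",
    [](const nnvm::NodeAttrs& attrs) {
      const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
      std::vector<ResourceRequest> request;
      request.emplace_back(ResourceRequest::kParallelRandom);  // placeholder for the non-cuDNN path
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
      // Ask for the shared cuDNN dropout descriptor only when cuDNN dropout will
      // actually run: non-zero keep probability, cuDNN not turned off, no axes.
      if (1.0f - param.p > 0
          && !(param.cudnn_off && param.cudnn_off.value())
          && param.axes.ndim() == 0) {
        request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
        return request;
      }
#endif
      return request;
    });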
63 changes: 57 additions & 6 deletions src/resource.cc
@@ -34,6 +34,7 @@
#include <atomic>
#include "./common/lazy_alloc_array.h"
#include "./common/utils.h"
#include "./common/cuda_utils.h"

namespace mxnet {
namespace resource {
@@ -92,11 +93,14 @@ class ResourceManagerImpl : public ResourceManager {
gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 1);
cpu_native_rand_copy_ = dmlc::GetEnv("MXNET_CPU_PARALLEL_RAND_COPY", 1);
gpu_native_rand_copy_ = dmlc::GetEnv("MXNET_GPU_PARALLEL_RAND_COPY", 4);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
gpu_cudnn_dropout_state_copy_ = dmlc::GetEnv("MXNET_GPU_CUDNN_DROPOUT_STATE_COPY", 4);
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
engine_ref_ = Engine::_GetSharedRef();
storage_ref_ = Storage::_GetSharedRef();
cpu_rand_.reset(new ResourceRandom<cpu>(
Context::CPU(), global_seed_));
cpu_space_.reset(new ResourceTempSpace(
cpu_space_.reset(new ResourceTempSpace<ResourceRequest::kTempSpace>(
Context::CPU(), cpu_temp_space_copy_));
cpu_parallel_rand_.reset(new ResourceParallelRandom<cpu>(
Context::CPU(), cpu_native_rand_copy_, global_seed_));
@@ -110,6 +114,9 @@ class ResourceManagerImpl : public ResourceManager {
gpu_rand_.Clear();
gpu_space_.Clear();
gpu_parallel_rand_.Clear();
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
gpu_cudnn_dropout_state_.Clear();
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#endif
if (engine_ref_ != nullptr) {
engine_ref_ = nullptr;
@@ -139,14 +146,21 @@
}
case ResourceRequest::kTempSpace: {
return gpu_space_.Get(ctx.dev_id, [ctx, this]() {
return new ResourceTempSpace(ctx, gpu_temp_space_copy_);
return new ResourceTempSpace<ResourceRequest::kTempSpace>(ctx, gpu_temp_space_copy_);
})->GetNext();
}
case ResourceRequest::kParallelRandom: {
return gpu_parallel_rand_.Get(ctx.dev_id, [ctx, this]() {
return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, global_seed_);
})->GetNext();
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc: {
return gpu_cudnn_dropout_state_.Get(ctx.dev_id, [ctx, this]() {
return new ResourceTempSpace<ResourceRequest::kCuDNNDropoutDesc>(ctx, gpu_cudnn_dropout_state_copy_);
})->GetNext();
}
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default: LOG(FATAL) << "Unknown supported type " << req.type;
}
#else
@@ -231,7 +245,8 @@ class ResourceManagerImpl : public ResourceManager {
}
};

// temporal space resource.
// temporary space resource.
template<ResourceRequest::Type req>
struct ResourceTempSpace {
/*! \brief the context of the device */
Context ctx;
@@ -248,7 +263,7 @@ class ResourceManagerImpl : public ResourceManager {
resource[i].var = Engine::Get()->NewVariable();
resource[i].id = static_cast<int32_t>(i);
resource[i].ptr_ = &space[i];
resource[i].req = ResourceRequest(ResourceRequest::kTempSpace);
resource[i].req = ResourceRequest(req);
space[i].ctx = ctx;
CHECK_EQ(space[i].handle.size, 0U);
}
@@ -372,16 +387,22 @@ class ResourceManagerImpl : public ResourceManager {
/*! \brief CPU random number resources */
std::unique_ptr<ResourceRandom<cpu> > cpu_rand_;
/*! \brief CPU temp space resources */
std::unique_ptr<ResourceTempSpace> cpu_space_;
std::unique_ptr<ResourceTempSpace<ResourceRequest::kTempSpace>> cpu_space_;
/*! \brief CPU parallel random number resources */
std::unique_ptr<ResourceParallelRandom<cpu> > cpu_parallel_rand_;
#if MXNET_USE_CUDA
/*! \brief random number generator for GPU */
common::LazyAllocArray<ResourceRandom<gpu> > gpu_rand_;
/*! \brief temp space for GPU */
common::LazyAllocArray<ResourceTempSpace> gpu_space_;
common::LazyAllocArray<ResourceTempSpace<ResourceRequest::kTempSpace>> gpu_space_;
/*! \brief GPU parallel (on device) random number resources */
common::LazyAllocArray<ResourceParallelRandom<gpu> > gpu_parallel_rand_;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
/*! \brief number of copies of the GPU cudnn dropout state space */
int gpu_cudnn_dropout_state_copy_;
/*! \brief GPU cudnn dropout state space resources */
common::LazyAllocArray<ResourceTempSpace<ResourceRequest::kCuDNNDropoutDesc>> gpu_cudnn_dropout_state_;
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#endif
};
} // namespace resource
@@ -394,6 +415,36 @@
return static_cast<resource::SpaceAllocator*>(ptr_)->GetHostSpace(size);
}

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
void Resource::get_cudnn_dropout_desc(
cudnnDropoutDescriptor_t* dropout_desc,
mshadow::Stream<gpu> *stream,
const float dropout,
uint64_t seed) const {

CHECK_EQ(req.type, ResourceRequest::kCuDNNDropoutDesc);
auto state_space = static_cast<resource::SpaceAllocator*>(ptr_);
CHECK_EQ(state_space->ctx.dev_id, stream->dev_id)
<< "The device id of cudnn dropout state space doesn't match that from stream.";
if (!state_space->handle.size) {
// not initialized yet.
size_t dropout_state_size;
CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size));
CUDNN_CALL(cudnnSetDropoutDescriptor(*dropout_desc, stream->dnn_handle_,
dropout,
state_space->GetSpace(dropout_state_size),
dropout_state_size,
seed));
} else {
CUDNN_CALL(cudnnRestoreDropoutDescriptor(*dropout_desc, stream->dnn_handle_,
dropout,
state_space->handle.dptr,
state_space->handle.size,
seed));
}
}
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7

ResourceManager* ResourceManager::Get() {
typedef dmlc::ThreadLocalStore<resource::ResourceManagerImpl> inst;
return inst::Get();
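As a usage note, the shared dropout state space is handed out round-robin per GPU, with the number of copies read above from MXNET_GPU_CUDNN_DROPOUT_STATE_COPY (default 4). The sketch below (not part of this commit) requests the resource directly, which is what AttachOpResources and SetDependency do on an operator's behalf; operators normally never call Request() themselves.

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#include <mxnet/resource.h>

// Illustrative sketch: successive Request() calls for the same GPU cycle
// through the per-device copies of the dropout state space, so concurrent
// operators do not all serialize on a single cuDNN dropout state buffer.
void RequestDropoutStateExample() {
  using mxnet::Context;
  using mxnet::Resource;
  using mxnet::ResourceManager;
  using mxnet::ResourceRequest;
  Context gpu0 = Context::GPU(0);
  ResourceRequest req(ResourceRequest::kCuDNNDropoutDesc);
  Resource r0 = ResourceManager::Get()->Request(gpu0, req);
  Resource r1 = ResourceManager::Get()->Request(gpu0, req);  // typically a different copy
  (void)r0; (void)r1;
}
#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7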
38 changes: 25 additions & 13 deletions tests/cpp/include/test_core_op.h
@@ -186,20 +186,32 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
if (!reqs.empty()) {
// Get the resource of temporal space.
for (const ResourceRequest& req : reqs) {
if (req.type == ResourceRequest::kTempSpace) {
Resource r = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
requested.emplace_back(r);
} else if (req.type == ResourceRequest::kRandom) {
requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
} else if (req.type == ResourceRequest::kParallelRandom) {
Resource rm = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
if (ctx->run_ctx.ctx.dev_mask() == Context::kCPU) {
common::random::RandGenerator<cpu, DType>::AllocState(
rm.get_parallel_random<cpu, DType>());
switch (req.type) {
case ResourceRequest::kTempSpace: {
requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
break;
}
requested.emplace_back(rm);
} else {
LOG(FATAL) << "resource type not yet supported";
case ResourceRequest::kRandom: {
requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
break;
}
case ResourceRequest::kParallelRandom: {
Resource rm = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
if (ctx->run_ctx.ctx.dev_mask() == Context::kCPU) {
common::random::RandGenerator<cpu, DType>::AllocState(
rm.get_parallel_random<cpu, DType>());
}
requested.emplace_back(rm);
break;
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc: {
requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
break;
}
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default:
LOG(FATAL) << "resource type " << req.type << " is not yet supported";
}
}
}
47 changes: 30 additions & 17 deletions tests/cpp/include/test_legacy_op.h
@@ -494,25 +494,38 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
ctx.dev_id = 0;

for (const ResourceRequest& req : reqs) {
if (req.type == ResourceRequest::kTempSpace) {
if (cached_temp.count(ctx) != 0) {
opContext_.requested.emplace_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
opContext_.requested.emplace_back(r);
cached_temp[ctx] = r;
switch (req.type) {
case ResourceRequest::kTempSpace: {
if (cached_temp.count(ctx) != 0) {
opContext_.requested.emplace_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
opContext_.requested.emplace_back(r);
cached_temp[ctx] = r;
}
break;
}
case ResourceRequest::kRandom: {
opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
break;
}
case ResourceRequest::kParallelRandom: {
Resource rm = ResourceManager::Get()->Request(ctx, req);
if (ctx.dev_mask() == Context::kCPU) {
common::random::RandGenerator<cpu, DType>::AllocState(
rm.get_parallel_random<cpu, DType>());
}
opContext_.requested.emplace_back(rm);
break;
}
} else if (req.type == ResourceRequest::kRandom) {
opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
} else if (req.type == ResourceRequest::kParallelRandom) {
Resource rm = ResourceManager::Get()->Request(ctx, req);
if (ctx.dev_mask() == Context::kCPU) {
common::random::RandGenerator<cpu, DType>::AllocState(
rm.get_parallel_random<cpu, DType>());
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
case ResourceRequest::kCuDNNDropoutDesc: {
opContext_.requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
opContext_.requested.emplace_back(rm);
} else {
LOG(FATAL) << "resource type not yet supported";
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
default:
LOG(FATAL) << "resource type " << req.type << " is not yet supported";
}
}
}
