diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index a34d2992c8c6..2a643a266b2b 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -492,24 +492,6 @@ class DropoutOp { #endif // MXNET_USE_CUDNN_DROPOUT }; // class DropoutOp -static OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, - const Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - const DropoutParam& param = nnvm::get(attrs.parsed); - OpStatePtr state; - MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, { - if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); - } else { - state = OpStatePtr::Create>(param, ctx); - } - return state; - }); - LOG(FATAL) << "should never reach here"; - return OpStatePtr(); // should never reach here -} - template void DropoutCompute(const OpStatePtr& state, const OpContext& ctx, diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index bd76bd0d6e49..63da5613df84 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -26,10 +26,31 @@ #include "./dropout-inl.h" #include "../operator_common.h" +#include "mxnet/op_attr_types.h" + + namespace mxnet { namespace op { +OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, + const Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + const auto& param = nnvm::get(attrs.parsed); + OpStatePtr state; + MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, { + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); + } else { + state = OpStatePtr::Create>(param, ctx); + } + return state; + }); + LOG(FATAL) << "should never reach here"; + return OpStatePtr(); // should never reach here +} + struct DropoutGrad { const char *op_name; std::vector operator()(const nnvm::NodePtr& n,