diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 74a1aff5093f..60cc16deebd4 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -1928,6 +1928,60 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
     op.Forward(ctx, in_blobs, req, out_blobs);
   });
 }
+
+static void RNNStatefulGradComputeCPU(const OpStatePtr& state_ptr,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  std::vector<TBlob> in_blobs;
+  std::vector<TBlob> out_blobs;
+  std::vector<NDArray> temp_ndarrays_i;
+  std::vector<NDArray> temp_ndarrays_o;
+  for (const NDArray& in : inputs) {
+    if (in.storage_type() == kDefaultStorage) {
+      temp_ndarrays_i.push_back(in.Reorder2Default());
+      in_blobs.emplace_back(temp_ndarrays_i.back().data());
+    } else {
+      in_blobs.emplace_back(in.data());
+    }
+  }
+  for (const NDArray& out : outputs) {
+    if (out.storage_type() == kDefaultStorage) {
+      temp_ndarrays_o.push_back(out.Reorder2Default());
+      out_blobs.emplace_back(temp_ndarrays_o.back().data());
+    } else {
+      out_blobs.emplace_back(out.data());
+    }
+  }
+
+  std::vector<TBlob> in_data(in_blobs.begin(), in_blobs.begin() + 3);
+  std::vector<TBlob> out_data{in_blobs[3]};
+  std::vector<TBlob> out_grad{in_blobs[4]};
+  const std::vector<TBlob>& in_grad = out_blobs;
+
+  int dtype = in_blobs[rnn_enum::kData].type_flag_;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    RNNOp<DType>& op = state_ptr.get_state<RNNOp<DType>>();
+    const RNNParam& param = op.param_;
+    int index = 5;
+    if (param.state_outputs) {
+      out_data.push_back(in_blobs[index++]);
+      out_grad.push_back(in_blobs[index++]);
+    }
+
+    if (param.mode == rnn_enum::kLstm) {
+      in_data.push_back(in_blobs[index++]);
+      if (param.state_outputs) {
+        out_data.push_back(in_blobs[index++]);
+        out_grad.push_back(in_blobs[index]);
+      }
+    }
+
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
+  });
+}
+
 #endif
 /*
 index description
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 4d66becbeba7..da81107ae727 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -258,11 +258,9 @@ The definition of GRU here is slightly different from paper but compatible with
 .set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
-/*
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
 #endif
-*/
 .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
 .set_attr<FResourceRequestEx>("FResourceRequestEx",
   [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
@@ -295,6 +293,9 @@ NNVM_REGISTER_OP(_backward_RNN)
 .set_attr_parser(ParamParser<RNNParam>)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeCPU)
+#endif
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>);
 }  // namespace op
 }  // namespace mxnet
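
Reviewer note: the index arithmetic in `RNNStatefulGradComputeCPU` is the subtle part of this patch. The backward op receives a single flat input list and slices it into `in_data` (forward inputs), `out_data` (forward outputs), and `out_grad` (incoming gradients) before calling `op.Backward`. Below is a minimal standalone sketch of that slicing, for tracing the layout only; the blob names (x, w, hx, y, dy, hy, dhy, cx, cy, dcy) are inferred from the order the kernel consumes them, and `Param`/`SliceBackwardInputs` are illustrative stand-ins, not MXNet APIs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for RNNParam (hypothetical): only the two fields the slicing reads.
struct Param {
  bool state_outputs;  // whether hy (and cy for LSTM) are emitted as outputs
  bool is_lstm;        // LSTM additionally carries cell-state tensors
};

// Mirrors the slicing in RNNStatefulGradComputeCPU: inputs are assumed to be
// ordered {x, w, hx, y, dy, [hy, dhy], [cx, [cy, dcy]]}, where the bracketed
// groups are present only when state_outputs / LSTM mode apply.
void SliceBackwardInputs(const std::vector<std::string>& in_blobs,
                         const Param& param,
                         std::vector<std::string>* in_data,
                         std::vector<std::string>* out_data,
                         std::vector<std::string>* out_grad) {
  in_data->assign(in_blobs.begin(), in_blobs.begin() + 3);  // x, w, hx
  out_data->assign({in_blobs[3]});                          // y
  out_grad->assign({in_blobs[4]});                          // dy
  size_t index = 5;
  if (param.state_outputs) {
    out_data->push_back(in_blobs[index++]);  // hy
    out_grad->push_back(in_blobs[index++]);  // dhy
  }
  if (param.is_lstm) {
    in_data->push_back(in_blobs[index++]);   // cx
    if (param.state_outputs) {
      out_data->push_back(in_blobs[index++]);  // cy
      out_grad->push_back(in_blobs[index]);    // dcy
    }
  }
}

int main() {
  Param lstm{true, true};
  std::vector<std::string> blobs{"x", "w", "hx", "y", "dy",
                                 "hy", "dhy", "cx", "cy", "dcy"};
  std::vector<std::string> in_data, out_data, out_grad;
  SliceBackwardInputs(blobs, lstm, &in_data, &out_data, &out_grad);
  assert(in_data == (std::vector<std::string>{"x", "w", "hx", "cx"}));
  assert(out_data == (std::vector<std::string>{"y", "hy", "cy"}));
  assert(out_grad == (std::vector<std::string>{"dy", "dhy", "dcy"}));
  return 0;
}
```

For an LSTM with `state_outputs` enabled this yields `in_data = {x, w, hx, cx}`, `out_data = {y, hy, cy}`, and `out_grad = {dy, dhy, dcy}`, which matches the argument grouping passed to `op.Backward(ctx, out_grad, in_data, out_data, req, in_grad)` in the diff.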