Commit: test compile

Li, Hao H committed Apr 19, 2019
1 parent 8914f78 commit fc597f3
Showing 2 changed files with 57 additions and 2 deletions.
54 changes: 54 additions & 0 deletions src/operator/rnn-inl.h
@@ -1928,6 +1928,60 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
     op.Forward(ctx, in_blobs, req, out_blobs);
   });
 }
+
+static void RNNStatefulGradComputeCPU(const OpStatePtr& state_ptr,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  std::vector<TBlob> in_blobs;
+  std::vector<TBlob> out_blobs;
+  std::vector<NDArray> temp_ndarrays_i;
+  std::vector<NDArray> temp_ndarrays_o;
+  for (const NDArray& in : inputs) {
+    if (in.storage_type() == kDefaultStorage) {
+      temp_ndarrays_i.push_back(in.Reorder2Default());
+      in_blobs.emplace_back(temp_ndarrays_i.back().data());
+    } else {
+      in_blobs.emplace_back(in.data());
+    }
+  }
+  for (const NDArray& out : outputs) {
+    if (out.storage_type() == kDefaultStorage) {
+      temp_ndarrays_o.push_back(out.Reorder2Default());
+      out_blobs.emplace_back(temp_ndarrays_o.back().data());
+    } else {
+      out_blobs.emplace_back(out.data());
+    }
+  }
+
+  std::vector<TBlob> in_data(in_blobs.begin(), in_blobs.begin() + 3);
+  std::vector<TBlob> out_data{in_blobs[3]};
+  std::vector<TBlob> out_grad{in_blobs[4]};
+  const std::vector<TBlob> &in_grad = out_blobs;
+
+  int dtype = in_blobs[rnn_enum::kData].type_flag_;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    RNNOp<cpu, DType>& op = state_ptr.get_state<RNNOp<cpu, DType>>();
+    const RNNParam& param = op.param_;
+    int index = 5;
+    if (param.state_outputs) {
+      out_data.push_back(in_blobs[index++]);
+      out_grad.push_back(in_blobs[index++]);
+    }
+
+    if (param.mode == rnn_enum::kLstm) {
+      in_data.push_back(in_blobs[index++]);
+      if (param.state_outputs) {
+        out_data.push_back(in_blobs[index++]);
+        out_grad.push_back(in_blobs[index]);
+      }
+    }
+
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
+  });
+}
+
 #endif
 /*
 index description
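The indexing logic in the new function pins down the flat layout of `inputs` that the backward op expects. As a reading aid, here is that layout spelled out; the slot names (x, w, hx, and so on) are inferred from the indexing and the usual RNN operator signature, not taken from the commit itself:

// Inferred layout of `inputs` for _backward_RNN (a sketch; slot names are assumptions):
//   in_blobs[0] : x   -> in_data   (rnn_enum::kData)
//   in_blobs[1] : w   -> in_data   (rnn_enum::kParams)
//   in_blobs[2] : hx  -> in_data   (rnn_enum::kState)
//   in_blobs[3] : y   -> out_data
//   in_blobs[4] : dy  -> out_grad
// Optional slots, consumed from index = 5 onward:
//   hy -> out_data and dhy -> out_grad   if param.state_outputs
//   cx -> in_data                        if param.mode == rnn_enum::kLstm
//   cy -> out_data and dcy -> out_grad   if kLstm && state_outputs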
5 changes: 3 additions & 2 deletions src/operator/rnn.cc
@@ -258,11 +258,9 @@ The definition of GRU here is slightly different from paper but compatible with
 .set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
-/*
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
 #endif
-*/
 .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
 .set_attr<FResourceRequestEx>("FResourceRequestEx",
 [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
@@ -295,6 +293,9 @@ NNVM_REGISTER_OP(_backward_RNN)
 .set_attr_parser(ParamParser<RNNParam>)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeCPU)
+#endif
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>);
 } // namespace op
 } // namespace mxnet
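Both registrations follow the same MXNet idiom: the TBlob-based FStatefulCompute kernel is always present, and an NDArray-based FStatefulComputeEx kernel is layered on only when MKL-DNN is compiled in, so inputs held in MKL-DNN internal layouts can be reordered to the default layout before the dense kernel runs. RNNStatefulGradComputeCPU does that bridging inline; in isolation the idiom looks roughly like this (a hypothetical helper, not code from this commit):

// Hypothetical sketch of the NDArray -> TBlob bridging idiom used above.
// `holders` keeps the reordered copies alive for the duration of the call,
// mirroring temp_ndarrays_i / temp_ndarrays_o in RNNStatefulGradComputeCPU.
inline void ToDenseBlobs(const std::vector<NDArray>& arrays,
                         std::vector<NDArray>* holders,
                         std::vector<TBlob>* blobs) {
  for (const NDArray& nd : arrays) {
    if (nd.storage_type() == kDefaultStorage) {
      holders->push_back(nd.Reorder2Default());   // MKL-DNN layout -> default
      blobs->emplace_back(holders->back().data());
    } else {
      blobs->emplace_back(nd.data());             // non-default storage: take as-is
    }
  }
}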
