diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index d7e1543ec781..efe38019cfda 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -81,6 +81,7 @@ struct CachedOp::CachedOpState { std::vector buff; std::vector arrays; + std::vector arrays_with_in_out; std::vector array_reqs; std::vector op_states; @@ -762,7 +763,8 @@ OpStatePtr CachedOp::StaticForward( // We are going to add input and output arrays to the array list. // The input and output arrays should only be valid for this run, // so we shouldn't modify the state's array list. - auto arrays = state.arrays; + state.arrays_with_in_out = state.arrays; + auto& arrays = state.arrays_with_in_out; if (config_.static_shape) { for (auto i : config_.param_indices) { auto nid = idx.input_nodes()[i]; @@ -1063,7 +1065,8 @@ void CachedOp::StaticBackward( // We are going to add input and output arrays to the array list. // The input and output arrays should only be valid for this run, // so we shouldn't modify the state's array list. - auto arrays = state.arrays; + state.arrays_with_in_out = state.arrays; + auto& arrays = state.arrays_with_in_out; for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { auto eid = state.info.bwd_input_eid[i]; if (eid == kEidNotExist) { diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc index 698666f94d90..3e03b6bd2d16 100644 --- a/src/nnvm/legacy_op_util.cc +++ b/src/nnvm/legacy_op_util.cc @@ -79,7 +79,6 @@ class OperatorState { public: OperatorState(Operator *opr, const OperatorProperty *prop) { opr_ = opr; - fwd_init_ = bwd_init_ = false; in_data_fwd_.resize(prop->ListArguments().size()); in_data_bwd_.resize(prop->ListArguments().size()); @@ -110,19 +109,16 @@ class OperatorState { const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if (!fwd_init_) { - CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size()); - CHECK_EQ(outputs.size(), out_data_.size()); - // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones - // referred by arg_data_ptr_ will be overriden - for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i]; - for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i]; - for (size_t i = 0; i < aux_data_.size(); ++i) { - aux_data_[i] = inputs[i + in_data_fwd_.size()]; - } - for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i]; - fwd_init_ = true; + CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size()); + CHECK_EQ(outputs.size(), out_data_.size()); + // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones + // referred by arg_data_ptr_ will be overriden + for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i]; + for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i]; + for (size_t i = 0; i < aux_data_.size(); ++i) { + aux_data_[i] = inputs[i + in_data_fwd_.size()]; } + for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i]; opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_); } @@ -130,27 +126,22 @@ class OperatorState { const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if (!bwd_init_) { - CHECK(fwd_init_); - CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size()); - // override tblobs pointed by arg_data_ptr_ since they might not contain - // initialized data during forward pass. - for (size_t i = 0; i < arg_data_ptr_.size(); ++i) { - *arg_data_ptr_[i] = inputs[i]; - } - for (size_t i = 0; i < aux_data_.size(); ++i) { - aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i]; - } - CHECK_EQ(outputs.size(), in_grad_.size()); - for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i]; - bwd_init_ = true; + CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size()); + // override tblobs pointed by arg_data_ptr_ since they might not contain + // initialized data during forward pass. + for (size_t i = 0; i < arg_data_ptr_.size(); ++i) { + *arg_data_ptr_[i] = inputs[i]; + } + for (size_t i = 0; i < aux_data_.size(); ++i) { + aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i]; } + CHECK_EQ(outputs.size(), in_grad_.size()); + for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i]; opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_); } private: Operator *opr_; - bool fwd_init_, bwd_init_; // input data blobs for forward and backward // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor // performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is