Skip to content

Commit

Permalink
Fix detection of gradient variables for double-grad ops
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Mar 29, 2022
1 parent ea5b2f2 commit 4d79c96
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ static T* DynLoad(void* handle, std::string name) {
return func;
}

inline static bool IsGradVar(const std::string& var_name) {
std::string suffix = kGradVarSuffix;
return var_name.rfind(suffix) != std::string::npos;
}

inline static bool IsDuplicableVar(const std::string& var_name) {
std::string suffix = kTensorVectorSuffix;
return var_name.rfind(suffix) != std::string::npos;
Expand All @@ -77,6 +72,17 @@ inline static std::string NoGrad(const std::string& var_name) {
return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
}

inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
std::string suffix = kGradVarSuffix;
if (!is_double_grad) {
return var_name.rfind(suffix) != std::string::npos;
} else {
// for double grad cases, the X@GRAD is not a grad var, X@GRAD@GRAD is a
// grad var, here we remove a @GRAD suffix
return NoGrad(var_name).rfind(suffix) != std::string::npos;
}
}

inline static bool IsMemberOf(const std::vector<std::string>& vec,
const std::string& name) {
return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
Expand Down Expand Up @@ -493,11 +499,12 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs)
const std::vector<std::string>& outputs, bool is_double_grad)
: SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
name_(name),
inputs_(inputs),
outputs_(outputs) {}
outputs_(outputs),
is_double_grad_(is_double_grad) {}

protected:
void Apply(GradOpPtr<OpDesc> grad_op) const override {
Expand All @@ -508,7 +515,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {

for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
if (!detail::IsGradVar(in_name)) {
if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
Expand Down Expand Up @@ -540,6 +547,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
bool is_double_grad_{false};
};

template <>
Expand All @@ -553,12 +561,13 @@ class CustomGradOpMaker<imperative::OpBase>
const AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map,
const std::string& name, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs)
const std::vector<std::string>& outputs, bool is_double_grad)
: SingleGradOpMaker<imperative::OpBase>(
type, var_base_map_in, var_base_map_out, attrs, inplace_map),
name_(name),
inputs_(inputs),
outputs_(outputs) {}
outputs_(outputs),
is_double_grad_(is_double_grad) {}

protected:
// TODO(chenweihang): The code is duplicated with the previous one, because
Expand All @@ -574,7 +583,7 @@ class CustomGradOpMaker<imperative::OpBase>

for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
if (!detail::IsGradVar(in_name)) {
if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
Expand All @@ -600,6 +609,7 @@ class CustomGradOpMaker<imperative::OpBase>
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
bool is_double_grad_{false};
};

//////////// Operator and Kernel Register //////////////
Expand Down Expand Up @@ -832,21 +842,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
VLOG(3) << "Custom Operator: backward, op outputs: "
<< string::join_strings(grad_op_outputs, ',');

bool is_double_grad = (i == 2);

// GradOpDescMaker
info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs](
info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs,
is_double_grad](
const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block) {
CustomGradOpMaker<paddle::framework::OpDesc> maker(
fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name,
grad_op_inputs, grad_op_outputs);
grad_op_inputs, grad_op_outputs, is_double_grad);
return maker();
};

// GradOpBaseMaker
info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs,
grad_op_outputs](
grad_op_outputs, is_double_grad](
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
Expand All @@ -855,7 +868,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
grad_op_name, grad_op_inputs, grad_op_outputs, is_double_grad);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};
Expand Down

1 comment on commit 4d79c96

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can now ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.