Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance in backward #5262

Merged
merged 1 commit into from
Oct 31, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions paddle/framework/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <deque>
#include <list>
#include <memory>
#include <unordered_set>

#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h"
Expand Down Expand Up @@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return true;
}

static std::string FwdName(const std::string& grad_name) {
auto pos = grad_name.find("@GRAD");
if (pos == std::string::npos) {
return "";
} else {
return grad_name.substr(0, pos);
}
}

static void CreateGradVarInBlock(
size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map,
Expand All @@ -294,15 +304,15 @@ static void CreateGradVarInBlock(
for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) {
bool need_infer_shape = false;
std::unordered_set<std::string> new_vars;
ForEachVarName(ops[op_index]->Outputs(),
[&](const std::string& grad_var_name) {
if (block_desc->HasVar(grad_var_name)) {
return false;
}
need_infer_shape = true;
auto var = block_desc->Var(grad_var_name);
// FIXME(qiao) infer the datatype
var->SetDataType(framework::DataType::FP32);
new_vars.insert(var->Name());
auto it = param_name_map.find(grad_var_name);
if (it == param_name_map.end()) {
return false;
Expand All @@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
});
if (need_infer_shape) {
ops[op_index]->InferVarType(block_desc);
for (auto& arg : ops[op_index]->OutputArgumentNames()) {
if (new_vars.find(arg) == new_vars.end()) {
continue;
}
auto pname = FwdName(arg);
auto* param = block_desc->FindVar(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
LOG(WARNING) << "Cannot find forward variable of " << arg
<< ". Set its gradient to FP32";
grad->SetDataType(DataType::FP32);
} else {
grad->SetDataType(param->GetDataType());
}
}
ops[op_index]->InferShape(*block_desc);
}
}
Expand Down