Skip to content

Commit

Permalink
support gpu mixed precision inference (PaddlePaddle#40531)
Browse files Browse the repository at this point in the history
  • Loading branch information
baoachun authored and liqitong-a committed Mar 17, 2022
1 parent 2ca34da commit 4815e0f
Show file tree
Hide file tree
Showing 14 changed files with 385 additions and 13 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(mixed_precision_configure_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)

Expand Down
149 changes: 149 additions & 0 deletions paddle/fluid/framework/ir/mixed_precision_configure_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mixed_precision_configure_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {

void MixedPrecisionConfigurePass::InsertCastOps(
Graph* graph, const StringSet& blacklist) const {
VLOG(3) << "Insert the cast op before and after the kernel that does not "
"supports fp16 precision";

auto update_cast_desc = [&](
framework::OpDesc& desc, const std::string& x_name,
const std::string& out_name, const int in_dtype, const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};

auto cast_input = [&](Graph* graph, Node* op_node,
const StringSet& cast_list) {
auto inlinks = op_node->inputs;
for (auto* pre_node : inlinks) {
if (pre_node->IsVar()) {
const auto is_persistable = pre_node->Var()->Persistable();
const auto is_float =
pre_node->Var()->GetDataType() == proto::VarType::FP16 ||
pre_node->Var()->GetDataType() == proto::VarType::FP32 ||
pre_node->Var()->GetDataType() == proto::VarType::FP64;
if (!is_persistable && is_float) {
int suffix = 0;
for (auto* pre_node_input : pre_node->inputs) {
if (!pre_node_input->IsOp()) continue;
const auto& type = pre_node_input->Op()->Type();
if (!cast_list.count(type) && type != "cast") {
std::string old_name = pre_node->Name();
std::string new_name =
old_name + "_cast.tmp_" + std::to_string(suffix);
suffix++;

framework::OpDesc new_op_desc(op_node->Op()->Block());
// 4 for fp16, 5 for fp32
update_cast_desc(new_op_desc, old_name, new_name, 4, 5);
auto* new_op = graph->CreateOpNode(&new_op_desc);

VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);

op_node->Op()->RenameInput(old_name, new_name);
IR_NODE_LINK_TO(pre_node, new_op);
IR_NODE_LINK_TO(new_op, node_var);
IR_NODE_LINK_TO(node_var, op_node);
}
}
}
}
}
};

auto cast_output = [&](Graph* graph, Node* op_node,
const StringSet& cast_list) {
auto outlinks = op_node->outputs;
for (auto* next_node : outlinks) {
if (next_node->IsVar()) {
const auto is_persistable = next_node->Var()->Persistable();
const auto is_float =
next_node->Var()->GetDataType() == proto::VarType::FP16 ||
next_node->Var()->GetDataType() == proto::VarType::FP32 ||
next_node->Var()->GetDataType() == proto::VarType::FP64;
if (!is_persistable && is_float) {
int suffix = 0;
for (auto* next_node_output : next_node->outputs) {
if (!next_node_output->IsOp()) continue;

const auto& type = next_node_output->Op()->Type();
if (!cast_list.count(type) && type != "cast") {
std::string old_name = next_node->Name();
std::string new_name =
old_name + "_cast.tmp_" + std::to_string(suffix);
suffix++;

framework::OpDesc new_op_desc(op_node->Op()->Block());
// 4 for fp16, 5 for fp32
update_cast_desc(new_op_desc, old_name, new_name, 5, 4);
auto* new_op = graph->CreateOpNode(&new_op_desc);

VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);

next_node_output->Op()->RenameInput(old_name, new_name);
IR_NODE_LINK_TO(next_node, new_op);
IR_NODE_LINK_TO(new_op, node_var);
IR_NODE_LINK_TO(node_var, next_node_output);
}
}
}
}
}
};

for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;

const auto& type = op_node->Op()->Type();
if (blacklist.count(type)) {
cast_input(graph, op_node, blacklist);
cast_output(graph, op_node, blacklist);
}
}
}

void MixedPrecisionConfigurePass::ApplyImpl(Graph* graph) const {
const auto blacklist =
Get<std::unordered_set<std::string>>("gpu_fp16_disabled_op_types");
InsertCastOps(graph, blacklist);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(mixed_precision_configure_pass,
paddle::framework::ir::MixedPrecisionConfigurePass);
39 changes: 39 additions & 0 deletions paddle/fluid/framework/ir/mixed_precision_configure_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

using StringSet = std::unordered_set<std::string>;

class MixedPrecisionConfigurePass : public FusePassBase {
public:
MixedPrecisionConfigurePass() = default;
virtual ~MixedPrecisionConfigurePass() {}

protected:
void ApplyImpl(Graph* graph) const override;

private:
void InsertCastOps(Graph* graph, const StringSet& blacklist) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
DECL_ARGUMENT_FIELD(use_gpu_fp16, UseGPUFp16, bool);
DECL_ARGUMENT_FIELD(gpu_fp16_disabled_op_types, GpuFp16DisabledOpTypes,
std::unordered_set<std::string>);

// Usually use for trt dynamic shape.
// TRT will select the best kernel according to opt shape
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ void IRPassManager::CreatePasses(Argument *argument,
new int(argument->dlnne_min_subgraph_size()));
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
} else if (pass_name == "mixed_precision_configure_pass") {
pass->Set("gpu_fp16_disabled_op_types",
new std::unordered_set<std::string>(
argument->gpu_fp16_disabled_op_types()));
}
if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
Expand Down Expand Up @@ -65,6 +66,26 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {

#else

void IrParamsSyncAmongDevicesPass::GetVarNameToOpTypeMap(
const framework::ir::Graph &graph,
std::unordered_map<std::string, std::string> *var_name_op_type_map) {
std::vector<framework::ir::Node *> node_list =
framework::ir::TopologyVarientSort(
graph, static_cast<framework::ir::SortKind>(0));
for (auto *op_node : node_list) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;

for (auto *pre_node : op_node->inputs) {
if (pre_node->IsVar() && pre_node->Var()->Persistable()) {
var_name_op_type_map->insert(std::pair<std::string, std::string>(
pre_node->Var()->Name(), op_node->Op()->Type()));
}
}
}
}

void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
// The parameters are on the cpu, therefore, synchronization is not necessary.
if (!argument->use_gpu()) return;
Expand Down Expand Up @@ -102,6 +123,16 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
if (with_dynamic_shape) {
reserve_cpu_weights = true;
}

bool mixed_precision_mode =
argument->Has("use_gpu_fp16") && argument->use_gpu_fp16();
std::unordered_map<std::string, std::string> var_name_op_type_map{};
std::unordered_set<std::string> blacklist{};
if (mixed_precision_mode) {
GetVarNameToOpTypeMap(graph, &var_name_op_type_map);
blacklist = argument->gpu_fp16_disabled_op_types();
}

for (auto &var_name : all_vars) {
if (std::count(repetitive_params.begin(), repetitive_params.end(),
var_name)) {
Expand All @@ -117,18 +148,29 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
var->IsType<framework::Tensor>()) {
auto *t = var->GetMutable<framework::LoDTensor>();

platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
temp_tensor.mutable_data<float>(cpu_place);

// Copy the parameter data to a tmp tensor.
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
// Reallocation the space on GPU
t->clear();

// Copy parameter data to newly allocated GPU space.
paddle::framework::TensorCopySync(temp_tensor, place, t);
bool is_float = t->dtype() == paddle::experimental::DataType::FLOAT32 ||
t->dtype() == paddle::experimental::DataType::FLOAT64;
if (mixed_precision_mode &&
!blacklist.count(var_name_op_type_map[var_name]) && is_float) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
auto *half_data =
half_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
auto *data = t->mutable_data<float>(platform::CPUPlace());
half_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(half_tensor, place, t);
} else {
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
t->clear();
paddle::framework::TensorCopySync(temp_tensor, place, t);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
#ifdef PADDLE_WITH_ASCEND_CL
void CopyParamsToNpu(Argument *argument);
#else
void CopyParamsToGpu(Argument *argument);

void GetVarNameToOpTypeMap(
const framework::ir::Graph& graph,
std::unordered_map<std::string, std::string>* var_name_op_type_map);

void CopyParamsToGpu(Argument* argument);
#endif
};

Expand Down
Loading

0 comments on commit 4815e0f

Please sign in to comment.