[Paddle-trt] fix: delete_quant_dequant_filter_op_pass, delete_quant_dequant_op_pass #35879

Merged
144 changes: 24 additions & 120 deletions paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
int range = ((1 << (bit_length - 1)) - 1);
std::vector<float> weight_scale;
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();

auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
"can not find the input %s.",
quant_dequant_op_out_name));
any_op2_desc->SetAttr("enable_int8", true);
// any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length);

// modify the any_op2's inputs
any_op2_desc->Flush();
auto dequant_type = quant_dequant_op->Op()->Type();
auto quantized_op_type = any_op2_desc->Type();

// get weight tensor
auto* weight_tensor =
scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();

float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());

// Get weight scale
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
auto scales_name = quant_dequant_op->Op()->Output("OutScale");
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));

// To Do @Wangzheee: use "OutScale" to quantdequant
/*auto scales_name = quant_dequant_op->Op()->Output("OutScale");
PADDLE_ENFORCE_EQ(scales_name.size(), 1,
platform::errors::InvalidArgument(
"Scales size in channel-wise quant dequantize op "
"should be 1, got %d.",
scales_name.size()));
const LoDTensor& channel_scale_tensor =
scope->GetVar(scales_name[0])->Get<LoDTensor>();
scope->FindVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE(
paddle::platform::is_cpu_place(channel_scale_tensor.place()),
platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU."));
// compute the channel wise abs max of the weight tensor
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));

PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
const float* channel_scale_data = channel_scale_tensor.data<float>();
for (int i = 0; i < channel_scale_tensor.numel(); i++) {
weight_scale.push_back(channel_scale_data[i]);
}*/

// Implement channel_wise_quantize_dequantize_abs_max quantization
// algorithm
const int64_t channel = w_dims[quant_axis];
weight_scale.resize(channel, 0);
if (quant_axis == 0) {
@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE(weight_scale[i], 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = range / weight_scale[i];
weight_scale[i] = weight_scale[i] / range;
}
} else {
auto scale_name = quant_dequant_op_outscale->Name();
// compute the abs max of the weight tensor
// Implement quantize_dequantize_abs_max quantization algorithm
float abs_max_weight = 0.;
for (int j = 0; j < weight_tensor->numel(); j++) {
abs_max_weight =
@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE(abs_max_weight, 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back((range * range) / abs_max_weight / range);
weight_scale.push_back(abs_max_weight / range);
}

nodes2rm.insert(quant_dequant_op_outscale);

// perform quantize dequantize operations
// If quantized op is not channel wise, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (dequant_type == "fake_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), 1,
platform::errors::InvalidArgument(
"%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
quantized_op_type, weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[0];
}
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "fc") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul op weight dequantized by "
"[fake_channel_wise_quantize_dequantize_abs_max] requires "
"weight scale "
"size = 2nd dim of mul's weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j % w_dims[1]], 0,
platform::errors::InvalidArgument(
"fc op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[0]), weight_scale.size()));
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j / inner_size], 0,
platform::errors::InvalidArgument(
"conv2d op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j / inner_size];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j / inner_size];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d_transpose") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
int inner_size = w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
platform::errors::InvalidArgument(
"conv2d_transpose op weight scale should be "
"nonzero, but get zero"));
quantized_weight_data[j] = quantized_weight_data[j] *
weight_scale[(j / inner_size) % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /=
weight_scale[(j / inner_size) % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
nodes2rm.insert(quant_dequant_op_out);

// link weight in quant_dequant_op_x to any_op2
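Note on the change above: the pass now stores the true dequantization scale (`abs_max / range`) instead of its reciprocal (`range / abs_max`), and the block that fake-quantized the weights in place — the bulk of the 120 deleted lines — is gone, so only the scale computation remains visible in this hunk. A minimal, self-contained sketch of the convention (plain C++ with illustrative values, not Paddle code; per-tensor for brevity, assuming `bit_length = 8` so `range = 127` — the channel-wise branch applies the same math per output channel along `quant_axis`):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for 8 bits

  std::vector<float> weight = {0.9f, -0.35f, 0.02f, -1.2f};

  // Abs-max of the weight tensor.
  float abs_max = 0.f;
  for (float w : weight) abs_max = std::max(abs_max, std::abs(w));

  // New convention: the real dequantization scale.
  const float scale = abs_max / range;

  // What the deleted block used to simulate in place: a fake
  // quantize-dequantize round trip over the weights.
  for (float& w : weight) {
    float q = std::round(w / scale);  // quantize to an integer level
    w = q * scale;                    // dequantize back to float
  }

  std::printf("scale = %f\n", scale);
  for (float w : weight) std::printf("%f\n", w);
  return 0;
}
```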
93 changes: 51 additions & 42 deletions paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
@@ -28,76 +28,85 @@ namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(quant_dequant_op_inscale); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(quant_dequant_op_out); \
GET_IR_NODE(any_op2);
GET_IR_NODE(quant_dequant_op_out);

void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_op_pattern";
FusePassBase::Init(pattern_name, graph);

GraphPatternDetector gpd;

std::string quantdequant_types =
"fake_quantize_dequantize_moving_average_abs_max";

auto* input_node = gpd.mutable_pattern()
->NewNode("input_node")
->assert_is_op_input(quantdequant_types, "X")
->AsInput();

patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
pattern(input_node, quantdequant_types);
auto* scope = param_scope();
int found_count = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
PADDLE_ENFORCE_EQ(
subgraph.count(input_node), true,
platform::errors::NotFound(
"Input act node(%s) not found in QuantDequantFuse pass.",
input_node->name()));
Node* input = subgraph.at(input_node);
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
int bit_length =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);

// Get input scale from tensor
std::string input_scale_var_name =
quant_dequant_op->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantOpPass should not be null."));
const LoDTensor& input_scale_tensor =
scope->GetVar(input_scale_var_name)->Get<LoDTensor>();

scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / 127.;
auto* any_op2_desc = any_op2->Op();
// auto input_args_names = any_op2_desc->InputArgumentNames();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
CHECK(arg_name.size() > 0) << "can not find the input "
<< quant_dequant_op_out_name;
any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr(arg_name + "_scale", input_scale);
float input_scale = input_scale_data[0] / range;

// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != quant_dequant_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
// Set input scale in attr, and relink nodes
std::string input_name = input->Var()->Name();
std::string quant_dequant_output_name = quant_dequant_op_out->Var()->Name();
auto outlinks = quant_dequant_op_out->outputs;
for (auto* quantized_node : outlinks) {
auto op_desc = quantized_node->Op();
std::string quantized_op_type = op_desc->Type();
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2") {
op_desc->SetAttr("X_scale", input_scale);
} else {
op_desc->SetAttr("Input_scale", input_scale);
}
op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(quant_dequant_output_name, input_name);
op_desc->Flush();
IR_NODE_LINK_TO(input, quantized_node);
}
any_op2_desc->Flush();

// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{quant_dequant_op, quant_dequant_op_out,
quant_dequant_op_inscale, quant_dequant_op_outscale});
{quant_dequant_op_inscale, quant_dequant_op,
quant_dequant_op_outscale, quant_dequant_op_out});
found_count++;
};

gpd(graph, handler);
AddStatis(found_count);
}

} // namespace ir
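Note on the rewritten handler above: the input scale is now derived from `bit_length` (`input_scale = scale / range`, which reduces to the old hard-coded `/ 127.` when `bit_length == 8`), and instead of patching a single matched `any_op2`, every op consuming the quant-dequant output is tagged with the scale and relinked to the original input. A small sketch of the two decisions (plain C++ with illustrative values, not Paddle code; the attribute name follows the consumer's input slot — `"X"` for mul/matmul-style ops, `"Input"` for e.g. conv2d):

```cpp
#include <cstdio>
#include <string>

// The scale tensor read by the pass holds the activation's recorded
// (moving-average) abs-max; divide by the integer range to get the
// per-unit dequantization scale.
float ActivationScale(float recorded_abs_max, int bit_length) {
  const int range = (1 << (bit_length - 1)) - 1;
  return recorded_abs_max / range;  // == recorded_abs_max / 127 for 8-bit
}

// mul/matmul/matmul_v2 read the quantized tensor through input "X";
// other quantized ops read it through "Input", hence the two names.
std::string ScaleAttrName(const std::string& op_type) {
  if (op_type == "mul" || op_type == "matmul" || op_type == "matmul_v2")
    return "X_scale";
  return "Input_scale";
}

int main() {
  std::printf("%f\n", ActivationScale(25.4f, 8));        // 0.200000
  std::printf("%s\n", ScaleAttrName("conv2d").c_str());  // Input_scale
  return 0;
}
```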
33 changes: 11 additions & 22 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out->LinksFrom({reshape2_op});
}

void patterns::DeleteQuantDequantOpPattern::operator()() {
auto any_op_out =
pattern->NewNode(any_op_out_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "X")
->AsInput();

void patterns::DeleteQuantDequantOpPattern::operator()(
PDNode *input_node, const std::string &quantdequant_types) {
auto quant_dequant_op_inscale =
pattern->NewNode(quant_dequant_op_inscale_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "InScale")
->assert_is_op_input(quantdequant_types, "InScale")
->AsInput();
auto quant_dequant_op =
pattern->NewNode(quant_dequant_op_repr())
->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");
auto quant_dequant_op = pattern->NewNode(quant_dequant_op_repr())
->assert_is_op(quantdequant_types);

auto quant_dequant_out =
auto quant_dequant_op_out =
pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "Out")
->AsIntermediate();
->assert_is_op_output(quantdequant_types, "Out")
->AsOutput();

auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "OutScale")
->assert_is_op_output(quantdequant_types, "OutScale")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();

quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale});
quant_dequant_op->LinksFrom({quant_dequant_op_inscale, input_node});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
quant_dequant_op_out->LinksFrom({quant_dequant_op});
}

void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
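Note on the pattern rework above: the `any_op_out`/`any_op2` endpoints are gone — the input node is now built by the caller, the quant-dequant op type arrives as a string, and `quant_dequant_op_out` is marked `AsOutput` because its consumers are rewritten in the pass rather than matched here. For reference, this fragment (assembled from the pass code earlier in this diff, not standalone-runnable) shows how the new signature is driven; only the moving-average type is wired up in this PR, but the string parameter would let other quant-dequant variants reuse the pattern:

```cpp
GraphPatternDetector gpd;
std::string quantdequant_types =
    "fake_quantize_dequantize_moving_average_abs_max";

// The caller owns the input node and hands it to the pattern.
auto* input_node = gpd.mutable_pattern()
                       ->NewNode("input_node")
                       ->assert_is_op_input(quantdequant_types, "X")
                       ->AsInput();

patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
                                              "delete_quantdequant_op_pattern");
pattern(input_node, quantdequant_types);
```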
4 changes: 1 addition & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {}

void operator()();
void operator()(PDNode* input_node, const std::string& quantdequant_types);

PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(quant_dequant_op_inscale);
PATTERN_DECL_NODE(quant_dequant_op);
PATTERN_DECL_NODE(quant_dequant_op_outscale);
PATTERN_DECL_NODE(quant_dequant_op_out);
PATTERN_DECL_NODE(any_op2);
};

struct DeleteQuantDequantFilterOpPattern : public PatternBase {