Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] add multi_encoder_xpu_slice_fuse_pass, generate_sequence_xpu_fuse_pass, generate_sequence_xpu kernel #50570

Merged
merged 2 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ if(WITH_XPU)
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu)
pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu)
endif()

cc_library(
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,12 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
->assert_is_op_input(mul_type_, "Y")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return true;
return node->Var()->GetShape().size() == 2;
});
auto* mul =
pattern->NewNode(mul_repr())
->assert_is_op(mul_type_)
->assert_more([](Node* node) {
return true;
auto op_type = node->Op()->Type();
if (op_type == "matmul") {
return !PADDLE_GET_CONST(bool,
Expand Down
182 changes: 182 additions & 0 deletions paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct GenerateSequenceXPUPattern : public PatternBase {
GenerateSequenceXPUPattern(PDPattern* pattern, const std::string& name_scope);

// declare operator node's name
PATTERN_DECL_NODE(fill_any_like);
PATTERN_DECL_NODE(cumsum);
PATTERN_DECL_NODE(elementwise_sub);
// declare variable node's name
PATTERN_DECL_NODE(fill_any_like_x);
PATTERN_DECL_NODE(fill_any_like_out);
PATTERN_DECL_NODE(cumsum_out);
PATTERN_DECL_NODE(elementwise_sub_out);
};

GenerateSequenceXPUPattern::GenerateSequenceXPUPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* fill_any_like_x = pattern->NewNode(fill_any_like_x_repr())
->assert_is_op_input("fill_any_like", "X")
->assert_var_not_persistable()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 2;
});
auto* fill_any_like =
pattern->NewNode(fill_any_like_repr())
->assert_is_op("fill_any_like")
->assert_more([](Node* node) {
float value = PADDLE_GET_CONST(float, node->Op()->GetAttr("value"));
return static_cast<int>(value) == 1;
});
auto* fill_any_like_out = pattern->NewNode(fill_any_like_out_repr())
->assert_is_op_output("fill_any_like", "Out")
->assert_is_op_input("cumsum", "X")
->assert_is_op_input("elementwise_sub", "Y")
->assert_var_not_persistable()
->assert_has_n_outputs(2);
auto* cumsum =
pattern->NewNode(cumsum_repr())
->assert_is_op("cumsum")
->assert_more([](Node* node) {
return !PADDLE_GET_CONST(bool, node->Op()->GetAttr("exclusive")) &&
!PADDLE_GET_CONST(bool, node->Op()->GetAttr("reverse")) &&
!PADDLE_GET_CONST(bool, node->Op()->GetAttr("flatten")) &&
((PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == 1) ||
(PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == -1));
});
auto* cumsum_out = pattern->NewNode(cumsum_out_repr())
->assert_is_op_output("cumsum", "Out")
->assert_is_op_input("elementwise_sub", "X")
->assert_var_not_persistable()
->assert_has_n_outputs(1);
auto* elementwise_sub =
pattern->NewNode(elementwise_sub_repr())
->assert_is_op("elementwise_sub")
->assert_more([](Node* node) {
return PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == -1;
});
auto* elementwise_sub_out =
pattern->NewNode(elementwise_sub_out_repr())
->assert_is_op_output("elementwise_sub", "Out")
->assert_var_not_persistable();
fill_any_like->LinksFrom({fill_any_like_x}).LinksTo({fill_any_like_out});
cumsum->LinksFrom({fill_any_like_out}).LinksTo({cumsum_out});
elementwise_sub->LinksFrom({cumsum_out, fill_any_like_out})
.LinksTo({elementwise_sub_out});
}

} // namespace patterns

/*
Origin subgraph:
fill_any_like
/ \
| |
| cumsum
| |
\ /
elemetwise_sub

Fused subgraph:
generate_sequence_xpu
*/
class GenerateSequenceXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
const std::string name_scope_{"generate_sequence_xpu_fuse_pass"};
};

void GenerateSequenceXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::GenerateSequenceXPUPattern pattern(gpd.mutable_pattern(),
name_scope_);

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle GenerateSequenceXPUFusePass fuse";
GET_IR_NODE(fill_any_like);
GET_IR_NODE(cumsum);
GET_IR_NODE(elementwise_sub);
GET_IR_NODE(fill_any_like_x);
GET_IR_NODE(fill_any_like_out);
GET_IR_NODE(cumsum_out);
GET_IR_NODE(elementwise_sub_out);

auto* block = fill_any_like->Op()->Block();
framework::OpDesc op_desc(block);
op_desc.SetType("generate_sequence_xpu");
op_desc.SetInput("x", {fill_any_like_x->Name()});
op_desc.SetOutput("out", {elementwise_sub_out->Name()});
op_desc.SetAttr(
"dtype", PADDLE_GET_CONST(int, fill_any_like->Op()->GetAttr("dtype")));
auto* generate_sequence_xpu = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(fill_any_like, generate_sequence_xpu);
IR_NODE_LINK_TO(generate_sequence_xpu, elementwise_sub_out);

// delete useless node
std::unordered_set<const Node*> delete_nodes{
fill_any_like, fill_any_like_out, cumsum, cumsum_out, elementwise_sub};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};

gpd(graph, handler);
AddStatis(found_subgraph_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(generate_sequence_xpu_fuse_pass,
paddle::framework::ir::GenerateSequenceXPUFusePass);

REGISTER_PASS_CAPABILITY(generate_sequence_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"generate_sequence_xpu", 0));
154 changes: 154 additions & 0 deletions paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct MultiEncoderXPUSlicePattern : public PatternBase {
MultiEncoderXPUSlicePattern(PDPattern* pattern,
const std::string& name_scope);

// declare operator node's name
PATTERN_DECL_NODE(multi_encoder_xpu);
PATTERN_DECL_NODE(slice);
// declare variable node's name
PATTERN_DECL_NODE(multi_encoder_xpu_out);
PATTERN_DECL_NODE(slice_out);
};

MultiEncoderXPUSlicePattern::MultiEncoderXPUSlicePattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* multi_encoder_xpu =
pattern->NewNode(multi_encoder_xpu_repr())
->assert_is_op("multi_encoder_xpu")
->assert_more([](Node* node) {
return (!PADDLE_GET_CONST(bool,
node->Op()->GetAttr("norm_before"))) &&
(PADDLE_GET_CONST(int, node->Op()->GetAttr("slice_idx")) ==
-1);
});
auto* multi_encoder_xpu_out =
pattern->NewNode(multi_encoder_xpu_out_repr())
->assert_is_op_output("multi_encoder_xpu", "out")
->assert_is_op_input("slice", "Input")
->assert_var_not_persistable()
->assert_has_n_outputs(1);
auto* slice =
pattern->NewNode(slice_repr())
->assert_is_op("slice")
->assert_more([](Node* node) {
std::vector<int> axes =
PADDLE_GET_CONST(std::vector<int>, node->Op()->GetAttr("axes"));
std::vector<int> decrease_axis = PADDLE_GET_CONST(
std::vector<int>, node->Op()->GetAttr("decrease_axis"));
std::vector<int> starts = PADDLE_GET_CONST(
std::vector<int>, node->Op()->GetAttr("starts"));
std::vector<int> ends =
PADDLE_GET_CONST(std::vector<int>, node->Op()->GetAttr("ends"));
return axes.size() == 1 && axes[0] == 1 &&
decrease_axis.size() == 1 && decrease_axis[0] == 1 &&
starts.size() == 1 && starts[0] == 0 && //
ends.size() == 1 && ends[0] == 1;
});
auto* slice_out = pattern->NewNode(slice_out_repr())
->assert_is_op_output("slice", "Out")
->assert_var_not_persistable();
multi_encoder_xpu->LinksTo({multi_encoder_xpu_out});
slice->LinksFrom({multi_encoder_xpu_out}).LinksTo({slice_out});
}

} // namespace patterns

/*
Origin subgraph:
multi_encoder_xpu
|
slice

Fused subgraph:
multi_encoder_xpu
*/
class MultiEncoderXPUSliceFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"};
};

void MultiEncoderXPUSliceFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::MultiEncoderXPUSlicePattern pattern(gpd.mutable_pattern(),
name_scope_);

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle MultiEncoderXPUSliceFusePass fuse";
GET_IR_NODE(multi_encoder_xpu);
GET_IR_NODE(slice);
GET_IR_NODE(multi_encoder_xpu_out);
GET_IR_NODE(slice_out);

auto* op_desc = multi_encoder_xpu->Op();
op_desc->SetOutput("out", {slice_out->Var()->Name()});
op_desc->SetAttr("slice_idx", static_cast<int>(0));
IR_NODE_LINK_TO(multi_encoder_xpu, slice_out);

// delete useless node
std::unordered_set<const Node*> delete_nodes{multi_encoder_xpu_out, slice};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};

gpd(graph, handler);
AddStatis(found_subgraph_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(multi_encoder_xpu_slice_fuse_pass,
paddle::framework::ir::MultiEncoderXPUSliceFusePass);

REGISTER_PASS_CAPABILITY(multi_encoder_xpu_slice_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"multi_encoder_xpu", 0));
4 changes: 2 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,11 +517,11 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_dropout_op_pass",
"generate_sequence_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
// "embedding_with_eltwise_add_xpu_fuse_pass",
"fc_xpu_fuse_pass",
// "multi_encoder_slice_link_xpu_fuse_pass",
// "generate_sequence_xpu_fuse_pass",
// "link_previous_out_max_xpu_pass",
});
use_xpu_ = true;
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
data_type : x
optional : bias

- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
infer_meta :
func : GenerateSequenceXPUInferMeta
kernel :
func : generate_sequence_xpu
data_type : dtype

- op : multi_encoder_xpu
args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
Expand Down
Loading