[xpu] optimize multi_encoder_xpu_fuse_pass performance #51346

Merged · 4 commits · Mar 13, 2023
197 changes: 22 additions & 175 deletions paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
@@ -23,26 +23,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/concat_kernel.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
@@ -515,175 +504,26 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(

} // namespace patterns

/*
step1: fuse single ops to single_encoder_xpu
step2: fuse multiple single_encoder_xpu to multi_encoder_xpu

1. step1
Origin subgraph:
------------ input_variable*
| / | \
| / | \
| v_matmul q_matmul k_matmul
| | | |
| | | |
| v_add q_add add
| | | |
| | | |
| v_reshape q_reshape k_reshape
| | | |
| | | |
| v_transpose q_transpose k_transpose
| | | |
| | \ /
| | qk_matmul
| | |
| | |
| | qk_add
| | |
| | |
| | qk_softmax
| | |
| | |
| ---------qkv_matmul_0
| |
| |
| qkv_transpose
| |
| |
| qkv_reshape
| |
| |
| qkv_matmul_1
| |
| |
| qkv_add_0
| |
| |
----------------------qkv_add_1
|
|
layer_norm_1
/ \
| |
| qkv_matmul_2
| |
| |
| qkv_add_2
| |
| |
| qkv_act
| |
| |
| qkv_matmul_3
| |
| |
| qkv_add_3
| |
\ /
qkv_add_4
|
layer_norm

Fused subgraph:
single_encoder_xpu

2. step2
Origin subgraph:
...
|
single_encoder_xpu
|
(single_encoder_xpu)
|
(single_encoder_xpu)
|
...
Fused subgraph:
multi_encoder_xpu
*/
class MultiEncoderXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
int ApplySingleEncoderXPUFuse(ir::Graph* graph,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask) const;

bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;

// Mask must be fp32 even if model is fp16
int CastMask(ir::Graph* graph) const;

// 1. Transpose q_w, k_w, v_w
// 2. Concat q_w, k_w, v_w
// 3. Generate qkv_w_max tensor
// 4. Quant qkv_w to int16
void PrepareQKVWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_w,
Node* k_w,
Node* v_w,
Node** qkv_w,
Node** qkv_w_max) const;

// 1. Cast bias to fp32
// 2. Concat q/k/v bias
void PrepareQKVBias(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_bias,
Node* k_bias,
Node* v_bias,
Node** qkv_bias) const;

const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
};

void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
std::vector<std::string> act_types{"gelu", "relu"};
std::vector<std::string> matmul_types_0{"matmul_v2", "matmul", "mul"};
std::vector<std::string> matmul_types_1{"matmul_v2", "matmul"};
std::vector<std::string> matmul_types_2{"matmul_v2", "matmul"};
std::vector<bool> norm_befores{true, false};
std::vector<bool> with_q_scales{true, false};
std::vector<bool> with_masks{true, false};

int single_encoder_fused_counts = 0;
int multi_encoder_fused_counts = 0;
for (auto act_type : act_types) {
for (auto matmul_type_0 : matmul_types_0) {
for (auto matmul_type_1 : matmul_types_1) {
for (auto matmul_type_2 : matmul_types_2) {
for (auto norm_before : norm_befores) {
for (auto with_q_scale : with_q_scales) {
for (auto with_mask : with_masks) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
act_type,
matmul_type_0,
matmul_type_1,
matmul_type_2,
norm_before,
with_q_scale,
with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
}
}
}
}
}
}
}
}
auto pattern_params = GeneratePatternParams();
for (auto pattern_param : pattern_params) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
pattern_param.act_type,
pattern_param.matmul_type_0,
pattern_param.matmul_type_1,
pattern_param.matmul_type_2,
pattern_param.norm_before,
pattern_param.with_q_scale,
pattern_param.with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
}
}
int cast_mask_counts = CastMask(graph);
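
Why this speeds the pass up: the old ApplyImpl swept the full cross-product of pattern knobs — 2 act_types × 3 matmul_types_0 × 2 matmul_types_1 × 2 matmul_types_2 × 2 norm_befores × 2 with_q_scales × 2 with_masks = 192 pattern-detector builds and matches per graph, even though only a handful of combinations occur in real models. GeneratePatternParams replaces the exhaustive sweep with a short curated list (a single combination in this PR). A hypothetical sketch of registering one more combination — the second entry below is illustrative only, not part of this PR:

std::vector<PatternParam> MultiEncoderXPUFusePass::GeneratePatternParams()
    const {
  return std::vector<PatternParam>{
      // Entry shipped in this PR.
      {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true},
      // Hypothetical extra entry: a pre-norm relu encoder without a mask.
      {"relu", "matmul_v2", "matmul", "matmul_v2", true, false, false}};
}
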
@@ -1372,6 +1212,13 @@ int MultiEncoderXPUFusePass::CastMask(ir::Graph* graph) const {
return cast_counts;
}

std::vector<PatternParam> MultiEncoderXPUFusePass::GeneratePatternParams()
const {
return std::vector<PatternParam>{
// Params are arranged in alphabetic order
{"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true}};
}

} // namespace ir
} // namespace framework
} // namespace paddle
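
PatternParam itself is not part of this diff; it is declared in the companion header paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h. A plausible sketch of its layout, with field names and order inferred from the brace-initializer in GeneratePatternParams and the parameter list of ApplySingleEncoderXPUFuse:

// Sketch only; the authoritative definition lives in the header.
struct PatternParam {
  std::string act_type;       // "gelu" or "relu"
  std::string matmul_type_0;  // q/k/v projections: "matmul_v2", "matmul", or "mul"
  std::string matmul_type_1;  // qk_matmul: "matmul_v2" or "matmul"
  std::string matmul_type_2;  // qkv_matmul: "matmul_v2" or "matmul"
  bool norm_before;           // pre-norm vs. post-norm encoder
  bool with_q_scale;          // scale applied to q before qk_matmul
  bool with_mask;             // attention mask present (qk_add branch)
};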