diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index c1be8daad52..5b9734bf7cc 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -79,6 +79,7 @@ USE_MIR_PASS(lite_greater_than_cast_fuse_pass);
 USE_MIR_PASS(assign_value_calc_offline_pass);
 USE_MIR_PASS(__xpu__graph_dedup_pass);
 USE_MIR_PASS(__xpu__resnet_fuse_pass);
+USE_MIR_PASS(__xpu__spatial_transformer_fuse_pass);
 USE_MIR_PASS(__xpu__gn_silu_fuse_pass);
 USE_MIR_PASS(__xpu__multihead_cross_attn_fuse_pass);
 USE_MIR_PASS(__xpu__multihead_self_attn_fuse_pass);
diff --git a/lite/core/optimizer/mir/fusion/__xpu__spatial_transformer_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__spatial_transformer_fuse_pass.cc
new file mode 100644
index 00000000000..f84c391f4cf
--- /dev/null
+++ b/lite/core/optimizer/mir/fusion/__xpu__spatial_transformer_fuse_pass.cc
@@ -0,0 +1,606 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include "lite/backends/xpu/math.h"
+#include "lite/core/optimizer/mir/pass_registry.h"
+#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
+#include "lite/operators/subgraph_op.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+static std::vector<int> Vec2DTo1D_int(
+    const std::vector<std::vector<int>>& vec) {
+  std::vector<int> res;
+  for (const auto& v : vec) {
+    for (const auto& ele : v) {
+      res.emplace_back(ele);
+    }
+  }
+  return res;
+}
+
+class SpatialTransformerfuser : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* input = VarNode("input")
+                      ->assert_is_op_input("group_norm", "X")
+                      ->assert_is_op_input("__xpu__conv2d", "Branch")
+                      ->AsInput();
+
+    // image to sequence
+    auto* gn_scale = VarNode("gn_scale")
+                         ->assert_is_op_input("group_norm", "Scale")
+                         ->AsInput();
+    auto* gn_bias =
+        VarNode("gn_bias")->assert_is_op_input("group_norm", "Bias")->AsInput();
+    auto* gn = OpNode("gn", "group_norm")->AsIntermediate();
+    auto* gn_out = VarNode("gn_out")
+                       ->assert_is_op_output("group_norm", "Y")
+                       ->assert_is_op_input("__xpu__conv2d", "Input")
+                       ->AsIntermediate();
+    auto* gn_mean = VarNode("gn_mean")
+                        ->assert_is_op_output("group_norm", "Mean")
+                        ->AsIntermediate();
+    auto* gn_var = VarNode("gn_var")
+                       ->assert_is_op_output("group_norm", "Variance")
+                       ->AsIntermediate();
+
+    auto* pre_xpu_conv2d =
+        OpNode("pre__xpu__conv2d", "__xpu__conv2d")->AsIntermediate();
+    auto* pre_xpu_conv2d_bias =
+        VarNode("pre__xpu__conv2d_bias")
+            ->assert_is_op_input("__xpu__conv2d", "Bias")
+            ->AsInput();
+    auto* pre_xpu_conv2d_filter =
+        VarNode("pre__xpu__conv2d_filter")
+            ->assert_is_op_input("__xpu__conv2d", "Filter")
+            ->AsInput();
+    auto* pre_xpu_conv2d_output =
+        VarNode("pre__xpu__conv2d_output")
+            ->AsIntermediate()
+            ->assert_is_op_input("transpose2", "X")
+            ->assert_is_op_output("__xpu__conv2d", "Output");
+    auto* pre_xpu_conv2d_output_max =
+        VarNode("pre__xpu__conv2d_output_max")
+            ->AsIntermediate()
+            ->assert_is_op_output("__xpu__conv2d", "OutputMax");
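+
+    // Rough shape of the subgraph matched below (summarized from the node
+    // asserts; each "add" is an elementwise_add residual connection, and the
+    // attention/geglu ops are fused XPU ops produced by earlier passes):
+    //   group_norm -> __xpu__conv2d -> transpose2 -> flatten
+    //     -> __xpu__multihead_self_attn -> add
+    //     -> __xpu__multihead_cross_attn -> add
+    //     -> __xpu__geglu -> add
+    //     -> reshape2 -> transpose2 -> __xpu__conv2d (Branch = input)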
->assert_is_op_output("__xpu__conv2d", "OutputMax"); + + auto* transpose2 = OpNode("transpose2", "transpose2")->AsIntermediate(); + auto* transpose2_output = + VarNode("transpose2_output") + ->AsIntermediate() + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("flatten_contiguous_range", "X"); + auto* transpose2_output_xshape = + VarNode("transpose2_output_xshape") + ->AsIntermediate() + ->assert_is_op_output("transpose2", "XShape"); + + auto* flatten = + OpNode("flatten_contiguous_range", "flatten_contiguous_range") + ->AsIntermediate(); + auto* flatten_output = + VarNode("flatten_output") + ->AsIntermediate() + ->assert_is_op_output("flatten_contiguous_range", "Out") + ->assert_is_op_input("__xpu__multihead_self_attn", "Input") + ->assert_is_op_input("elementwise_add", "Y"); + auto* flatten_output_xshape = + VarNode("flatten_output_xshape") + ->AsIntermediate() + ->assert_is_op_output("flatten_contiguous_range", "XShape"); + + // __xpu__multihead_self_attn + auto* __xpu__multihead_self_attn = + OpNode("__xpu__multihead_self_attn", "__xpu__multihead_self_attn") + ->AsIntermediate(); + auto* __xpu__multihead_self_attn_fcbias = + VarNode("__xpu__multihead_self_attn_fcbias") + ->assert_is_op_input("__xpu__multihead_self_attn", "FCBias") + ->AsInput(); + auto* __xpu__multihead_self_attn_lnbias = + VarNode("__xpu__multihead_self_attn_lnbias") + ->assert_is_op_input("__xpu__multihead_self_attn", "LNBias") + ->AsInput(); + auto* __xpu__multihead_self_attn_lnscale = + VarNode("__xpu__multihead_self_attn_lnscale") + ->assert_is_op_input("__xpu__multihead_self_attn", "LNScale") + ->AsInput(); + auto* __xpu__multihead_self_attn_fcweight0 = + VarNode("__xpu__multihead_self_attn_fcweight0") + ->assert_is_op_nth_input( + "__xpu__multihead_self_attn", "FCWeight", 0) + ->AsInput(); + auto* __xpu__multihead_self_attn_fcweight1 = + VarNode("__xpu__multihead_self_attn_fcweight1") + ->assert_is_op_nth_input( + "__xpu__multihead_self_attn", "FCWeight", 1) + ->AsInput(); + auto* __xpu__multihead_self_attn_fcweight2 = + VarNode("__xpu__multihead_self_attn_fcweight2") + ->assert_is_op_nth_input( + "__xpu__multihead_self_attn", "FCWeight", 2) + ->AsInput(); + auto* __xpu__multihead_self_attn_fcweight3 = + VarNode("__xpu__multihead_self_attn_fcweight3") + ->assert_is_op_nth_input( + "__xpu__multihead_self_attn", "FCWeight", 3) + ->AsInput(); + auto* __xpu__multihead_self_attn_output = + VarNode("__xpu__multihead_self_attn_output") + ->AsIntermediate() + ->assert_is_op_output("__xpu__multihead_self_attn", "Output") + ->assert_is_op_input("elementwise_add", "X"); + auto* residual_add = + OpNode("elementwise_add", "elementwise_add")->AsIntermediate(); + auto* residual_add_output = + VarNode("residual_add_output") + ->AsIntermediate() + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_op_input("__xpu__multihead_cross_attn", "Input"); + + // __xpu__multihead_cross_attn + auto* __xpu__multihead_cross_attn = + OpNode("__xpu__multihead_cross_attn", "__xpu__multihead_cross_attn") + ->AsIntermediate(); + auto* __xpu__multihead_cross_attn_embedding = + VarNode("__xpu__multihead_cross_attn_embedding") + ->assert_is_op_input("__xpu__multihead_cross_attn", "Embedding") + ->AsInput(); + auto* __xpu__multihead_cross_attn_fcbias = + VarNode("__xpu__multihead_cross_attn_fcbias") + ->assert_is_op_input("__xpu__multihead_cross_attn", "FCBias") + ->AsInput(); + auto* __xpu__multihead_cross_attn_lnbias = + VarNode("__xpu__multihead_cross_attn_lnbias") + 
->assert_is_op_input("__xpu__multihead_cross_attn", "LNBias") + ->AsInput(); + auto* __xpu__multihead_cross_attn_lnscale = + VarNode("__xpu__multihead_cross_attn_lnscale") + ->assert_is_op_input("__xpu__multihead_cross_attn", "LNScale") + ->AsInput(); + auto* __xpu__multihead_cross_attn_fcweight0 = + VarNode("__xpu__multihead_cross_attn_fcweight0") + ->assert_is_op_nth_input( + "__xpu__multihead_cross_attn", "FCWeight", 0) + ->AsInput(); + auto* __xpu__multihead_cross_attn_fcweight1 = + VarNode("__xpu__multihead_cross_attn_fcweight1") + ->assert_is_op_nth_input( + "__xpu__multihead_cross_attn", "FCWeight", 1) + ->AsInput(); + auto* __xpu__multihead_cross_attn_fcweight2 = + VarNode("__xpu__multihead_cross_attn_fcweight2") + ->assert_is_op_nth_input( + "__xpu__multihead_cross_attn", "FCWeight", 2) + ->AsInput(); + auto* __xpu__multihead_cross_attn_fcweight3 = + VarNode("__xpu__multihead_cross_attn_fcweight3") + ->assert_is_op_nth_input( + "__xpu__multihead_cross_attn", "FCWeight", 3) + ->AsInput(); + auto* __xpu__multihead_cross_attn_output = + VarNode("__xpu__multihead_cross_attn_output") + ->AsIntermediate() + ->assert_is_op_output("__xpu__multihead_cross_attn", "Output") + ->assert_is_op_input("elementwise_add", "X"); + auto* residual_add2 = + OpNode("elementwise_add2", "elementwise_add")->AsIntermediate(); + auto* residual_add2_output = + VarNode("residual2_add_output") + ->AsIntermediate() + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_op_input("__xpu__geglu", "Input"); + + // geglu + auto* __xpu__geglu = + OpNode("__xpu__geglu", "__xpu__geglu")->AsIntermediate(); + auto* __xpu__geglu_fcbias0 = + VarNode("__xpu__geglu_fcbias0") + ->assert_is_op_nth_input("__xpu__geglu", "FCBias", 0) + ->AsInput(); + auto* __xpu__geglu_fcbias1 = + VarNode("__xpu__geglu_fcbias1") + ->assert_is_op_nth_input("__xpu__geglu", "FCBias", 1) + ->AsInput(); + auto* __xpu__geglu_lnbias = + VarNode("__xpu__geglu_lnbias") + ->assert_is_op_input("__xpu__geglu", "LNBias") + ->AsInput(); + auto* __xpu__geglu_lnscale = + VarNode("__xpu__geglu_lnscale") + ->assert_is_op_input("__xpu__geglu", "LNScale") + ->AsInput(); + auto* __xpu__geglu_fcweight0 = + VarNode("__xpu__geglu_fcweight0") + ->assert_is_op_nth_input("__xpu__geglu", "FCWeight", 0) + ->AsInput(); + auto* __xpu__geglu_fcweight1 = + VarNode("__xpu__geglu_fcweight1") + ->assert_is_op_nth_input("__xpu__geglu", "FCWeight", 1) + ->AsInput(); + auto* __xpu__geglu_output = + VarNode("__xpu__geglu_output") + ->AsIntermediate() + ->assert_is_op_output("__xpu__geglu", "Output") + ->assert_is_op_input("elementwise_add", "X"); + auto* residual_add3 = + OpNode("elementwise_add3", "elementwise_add")->AsIntermediate(); + auto* residual_add3_output = + VarNode("residual3_add_output") + ->AsIntermediate() + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("reshape2", "X"); + + // sequence to image + auto* reshape = OpNode("reshape2", "reshape2")->AsIntermediate(); + auto* reshape_output = VarNode("reshape_output") + ->AsIntermediate() + ->assert_is_op_input("transpose2", "X") + ->assert_is_op_output("reshape2", "Out"); + auto* reshape_output_xshape = + VarNode("reshape_output_xshape") + ->AsIntermediate() + ->assert_is_op_output("reshape2", "XShape"); + auto* transpose2_2 = OpNode("transpose2_2", "transpose2")->AsIntermediate(); + auto* transpose2_2_output = + VarNode("transpose2_2_output") + ->AsIntermediate() + ->assert_is_op_input("__xpu__conv2d", "Input") + 
->assert_is_op_output("transpose2", "Out"); + auto* transpose2_2_output_xshape = + VarNode("transpose2_2_output_xshape") + ->AsIntermediate() + ->assert_is_op_output("transpose2", "XShape"); + auto* post_xpu_conv2d = + OpNode("post__xpu__conv2d", "__xpu__conv2d")->AsIntermediate(); + auto* post_xpu_conv2d_bias = + VarNode("post__xpu__conv2d_bias") + ->assert_is_op_input("__xpu__conv2d", "Bias") + ->AsInput(); + auto* post_xpu_conv2d_filter = + VarNode("post__xpu__conv2d_filter") + ->assert_is_op_input("__xpu__conv2d", "Filter") + ->AsInput(); + auto* post_xpu_conv2d_output = + VarNode("post__xpu__conv2d_output") + ->AsOutput() + ->assert_is_op_output("__xpu__conv2d", "Output"); + auto* post_xpu_conv2d_outputmax = + VarNode("post__xpu__conv2d_output_max") + ->AsIntermediate() + ->assert_is_op_output("__xpu__conv2d", "OutputMax"); + + std::vector gn_input{input, gn_bias, gn_scale}; + std::vector gn_output{gn_out, gn_mean, gn_var}; + gn_input >> *gn >> gn_output; + std::vector pre_conv2d_input{ + gn_out, pre_xpu_conv2d_bias, pre_xpu_conv2d_filter}; + std::vector pre_conv2d_output{pre_xpu_conv2d_output, + pre_xpu_conv2d_output_max}; + pre_conv2d_input >> *pre_xpu_conv2d >> pre_conv2d_output; + *pre_xpu_conv2d_output >> *transpose2 >> *transpose2_output >> *flatten >> + *flatten_output; + *transpose2 >> *transpose2_output_xshape; + *flatten >> *flatten_output_xshape; + + std::vector mhsa_input{flatten_output, + __xpu__multihead_self_attn_fcbias, + __xpu__multihead_self_attn_fcweight0, + __xpu__multihead_self_attn_fcweight1, + __xpu__multihead_self_attn_fcweight2, + __xpu__multihead_self_attn_fcweight3, + __xpu__multihead_self_attn_lnbias, + __xpu__multihead_self_attn_lnscale}; + mhsa_input >> *__xpu__multihead_self_attn >> + *__xpu__multihead_self_attn_output >> *residual_add >> + *residual_add_output; + *flatten_output >> *residual_add; + + std::vector mhca_input{residual_add_output, + __xpu__multihead_cross_attn_embedding, + __xpu__multihead_cross_attn_fcbias, + __xpu__multihead_cross_attn_lnbias, + __xpu__multihead_cross_attn_lnscale, + __xpu__multihead_cross_attn_fcweight0, + __xpu__multihead_cross_attn_fcweight1, + __xpu__multihead_cross_attn_fcweight2, + __xpu__multihead_cross_attn_fcweight3}; + mhca_input >> *__xpu__multihead_cross_attn >> + *__xpu__multihead_cross_attn_output >> *residual_add2 >> + *residual_add2_output; + *residual_add_output >> *residual_add2; + + std::vector geglu_input{residual_add2_output, + __xpu__geglu_fcbias0, + __xpu__geglu_fcbias1, + __xpu__geglu_lnbias, + __xpu__geglu_lnscale, + __xpu__geglu_fcweight0, + __xpu__geglu_fcweight1}; + geglu_input >> *__xpu__geglu >> *__xpu__geglu_output >> *residual_add3 >> + *residual_add3_output; + *residual_add2_output >> *residual_add3; + + *residual_add3_output >> *reshape >> *reshape_output >> *transpose2_2 >> + *transpose2_2_output; + *reshape >> *reshape_output_xshape; + *transpose2_2 >> *transpose2_2_output_xshape; + + std::vector post_conv2d_input{transpose2_2_output, + post_xpu_conv2d_bias, + input, + post_xpu_conv2d_filter}; + std::vector post_conv2d_output{post_xpu_conv2d_output, + post_xpu_conv2d_outputmax}; + post_conv2d_input >> *post_xpu_conv2d >> post_conv2d_output; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + // OpDesc + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__spatial_transformer"); + auto* gn_op_info = matched.at("gn")->stmt()->op_info(); + auto* mhsa_op_info = + matched.at("__xpu__multihead_self_attn")->stmt()->op_info(); + auto* mhca_op_info = + 
matched.at("__xpu__multihead_cross_attn")->stmt()->op_info(); + auto* geglu_op_info = matched.at("__xpu__geglu")->stmt()->op_info(); + + std::vector fc_weight_names; + for (const auto& name : mhsa_op_info->Input("FCWeight")) { + fc_weight_names.push_back(name); + } + for (const auto& name : mhca_op_info->Input("FCWeight")) { + fc_weight_names.push_back(name); + } + for (const auto& name : geglu_op_info->Input("FCWeight")) { + fc_weight_names.push_back(name); + } + CHECK_EQ(fc_weight_names.size(), 10); + std::vector fc_weight_maxptr_names; + for (const auto& name : + mhsa_op_info->GetAttr>("FCWeightMax")) { + fc_weight_maxptr_names.push_back(name); + } + for (const auto& name : + mhca_op_info->GetAttr>("FCWeightMax")) { + fc_weight_maxptr_names.push_back(name); + } + for (const auto& name : + geglu_op_info->GetAttr>("FCWeightMax")) { + fc_weight_maxptr_names.push_back(name); + } + CHECK_EQ(fc_weight_maxptr_names.size(), 10); + + std::vector ln_scale_names; + for (const auto& name : mhsa_op_info->Input("LNScale")) { + ln_scale_names.push_back(name); + } + for (const auto& name : mhca_op_info->Input("LNScale")) { + ln_scale_names.push_back(name); + } + for (const auto& name : geglu_op_info->Input("LNScale")) { + ln_scale_names.push_back(name); + } + std::vector ln_bias_names; + for (const auto& name : mhsa_op_info->Input("LNBias")) { + ln_bias_names.push_back(name); + } + for (const auto& name : mhca_op_info->Input("LNBias")) { + ln_bias_names.push_back(name); + } + for (const auto& name : geglu_op_info->Input("LNBias")) { + ln_bias_names.push_back(name); + } + std::vector fc_bias_names; + for (const auto& name : mhsa_op_info->Input("FCBias")) { + fc_bias_names.push_back(name); + } + for (const auto& name : mhca_op_info->Input("FCBias")) { + fc_bias_names.push_back(name); + } + for (const auto& name : geglu_op_info->Input("FCBias")) { + fc_bias_names.push_back(name); + } + + op_desc.SetInput("Input", {matched.at("input")->arg()->name}); + op_desc.SetInput("Embedding", mhca_op_info->Input("Embedding")); + op_desc.SetInput("FCWeight", fc_weight_names); + op_desc.SetInput("FCBias", fc_bias_names); + op_desc.SetInput("LNScale", ln_scale_names); + op_desc.SetInput("LNBias", ln_bias_names); + op_desc.SetAttr("groups", gn_op_info->GetAttr("groups")); + op_desc.SetAttr("epsilon", gn_op_info->GetAttr("epsilon")); + op_desc.SetInput("ConvBias", + {matched.at("pre__xpu__conv2d_bias")->arg()->name, + matched.at("post__xpu__conv2d_bias")->arg()->name}); + op_desc.SetInput("GNScale", {matched.at("gn_scale")->arg()->name}); + op_desc.SetInput("GNBias", {matched.at("gn_bias")->arg()->name}); + std::vector conv_filter_names = { + matched.at("pre__xpu__conv2d_filter")->arg()->name, + matched.at("post__xpu__conv2d_filter")->arg()->name}; + op_desc.SetInput("ConvWeight", conv_filter_names); + std::vector conv_filter_maxptr_names = { + matched.at("pre__xpu__conv2d_filter")->arg()->name + "_max", + matched.at("post__xpu__conv2d_filter")->arg()->name + "_max"}; + op_desc.SetAttr>("ConvFilterMax", + conv_filter_maxptr_names); + op_desc.SetOutput("Output", + {matched.at("post__xpu__conv2d_output")->arg()->name}); + op_desc.SetAttr>("FCWeightMax", + fc_weight_maxptr_names); + op_desc.SetAttr("head_num", mhsa_op_info->GetAttr("head_num")); + op_desc.SetAttr("size_per_head", + mhsa_op_info->GetAttr("size_per_head")); + op_desc.SetAttr("hidden_dim", + mhsa_op_info->GetAttr("hidden_dim")); + op_desc.SetAttr("embedding_dim", + mhca_op_info->GetAttr("embedding_dim")); + op_desc.SetAttr("gelu_dim", 
geglu_op_info->GetAttr("gelu_dim")); + + std::vector> strides; + std::vector> paddings; + std::vector> dilations; + std::vector> filter_dims; + std::vector groups; + std::vector conv_vec = {"pre__xpu__conv2d", + "post__xpu__conv2d"}; + for (auto pm_name : conv_vec) { + auto* conv_op_info = matched.at(pm_name)->stmt()->op_info(); + auto strides_tmp = conv_op_info->GetAttr>("strides"); + strides.emplace_back(std::move(strides_tmp)); + auto paddings_tmp = conv_op_info->GetAttr>("paddings"); + paddings.emplace_back(std::move(paddings_tmp)); + auto dilations_tmp = conv_op_info->GetAttr>("dilations"); + dilations.emplace_back(std::move(dilations_tmp)); + std::vector groups_tmp = + conv_op_info->GetAttr>("groups"); + groups.push_back(groups_tmp[0]); + auto filter_dims_tmp = + conv_op_info->GetAttr>("filter_dims"); + filter_dims.emplace_back(std::move(filter_dims_tmp)); + } + op_desc.SetAttr>("Conv_Groups", groups); + op_desc.SetAttr>("Strides", Vec2DTo1D_int(strides)); + op_desc.SetAttr>("Paddings", Vec2DTo1D_int(paddings)); + op_desc.SetAttr>("Dilations", Vec2DTo1D_int(dilations)); + op_desc.SetAttr>("FilterDims", Vec2DTo1D_int(filter_dims)); + + auto spatial_transformer_op = + LiteOpRegistry::Global().Create(op_desc.Type()); + auto* scope = matched.at("gn")->stmt()->op()->scope(); + UpdateWeight(scope, conv_filter_names, conv_filter_maxptr_names, false); + spatial_transformer_op->Attach(op_desc, scope); + spatial_transformer_op->SetValidPlaces( + matched.at("gn")->stmt()->op()->valid_places()); + auto kernels = spatial_transformer_op->CreateKernels( + spatial_transformer_op->valid_places()); + auto* new_op_node = graph->GraphCreateInstructNode( + spatial_transformer_op, spatial_transformer_op->valid_places()); + + std::vector froms = {"input", + "gn_scale", + "gn_bias", + "pre__xpu__conv2d_bias", + "pre__xpu__conv2d_filter", + "__xpu__multihead_self_attn_fcbias", + "__xpu__multihead_self_attn_lnbias", + "__xpu__multihead_self_attn_lnscale", + "__xpu__multihead_self_attn_fcweight0", + "__xpu__multihead_self_attn_fcweight1", + "__xpu__multihead_self_attn_fcweight2", + "__xpu__multihead_self_attn_fcweight3", + "__xpu__multihead_cross_attn_embedding", + "__xpu__multihead_cross_attn_fcbias", + "__xpu__multihead_cross_attn_lnbias", + "__xpu__multihead_cross_attn_lnscale", + "__xpu__multihead_cross_attn_fcweight0", + "__xpu__multihead_cross_attn_fcweight1", + "__xpu__multihead_cross_attn_fcweight2", + "__xpu__multihead_cross_attn_fcweight3", + "__xpu__geglu_fcbias0", + "__xpu__geglu_fcbias1", + "__xpu__geglu_lnbias", + "__xpu__geglu_lnscale", + "__xpu__geglu_fcweight0", + "__xpu__geglu_fcweight1", + "post__xpu__conv2d_bias", + "post__xpu__conv2d_filter"}; + + for (auto& from : froms) { + IR_NODE_LINK_TO(matched.at(from), new_op_node); + } + + IR_OP_VAR_LINK(new_op_node, matched.at("post__xpu__conv2d_output")); + } + + private: + void UpdateWeight(Scope* scope, + const std::vector& fc_weight_names, + const std::vector& fc_weight_max_names, + bool trans) { + std::vector weight_tensor_vec(fc_weight_names.size(), nullptr); + std::vector weight_dims_vec(fc_weight_names.size()); + std::vector weight_len_vec(fc_weight_names.size()); + + for (size_t i = 0; i < fc_weight_names.size(); ++i) { + weight_tensor_vec[i] = scope->FindMutableTensor(fc_weight_names[i]); + CHECK(weight_tensor_vec[i] != nullptr); + weight_dims_vec[i] = weight_tensor_vec[i]->dims(); + weight_len_vec[i] = weight_tensor_vec[i]->numel(); + if (trans && i > 0) { + CHECK_EQ(weight_dims_vec[i][0], weight_dims_vec[i - 1][0]); + } + } + for 
+    for (size_t i = 0; i < fc_weight_names.size(); ++i) {
+      float* weight_host_ptr = weight_tensor_vec[i]->mutable_data<float>();
+      std::unique_ptr<float[]> weight_host_trans(
+          new float[weight_len_vec[i]]);
+      std::unique_ptr<int16_t[]> weight_host_trans_int16(
+          new int16_t[weight_len_vec[i]]);
+      if (trans) {
+        paddle::lite::xpu::math::Transpose(weight_host_ptr,
+                                           weight_host_trans.get(),
+                                           weight_dims_vec[i][0],
+                                           weight_dims_vec[i][1]);
+      } else {
+        memcpy(weight_host_trans.get(),
+               weight_host_ptr,
+               weight_len_vec[i] * sizeof(float));
+      }
+      float max_f = paddle::lite::xpu::math::FindMaxAbs(
+          weight_host_trans.get(), weight_len_vec[i]);
+      paddle::lite::xpu::math::ConvertFP32ToInt16(
+          weight_host_trans.get(),
+          weight_host_trans_int16.get(),
+          max_f,
+          weight_len_vec[i]);
+      memcpy(weight_tensor_vec[i]->mutable_data<int16_t>(),
+             weight_host_trans_int16.get(),
+             weight_len_vec[i] * sizeof(int16_t));
+      scope->NewTensor(fc_weight_max_names[i]);
+      Tensor* weight_maxptr_tensor =
+          scope->FindMutableTensor(fc_weight_max_names[i]);
+      weight_maxptr_tensor->Resize({6});
+      std::vector<float> weight_maxptr_host(6, max_f);
+      memcpy(weight_maxptr_tensor->mutable_data<float>(),
+             weight_maxptr_host.data(),
+             weight_maxptr_host.size() * sizeof(float));
+    }
+  }
+};
+
+}  // namespace fusion
+
+class XPUSpatialTransformerfusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    fusion::SpatialTransformerfuser fuser;
+    fuser(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(__xpu__spatial_transformer_fuse_pass,
+                  paddle::lite::mir::XPUSpatialTransformerfusePass)
+    .BindTargets({TARGET(kXPU)});
diff --git a/lite/core/optimizer/optimizer.cc b/lite/core/optimizer/optimizer.cc
index ed3e4fdf135..c9cae2b90a9 100644
--- a/lite/core/optimizer/optimizer.cc
+++ b/lite/core/optimizer/optimizer.cc
@@ -206,6 +206,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
       "__xpu__multihead_cross_attn_fuse_pass",
       "__xpu__geglu_fuse_pass",
       "__xpu__quick_gelu_fuse_pass",
+      "__xpu__spatial_transformer_fuse_pass",
       "__xpu__gn_silu_fuse_pass",
       "__xpu__multi_encoder_fuse_pass",
      "__xpu__embedding_with_eltwise_add_fuse_pass",
diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt
index 975d3b0f7d1..43095e3b73c 100644
--- a/lite/kernels/xpu/CMakeLists.txt
+++ b/lite/kernels/xpu/CMakeLists.txt
@@ -140,4 +140,5 @@ add_kernel(__xpu__geglu_compute_xpu XPU extra SRCS __xpu__geglu_compute.cc)
 if(XPU_WITH_XFT)
     add_kernel(fusion_decoding_compute_xpu XPU extra SRCS fusion_decoding_compute.cc)
     add_kernel(fusion_unified_decoding_compute_xpu XPU extra SRCS fusion_unified_decoding_compute.cc)
+    add_kernel(__xpu__spatial_transformer_compute_xpu XPU extra SRCS __xpu__spatial_transformer_compute.cc)
 endif(XPU_WITH_XFT)
diff --git a/lite/kernels/xpu/__xpu__spatial_transformer_compute.cc b/lite/kernels/xpu/__xpu__spatial_transformer_compute.cc
new file mode 100644
index 00000000000..0ded0328404
--- /dev/null
+++ b/lite/kernels/xpu/__xpu__spatial_transformer_compute.cc
@@ -0,0 +1,312 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/__xpu__spatial_transformer_compute.h"
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+template <typename T>
+static std::vector<const T*> PrepareWeight(
+    const std::vector<lite::Tensor*>& fc_weight) {
+  std::vector<const T*> result;
+  for (auto* weight : fc_weight) {
+    result.push_back(reinterpret_cast<const T*>(weight->data<float>()));
+  }
+  return result;
+}
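+
+// Note: by the time the kernel sees them, the fuse pass has already
+// rewritten every FCWeight/ConvWeight tensor to int16 in place (see
+// UpdateWeight in the fuse pass), so PrepareWeight only reinterprets the
+// raw buffers; no conversion happens at kernel setup time.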
+template <typename InType, PrecisionType PType>
+void XPUSpatialTransformerCompute<InType, PType>::PrepareWeightMax(
+    const std::vector<lite::Tensor*>& weight_max,
+    int max_ptr_len,
+    std::vector<const float*>* max_xpu_ptrs) {
+  int max_value_num = 0;
+  for (auto max_tensor : weight_max) {
+    max_value_num += max_tensor->numel();
+  }
+  VLOG(3) << "Total weight max value number: " << max_value_num;
+  weight_max_guard_ =
+      TargetWrapperXPU::MallocScratchPad(max_value_num * sizeof(float));
+  float* weight_max_ptr = reinterpret_cast<float*>(weight_max_guard_->addr_);
+
+  int offset = 0;
+  for (auto max_tensor : weight_max) {
+    float* cur_weight_max_ptr = weight_max_ptr + offset;
+    auto len = max_tensor->numel();
+    VLOG(6) << "weight max value: " << max_tensor->data<float>()[0] << " "
+            << max_tensor->data<float>()[len - 1];
+    std::vector<float> cpu_max(max_ptr_len, max_tensor->data<float>()[0]);
+    lite::TargetWrapperXPU::MemcpySync(cur_weight_max_ptr,
+                                       cpu_max.data(),
+                                       sizeof(float) * max_ptr_len,
+                                       IoDirection::HtoD);
+    max_xpu_ptrs->push_back(cur_weight_max_ptr);
+    offset += max_ptr_len;
+  }
+}
+
+template <typename InType, PrecisionType PType>
+void XPUSpatialTransformerCompute<InType, PType>::PrepareFilterMax(
+    const std::vector<lite::Tensor*>& filter_max,
+    int max_ptr_len,
+    std::vector<const float*>* max_xpu_ptrs) {
+  int max_value_num = 0;
+  for (auto max_tensor : filter_max) {
+    max_value_num += max_tensor->numel();
+  }
+  VLOG(3) << "Total filter max value number: " << max_value_num;
+  filter_max_guard_ =
+      TargetWrapperXPU::MallocScratchPad(max_value_num * sizeof(float));
+  float* filter_max_ptr = reinterpret_cast<float*>(filter_max_guard_->addr_);
+
+  int offset = 0;
+  for (auto max_tensor : filter_max) {
+    float* cur_filter_max_ptr = filter_max_ptr + offset;
+    auto len = max_tensor->numel();
+    VLOG(6) << "filter max value: " << max_tensor->data<float>()[0] << " "
+            << max_tensor->data<float>()[len - 1];
+    std::vector<float> cpu_max(max_ptr_len, max_tensor->data<float>()[0]);
+    lite::TargetWrapperXPU::MemcpySync(cur_filter_max_ptr,
+                                       cpu_max.data(),
+                                       sizeof(float) * max_ptr_len,
+                                       IoDirection::HtoD);
+    max_xpu_ptrs->push_back(cur_filter_max_ptr);
+    offset += max_ptr_len;
+  }
+}
+
+template <typename InType, PrecisionType PType>
+void XPUSpatialTransformerCompute<InType, PType>::PrepareForRun() {
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  auto& param = this->template Param<param_t>();
+  xft_attn_fc_bias.emplace_back(
+      const_cast<float*>(param.fc_bias[0]->template data<float>()),
+      xft::xftVec<float>::dim_t{param.fc_bias[0]->dims()[0]});
+  xft_attn_fc_bias.emplace_back(
+      const_cast<float*>(param.fc_bias[1]->template data<float>()),
+      xft::xftVec<float>::dim_t{param.fc_bias[1]->dims()[0]});
+  xft_geglu_fc_bias.emplace_back(
+      const_cast<float*>(param.fc_bias[2]->template data<float>()),
+      xft::xftVec<float>::dim_t{param.fc_bias[2]->dims()[0]});
+  xft_geglu_fc_bias.emplace_back(
+      const_cast<float*>(param.fc_bias[3]->template data<float>()),
+      xft::xftVec<float>::dim_t{param.fc_bias[3]->dims()[0]});
+  // prepare scale
+  for (auto* ln_scale : param.ln_scale) {
+    xft_ln_weights.emplace_back(
+        const_cast<float*>(ln_scale->template data<float>()),
+        xft::xftVec<float>::dim_t{ln_scale->dims()[0]});
+  }
+  // prepare ln_bias
+  for (auto* ln_bias : param.ln_bias) {
+    xft_ln_bias.emplace_back(
+        const_cast<float*>(ln_bias->template data<float>()),
+        xft::xftVec<float>::dim_t{ln_bias->dims()[0]});
+  }
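+
+  // Each xft::xftVec<float> built above appears to be a non-owning view:
+  // the emplace calls only record the tensor's existing buffer pointer plus
+  // a 1-D shape, and the scope keeps owning the underlying tensors.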
+
+  // prepare gn_scale
+  for (auto* gn_scale : param.gn_scale) {
+    xft_gn_weights.emplace_back(
+        const_cast<float*>(gn_scale->template data<float>()),
+        xft::xftVec<float>::dim_t{gn_scale->dims()[0]});
+  }
+  // prepare gn_bias
+  for (auto* gn_bias : param.gn_bias) {
+    xft_gn_bias.emplace_back(
+        const_cast<float*>(gn_bias->template data<float>()),
+        xft::xftVec<float>::dim_t{gn_bias->dims()[0]});
+  }
+  // prepare conv bias
+  for (auto* conv_bias : param.conv_bias) {
+    xft_conv_bias.emplace_back(
+        const_cast<float*>(conv_bias->template data<float>()),
+        xft::xftVec<float>::dim_t{conv_bias->dims()[0]});
+  }
+
+  arg_fc_weight_int16_ = PrepareWeight<int16_t>(param.fc_weight);
+  arg_conv_filter_int16_ = PrepareWeight<int16_t>(param.conv_weight);
+  const int XPU_QUANT_SCALE_NUM = ctx.GetRawContext()->max_ptr_size();
+  PrepareWeightMax(param.weight_max, XPU_QUANT_SCALE_NUM, &fc_weight_max_);
+  PrepareFilterMax(param.conv_max, XPU_QUANT_SCALE_NUM, &conv_filter_max_);
+
+  int channel = static_cast<int>(param.input->dims()[1]);
+  int xh = static_cast<int>(param.input->dims()[2]);
+  int xw = static_cast<int>(param.input->dims()[3]);
+  int hidden_dim = xh * xw;
+  int embedding_dim = static_cast<int>(param.embedding->dims()[2]);
+
+  // xft fc weights
+  xft_q_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[0]),
+      const_cast<float*>(fc_weight_max_[0]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, hidden_dim});
+  xft_q_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[4]),
+      const_cast<float*>(fc_weight_max_[4]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, hidden_dim});
+  xft_k_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[1]),
+      const_cast<float*>(fc_weight_max_[1]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, hidden_dim});
+  xft_k_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[5]),
+      const_cast<float*>(fc_weight_max_[5]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, embedding_dim});
+  xft_v_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[2]),
+      const_cast<float*>(fc_weight_max_[2]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, hidden_dim});
+  xft_v_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[6]),
+      const_cast<float*>(fc_weight_max_[6]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, embedding_dim});
+  xft_attn_fc_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[3]),
+      const_cast<float*>(fc_weight_max_[3]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, hidden_dim});
+  xft_attn_fc_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[7]),
+      const_cast<float*>(fc_weight_max_[7]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, hidden_dim});
+  xft_geglu_fc_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[8]),
+      const_cast<float*>(fc_weight_max_[8]),
+      xft::xftMat<int16_t>::dim_t{param.geglu_dim * 2, hidden_dim});
+  xft_geglu_fc_weights.emplace_back(
+      const_cast<int16_t*>(arg_fc_weight_int16_[9]),
+      const_cast<float*>(fc_weight_max_[9]),
+      xft::xftMat<int16_t>::dim_t{hidden_dim, param.geglu_dim});
+  for (size_t i = 0; i < arg_conv_filter_int16_.size(); i++) {
+    int kh = param.filter_dims[i][2];
+    int kw = param.filter_dims[i][3];
+    xft_conv_weights.emplace_back(
+        const_cast<int16_t*>(arg_conv_filter_int16_[i]),
+        const_cast<float*>(conv_filter_max_[i]),
+        xft::xftTensor<int16_t, 4>::dim_t{channel, hidden_dim, kh, kw});
+  }
+  st_param.n_head = param.head_num;
+  st_param.size_per_head = param.size_per_head;
+  st_param.geglu_dim = param.geglu_dim;
+  st_param.add_res = true;
+  st_param.conv_groups = param.conv_groups;
+  st_param.kernel_dims = param.filter_dims;
+  st_param.dilations = param.dilations;
+  st_param.paddings = param.paddings;
+  st_param.strides = param.strides;
+  st_param.gn_groups.push_back(param.groups);
+  st_param.gn_eps.push_back(param.epsilon);
+}
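+
+// Index map for the ten fused FC weights, in the order the fuse pass
+// concatenated them: 0-3 self-attn (q, k, v, out proj), 4-7 cross-attn
+// (q, k, v, out proj), 8-9 geglu (the up/gate and down projections,
+// judging by their shapes above).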
+template <typename InType, PrecisionType PType>
+void XPUSpatialTransformerCompute<InType, PType>::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  const InType* in = param.input->template data<InType>();
+  const InType* embedding = param.embedding->template data<InType>();
+  InType* out = param.output->template mutable_data<InType>(TARGET(kXPU));
+  int batch = static_cast<int>(param.input->dims()[0]);
+  int hidden_dim = static_cast<int>(param.input->dims()[1]);
+  int channel = hidden_dim;
+  int xh = static_cast<int>(param.input->dims()[2]);
+  int xw = static_cast<int>(param.input->dims()[3]);
+  int embedding_seq = static_cast<int>(param.embedding->dims()[1]);
+  int embedding_dim = static_cast<int>(param.embedding->dims()[2]);
+  // input
+  xft::xftTensor<InType, 4> in_tensor(
+      const_cast<InType*>(in), nullptr, {batch, channel, xh, xw});
+  xft::xftTensor<InType, 3> embedding_tensor(
+      const_cast<InType*>(embedding),
+      nullptr,
+      {batch, embedding_seq, embedding_dim});
+  // output
+  xft::xftTensor<InType, 4> output_tensor(out, {batch, channel, xh, xw});
+  int r = xft::st_spatial_transformer_fusion(ctx.GetRawContext(),
+                                             in_tensor,
+                                             embedding_tensor,
+                                             xft_ln_weights,
+                                             xft_ln_bias,
+                                             xft_gn_weights,
+                                             xft_gn_bias,
+                                             xft_q_weights,
+                                             xft_k_weights,
+                                             xft_v_weights,
+                                             xft_attn_fc_weights,
+                                             xft_attn_fc_bias,
+                                             xft_geglu_fc_weights,
+                                             xft_geglu_fc_bias,
+                                             xft_conv_weights,
+                                             xft_conv_bias,
+                                             &output_tensor,
+                                             st_param);
+  CHECK_EQ(r, 0);
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+namespace xpu = paddle::lite::kernels::xpu;
+
+// using XPUSpatialTransformer_FP32 =
+//     xpu::XPUSpatialTransformerCompute<float, PRECISION(kFloat)>;
+using XPUSpatialTransformer_FP16 =
+    xpu::XPUSpatialTransformerCompute<float16, PRECISION(kFP16)>;
+
+// REGISTER_LITE_KERNEL(
+//     __xpu__spatial_transformer,
+//     kXPU,
+//     kFloat,
+//     kNCHW,
+//     XPUSpatialTransformer_FP32,
+//     def)
+//     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("Embedding", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("ConvWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("ConvBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("GNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindInput("GNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
+//     .Finalize();
+REGISTER_LITE_KERNEL(__xpu__spatial_transformer,
+                     kXPU,
+                     kFP16,
+                     kNCHW,
+                     XPUSpatialTransformer_FP16,
+                     def_fp16)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Embedding",
+               {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("FCWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("ConvWeight", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("ConvBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("GNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("GNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
diff --git a/lite/kernels/xpu/__xpu__spatial_transformer_compute.h b/lite/kernels/xpu/__xpu__spatial_transformer_compute.h
new file mode 100644
index 00000000000..bcbfa6f6e07
--- /dev/null
+++ b/lite/kernels/xpu/__xpu__spatial_transformer_compute.h
@@ -0,0 +1,108 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#define USE_XFT
+
+#ifdef USE_XFT
+#include "layers/spatial_transformer.h"
+#endif
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+namespace xft = baidu::xpu::xft;
+
+template <typename T>
+struct identity {
+  typedef T type;
+};
+
+template <typename InType, PrecisionType PType>
+class XPUSpatialTransformerCompute : public KernelLite<TARGET(kXPU), PType> {
+ public:
+  using param_t = operators::XPUSpatialTransformerParam;
+
+  virtual void PrepareForRun();
+
+  virtual void Run();
+
+  virtual ~XPUSpatialTransformerCompute() = default;
+
+ private:
+#ifdef USE_XFT
+  xft::SpatialTransformerFusionParam st_param;
+  std::vector<xft::xftVec<float>> xft_gn_weights;
+  std::vector<xft::xftVec<float>> xft_gn_bias;
+  std::vector<xft::xftVec<float>> xft_ln_weights;
+  std::vector<xft::xftVec<float>> xft_ln_bias;
+  std::vector<xft::xftMat<int16_t>> xft_q_weights;
+  std::vector<xft::xftMat<int16_t>> xft_k_weights;
+  std::vector<xft::xftMat<int16_t>> xft_v_weights;
+  std::vector<xft::xftMat<int16_t>> xft_attn_fc_weights;
+  std::vector<xft::xftVec<float>> xft_attn_fc_bias;
+  std::vector<xft::xftMat<int16_t>> xft_geglu_fc_weights;
+  std::vector<xft::xftVec<float>> xft_geglu_fc_bias;
+  std::vector<xft::xftTensor<int16_t, 4>> xft_conv_weights;
+  std::vector<xft::xftVec<float>> xft_conv_bias;
+#else
+  std::vector<const float*> arg_fc_bias_;
+  std::vector<const float*> arg_ln_scale_;
+  std::vector<const float*> arg_ln_bias_;
+  std::vector<const float*> arg_gn_scale_;
+  std::vector<const float*> arg_gn_bias_;
+  std::vector<const float*> arg_conv_bias_;
+#endif
+  std::vector<const int16_t*> arg_fc_weight_int16_;
+  std::vector<const int16_t*> arg_conv_filter_int16_;
+  std::vector<const float*> fc_weight_max_;
+  std::vector<const float*> conv_filter_max_;
+  XPUScratchPadGuard weight_max_guard_;
+  XPUScratchPadGuard filter_max_guard_;
+
+  template <typename T>
+  std::vector<const T*>* GetWeight() {
+    LOG(FATAL) << "Invalid Weight Type";
+    return nullptr;
+  }
+
+  std::vector<const int16_t*>* GetWeight() { return &arg_fc_weight_int16_; }
+
+  template <typename T>
+  std::vector<const T*>* GetFilter() {
+    LOG(FATAL) << "Invalid Weight Type";
+    return nullptr;
+  }
+
+  std::vector<const int16_t*>* GetFilter() { return &arg_conv_filter_int16_; }
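+
+  // GetWeight/GetFilter hand back the typed weight caches via overload
+  // resolution; only the int16 variants are populated in this kernel, and
+  // the identity<T> helper above is presumably the hook for extending the
+  // dispatch to other quantization types.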
+
+  void PrepareWeightMax(const std::vector<lite::Tensor*>& weight_max,
+                        int max_ptr_len,
+                        std::vector<const float*>* max_xpu_ptrs);
+  void PrepareFilterMax(const std::vector<lite::Tensor*>& filter_max,
+                        int max_ptr_len,
+                        std::vector<const float*>* max_xpu_ptrs);
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index bb84d6cbe8d..57d8ef1ecfa 100755
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -256,6 +256,7 @@ add_operator(__xpu__gn_silu_op extra SRCS __xpu__gn_silu_op.cc)
 add_operator(__xpu__multihead_self_attn_op extra SRCS __xpu__multihead_self_attn_op.cc)
 add_operator(__xpu__multihead_cross_attn_op extra SRCS __xpu__multihead_cross_attn_op.cc)
 add_operator(__xpu__geglu_op extra SRCS __xpu__geglu_op.cc)
+add_operator(__xpu__spatial_transformer_op XPU extra SRCS __xpu__spatial_transformer_op.cc)
 if(XPU_WITH_XFT)
     add_operator(fusion_decoding_op extra SRCS fusion_decoding_op.cc)
diff --git a/lite/operators/__xpu__spatial_transformer_op.cc b/lite/operators/__xpu__spatial_transformer_op.cc
new file mode 100644
index 00000000000..2c2c48f46b2
--- /dev/null
+++ b/lite/operators/__xpu__spatial_transformer_op.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/__xpu__spatial_transformer_op.h"
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+static std::vector<std::vector<int>> Vec1DTo2D_int(const std::vector<int>& vec,
+                                                   int dim) {
+  std::vector<std::vector<int>> res;
+  for (size_t i = 0; i < vec.size(); i += dim) {
+    std::vector<int> tmp;
+    for (size_t j = 0; j < dim; j++) {
+      tmp.push_back(vec[i + j]);
+    }
+    res.emplace_back(std::move(tmp));
+  }
+  return res;
+}
+
+bool XPUSpatialTransformerOp::CheckShape() const {
+  CHECK_EQ(param_.input->dims().size(), 4UL);
+  return true;
+}
+
+bool XPUSpatialTransformerOp::InferShapeImpl() const {
+  param_.output->Resize(param_.input->dims());
+  return true;
+}
+
+bool XPUSpatialTransformerOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                         lite::Scope* scope) {
+  param_.input = scope->FindTensor(op_desc.Input("Input").front());
+  param_.embedding = scope->FindTensor(op_desc.Input("Embedding").front());
+  param_.output = scope->FindMutableTensor(op_desc.Output("Output").front());
+
+  param_.fc_weight.clear();
+  for (auto& name : op_desc.Input("FCWeight")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.fc_weight.push_back(t);
+  }
+  param_.fc_bias.clear();
+  for (auto& name : op_desc.Input("FCBias")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.fc_bias.push_back(t);
+  }
+  param_.ln_scale.clear();
+  for (auto& name : op_desc.Input("LNScale")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.ln_scale.push_back(t);
+  }
+  param_.ln_bias.clear();
+  for (auto& name : op_desc.Input("LNBias")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.ln_bias.push_back(t);
+  }
+  param_.conv_weight.clear();
+  for (auto& name : op_desc.Input("ConvWeight")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.conv_weight.push_back(t);
+  }
+  param_.conv_bias.clear();
+  for (auto& name : op_desc.Input("ConvBias")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.conv_bias.push_back(t);
+  }
+  param_.gn_scale.clear();
+  for (auto& name : op_desc.Input("GNScale")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.gn_scale.push_back(t);
+  }
+  param_.gn_bias.clear();
+  for (auto& name : op_desc.Input("GNBias")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.gn_bias.push_back(t);
+  }
+
+  param_.hidden_dim = op_desc.GetAttr<int>("hidden_dim");
+  param_.head_num = op_desc.GetAttr<int>("head_num");
+  param_.size_per_head = op_desc.GetAttr<int>("size_per_head");
+  param_.embedding_dim = op_desc.GetAttr<int>("embedding_dim");
+  param_.geglu_dim = op_desc.GetAttr<int>("gelu_dim");
+  param_.groups = op_desc.GetAttr<int>("groups");
+  param_.epsilon = op_desc.GetAttr<float>("epsilon");
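+
+  // "FCWeightMax" / "ConvFilterMax" carry tensor *names*, not values: the
+  // fuse pass registered a "<weight>_max" tensor per weight in the scope,
+  // and they are resolved back to Tensor pointers here.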
op_desc.GetAttr("epsilon"); + + param_.weight_max.clear(); + for (const auto& weight_max_tensor : + op_desc.GetAttr>("FCWeightMax")) { + auto tensor = scope->FindMutableTensor(weight_max_tensor); + CHECK(tensor != nullptr); + param_.weight_max.push_back(tensor); + } + param_.conv_max.clear(); + for (const auto& weight_max_tensor : + op_desc.GetAttr>("ConvFilterMax")) { + auto tensor = scope->FindMutableTensor(weight_max_tensor); + CHECK(tensor != nullptr); + param_.conv_max.push_back(tensor); + } + param_.conv_groups = op_desc.GetAttr>("Conv_Groups"); + param_.strides = + Vec1DTo2D_int(op_desc.GetAttr>("Strides"), 2); + param_.paddings = + Vec1DTo2D_int(op_desc.GetAttr>("Paddings"), 4); + param_.dilations = + Vec1DTo2D_int(op_desc.GetAttr>("Dilations"), 2); + param_.filter_dims = + Vec1DTo2D_int(op_desc.GetAttr>("FilterDims"), 4); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(__xpu__spatial_transformer, + paddle::lite::operators::XPUSpatialTransformerOp); diff --git a/lite/operators/__xpu__spatial_transformer_op.h b/lite/operators/__xpu__spatial_transformer_op.h new file mode 100644 index 00000000000..329b87bf250 --- /dev/null +++ b/lite/operators/__xpu__spatial_transformer_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(__xpu__spatial_transformer,
+                 paddle::lite::operators::XPUSpatialTransformerOp);
diff --git a/lite/operators/__xpu__spatial_transformer_op.h b/lite/operators/__xpu__spatial_transformer_op.h
new file mode 100644
index 00000000000..329b87bf250
--- /dev/null
+++ b/lite/operators/__xpu__spatial_transformer_op.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class XPUSpatialTransformerOp : public OpLite {
+ public:
+  XPUSpatialTransformerOp() {}
+
+  explicit XPUSpatialTransformerOp(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "XPUSpatialTransformer"; }
+
+ private:
+  mutable XPUSpatialTransformerParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 238c268e592..cf2d9f80a52 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1797,6 +1797,34 @@ struct XPUMultiEncoderParam : ParamBase {
   bool already_qkv_fusion{false};  // qkv is already fusion in graph
 };
 
+struct XPUSpatialTransformerParam : ParamBase {
+  const lite::Tensor* input{};
+  const lite::Tensor* embedding{};
+  std::vector<lite::Tensor*> fc_weight;
+  std::vector<lite::Tensor*> fc_bias;
+  std::vector<lite::Tensor*> ln_scale;
+  std::vector<lite::Tensor*> ln_bias;
+  std::vector<lite::Tensor*> conv_weight;
+  std::vector<lite::Tensor*> conv_bias;
+  std::vector<lite::Tensor*> gn_scale;
+  std::vector<lite::Tensor*> gn_bias;
+  lite::Tensor* output{nullptr};
+  std::vector<lite::Tensor*> weight_max{};
+  std::vector<lite::Tensor*> conv_max{};
+  std::vector<int> conv_groups{};
+  std::vector<std::vector<int>> strides{};
+  std::vector<std::vector<int>> paddings{};
+  std::vector<std::vector<int>> dilations{};
+  std::vector<std::vector<int>> filter_dims{};
+  int head_num{};
+  int size_per_head{};
+  int hidden_dim{};
+  int embedding_dim{};
+  int geglu_dim{};
+  int groups{};
+  float epsilon{};
+};
+
 struct XPUGnSiluParam : ParamBase {
   const lite::Tensor* input{};
   std::vector<lite::Tensor*> gn_scale;