Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
linkk08 committed Mar 3, 2023
1 parent 2cd3a88 commit 512600b
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 70 deletions.
65 changes: 1 addition & 64 deletions lite/kernels/xpu/__xpu__spatial_transformer_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ template <typename InType, PrecisionType PType>
void XPUSpatialTransformerCompute<InType, PType>::PrepareForRun() {
auto& ctx = this->ctx_->template As<XPUContext>();
auto& param = this->template Param<param_t>();
// prepare fc bias
#ifdef USE_XFT
xft_attn_fc_bias.emplace_back(
const_cast<float*>(param.fc_bias[0]->template data<float>()),
xft::xftVec<float>::dim_t{param.fc_bias[0]->dims()[0]});
Expand All @@ -110,69 +108,44 @@ void XPUSpatialTransformerCompute<InType, PType>::PrepareForRun() {
xft_geglu_fc_bias.emplace_back(
const_cast<float*>(param.fc_bias[3]->template data<float>()),
xft::xftVec<float>::dim_t{param.fc_bias[3]->dims()[0]});
#else
for (auto* fc_bias : param.fc_bias) {
arg_fc_bias_.push_back(fc_bias->template data<float>());
}
#endif
// prepare scale
for (auto* ln_scale : param.ln_scale) {
#ifdef USE_XFT
xft_ln_weights.emplace_back(
const_cast<float*>(ln_scale->template data<float>()),
xft::xftVec<float>::dim_t{ln_scale->dims()[0]});
#else
arg_ln_scale_.push_back(ln_scale->template data<float>());
#endif
}
// prepare ln_bias
for (auto* ln_bias : param.ln_bias) {
#ifdef USE_XFT
xft_ln_bias.emplace_back(
const_cast<float*>(ln_bias->template data<float>()),
xft::xftVec<float>::dim_t{ln_bias->dims()[0]});
#else
arg_ln_bias_.push_back(ln_bias->template data<float>());
#endif
}

// prepare gn_scale
for (auto* gn_scale : param.gn_scale) {
#ifdef USE_XFT
xft_gn_weights.emplace_back(
const_cast<float*>(gn_scale->template data<float>()),
xft::xftVec<float>::dim_t{gn_scale->dims()[0]});
#else
arg_gn_scale_.push_back(gn_scale->template data<float>());
#endif
}
// prepare gn_bias
for (auto* gn_bias : param.gn_bias) {
#ifdef USE_XFT
xft_gn_bias.emplace_back(
const_cast<float*>(gn_bias->template data<float>()),
xft::xftVec<float>::dim_t{gn_bias->dims()[0]});
#else
arg_gn_bias_.push_back(gn_bias->template data<float>());
#endif
}
// prepare conv bias
for (auto* conv_bias : param.conv_bias) {
#ifdef USE_XFT
xft_conv_bias.emplace_back(
const_cast<float*>(conv_bias->template data<float>()),
xft::xftVec<float>::dim_t{conv_bias->dims()[0]});
#else
arg_conv_bias_.push_back(conv_bias->template data<float>());
#endif
}

arg_fc_weight_int16_ = PrepareWeight<int16_t>(param.fc_weight);
arg_conv_filter_int16_ = PrepareWeight<int16_t>(param.conv_weight);
const int XPU_QUANT_SCALE_NUM = ctx.GetRawContext()->max_ptr_size();
PrepareWeightMax(param.weight_max, XPU_QUANT_SCALE_NUM, &fc_weight_max_);
PrepareFilterMax(param.conv_max, XPU_QUANT_SCALE_NUM, &conv_filter_max_);
#ifdef USE_XFT

int channel = static_cast<int>(param.input->dims()[1]);
int xh = static_cast<int>(param.input->dims()[2]);
int xw = static_cast<int>(param.input->dims()[3]);
Expand Down Expand Up @@ -239,7 +212,6 @@ void XPUSpatialTransformerCompute<InType, PType>::PrepareForRun() {
st_param.strides = param.strides;
st_param.gn_groups.push_back(param.groups);
st_param.gn_eps.push_back(param.epsilon);
#endif
}

template <typename InType, PrecisionType PType>
Expand All @@ -256,7 +228,6 @@ void XPUSpatialTransformerCompute<InType, PType>::Run() {
int xw = static_cast<int>(param.input->dims()[3]);
int embedding_seq = static_cast<int>(param.embedding->dims()[1]);
int embedding_dim = static_cast<int>(param.embedding->dims()[2]);
#ifdef USE_XFT
// input
xft::xftTensor<InType, 4> in_tensor(
const_cast<InType*>(in), nullptr, {batch, channel, xh, xw});
Expand Down Expand Up @@ -286,40 +257,6 @@ void XPUSpatialTransformerCompute<InType, PType>::Run() {
&output_tensor,
st_param);
CHECK_EQ(r, 0);
#else
int r = xdnn::spatial_transformer_fusion<InType, int16_t, InType, int16_t>(
ctx.GetRawContext(),
in,
embedding,
*(XPUSpatialTransformerCompute::get_weight<int16_t>()),
*(XPUSpatialTransformerCompute::get_filter<int16_t>()),
out,
arg_fc_bias_,
arg_conv_bias_,
arg_ln_scale_,
arg_ln_bias_,
arg_gn_scale_,
arg_gn_bias_,
fc_weight_max_,
conv_filter_max_,
param.filter_dims,
param.dilations,
param.paddings,
param.strides,
param.conv_groups,
batch,
param.head_num,
param.size_per_head,
xh,
xw,
hidden_dim,
embedding_seq,
param.embedding_dim,
param.groups,
param.epsilon,
param.geglu_dim);
CHECK_EQ(r, 0);
#endif
}

} // namespace xpu
Expand Down
5 changes: 0 additions & 5 deletions lite/operators/__xpu__spatial_transformer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ bool XPUSpatialTransformerOp::CheckShape() const {
}

bool XPUSpatialTransformerOp::InferShapeImpl() const {
// auto input_shape = param_.input->dims();
// auto batch_size = input_shape[0];
// auto channel = input_shape[1];
// auto height = input_shape[2];
// auto weight = input_shape[3];
param_.output->Resize(param_.input->dims());
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion lite/operators/__xpu__spatial_transformer_op.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 512600b

Please sign in to comment.