[cherry-pick] fix sve backends bug(matmul_v2&conv) #9696

Merged: 4 commits, Nov 17, 2022
110 changes: 110 additions & 0 deletions lite/backends/arm/math/gemm_s8.cc
@@ -13,6 +13,9 @@
// limitations under the License.

#include "lite/backends/arm/math/gemm_s8.h"
#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
#include "lite/backends/arm/math/sve/gemm_sve_i8mm.h"
#endif

namespace paddle {
namespace lite {
@@ -112,6 +115,113 @@ template void gemm_s8<int8_t>(bool is_transA,
                              const operators::ActivationParam act_param,
                              ARMContext* ctx);

#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
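// SVE2 variant of gemm_s8: same contract as the NEON path above, but A is
// packed for the SVE i8mm micro-kernels before the prepacked GEMM runs.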
template <typename Dtype>
void gemm_sve(bool is_transA,
              bool is_transB,
              int M,
              int N,
              int K,
              const int8_t* A,
              const int8_t* B,
              Dtype* C,
              const float* bias,
              bool is_bias,
              const float* scale,
              const operators::ActivationParam act_param,
              ARMContext* ctx) {
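  // N == 1: C is M x 1, so the whole GEMM degenerates to a GEMV on A.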
  if (N == 1) {
    gemv_int8(A, B, C, is_transA, M, K, scale, is_bias, bias, act_param, ctx);
    return;
  }
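  // M == 1: compute the single output row as a GEMV over B (treating it as
  // C^T = B^T * A^T); the per-row bias/scale (a single value when M == 1)
  // is broadcast across the N outputs, since gemv_int8 expects per-output
  // arrays.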
  if (M == 1) {
#ifdef TARGET_IOS
    float* bias_ptr = new float[N];
    float* scale_ptr = new float[N];
#else
    float bias_ptr[N];   // NOLINT
    float scale_ptr[N];  // NOLINT
#endif
    if (is_bias) {
      for (int i = 0; i < N; i++) {
        bias_ptr[i] = bias[0];
      }
    }
    for (int i = 0; i < N; i++) {
      scale_ptr[i] = scale[0];
    }
    gemv_int8(B,
              A,
              C,
              !is_transB,
              N,
              K,
              scale_ptr,
              is_bias,
              bias_ptr,
              act_param,
              ctx);
#ifdef TARGET_IOS
    delete[] bias_ptr;
    delete[] scale_ptr;
#endif
    return;
  }

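  // Pack A for the SVE kernel: round M up to a multiple of the kernel's
  // block height and K up to a multiple of 8 (e.g., assuming a block height
  // of 8, M = 30 would pack as 32 rows).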
  //! prepack
  Tensor tpackedA_sve;
  int hblock_sve = paddle::lite::arm::math::sve::get_hblock_int8_sve(ctx);
  int round_up_a_sve = ((hblock_sve + M - 1) / hblock_sve) * hblock_sve;
  int round_up_k_sve = 8 * ((K + 7) / 8);
  tpackedA_sve.Resize({round_up_a_sve * round_up_k_sve});
  int lda = is_transA ? M : K;
  paddle::lite::arm::math::sve::prepackA_int8_sve(
      tpackedA_sve.mutable_data<int8_t>(), A, lda, 0, M, 0, K, is_transA, ctx);
  // run the SVE i8mm GEMM on the packed A block
  lite::arm::math::sve::gemm_prepack_int8_sve<Dtype>(
      tpackedA_sve.data<int8_t>(),
      B,
      bias,
      C,
      M,
      N,
      K,
      is_bias,
      is_transB,
      scale,
      act_param,
      ctx);
}

template void gemm_sve<float>(bool is_transA,
                              bool is_transB,
                              int M,
                              int N,
                              int K,
                              const int8_t* A,
                              const int8_t* B,
                              float* C,
                              const float* bias,
                              bool is_bias,
                              const float* scale,
                              const operators::ActivationParam act_param,
                              ARMContext* ctx);

template void gemm_sve<int8_t>(bool is_transA,
                               bool is_transB,
                               int M,
                               int N,
                               int K,
                               const int8_t* A,
                               const int8_t* B,
                               int8_t* C,
                               const float* bias,
                               bool is_bias,
                               const float* scale,
                               const operators::ActivationParam act_param,
                               ARMContext* ctx);
#endif

} // namespace math
} // namespace arm
} // namespace lite
16 changes: 16 additions & 0 deletions lite/backends/arm/math/gemm_s8.h
@@ -39,6 +39,22 @@ void gemm_s8(bool is_transA,
             const operators::ActivationParam act_param,
             ARMContext* ctx);

#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
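// Int8 GEMM entry point for the SVE2 backend; only declared when building
// for aarch64 with LITE_WITH_ARM8_SVE2 enabled.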
template <typename Dtype>
void gemm_sve(bool is_transA,
              bool is_transB,
              int M,
              int N,
              int K,
              const int8_t* A,
              const int8_t* B,
              Dtype* C,
              const float* bias,
              bool is_bias,
              const float* scale,
              const operators::ActivationParam act_param,
              ARMContext* ctx);
#endif
} // namespace math
} // namespace arm
} // namespace lite
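For context, a minimal sketch of how a caller (e.g., the matmul_v2 or conv kernel) might dispatch between the two entry points. The guard mirrors the one used in this patch; the surrounding variables (`is_transA`, `is_transB`, `M`, `N`, `K`, `A`, `B`, `C`, `bias`, `is_bias`, `scale`, `act_param`, `ctx`) are assumed to be set up by the caller and are not taken from this diff:

```cpp
// Hypothetical caller-side dispatch (not part of this patch): use the SVE2
// i8mm GEMM when it was compiled in, otherwise fall back to the NEON path.
#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
  paddle::lite::arm::math::gemm_sve<float>(is_transA, is_transB, M, N, K,
      A, B, C, bias, is_bias, scale, act_param, ctx);
#else
  paddle::lite::arm::math::gemm_s8<float>(is_transA, is_transB, M, N, K,
      A, B, C, bias, is_bias, scale, act_param, ctx);
#endif
```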