[arm] add v9 gemm implementation (PaddlePaddle#9083)
chenjiaoAngel authored and wangbingnan03 committed Jun 13, 2022
1 parent 111d00f commit 3a00d24
Showing 13 changed files with 1,724 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cmake/postproject.cmake
@@ -45,8 +45,8 @@ if(ANDROID)
   if ((ARM_TARGET_ARCH_ABI STREQUAL "armv8"))
     if (${ANDROID_NDK_MAJOR})
       if(${ANDROID_NDK_MAJOR} GREATER_EQUAL "23")
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+sve2")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+sve2")
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+sve2+fp16+dotprod+f32mm+i8mm")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+sve2+fp16+dotprod+f32mm+i8mm")
       else()
         message(FATAL_ERROR "NDK VERSION: ${ANDROID_NDK_MAJOR}, however it must be greater equal 23 when sve2 is ON")
       endif()
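
The widened -march string only affects code generation; whether the device can actually execute SVE2/I8MM/F32MM instructions still has to be checked at runtime. Below is a minimal detection sketch for an AArch64 Linux/Android toolchain; it is illustrative only, since Paddle-Lite has its own device-info path and the HWCAP2_* bits require reasonably recent kernel headers.

// Hedged sketch: runtime check for the features the new -march line assumes.
// AArch64 Linux/Android only; HWCAP2_* macros come from <asm/hwcap.h>.
#include <sys/auxv.h>   // getauxval, AT_HWCAP2
#include <asm/hwcap.h>  // HWCAP2_SVE2, HWCAP2_I8MM, HWCAP2_F32MM
#include <cstdio>

bool has_sve2_matmul() {
#if defined(HWCAP2_SVE2) && defined(HWCAP2_I8MM) && defined(HWCAP2_F32MM)
  unsigned long hwcap2 = getauxval(AT_HWCAP2);
  return (hwcap2 & HWCAP2_SVE2) && (hwcap2 & HWCAP2_I8MM) &&
         (hwcap2 & HWCAP2_F32MM);
#else
  return false;  // headers predate these bits; assume unsupported
#endif
}

int main() {
  std::printf("sve2+i8mm+f32mm: %s\n", has_sve2_matmul() ? "yes" : "no");
}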
2 changes: 1 addition & 1 deletion lite/backends/arm/math/fp16/common_preprocess.h
@@ -52,7 +52,7 @@ typedef __fp16 float16_t;
       const dtype *inptr12, const dtype *inptr13, const dtype *inptr14, \
       const dtype *inptr15, int numa, int numb

-#define X_BLOCK_COMPUTE(llc_size, MBLOCK, NBLOCK, KBLOCK, beta)       \
+#define X_BLOCK_COMPUTE_FP16(llc_size, MBLOCK, NBLOCK, KBLOCK, beta)  \
   /* MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2*/          \
   int x_block =                                                       \
       (llc_size - (MBLOCK * K)) / (sizeof(float16_t) * (K + MBLOCK)); \
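
For context, the renamed macro solves the cache-budget equation in its comment for x: the C tile (MBLOCK * x), the packed A panel (MBLOCK * k), and the packed B panel (x * k) together must fit in the last-level cache. The plain-function transcription below covers only the lines shown in this hunk; the unshown remainder of the macro rounds and clamps x_block further, and the units of the A-panel subtraction are kept exactly as the macro writes them. The clamp to 1 is an assumption for safety, not the macro's verified behavior.

// Transcription of the shown macro body: solve
//   MBLOCK * x + MBLOCK * K + x * K <= llc_size
// for x, the widest column block that still fits in cache.
inline int compute_x_block_fp16(long llc_size, int MBLOCK, int K) {
  const long fp16_bytes = 2;  // sizeof(float16_t)
  long x_block = (llc_size - (MBLOCK * K)) / (fp16_bytes * (K + MBLOCK));
  return x_block > 0 ? static_cast<int>(x_block) : 1;  // clamp is an assumption
}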
4 changes: 2 additions & 2 deletions lite/backends/arm/math/fp16/gemm_fp16.cc
@@ -1835,7 +1835,7 @@ void gemm_prepack_8x16(bool is_transB,

   float16x8_t valpha = vdupq_n_f16(static_cast<float16_t>(local_alpha));
   //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
-  X_BLOCK_COMPUTE(llc_size, MBLOCK_FP16, NBLOCK_FP16, KBLOCK_FP16, beta)
+  X_BLOCK_COMPUTE_FP16(llc_size, MBLOCK_FP16, NBLOCK_FP16, KBLOCK_FP16, beta)
   float16x8_t vbeta = vdupq_n_f16(beta);
   float16x8_t vzero = vdupq_n_f16(0.f);
   float16x8_t voffset = vdupq_n_f16(offset);
@@ -2594,7 +2594,7 @@ void gemm_prepack_8x8(bool is_transB,
   }
   float16_t alpha_ptr[40] = {0.f};
   //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
-  X_BLOCK_COMPUTE(llc_size, MBLOCK_FP16, NBLOCK_FP16, KBLOCK_FP16, beta)
+  X_BLOCK_COMPUTE_FP16(llc_size, MBLOCK_FP16, NBLOCK_FP16, KBLOCK_FP16, beta)
   tail_pre = tail_pre * 8 + flag_act;
   k_pre = k_pre * 32 + tail_pre;
   for (int i = 0; i < 8; i++) {
5 changes: 5 additions & 0 deletions lite/backends/arm/math/packed_sgemm.h
@@ -85,6 +85,11 @@ void loadb(

 void loadb_trans(
     float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax);
+void loadb_eight(
+    float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax);
+
+void loadb_trans_eight(
+    float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax);
 void sgemm_prepack(bool is_transB,
                    int M,
                    int N,
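
The two new declarations extend the existing loadb/loadb_trans pair with variants that pack B in panels of eight columns. The actual implementations live in the packed-sgemm source and are NEON-optimized; the scalar sketch below only illustrates the panel layout such a routine typically produces. The eight-column width, zero padding of the ragged last panel, and row-major traversal are assumptions, not the verified behavior of loadb_eight.

// Scalar reference sketch of an 8-column B-packing routine of this shape.
#include <algorithm>
#include <cstring>

void loadb_eight_ref(float* out, const float* in, int ldin,
                     int k0, int kmax, int n0, int nmax) {
  const int NBLOCK = 8;
  for (int n = n0; n < nmax; n += NBLOCK) {
    int nw = std::min(NBLOCK, nmax - n);     // columns in this panel
    for (int k = k0; k < kmax; ++k) {
      const float* src = in + k * ldin + n;  // row-major B
      std::memcpy(out, src, nw * sizeof(float));
      out += nw;
      // zero-pad the ragged last panel to the full width (assumption)
      for (int j = nw; j < NBLOCK; ++j) *out++ = 0.f;
    }
  }
}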
312 changes: 312 additions & 0 deletions lite/backends/arm/math/sve/conv_impl_sve.cc
@@ -0,0 +1,312 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/arm/math/sve/conv_impl_sve.h"
#include <arm_neon.h>
#include <algorithm>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/sve/gemm_sve.h"
#include "lite/core/context.h"
#include "lite/core/parallel_defines.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
#ifdef ENABLE_ARM_FP16
#include "lite/backends/arm/math/fp16/conv_impl_fp16.h"
#endif

namespace paddle {
namespace lite {
namespace arm {
namespace math {
namespace sve {

/**
 * \brief convolution function for kernel size 1x1, stride size 1, gemm
 * implementation
 */
template <typename Dtype>
void conv1x1s1_gemm_sve(const Dtype* i_data,
                        Dtype* o_data,
                        int num,
                        int oc,
                        int oh,
                        int ow,
                        int ic,
                        int ih,
                        int win,
                        const Dtype* weights,
                        const Dtype* bias,
                        const operators::ConvParam& param,
                        ARMContext* ctx) {
  int channel_size_out = ow * oh;
  int channel_size_in = win * ih;

  const int group = param.groups;
  const int m = oc / group;
  const int n = oh * ow;
  const int k = ic / group;

  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;

  auto act_param = param.activation_param;

  int hblock = get_hblock_sve(ctx, m, sizeof(Dtype));
  int m_roundup = hblock * ((m + hblock - 1) / hblock);
  int weights_size_per_group = m * k;
  if (n > 1 && m > 1) {
    weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
  }
  //! use gemv when the output channel size = 1
  for (int b = 0; b < num; ++b) {
    // dC
    for (int g = 0; g < group; ++g) {
      Dtype* dout_group =
          static_cast<Dtype*>(o_data) + (b * oc + g * m) * channel_size_out;
      const Dtype* din_group = static_cast<const Dtype*>(i_data) +
                               (b * ic + g * k) * channel_size_in;
      const Dtype* weights_group =
          static_cast<const Dtype*>(weights) + g * weights_size_per_group;
      const Dtype* bias_group = static_cast<const Dtype*>(bias) + g * m;

      sgemm_prepack_sve<Dtype>(false,
                               m,
                               n,
                               k,
                               weights_group,
                               k,
                               din_group,
                               n,
                               0.f,
                               dout_group,
                               n,
                               bias_group,
                               flag_bias,
                               act_param,
                               ctx);
    }
  }
}
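
The mapping above works because a 1x1 convolution with stride 1 needs no im2col: per group, the weights form an m x k matrix, the input channels form a k x n matrix with n = oh * ow, and the output is their product. A naive reference (a hypothetical helper, float only, no bias or activation) that pins down the indexing:

// Naive per-group reference for the 1x1/s1 case; illustrative, not Paddle-Lite API.
void conv1x1s1_ref(const float* x, const float* w, float* y,
                   int m, int n, int k) {
  // y[i][j] = sum_c w[i][c] * x[c][j], with n = oh * ow
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int c = 0; c < k; ++c) acc += w[i * k + c] * x[c * n + j];
      y[i * n + j] = acc;
    }
  }
}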

template void conv1x1s1_gemm_sve<float>(const float* i_data,
                                        float* o_data,
                                        int num,
                                        int oc,
                                        int oh,
                                        int ow,
                                        int ic,
                                        int ih,
                                        int win,
                                        const float* weights,
                                        const float* bias,
                                        const operators::ConvParam& param,
                                        ARMContext* ctx);

#ifdef ENABLE_ARM_FP16
template void conv1x1s1_gemm_sve<float16_t>(const float16_t* i_data,
                                            float16_t* o_data,
                                            int num,
                                            int oc,
                                            int oh,
                                            int ow,
                                            int ic,
                                            int ih,
                                            int win,
                                            const float16_t* weights,
                                            const float16_t* bias,
                                            const operators::ConvParam& param,
                                            ARMContext* ctx);
#endif
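
The weights_size_per_group arithmetic in conv1x1s1_gemm_sve above deserves a note: when the GEMM path is taken, m is first rounded up to a multiple of hblock (whatever get_hblock_sve returns for the target), and the resulting element count is then padded to a multiple of 16 so each group's prepacked panel starts at an aligned offset. A small illustration; round_up and the gemm_path flag are made-up names:

// Illustrative only; mirrors the rounding in conv1x1s1_gemm_sve above.
inline int round_up(int v, int unit) { return ((v + unit - 1) / unit) * unit; }

int weights_stride_per_group(int m, int k, int hblock, bool gemm_path) {
  if (!gemm_path) return m * k;         // n == 1 or m == 1 case
  int m_roundup = round_up(m, hblock);  // pad rows to a multiple of hblock
  return round_up(m_roundup * k, 16);   // pad element count to a multiple of 16
}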

/**
 * \brief general convolution function, im2col + gemm implementation
 */
template <>
void conv_im2col_gemm_sve(const float* i_data,
                          float* o_data,
                          int num,
                          int oc,
                          int oh,
                          int ow,
                          int ic,
                          int ih,
                          int win,
                          const float* weights,
                          const float* bias,
                          const operators::ConvParam& param,
                          ARMContext* ctx) {
  const int group = param.groups;
  auto filter_dims = param.filter->dims();
  const int kernel_h = filter_dims[2];
  const int kernel_w = filter_dims[3];  // nchw
  const int m = oc / group;
  const int n = oh * ow;
  const int k = ic * kernel_h * kernel_w / group;
  const int chin_per_group = ic / group;
  int channel_size_out = ow * oh;
  int channel_size_in = win * ih;
  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;
  int hblock = get_hblock_sve(ctx, m, 4);
  int m_roundup = hblock * ((m + hblock - 1) / hblock);
  int weights_size_per_group = m * k;

  auto act_param = param.activation_param;
  if (n > 1 && m > 1) {
    weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
  }

  float* tmp_work_space =
      ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);

  auto paddings = *param.paddings;
  auto dilations = *param.dilations;
  //! use gemv when the output channel size = 1
  for (int b = 0; b < num; ++b) {
    // dC
    for (int g = 0; g < group; ++g) {
      float* dout_group = o_data + (b * oc + g * m) * channel_size_out;
      const float* din_group =
          i_data + (b * ic + g * chin_per_group) * channel_size_in;
      const float* weights_group = weights + g * weights_size_per_group;
      const float* bias_group = bias + g * m;
      float* dB = tmp_work_space;
      im2col<float>(din_group,
                    chin_per_group,
                    ih,
                    win,
                    kernel_h,
                    kernel_w,
                    paddings[0],
                    paddings[1],
                    paddings[2],
                    paddings[3],
                    param.strides[0],
                    param.strides[1],
                    dilations[0],
                    dilations[1],
                    dB);
      int ldb = n;
      sgemm_prepack_sve<float>(false,
                               m,
                               n,
                               k,
                               weights_group,
                               k,
                               dB,
                               ldb,
                               0.f,
                               dout_group,
                               n,
                               bias_group,
                               flag_bias,
                               act_param,
                               ctx);
    }
  }
}
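
For kernels larger than 1x1 the input is first unfolded: im2col writes one row per (channel, ky, kx) position and one column per output pixel, so the convolution again reduces to the same prepacked GEMM with ldb = n. Below is a scalar sketch of that layout with a simplified signature; the real im2col<float> above also takes bottom and right paddings, which only matter for computing oh/ow and are treated as given here.

// Scalar im2col sketch matching the (k x n) layout fed to sgemm_prepack_sve:
// one row per (channel, fy, fx) triple, one column per output pixel.
void im2col_ref(const float* in, int ic, int ih, int iw,
                int kh, int kw, int pad_top, int pad_left,
                int stride_h, int stride_w, int dila_h, int dila_w,
                int oh, int ow, float* out) {
  for (int c = 0; c < ic; ++c)
    for (int fy = 0; fy < kh; ++fy)
      for (int fx = 0; fx < kw; ++fx)
        for (int y = 0; y < oh; ++y)
          for (int x = 0; x < ow; ++x) {
            int sy = y * stride_h - pad_top + fy * dila_h;
            int sx = x * stride_w - pad_left + fx * dila_w;
            float v = (sy >= 0 && sy < ih && sx >= 0 && sx < iw)
                          ? in[(c * ih + sy) * iw + sx]
                          : 0.f;  // zero padding outside the input
            *out++ = v;  // row ((c*kh+fy)*kw+fx), col (y*ow+x)
          }
}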

#ifdef ENABLE_ARM_FP16
template <>
void conv_im2col_gemm_sve(const float16_t* i_data,
                          float16_t* o_data,
                          int num,
                          int oc,
                          int oh,
                          int ow,
                          int ic,
                          int ih,
                          int win,
                          const float16_t* weights,
                          const float16_t* bias,
                          const operators::ConvParam& param,
                          ARMContext* ctx) {
  const int group = param.groups;
  auto filter_dims = param.filter->dims();
  const int kernel_h = filter_dims[2];
  const int kernel_w = filter_dims[3];  // nchw
  const int m = oc / group;
  const int n = oh * ow;
  const int k = ic * kernel_h * kernel_w / group;
  const int chin_per_group = ic / group;
  int channel_size_out = ow * oh;
  int channel_size_in = win * ih;
  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;
  int hblock = get_hblock_sve(ctx, m, 2);
  int m_roundup = hblock * ((m + hblock - 1) / hblock);
  int weights_size_per_group = m * k;

  auto act_param = param.activation_param;
  if (n > 1 && m > 1) {
    weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
  }

  float16_t* tmp_work_space =
      ctx->workspace_data<float16_t>() + ctx->llc_size() / sizeof(float16_t);

  auto paddings = *param.paddings;
  auto dilations = *param.dilations;
  //! use gemv when the output channel size = 1
  for (int b = 0; b < num; ++b) {
    // dC
    for (int g = 0; g < group; ++g) {
      float16_t* dout_group = o_data + (b * oc + g * m) * channel_size_out;
      const float16_t* din_group =
          i_data + (b * ic + g * chin_per_group) * channel_size_in;
      const float16_t* weights_group = weights + g * weights_size_per_group;
      const float16_t* bias_group = bias + g * m;
      float16_t* dB = tmp_work_space;
      fp16::im2col_fp16(din_group,
                        chin_per_group,
                        ih,
                        win,
                        kernel_h,
                        kernel_w,
                        paddings[0],
                        paddings[1],
                        paddings[2],
                        paddings[3],
                        dilations[0],
                        dilations[1],
                        dB,
                        param.strides[0],
                        param.strides[1]);
      int ldb = n;
      sgemm_prepack_sve<float16_t>(false,
                                   m,
                                   n,
                                   k,
                                   weights_group,
                                   k,
                                   dB,
                                   ldb,
                                   0.f,
                                   dout_group,
                                   n,
                                   bias_group,
                                   flag_bias,
                                   act_param,
                                   ctx);
    }
  }
}
#endif

} // namespace sve
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle