[arm] add fc+relu6 pass support (#8755)
chenjiaoAngel authored Apr 12, 2022
1 parent d8dd81e commit ae52d13
Showing 13 changed files with 354 additions and 300 deletions.
docs/quick_start/faq.md (3 changes: 2 additions & 1 deletion)
@@ -11,7 +11,8 @@
A: Replace the current Paddle Lite inference library with a prebuilt library built with the `with_extra = ON` option.

3. Multi-threading support on the ARM CPU side: why does multi-threading have no effect in some cases?
A: When compiled with gcc, both V7 and V8 support multi-threading; when compiled with clang, V8 supports multi-threading while V7 can only run single-threaded

A: When compiled with gcc, both V7 and V8 support multi-threading; when compiled with clang, V8 supports multi-threading while V7 only supports a single thread
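
A minimal sketch (not part of this commit) of how the thread count is typically configured through the Paddle Lite C++ API; it assumes the standard `MobileConfig` interface, and the model path is a placeholder:

```cpp
#include "paddle_api.h"  // Paddle Lite C++ inference API

using namespace paddle::lite_api;  // NOLINT

int main() {
  MobileConfig config;
  // Placeholder path to a model already optimized into the .nb format.
  config.set_model_from_file("mobilenet_v1.nb");
  config.set_threads(4);                      // request 4 ARM CPU worker threads
  config.set_power_mode(LITE_POWER_NO_BIND);  // do not pin threads to big/little cores
  auto predictor = CreatePaddlePredictor<MobileConfig>(config);
  // ... fill input tensors, call predictor->Run(), then read the outputs ...
  return 0;
}
```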

### Model Conversion

lite/backends/arm/math/fp16/funcs_fp16.cc (184 changes: 104 additions & 80 deletions)
@@ -21,106 +21,122 @@ namespace arm {
namespace math {
namespace fp16 {

#define LOADA_DATA_32 \
float16x8_t vin1 = vld1q_f16(ptr_out); \
float16x8_t vb1 = vld1q_f16(ptr_bias); \
float16x8_t vin2 = vld1q_f16(ptr_out + 8); \
float16x8_t vb2 = vld1q_f16(ptr_bias + 8); \
float16x8_t vin3 = vld1q_f16(ptr_out + 16); \
float16x8_t vb3 = vld1q_f16(ptr_bias + 16); \
float16x8_t vin4 = vld1q_f16(ptr_out + 24); \
float16x8_t vb4 = vld1q_f16(ptr_bias + 24); \
float16x8_t vout1 = vaddq_f16(vin1, vb1); \
float16x8_t vout2 = vaddq_f16(vin2, vb2); \
float16x8_t vout3 = vaddq_f16(vin3, vb3); \
float16x8_t vout4 = vaddq_f16(vin4, vb4);
#define RELU_32 \
vout1 = vmaxq_f16(vout1, vzero); \
vout2 = vmaxq_f16(vout2, vzero); \
vout3 = vmaxq_f16(vout3, vzero); \
vout4 = vmaxq_f16(vout4, vzero);
#define RELU6_32 \
vout1 = vminq_f16(vout1, valpha); \
vout2 = vminq_f16(vout2, valpha); \
vout3 = vminq_f16(vout3, valpha); \
vout4 = vminq_f16(vout4, valpha);
#define STORE_32 \
vst1q_f16(ptr_out, vout1); \
vst1q_f16(ptr_out + 8, vout2); \
vst1q_f16(ptr_out + 16, vout3); \
vst1q_f16(ptr_out + 24, vout4); \
ptr_out += 32; \
ptr_bias += 32;
#define LOADA_DATA_8 \
float16x8_t vin1 = vld1q_f16(ptr_out); \
float16x8_t vb1 = vld1q_f16(ptr_bias); \
float16x8_t vout1 = vaddq_f16(vin1, vb1);
#define RELU_8 vout1 = vmaxq_f16(vout1, vzero);
#define RELU6_8 vout1 = vminq_f16(vout1, valpha);
#define STORE_8 \
vst1q_f16(ptr_out, vout1); \
ptr_out += 8; \
ptr_bias += 8;

template <>
void fill_bias_fc<float16_t>(float16_t *out,
const float16_t *bias,
int num,
int channel,
bool flag_relu) {
const operators::ActivationParam *act_param) {
int cnt = channel >> 5;
int remain = channel & 31;
int cnt_num = remain >> 3;
int cnt_rem = remain & 7;
if (flag_relu) {
if (act_param != nullptr && act_param->has_active) {
float16x8_t vzero = vdupq_n_f16(0.f);
for (int j = 0; j < num; ++j) {
const float16_t *ptr_bias = bias;
float16_t *ptr_out = out + j * channel;

for (int i = 0; i < cnt; ++i) {
float16x8_t vin1 = vld1q_f16(ptr_out);
float16x8_t vb1 = vld1q_f16(ptr_bias);

float16x8_t vin2 = vld1q_f16(ptr_out + 8);
float16x8_t vb2 = vld1q_f16(ptr_bias + 8);

float16x8_t vin3 = vld1q_f16(ptr_out + 16);
float16x8_t vb3 = vld1q_f16(ptr_bias + 16);

float16x8_t vin4 = vld1q_f16(ptr_out + 24);
float16x8_t vb4 = vld1q_f16(ptr_bias + 24);

float16x8_t vout1 = vaddq_f16(vin1, vb1);
float16x8_t vout2 = vaddq_f16(vin2, vb2);
float16x8_t vout3 = vaddq_f16(vin3, vb3);
float16x8_t vout4 = vaddq_f16(vin4, vb4);

vout1 = vmaxq_f16(vout1, vzero);
vout2 = vmaxq_f16(vout2, vzero);
vout3 = vmaxq_f16(vout3, vzero);
vout4 = vmaxq_f16(vout4, vzero);

vst1q_f16(ptr_out, vout1);
vst1q_f16(ptr_out + 8, vout2);
vst1q_f16(ptr_out + 16, vout3);
vst1q_f16(ptr_out + 24, vout4);

ptr_out += 32;
ptr_bias += 32;
if (act_param->active_type == lite_api::ActivationType::kRelu) {
for (int j = 0; j < num; ++j) {
const float16_t *ptr_bias = bias;
float16_t *ptr_out = out + j * channel;

for (int i = 0; i < cnt; ++i) {
LOADA_DATA_32
RELU_32
STORE_32
}
for (int i = 0; i < cnt_num; i++) {
LOADA_DATA_8
RELU_8
STORE_8
}
for (int i = 0; i < cnt_rem; ++i) {
*ptr_out += *(ptr_bias++);
*ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f;
ptr_out++;
}
}
for (int i = 0; i < cnt_num; i++) {
float16x8_t vin1 = vld1q_f16(ptr_out);
float16x8_t vb1 = vld1q_f16(ptr_bias);
float16x8_t vout1 = vaddq_f16(vin1, vb1);
vout1 = vmaxq_f16(vout1, vzero);
vst1q_f16(ptr_out, vout1);
ptr_out += 8;
ptr_bias += 8;
}
for (int i = 0; i < cnt_rem; ++i) {
*ptr_out += *(ptr_bias++);
*ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f;
ptr_out++;
} else if (act_param->active_type == lite_api::ActivationType::kRelu6) {
float16_t alpha = static_cast<float16_t>(act_param->Relu_clipped_coef);
float16x8_t valpha = vdupq_n_f16(alpha);
for (int j = 0; j < num; ++j) {
const float16_t *ptr_bias = bias;
float16_t *ptr_out = out + j * channel;

for (int i = 0; i < cnt; ++i) {
LOADA_DATA_32
RELU_32
RELU6_32
STORE_32
}
for (int i = 0; i < cnt_num; i++) {
LOADA_DATA_8
RELU_8
RELU6_8
STORE_8
}
for (int i = 0; i < cnt_rem; ++i) {
*ptr_out += *(ptr_bias++);
*ptr_out =
*ptr_out > 0.f ? ((*ptr_out < alpha) ? *ptr_out : alpha) : 0.f;
ptr_out++;
}
}
} else {
LOG(FATAL) << "This act_type: "
<< static_cast<int>(act_param->active_type)
<< " doesn't support";
}
} else {
for (int j = 0; j < num; ++j) {
const float16_t *ptr_bias = bias;
float16_t *ptr_out = out + j * channel;

for (int i = 0; i < cnt; ++i) {
float16x8_t vin1 = vld1q_f16(ptr_out);
float16x8_t vb1 = vld1q_f16(ptr_bias);

float16x8_t vin2 = vld1q_f16(ptr_out + 8);
float16x8_t vb2 = vld1q_f16(ptr_bias + 8);

float16x8_t vin3 = vld1q_f16(ptr_out + 16);
float16x8_t vb3 = vld1q_f16(ptr_bias + 16);

float16x8_t vin4 = vld1q_f16(ptr_out + 24);
float16x8_t vb4 = vld1q_f16(ptr_bias + 24);

float16x8_t vout1 = vaddq_f16(vin1, vb1);
float16x8_t vout2 = vaddq_f16(vin2, vb2);
float16x8_t vout3 = vaddq_f16(vin3, vb3);
float16x8_t vout4 = vaddq_f16(vin4, vb4);

vst1q_f16(ptr_out, vout1);
vst1q_f16(ptr_out + 8, vout2);
vst1q_f16(ptr_out + 16, vout3);
vst1q_f16(ptr_out + 24, vout4);

ptr_out += 32;
ptr_bias += 32;
LOADA_DATA_32
STORE_32
}
for (int i = 0; i < cnt_num; i++) {
float16x8_t vin1 = vld1q_f16(ptr_out);
float16x8_t vb1 = vld1q_f16(ptr_bias);
float16x8_t vout1 = vaddq_f16(vin1, vb1);
vst1q_f16(ptr_out, vout1);
ptr_out += 8;
ptr_bias += 8;
LOADA_DATA_8
STORE_8
}
for (int i = 0; i < cnt_rem; ++i) {
*ptr_out += *(ptr_bias++);
@@ -129,6 +145,14 @@ void fill_bias_fc<float16_t>(float16_t *out,
}
}
}
#undef LOADA_DATA_32
#undef RELU_32
#undef RELU6_32
#undef STORE_32
#undef LOADA_DATA_8
#undef RELU_8
#undef RELU6_8
#undef STORE_8

} // namespace fp16
} // namespace math
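For readers unfamiliar with the NEON fp16 intrinsics above, here is a scalar reference sketch (not from the commit; names are illustrative) of the fused bias-plus-activation epilogue that the `LOADA_DATA_*`, `RELU_*`, `RELU6_*`, and `STORE_*` macros vectorize:

```cpp
#include <algorithm>

// Scalar reference for fill_bias_fc: add a per-channel bias to the
// [num x channel] FC output, then optionally apply ReLU or a ReLU6-style clip.
// act: 0 = no activation, 1 = ReLU, 2 = ReLU followed by clipping at `alpha`.
void fill_bias_fc_ref(float* out, const float* bias,
                      int num, int channel, int act, float alpha) {
  for (int j = 0; j < num; ++j) {
    float* ptr_out = out + j * channel;
    for (int c = 0; c < channel; ++c) {
      float v = ptr_out[c] + bias[c];       // bias add (vaddq_f16)
      if (act >= 1) v = std::max(v, 0.f);   // ReLU     (vmaxq_f16 with zero)
      if (act == 2) v = std::min(v, alpha); // clip     (vminq_f16 with alpha)
      ptr_out[c] = v;
    }
  }
}
```

The kRelu6 branch in the commit applies both the max-with-zero and the min-with-`Relu_clipped_coef` steps, matching the `act == 2` case here.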
lite/backends/arm/math/fp16/funcs_fp16.h (7 changes: 5 additions & 2 deletions; file mode 100755 → 100644)
@@ -45,8 +45,11 @@ namespace math {
namespace fp16 {

template <typename T>
void fill_bias_fc(
T* tensor, const T* bias, int num, int channel, bool flag_relu);
void fill_bias_fc(T* tensor,
const T* bias,
int num,
int channel,
const operators::ActivationParam* act_param);

// exp() computed for 8 float at once
inline float16x8_t expq_ps_f16(float16x8_t x) {
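A hypothetical call site for the new signature, assuming `ActivationParam` lives in `lite/operators/op_params.h`, the namespace nesting is `paddle::lite::arm::math::fp16`, and the struct exposes the fields this commit reads (`has_active`, `active_type`, `Relu_clipped_coef`):

```cpp
#include "lite/backends/arm/math/fp16/funcs_fp16.h"
#include "lite/operators/op_params.h"  // assumed location of ActivationParam

using namespace paddle::lite;  // NOLINT

void fc_bias_relu6_epilogue(float16_t* out, const float16_t* bias,
                            int num, int channel) {
  operators::ActivationParam act_param;
  act_param.has_active = true;
  act_param.active_type = lite_api::ActivationType::kRelu6;
  act_param.Relu_clipped_coef = 6.f;  // clip value read by the kRelu6 branch

  // out holds the [num x channel] FC result; bias has `channel` elements.
  arm::math::fp16::fill_bias_fc<float16_t>(out, bias, num, channel, &act_param);

  // Plain bias add: pass nullptr (or an ActivationParam with has_active = false).
  // arm::math::fp16::fill_bias_fc<float16_t>(out, bias, num, channel, nullptr);
}
```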