From 397757178030bb0b64528e0f9f8eab80632c4fd2 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Thu, 31 Mar 2022 11:34:10 +0800 Subject: [PATCH 1/7] add fc+relu6 pass --- lite/backends/arm/math/fp16/funcs_fp16.cc | 186 ++++++------ lite/backends/arm/math/fp16/funcs_fp16.h | 7 +- lite/backends/arm/math/funcs.cc | 269 ++++++++---------- lite/backends/arm/math/funcs.h | 7 +- .../core/optimizer/mir/fusion/fc_fuse_pass.cc | 32 ++- lite/core/optimizer/mir/fusion/fc_fuser.cc | 41 ++- lite/core/optimizer/mir/fusion/fc_fuser.h | 6 +- lite/kernels/arm/fc_compute.cc | 49 ++-- lite/kernels/arm/rnn_compute.cc | 21 +- lite/operators/fc_op.cc | 3 + lite/operators/op_params.h | 1 + 11 files changed, 330 insertions(+), 292 deletions(-) mode change 100755 => 100644 lite/backends/arm/math/fp16/funcs_fp16.h diff --git a/lite/backends/arm/math/fp16/funcs_fp16.cc b/lite/backends/arm/math/fp16/funcs_fp16.cc index b4eeb9f758b..4c9c9280b90 100644 --- a/lite/backends/arm/math/fp16/funcs_fp16.cc +++ b/lite/backends/arm/math/fp16/funcs_fp16.cc @@ -21,106 +21,122 @@ namespace arm { namespace math { namespace fp16 { +#define LOADA_DATA_32 \ + float16x8_t vin1 = vld1q_f16(ptr_out); \ + float16x8_t vb1 = vld1q_f16(ptr_bias); \ + float16x8_t vin2 = vld1q_f16(ptr_out + 8); \ + float16x8_t vb2 = vld1q_f16(ptr_bias + 8); \ + float16x8_t vin3 = vld1q_f16(ptr_out + 16); \ + float16x8_t vb3 = vld1q_f16(ptr_bias + 16); \ + float16x8_t vin4 = vld1q_f16(ptr_out + 24); \ + float16x8_t vb4 = vld1q_f16(ptr_bias + 24); \ + float16x8_t vout1 = vaddq_f16(vin1, vb1); \ + float16x8_t vout2 = vaddq_f16(vin2, vb2); \ + float16x8_t vout3 = vaddq_f16(vin3, vb3); \ + float16x8_t vout4 = vaddq_f16(vin4, vb4); +#define RELU_32 \ + vout1 = vmaxq_f16(vout1, vzero); \ + vout2 = vmaxq_f16(vout2, vzero); \ + vout3 = vmaxq_f16(vout3, vzero); \ + vout4 = vmaxq_f16(vout4, vzero); +#define RELU6_32 \ + vout1 = vminq_f16(vout1, valpha); \ + vout2 = vminq_f16(vout2, valpha); \ + vout3 = vminq_f16(vout3, valpha); \ + vout4 = vminq_f16(vout4, valpha); +#define STORE_32 \ + vst1q_f16(ptr_out, vout1); \ + vst1q_f16(ptr_out + 8, vout2); \ + vst1q_f16(ptr_out + 16, vout3); \ + vst1q_f16(ptr_out + 24, vout4); \ + ptr_out += 32; \ + ptr_bias += 32; +#define LOADA_DATA_8 \ + float16x8_t vin1 = vld1q_f16(ptr_out); \ + float16x8_t vb1 = vld1q_f16(ptr_bias); \ + float16x8_t vout1 = vaddq_f16(vin1, vb1); +#define RELU_8 vout1 = vmaxq_f16(vout1, vzero); +#define RELU6_8 vout1 = vminq_f16(vout1, valpha); +#define STORE_8 \ + vst1q_f16(ptr_out, vout1); \ + ptr_out += 8; \ + ptr_bias += 8; + template <> void fill_bias_fc(float16_t *out, const float16_t *bias, int num, int channel, - bool flag_relu) { + const operators::ActivationParam *act_param) { int cnt = channel >> 5; int remain = channel & 31; int cnt_num = remain >> 3; int cnt_rem = remain & 7; - if (flag_relu) { - float16x8_t vzero = vdupq_n_f16(0.f); - for (int j = 0; j < num; ++j) { - const float16_t *ptr_bias = bias; - float16_t *ptr_out = out + j * channel; - - for (int i = 0; i < cnt; ++i) { - float16x8_t vin1 = vld1q_f16(ptr_out); - float16x8_t vb1 = vld1q_f16(ptr_bias); - - float16x8_t vin2 = vld1q_f16(ptr_out + 8); - float16x8_t vb2 = vld1q_f16(ptr_bias + 8); - - float16x8_t vin3 = vld1q_f16(ptr_out + 16); - float16x8_t vb3 = vld1q_f16(ptr_bias + 16); - - float16x8_t vin4 = vld1q_f16(ptr_out + 24); - float16x8_t vb4 = vld1q_f16(ptr_bias + 24); - - float16x8_t vout1 = vaddq_f16(vin1, vb1); - float16x8_t vout2 = vaddq_f16(vin2, vb2); - float16x8_t vout3 = vaddq_f16(vin3, vb3); - float16x8_t vout4 = vaddq_f16(vin4, vb4); - - vout1 = vmaxq_f16(vout1, vzero); - vout2 = vmaxq_f16(vout2, vzero); - vout3 = vmaxq_f16(vout3, vzero); - vout4 = vmaxq_f16(vout4, vzero); - - vst1q_f16(ptr_out, vout1); - vst1q_f16(ptr_out + 8, vout2); - vst1q_f16(ptr_out + 16, vout3); - vst1q_f16(ptr_out + 24, vout4); - - ptr_out += 32; - ptr_bias += 32; + if (act_param != nullptr && act_param->has_active) { + float32x4_t vzero = vdupq_n_f32(0.f); + if (act_param->active_type == lite_api::ActivationType::kRelu) { + for (int j = 0; j < num; ++j) { + const float16_t *ptr_bias = bias; + float16_t *ptr_out = out + j * channel; + + for (int i = 0; i < cnt; ++i) { + LOADA_DATA_32 + RELU_32 + STORE_32 + } + for (int i = 0; i < cnt_num; i++) { + LOADA_DATA_8 + RELU_8 + STORE_8 + } + for (int i = 0; i < cnt_rem; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f; + ptr_out++; + } } - for (int i = 0; i < cnt_num; i++) { - float16x8_t vin1 = vld1q_f16(ptr_out); - float16x8_t vb1 = vld1q_f16(ptr_bias); - float16x8_t vout1 = vaddq_f16(vin1, vb1); - vout1 = vmaxq_f16(vout1, vzero); - vst1q_f16(ptr_out, vout1); - ptr_out += 8; - ptr_bias += 8; - } - for (int i = 0; i < cnt_rem; ++i) { - *ptr_out += *(ptr_bias++); - *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f; - ptr_out++; + } else if (act_param->active_type == lite_api::ActivationType::kRelu6) { + float alpha = act_param->Relu_clipped_coef; + float32x4_t valpha = vdupq_n_f32(act_param->Relu_clipped_coef); + for (int j = 0; j < num; ++j) { + const float16_t *ptr_bias = bias; + float16_t *ptr_out = out + j * channel; + + for (int i = 0; i < cnt; ++i) { + LOADA_DATA_32 + RELU_32 + RELU6_32 + STORE_32 + } + for (int i = 0; i < cnt_num; i++) { + LOADA_DATA_8 + RELU_8 + RELU6_8 + STORE_8 + } + for (int i = 0; i < cnt_rem; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = + *ptr_out > 0.f ? ((*ptr_out < alpha) ? *ptr_out : alpha) : 0.f; + ptr_out++; + } } + } else { + LOG(FATAL) << "This act_type: " + << static_cast(act_param->active_type) + << " doesn't support"; } } else { for (int j = 0; j < num; ++j) { const float16_t *ptr_bias = bias; float16_t *ptr_out = out + j * channel; - for (int i = 0; i < cnt; ++i) { - float16x8_t vin1 = vld1q_f16(ptr_out); - float16x8_t vb1 = vld1q_f16(ptr_bias); - - float16x8_t vin2 = vld1q_f16(ptr_out + 8); - float16x8_t vb2 = vld1q_f16(ptr_bias + 8); - - float16x8_t vin3 = vld1q_f16(ptr_out + 16); - float16x8_t vb3 = vld1q_f16(ptr_bias + 16); - - float16x8_t vin4 = vld1q_f16(ptr_out + 24); - float16x8_t vb4 = vld1q_f16(ptr_bias + 24); - - float16x8_t vout1 = vaddq_f16(vin1, vb1); - float16x8_t vout2 = vaddq_f16(vin2, vb2); - float16x8_t vout3 = vaddq_f16(vin3, vb3); - float16x8_t vout4 = vaddq_f16(vin4, vb4); - - vst1q_f16(ptr_out, vout1); - vst1q_f16(ptr_out + 8, vout2); - vst1q_f16(ptr_out + 16, vout3); - vst1q_f16(ptr_out + 24, vout4); - - ptr_out += 32; - ptr_bias += 32; + LOADA_DATA_32 + STORE_32 } for (int i = 0; i < cnt_num; i++) { - float16x8_t vin1 = vld1q_f16(ptr_out); - float16x8_t vb1 = vld1q_f16(ptr_bias); - float16x8_t vout1 = vaddq_f16(vin1, vb1); - vst1q_f16(ptr_out, vout1); - ptr_out += 8; - ptr_bias += 8; + LOADA_DATA_8 + STORE_8 } for (int i = 0; i < cnt_rem; ++i) { *ptr_out += *(ptr_bias++); @@ -129,6 +145,14 @@ void fill_bias_fc(float16_t *out, } } } +#undef LOADA_DATA_32 +#undef RELU_32 +#undef RELU6_32 +#undef STORE_32 +#undef LOADA_DATA_8 +#undef RELU_8 +#undef RELU6_8 +#undef STORE_8 } // namespace fp16 } // namespace math diff --git a/lite/backends/arm/math/fp16/funcs_fp16.h b/lite/backends/arm/math/fp16/funcs_fp16.h old mode 100755 new mode 100644 index cd0c8af1aa2..2db932b1ac1 --- a/lite/backends/arm/math/fp16/funcs_fp16.h +++ b/lite/backends/arm/math/fp16/funcs_fp16.h @@ -44,8 +44,11 @@ namespace math { namespace fp16 { template -void fill_bias_fc( - T* tensor, const T* bias, int num, int channel, bool flag_relu); +void fill_bias_fc(T* tensor, + const T* bias, + int num, + int channel, + const operators::ActivationParam* act_param); // exp() computed for 8 float at once inline float16x8_t expq_ps_f16(float16x8_t x) { diff --git a/lite/backends/arm/math/funcs.cc b/lite/backends/arm/math/funcs.cc index 8d20e5242e5..6870e070f4a 100644 --- a/lite/backends/arm/math/funcs.cc +++ b/lite/backends/arm/math/funcs.cc @@ -20,53 +20,109 @@ namespace lite { namespace arm { namespace math { +#define LOADA_DATA_16 \ + float32x4_t vin1 = vld1q_f32(ptr_out); \ + float32x4_t vb1 = vld1q_f32(ptr_bias); \ + float32x4_t vin2 = vld1q_f32(ptr_out + 4); \ + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); \ + float32x4_t vin3 = vld1q_f32(ptr_out + 8); \ + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); \ + float32x4_t vin4 = vld1q_f32(ptr_out + 12); \ + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); \ + float32x4_t vout1 = vaddq_f32(vin1, vb1); \ + float32x4_t vout2 = vaddq_f32(vin2, vb2); \ + float32x4_t vout3 = vaddq_f32(vin3, vb3); \ + float32x4_t vout4 = vaddq_f32(vin4, vb4); +#define RELU_16 \ + vout1 = vmaxq_f32(vout1, vzero); \ + vout2 = vmaxq_f32(vout2, vzero); \ + vout3 = vmaxq_f32(vout3, vzero); \ + vout4 = vmaxq_f32(vout4, vzero); +#define RELU6_16 \ + vout1 = vminq_f32(vout1, valpha); \ + vout2 = vminq_f32(vout2, valpha); \ + vout3 = vminq_f32(vout3, valpha); \ + vout4 = vminq_f32(vout4, valpha); +#define STORE_16 \ + vst1q_f32(ptr_out, vout1); \ + vst1q_f32(ptr_out + 4, vout2); \ + vst1q_f32(ptr_out + 8, vout3); \ + vst1q_f32(ptr_out + 12, vout4); \ + ptr_out += 16; \ + ptr_bias += 16; +#define LOADA_DATA_4 \ + float32x4_t vin1 = vld1q_f32(ptr_out); \ + float32x4_t vb1 = vld1q_f32(ptr_bias); \ + float32x4_t vout1 = vaddq_f32(vin1, vb1); +#define RELU_4 vout1 = vmaxq_f32(vout1, vzero); +#define RELU6_4 vout1 = vminq_f32(vout1, valpha); +#define STORE_4 \ + vst1q_f32(ptr_out, vout1); \ + ptr_out += 4; \ + ptr_bias += 4; + template <> -void fill_bias_fc( - float *out, const float *bias, int num, int channel, bool flag_relu) { +void fill_bias_fc(float *out, + const float *bias, + int num, + int channel, + const operators::ActivationParam *act_param) { int cnt = channel >> 4; int remain = channel & 15; - if (flag_relu) { - float32x4_t vzero = vdupq_n_f32(0.f); - for (int j = 0; j < num; ++j) { - const float *ptr_bias = bias; - float *ptr_out = out + j * channel; - - for (int i = 0; i < cnt; ++i) { - float32x4_t vin1 = vld1q_f32(ptr_out); - float32x4_t vb1 = vld1q_f32(ptr_bias); + int cnt_num = remain >> 2; + int cnt_rem = remain & 3; - float32x4_t vin2 = vld1q_f32(ptr_out + 4); - float32x4_t vb2 = vld1q_f32(ptr_bias + 4); - - float32x4_t vin3 = vld1q_f32(ptr_out + 8); - float32x4_t vb3 = vld1q_f32(ptr_bias + 8); - - float32x4_t vin4 = vld1q_f32(ptr_out + 12); - float32x4_t vb4 = vld1q_f32(ptr_bias + 12); - - float32x4_t vout1 = vaddq_f32(vin1, vb1); - float32x4_t vout2 = vaddq_f32(vin2, vb2); - float32x4_t vout3 = vaddq_f32(vin3, vb3); - float32x4_t vout4 = vaddq_f32(vin4, vb4); - - vout1 = vmaxq_f32(vout1, vzero); - vout2 = vmaxq_f32(vout2, vzero); - vout3 = vmaxq_f32(vout3, vzero); - vout4 = vmaxq_f32(vout4, vzero); - - vst1q_f32(ptr_out, vout1); - vst1q_f32(ptr_out + 4, vout2); - vst1q_f32(ptr_out + 8, vout3); - vst1q_f32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; + if (act_param != nullptr && act_param->has_active) { + float32x4_t vzero = vdupq_n_f32(0.f); + if (act_param->active_type == lite_api::ActivationType::kRelu) { + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + for (int i = 0; i < cnt; ++i) { + LOADA_DATA_16 + RELU_16 + STORE_16 + } + for (int i = 0; i < cnt_num; ++i) { + LOADA_DATA_4 + RELU_4 + STORE_4 + } + for (int i = 0; i < cnt_rem; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f; + ptr_out++; + } } - for (int i = 0; i < remain; ++i) { - *ptr_out += *(ptr_bias++); - *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f; - ptr_out++; + } else if (act_param->active_type == lite_api::ActivationType::kRelu6) { + float alpha = act_param->Relu_clipped_coef; + float32x4_t valpha = vdupq_n_f32(act_param->Relu_clipped_coef); + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + for (int i = 0; i < cnt; ++i) { + LOADA_DATA_16 + RELU_16 + RELU6_16 + STORE_16 + } + for (int i = 0; i < cnt_num; ++i) { + LOADA_DATA_4 + RELU_4 + RELU6_4 + STORE_4 + } + for (int i = 0; i < cnt_rem; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = + *ptr_out > 0.f ? ((*ptr_out < alpha) ? *ptr_out : alpha) : 0.f; + ptr_out++; + } } + } else { + LOG(FATAL) << "This act_type: " + << static_cast(act_param->active_type) + << " doesn't support"; } } else { for (int j = 0; j < num; ++j) { @@ -74,130 +130,27 @@ void fill_bias_fc( float *ptr_out = out + j * channel; for (int i = 0; i < cnt; ++i) { - float32x4_t vin1 = vld1q_f32(ptr_out); - float32x4_t vb1 = vld1q_f32(ptr_bias); - - float32x4_t vin2 = vld1q_f32(ptr_out + 4); - float32x4_t vb2 = vld1q_f32(ptr_bias + 4); - - float32x4_t vin3 = vld1q_f32(ptr_out + 8); - float32x4_t vb3 = vld1q_f32(ptr_bias + 8); - - float32x4_t vin4 = vld1q_f32(ptr_out + 12); - float32x4_t vb4 = vld1q_f32(ptr_bias + 12); - - float32x4_t vout1 = vaddq_f32(vin1, vb1); - float32x4_t vout2 = vaddq_f32(vin2, vb2); - float32x4_t vout3 = vaddq_f32(vin3, vb3); - float32x4_t vout4 = vaddq_f32(vin4, vb4); - - vst1q_f32(ptr_out, vout1); - vst1q_f32(ptr_out + 4, vout2); - vst1q_f32(ptr_out + 8, vout3); - vst1q_f32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; + LOADA_DATA_16 + STORE_16 } - for (int i = 0; i < remain; ++i) { - *(ptr_out++) += *(ptr_bias++); + for (int i = 0; i < cnt_num; ++i) { + LOADA_DATA_4 + STORE_4 } - } - } -} - -template <> -void fill_bias_fc( - int *out, const int *bias, int num, int channel, bool flag_relu) { - int cnt = channel >> 4; - int remain = channel & 15; - if (flag_relu) { - for (int j = 0; j < num; ++j) { - const int *ptr_bias = bias; - int *ptr_out = out + j * channel; - - int32x4_t vzero = vdupq_n_s32(0); - - for (int i = 0; i < cnt; ++i) { - int32x4_t vin1 = vld1q_s32(ptr_out); - int32x4_t vb1 = vld1q_s32(ptr_bias); - - int32x4_t vin2 = vld1q_s32(ptr_out + 4); - int32x4_t vb2 = vld1q_s32(ptr_bias + 4); - - int32x4_t vin3 = vld1q_s32(ptr_out + 8); - int32x4_t vb3 = vld1q_s32(ptr_bias + 8); - - int32x4_t vin4 = vld1q_s32(ptr_out + 12); - int32x4_t vb4 = vld1q_s32(ptr_bias + 12); - - int32x4_t vout1 = vaddq_s32(vin1, vb1); - int32x4_t vout2 = vaddq_s32(vin2, vb2); - int32x4_t vout3 = vaddq_s32(vin3, vb3); - int32x4_t vout4 = vaddq_s32(vin4, vb4); - - vout1 = vmaxq_s32(vout1, vzero); - vout2 = vmaxq_s32(vout2, vzero); - vout3 = vmaxq_s32(vout3, vzero); - vout4 = vmaxq_s32(vout4, vzero); - - vst1q_s32(ptr_out, vout1); - vst1q_s32(ptr_out + 4, vout2); - vst1q_s32(ptr_out + 8, vout3); - vst1q_s32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; - } - for (int i = 0; i < remain; ++i) { - *ptr_out += *(ptr_bias++); - *ptr_out = *ptr_out > 0 ? *ptr_out : 0; - ptr_out++; - } - } - } else { - for (int j = 0; j < num; ++j) { - const int *ptr_bias = bias; - int *ptr_out = out + j * channel; - - int32x4_t vout1; - int32x4_t vout2; - int32x4_t vout3; - int32x4_t vout4; - - for (int i = 0; i < cnt; ++i) { - int32x4_t vin1 = vld1q_s32(ptr_out); - int32x4_t vb1 = vld1q_s32(ptr_bias); - - int32x4_t vin2 = vld1q_s32(ptr_out + 4); - int32x4_t vb2 = vld1q_s32(ptr_bias + 4); - - int32x4_t vin3 = vld1q_s32(ptr_out + 8); - int32x4_t vb3 = vld1q_s32(ptr_bias + 8); - - int32x4_t vin4 = vld1q_s32(ptr_out + 12); - int32x4_t vb4 = vld1q_s32(ptr_bias + 12); - - vout1 = vaddq_s32(vin1, vb1); - vout2 = vaddq_s32(vin2, vb2); - vout3 = vaddq_s32(vin3, vb3); - vout4 = vaddq_s32(vin4, vb4); - - vst1q_s32(ptr_out, vout1); - vst1q_s32(ptr_out + 4, vout2); - vst1q_s32(ptr_out + 8, vout3); - vst1q_s32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; - } - for (int i = 0; i < remain; ++i) { + for (int i = 0; i < cnt_rem; ++i) { *(ptr_out++) += *(ptr_bias++); } } } } - +#undef LOADA_DATA_16 +#undef RELU_16 +#undef RELU6_16 +#undef STORE_16 +#undef LOADA_DATA_4 +#undef RELU_4 +#undef RELU6_4 +#undef STORE_4 } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index 2733d5984e7..6e4cea16b2b 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -391,8 +391,11 @@ inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) { } template -void fill_bias_fc( - T* tensor, const T* bias, int num, int channel, bool flag_relu); +void fill_bias_fc(T* tensor, + const T* bias, + int num, + int channel, + const operators::ActivationParam* act_param); template inline float32x4_t vactive_f32(const float32x4_t& x) { diff --git a/lite/core/optimizer/mir/fusion/fc_fuse_pass.cc b/lite/core/optimizer/mir/fusion/fc_fuse_pass.cc index f96fe6c0277..ac6364284fe 100644 --- a/lite/core/optimizer/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/fc_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/core/optimizer/mir/fusion/fc_fuse_pass.h" +#include #include #include #include "lite/core/optimizer/mir/fusion/fc_fuser.h" @@ -24,17 +25,38 @@ namespace mir { void FcFusePass::Apply(const std::unique_ptr& graph) { std::vector mul_types{"mul"}; - std::vector act_types; + std::vector act_types; + bool has_int8 = false; + bool has_arm = false; + bool has_weight_quant = false; for (auto& place : graph->valid_places()) { if (place.target != TARGET(kMLU)) { - act_types.push_back(true); + act_types.push_back("relu"); } if (place.target == TARGET(kARM)) { - mul_types.push_back("matmul"); - mul_types.push_back("matmul_v2"); + has_arm = true; + act_types.push_back("relu6"); + if (place.precision == PRECISION(kInt8)) { + has_int8 = true; + } } } - act_types.push_back(false); + act_types.push_back(""); + const std::list& nodes = graph->nodes(); + for (auto& node : nodes) { + if (node.IsStmt()) { + auto* op_info = (node.stmt())->op_info(); + if (op_info->HasAttr("quantization_type")) { + has_weight_quant = true; + break; + } + } + } + if (!(has_int8 && has_weight_quant) && has_arm) { + // only support FP32/FP16 + mul_types.push_back("matmul"); + mul_types.push_back("matmul_v2"); + } for (auto op_type : mul_types) { for (auto act_type : act_types) { fusion::FcFuser fuser(op_type, act_type); diff --git a/lite/core/optimizer/mir/fusion/fc_fuser.cc b/lite/core/optimizer/mir/fusion/fc_fuser.cc index 071e28501d3..3410e7278f8 100644 --- a/lite/core/optimizer/mir/fusion/fc_fuser.cc +++ b/lite/core/optimizer/mir/fusion/fc_fuser.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/core/optimizer/mir/fusion/fc_fuser.h" +#include #include #include @@ -23,14 +24,13 @@ namespace fusion { void FcFuser::BuildPattern() { auto inputs_teller0 = [](const Node* node) -> bool { - return true; auto op_desc = *const_cast(node)->stmt()->op_info(); auto input_w_name = op_desc.Input("Y").front(); auto* scope = const_cast(node)->AsStmt().op()->scope(); auto w_shape = scope->FindVar(input_w_name)->Get().dims(); size_t w_rank = w_shape.size(); - - return w_rank == 2; + bool res = w_rank == 2; + return res; }; auto inputs_teller1 = [](const Node* node) -> bool { @@ -39,21 +39,31 @@ void FcFuser::BuildPattern() { auto* scope = const_cast(node)->AsStmt().op()->scope(); auto b_shape = scope->FindVar(input_b_name)->Get().dims(); size_t b_rank = b_shape.size(); - return b_rank == 2 || b_rank == 1; + auto res = (b_rank == 2 || b_rank == 1); + return res; }; auto input_attr_teller = [](const Node* node) -> bool { auto op_desc = *const_cast(node)->stmt()->op_info(); bool trans_x = op_desc.GetAttr("transpose_X"); bool trans_y = op_desc.GetAttr("transpose_Y"); - return trans_x == false && trans_y == false; + // assert alpha = 1.0f + auto alpha = op_desc.GetAttr("alpha"); + bool has_alpha = (fabsf(alpha - 1.f) > 1e-8f); + auto res = (trans_x == false && trans_y == false && !has_alpha); + return res; }; auto input_attr_teller_v2 = [](const Node* node) -> bool { auto op_desc = *const_cast(node)->stmt()->op_info(); bool trans_x = op_desc.GetAttr("trans_x"); bool trans_y = op_desc.GetAttr("trans_y"); - return trans_x == false && trans_y == false; + bool has_alpha = false; + if (op_desc.HasAttr("alpha")) { + auto alpha = op_desc.GetAttr("alpha"); + has_alpha = (fabsf(alpha - 1.f) > 1e-8f); + } + bool res = (trans_x == false && trans_y == false && !has_alpha); + return res; }; - // create nodes. auto* x = VarNode("x")->assert_is_op_input(op_type_, "X"); auto* W = VarNode("W")->assert_is_op_input(op_type_, "Y"); @@ -79,7 +89,7 @@ void FcFuser::BuildPattern() { mul->AsIntermediate(); add->AsIntermediate(); - if (with_relu_) { + if (act_type_ == "relu") { auto* add_out = VarNode("add_out"); auto* relu = OpNode("relu", "relu"); std::vector relu_inputs{add_out}; @@ -87,6 +97,14 @@ void FcFuser::BuildPattern() { relu_inputs >> *relu >> *Out; add_out->AsIntermediate(); relu->AsIntermediate(); + } else if (act_type_ == "relu6") { + auto* add_out = VarNode("add_out"); + auto* relu6 = OpNode("relu6", "relu6"); + std::vector relu6_inputs{add_out}; + add_inputs >> *add >> *add_out; + relu6_inputs >> *relu6 >> *Out; + add_out->AsIntermediate(); + relu6->AsIntermediate(); } else { add_inputs >> *add >> *Out; } @@ -155,8 +173,13 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { } op_desc.SetAttr("op_type", op_type_); - if (with_relu_) { + if (act_type_ == "relu") { op_desc.SetAttr("activation_type", std::string{"relu"}); + } else if (act_type_ == "relu6") { + op_desc.SetAttr("activation_type", std::string{"relu6"}); + auto relu6_desc = *matched.at("relu6")->stmt()->op_info(); + auto alpha = relu6_desc.GetAttr("threshold"); + op_desc.SetAttr("alpha", alpha); } // Set the input scale into fc diff --git a/lite/core/optimizer/mir/fusion/fc_fuser.h b/lite/core/optimizer/mir/fusion/fc_fuser.h index b952f6fd01b..3533366e735 100644 --- a/lite/core/optimizer/mir/fusion/fc_fuser.h +++ b/lite/core/optimizer/mir/fusion/fc_fuser.h @@ -25,15 +25,15 @@ namespace fusion { class FcFuser : public FuseBase { public: - explicit FcFuser(std::string op_type, bool with_relu) - : op_type_(op_type), with_relu_(with_relu) {} + explicit FcFuser(std::string op_type, std::string act_type) + : op_type_(op_type), act_type_(act_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; std::string op_type_; - bool with_relu_; + std::string act_type_; }; } // namespace fusion diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index bfbd6043e15..0224f14c45b 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -106,6 +106,9 @@ void FcCompute::ReInitWhenNeeded() { m_ = x_dims.Slice(0, in_num_col_dims).production(); k_ = x_dims.Slice(in_num_col_dims, x_dims.size()).production(); + // LOG(INFO) << "in_num_col_dims: " << param.in_num_col_dims << ", x_dims: " + // << x_dims; + // LOG(INFO) << "w_dims: " << w_dims; CHECK_EQ(k_, w_dims[0]); n_ = w_dims[1]; CHECK_EQ(k_, static_cast(w_dims[0])); @@ -205,17 +208,25 @@ void FcCompute::Run() { act_param, &ctx); if (param.bias) { - bool flag_act = false; if (param.activation_type == "relu") { - flag_act = true; + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; } CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_act); + lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, &act_param); } } else { if (param.activation_type == "relu") { + act_param.has_active = true; act_param.active_type = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; } for (int i = 0; i < m_; ++i) { auto* i_data_batch = i_data + i * k_; @@ -248,13 +259,16 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } - bool flag_relu = false; operators::ActivationParam act_param; - lite_api::ActivationType act; + // lite_api::ActivationType act; act_param.has_active = false; if (param.activation_type == "relu") { - act = lite_api::ActivationType::kRelu; - flag_relu = true; + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; } if (flag_gemm_) { lite::arm::math::gemm_s8(false, @@ -272,7 +286,7 @@ void FcCompute::Run() { &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu); + lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, &act_param); } } else { for (int i = 0; i < m_; ++i) { @@ -306,16 +320,15 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } - // b_data = param.bias->data(); - bool flag_relu = false; operators::ActivationParam act_param; act_param.has_active = false; - lite_api::ActivationType act; if (param.activation_type == "relu") { - flag_relu = true; act_param.has_active = true; act_param.active_type = lite_api::ActivationType::kRelu; - act = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; } if (flag_gemm_) { CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel " @@ -383,12 +396,14 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } - bool flag_act = false; operators::ActivationParam act_param; if (param.activation_type == "relu") { + act_param.has_active = true; act_param.active_type = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { act_param.has_active = true; - flag_act = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; } if (flag_gemm_) { act_param.has_active = false; @@ -411,7 +426,7 @@ void FcCompute::Run() { &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fp16::fill_bias_fc(o_data, b_data, m_, n_, flag_act); + lite::arm::math::fp16::fill_bias_fc(o_data, b_data, m_, n_, &act_param); } } else { for (int i = 0; i < m_; ++i) { @@ -426,7 +441,7 @@ void FcCompute::Run() { 0.f, param.bias != nullptr, b_data, - flag_act, + act_param.has_active, act_param, &ctx); } diff --git a/lite/kernels/arm/rnn_compute.cc b/lite/kernels/arm/rnn_compute.cc index e17ca410299..7d22d14fe39 100644 --- a/lite/kernels/arm/rnn_compute.cc +++ b/lite/kernels/arm/rnn_compute.cc @@ -114,7 +114,6 @@ static void preprocess(ARMContext* ctx, auto* i_data = input->data(); auto* w_data = weight.data(); auto* o_data = cache_input->mutable_data(); - bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; auto input_dims = input->dims(); @@ -140,7 +139,7 @@ static void preprocess(ARMContext* ctx, false, act_param, ctx); - lite::arm::math::fill_bias_fc(o_data, bias_ih.data(), m, n, flag_act); + lite::arm::math::fill_bias_fc(o_data, bias_ih.data(), m, n, nullptr); if ("GRU" == mode) { Tensor bias_tmp_hh; @@ -152,10 +151,9 @@ static void preprocess(ARMContext* ctx, std::memset( bias_ptr + bias_offt, 0, (bias_hh.numel() - bias_offt) * sizeof(float)); lite::arm::math::fill_bias_fc( - o_data, bias_tmp_hh.data(), m, n, flag_act); + o_data, bias_tmp_hh.data(), m, n, nullptr); } else { - lite::arm::math::fill_bias_fc( - o_data, bias_hh.data(), m, n, flag_act); + lite::arm::math::fill_bias_fc(o_data, bias_hh.data(), m, n, nullptr); } } @@ -308,7 +306,6 @@ static void lstm_cell(ARMContext* ctx, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh) { - bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; auto h_dims = init_h->dims(); @@ -395,7 +392,6 @@ static void gru_cell(ARMContext* ctx, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) { - bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; auto h_dims = init_h->dims(); @@ -743,10 +739,10 @@ void RnnCompute::Run() { last_h_unbind[i].Resize(dims); init_h_unbind_t.push_back(&init_h_unbind[i]); last_h_unbind_t.push_back(&last_h_unbind[i]); - last_h_unbind[i].mutable_data(); } lite::host::math::split( pre_state[0]->data(), init_h_unbind_t, 0, stride1); + lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride1); if ("LSTM" == mode) { for (int i = 0; i < pre_state[1]->dims()[0]; i++) { @@ -758,10 +754,11 @@ void RnnCompute::Run() { last_c_unbind[i].Resize(dims); init_c_unbind_t.push_back(&init_c_unbind[i]); last_c_unbind_t.push_back(&last_c_unbind[i]); - last_c_unbind[i].mutable_data(); } lite::host::math::split( pre_state[1]->data(), init_c_unbind_t, 0, stride2); + lite::host::math::split( + state[1]->data(), last_c_unbind_t, 0, stride2); } std::vector output_vec(2); @@ -800,12 +797,6 @@ void RnnCompute::Run() { RUN_RNN_LAYER(i, output_holder, false, 0); } } - - lite::arm::math::concat_func(last_h_unbind_t, 0, state[0]); - if ("LSTM" == mode) { - lite::arm::math::concat_func(last_c_unbind_t, 0, state[1]); - } - // output_holder != output if (num_layers % 2 == 0) { output->CopyDataFrom(*output_holder); diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 01d88217548..6dd08f41027 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -126,6 +126,9 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { param_.Prelu_alpha = const_cast(&(prelu_alpha_var->Get())); } + if (param_.activation_type == "relu6") { + param_.alpha = op_desc.GetAttr("alpha"); + } // For Int8 const OpInfo* op_info = static_cast(&op_desc); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 4be3c9ee09e..510e90c9b25 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -106,6 +106,7 @@ struct FcParam : ParamBase { std::string Prelu_mode{ "channel"}; // prelu param, can be "all", "channel" or "element" std::string op_type{"mul"}; + float alpha{6.f}; // for int8 WITH_INT8_CONFIG }; From cb63e73cb52ee76eb727a1907c98175a72be31d2 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Fri, 1 Apr 2022 16:59:45 +0800 Subject: [PATCH 2/7] fix fc+relu6 test --- lite/backends/arm/math/fp16/funcs_fp16.cc | 6 +++--- lite/kernels/arm/fc_compute.cc | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lite/backends/arm/math/fp16/funcs_fp16.cc b/lite/backends/arm/math/fp16/funcs_fp16.cc index 4c9c9280b90..d641be01997 100644 --- a/lite/backends/arm/math/fp16/funcs_fp16.cc +++ b/lite/backends/arm/math/fp16/funcs_fp16.cc @@ -73,7 +73,7 @@ void fill_bias_fc(float16_t *out, int cnt_num = remain >> 3; int cnt_rem = remain & 7; if (act_param != nullptr && act_param->has_active) { - float32x4_t vzero = vdupq_n_f32(0.f); + float16x8_t vzero = vdupq_n_f16(0.f); if (act_param->active_type == lite_api::ActivationType::kRelu) { for (int j = 0; j < num; ++j) { const float16_t *ptr_bias = bias; @@ -96,8 +96,8 @@ void fill_bias_fc(float16_t *out, } } } else if (act_param->active_type == lite_api::ActivationType::kRelu6) { - float alpha = act_param->Relu_clipped_coef; - float32x4_t valpha = vdupq_n_f32(act_param->Relu_clipped_coef); + float16_t alpha = static_cast(act_param->Relu_clipped_coef); + float16x8_t valpha = vdupq_n_f16(alpha); for (int j = 0; j < num; ++j) { const float16_t *ptr_bias = bias; float16_t *ptr_out = out + j * channel; diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index 0224f14c45b..3ddd8cccc05 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -397,14 +397,6 @@ void FcCompute::Run() { b_data = bias_.data(); } operators::ActivationParam act_param; - if (param.activation_type == "relu") { - act_param.has_active = true; - act_param.active_type = lite_api::ActivationType::kRelu; - } else if (param.activation_type == "relu6") { - act_param.has_active = true; - act_param.active_type = lite_api::ActivationType::kRelu6; - act_param.Relu_clipped_coef = param.alpha; - } if (flag_gemm_) { act_param.has_active = false; lite::arm::math::fp16::sgemm_fp16(false, @@ -426,6 +418,14 @@ void FcCompute::Run() { &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); + if (param.activation_type == "relu") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; + } lite::arm::math::fp16::fill_bias_fc(o_data, b_data, m_, n_, &act_param); } } else { From fe496868d397c3314e33da071a1b375593f2caee Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Fri, 1 Apr 2022 18:13:17 +0800 Subject: [PATCH 3/7] fix fc+relu6 error --- lite/kernels/arm/fc_compute.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index 3ddd8cccc05..c2d756f2b24 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -429,6 +429,14 @@ void FcCompute::Run() { lite::arm::math::fp16::fill_bias_fc(o_data, b_data, m_, n_, &act_param); } } else { + if (param.activation_type == "relu") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; + } else if (param.activation_type == "relu6") { + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu6; + act_param.Relu_clipped_coef = param.alpha; + } for (int i = 0; i < m_; ++i) { auto* i_data_batch = i_data + i * k_; auto* o_data_batch = o_data + i * n_; From cfd68f0bd0e120f2e2fb084375ba2d999d417bd1 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Fri, 1 Apr 2022 18:51:51 +0800 Subject: [PATCH 4/7] add fc+relu6 ut --- lite/tests/unittest_py/op/test_fc_op.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/lite/tests/unittest_py/op/test_fc_op.py b/lite/tests/unittest_py/op/test_fc_op.py index 051f37e9f0b..04f3e0028f2 100644 --- a/lite/tests/unittest_py/op/test_fc_op.py +++ b/lite/tests/unittest_py/op/test_fc_op.py @@ -104,7 +104,7 @@ def generate_bias(*args, **kwargs): act_type = "" if (with_bias and random.random() > 0.5): - act_type = "relu" + act_type = draw(st.sampled_from(["relu", "relu6"])) op_inputs = {} program_inputs = {} @@ -137,6 +137,7 @@ def generate_bias(*args, **kwargs): "in_num_col_dims": in_num_col_dims, "activation_type": act_type, "use_mkldnn": False, + "alpha": 6.0, "padding_weights": padding_weights, "use_quantizer": False, "Scale_in": float(1), @@ -154,7 +155,17 @@ def sample_predictor_configs(self): return self.get_predictor_configs(), ["fc"], (1e-5, 1e-5) def add_ignore_pass_case(self): - pass + def _teller1(program_config, predictor_config): + target_type = predictor_config.target() + act_type = program_config.ops[0].attrs["activation_type"] + if act_type == "relu6": + if target_type == TargetType.Metal or target_type == TargetType.OpenCL or target_type == TargetType.X86: + return True + + self.add_ignore_check_case( + _teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT, + "Opencl/Metal/X86 doesn't support fc+relu6 compute, we will fix as soon as possible." + ) def test(self, *args, **kwargs): self.run_and_statis(quant=False, max_examples=300) From f9c2059461b6da7c4d8f68dad275c023a7d14636 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Fri, 1 Apr 2022 19:10:56 +0800 Subject: [PATCH 5/7] fix doc --- docs/quick_start/faq.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/quick_start/faq.md b/docs/quick_start/faq.md index a1f5b08f870..dd3a74d6e91 100644 --- a/docs/quick_start/faq.md +++ b/docs/quick_start/faq.md @@ -11,7 +11,8 @@ 答:更换当前 Paddle Lite 预测库为带 `with_extra = ON` 标签的预编译库。 3、ARM CPU 端多线程支持情况,某些case下,多线程没有效果? -答:gcc 编译模式下,V7/V8 多线程均支持;clang 编译模式下,V8 支持多线程,V7 只能跑单线程 + +答:gcc 编译模式下,V7/V8 多线程均支持;clang 编译模式下,V8 支持多线程,V7 只支持单线程 ### 模型转换 From 8fffdf9a6e88980308b2c33a468ae2e394bcc070 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Sat, 2 Apr 2022 18:04:31 +0800 Subject: [PATCH 6/7] fix relu6 ut --- lite/tests/unittest_py/op/test_fc_op.py | 15 ++------------- lite/tests/unittest_py/pass/test_fc_fuse_pass.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/lite/tests/unittest_py/op/test_fc_op.py b/lite/tests/unittest_py/op/test_fc_op.py index 04f3e0028f2..051f37e9f0b 100644 --- a/lite/tests/unittest_py/op/test_fc_op.py +++ b/lite/tests/unittest_py/op/test_fc_op.py @@ -104,7 +104,7 @@ def generate_bias(*args, **kwargs): act_type = "" if (with_bias and random.random() > 0.5): - act_type = draw(st.sampled_from(["relu", "relu6"])) + act_type = "relu" op_inputs = {} program_inputs = {} @@ -137,7 +137,6 @@ def generate_bias(*args, **kwargs): "in_num_col_dims": in_num_col_dims, "activation_type": act_type, "use_mkldnn": False, - "alpha": 6.0, "padding_weights": padding_weights, "use_quantizer": False, "Scale_in": float(1), @@ -155,17 +154,7 @@ def sample_predictor_configs(self): return self.get_predictor_configs(), ["fc"], (1e-5, 1e-5) def add_ignore_pass_case(self): - def _teller1(program_config, predictor_config): - target_type = predictor_config.target() - act_type = program_config.ops[0].attrs["activation_type"] - if act_type == "relu6": - if target_type == TargetType.Metal or target_type == TargetType.OpenCL or target_type == TargetType.X86: - return True - - self.add_ignore_check_case( - _teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT, - "Opencl/Metal/X86 doesn't support fc+relu6 compute, we will fix as soon as possible." - ) + pass def test(self, *args, **kwargs): self.run_and_statis(quant=False, max_examples=300) diff --git a/lite/tests/unittest_py/pass/test_fc_fuse_pass.py b/lite/tests/unittest_py/pass/test_fc_fuse_pass.py index ac78fbfad96..be9094b7690 100644 --- a/lite/tests/unittest_py/pass/test_fc_fuse_pass.py +++ b/lite/tests/unittest_py/pass/test_fc_fuse_pass.py @@ -61,7 +61,7 @@ def is_program_valid(self, return True def sample_program_configs(self, draw): - has_relu = draw(st.sampled_from([True, False])) + act_type = draw(st.sampled_from(["", "relu", "relu6"])) op_type = draw(st.sampled_from(["mul", "matmul", "matmul_v2"])) mul_x_in_shape = draw( st.lists( @@ -152,15 +152,18 @@ def sample_program_configs(self, draw): outputs={"Out": ["elementwise_add_output_data"]}, attrs={"axis": axis}) + act_attrs = {} + if act_type == "relu6": + act_attrs = {"threshold": 6.0, } active_op = OpConfig( - type="relu", + type=act_type, inputs={"X": ["elementwise_add_output_data"]}, outputs={"Out": ["output_data"]}, - attrs={}) + attrs=act_attrs) ops = [mul_op, elementwise_add_op] output_data = "elementwise_add_output_data" - if has_relu: + if act_type == "relu" or act_type == "relu6": ops.append(active_op) output_data = "output_data" program_config = ProgramConfig( @@ -178,9 +181,12 @@ def add_ignore_pass_case(self): def _teller1(program_config, predictor_config): target_type = predictor_config.target() op_type = program_config.ops[0].type + act_type = program_config.ops[2].type if target_type == TargetType.X86: if op_type == "matmul" or op_type == "matmul_v2": return True + if act_type == "relu6": + return True self.add_ignore_check_case( _teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT, From 6bbf36de578142549801592100e026a42071a832 Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Wed, 6 Apr 2022 14:06:06 +0800 Subject: [PATCH 7/7] Update test_fc_fuse_pass.py --- lite/tests/unittest_py/pass/test_fc_fuse_pass.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lite/tests/unittest_py/pass/test_fc_fuse_pass.py b/lite/tests/unittest_py/pass/test_fc_fuse_pass.py index be9094b7690..11559df9eb8 100644 --- a/lite/tests/unittest_py/pass/test_fc_fuse_pass.py +++ b/lite/tests/unittest_py/pass/test_fc_fuse_pass.py @@ -181,12 +181,13 @@ def add_ignore_pass_case(self): def _teller1(program_config, predictor_config): target_type = predictor_config.target() op_type = program_config.ops[0].type - act_type = program_config.ops[2].type if target_type == TargetType.X86: if op_type == "matmul" or op_type == "matmul_v2": return True - if act_type == "relu6": - return True + if len(program_config.ops) > 2: + act_type = program_config.ops[2].type + if act_type == "relu6": + return True self.add_ignore_check_case( _teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT,