diff --git a/lite/backends/arm/math/fp16/elementwise_fp16.cc b/lite/backends/arm/math/fp16/elementwise_fp16.cc
index 1e0c1ddc117..133ba5369d8 100644
--- a/lite/backends/arm/math/fp16/elementwise_fp16.cc
+++ b/lite/backends/arm/math/fp16/elementwise_fp16.cc
@@ -189,8 +189,6 @@ namespace fp16 {
                : "cc", \
                  "memory", \
                  ASM_VAR);
-
-
 #else
 #define INIT_1 \
   "vld1.16 {d0-d1}, [%[dinx_ptr]]! \n" \
@@ -260,10 +258,10 @@ namespace fp16 {
 
 #define SIMPLE_COMPUTE_TYPE(op) \
   asm volatile(INIT SIMPLE_COMPUTE(v##op.f16) STORE \
+               : [dinx_ptr] "+r"(dinx_ptr), \
+                 [diny_ptr] "+r"(diny_ptr), \
+                 [dout_ptr] "+r"(dout_ptr) \
                : \
-               : [dinx_ptr] "r"(dinx_ptr), \
-                 [diny_ptr] "r"(diny_ptr), \
-                 [dout_ptr] "r"(dout_ptr) \
                : "cc", \
                  "memory", \
                  ASM_VAR);
@@ -281,11 +279,10 @@ namespace fp16 {
 
 #define SIMPLE_COMPUTE_TYPE_RELU(op) \
   asm volatile(INIT SIMPLE_COMPUTE(v##op.f16) RELU STORE \
-               : \
-               : [dinx_ptr] "r"(dinx_ptr), \
-                 [diny_ptr] "r"(diny_ptr), \
-                 [dout_ptr] "r"(dout_ptr), \
-                 [vzero] "w"(vzero) \
+               : [dinx_ptr] "+r"(dinx_ptr), \
+                 [diny_ptr] "+r"(diny_ptr), \
+                 [dout_ptr] "+r"(dout_ptr) \
+               : [vzero] "w"(vzero) \
                : "cc", \
                  "memory", \
                  ASM_VAR);
@@ -303,10 +300,9 @@ namespace fp16 {
 
 #define SIMPLE_COMPUTE_TYPE_BROADCAST(op) \
   asm volatile(INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST(v##op.f16) STORE \
-               : \
-               : [dinx_ptr] "r"(dinx_ptr_1), \
-                 [dout_ptr] "r"(dout_ptr_1), \
-                 [val_y] "w"(val_y) \
+               : [dinx_ptr] "+r"(dinx_ptr_1), \
+                 [dout_ptr] "+r"(dout_ptr_1) \
+               : [val_y] "w"(val_y) \
                : "cc", \
                  "memory", \
                  ASM_VAR);
@@ -323,17 +319,16 @@ namespace fp16 {
 
 #define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU(op) \
   asm volatile(INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST(v##op.f16) RELU STORE \
-               : \
-               : [dinx_ptr] "r"(dinx_ptr_1), \
-                 [dout_ptr] "r"(dout_ptr_1), \
-                 [val_y] "w"(val_y), \
+               : [dinx_ptr] "+r"(dinx_ptr_1), \
+                 [dout_ptr] "+r"(dout_ptr_1) \
+               : [val_y] "w"(val_y), \
                  [vzero] "w"(vzero) \
                : "cc", \
                  "memory", \
                  ASM_VAR);
 
 #define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU_1(op) \
-  asm volatile(INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST(v##op.f16) RELU STORE_1 \
+  asm volatile(INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST(v##op.f16) RELU_1 STORE_1 \
                : [cnt_num] "+r"(cnt_num), \
                  [dinx_ptr] "+r"(dinx_ptr_1), \
                  [dout_ptr] "+r"(dout_ptr_1) \
@@ -352,7 +347,6 @@ namespace fp16 {
                                float16_t* dout, \
                                int num) { \
     LOOP_CNT(num) \
-    \
     for (int i = 0; i < cnt; i++) { \
       int stride = i << 5; \
       const float16_t* dinx_ptr = dinx + stride; \
@@ -517,6 +511,113 @@ elmentwise_simple_compute(mul);
 elmentwise_simple_compute(sub);
 #ifdef __aarch64__
 elmentwise_simple_compute(div);
+#else
+// Non-aarch64 fp16 element-wise division: dout[i] = dinx[i] / diny[i] for
+// i in [0, num). cnt (blocks of 32 elements), rem_cnt (following vectors of
+// 8 elements) and rem_rem (trailing scalars) are expected from LOOP_CNT(num).
+void elementwise_div(const float16_t* dinx,
+                     const float16_t* diny,
+                     float16_t* dout,
+                     int num) {
+  LOOP_CNT(num)
+  for (int i = 0; i < cnt; i++) {
+    int stride = i << 5;
+    const float16_t* dinx_ptr = dinx + stride;
+    const float16_t* diny_ptr = diny + stride;
+    float16_t* dout_ptr = dout + stride;
+    float16x8_t vec_a1 = vld1q_f16(dinx_ptr);
+    float16x8_t vec_a2 = vld1q_f16(dinx_ptr + 8);
+    float16x8_t vec_b1 = vld1q_f16(diny_ptr);
+    float16x8_t vec_b2 = vld1q_f16(diny_ptr + 8);
+    vst1q_f16(dout_ptr, divq_ps_f16(vec_a1, vec_b1));
+    vst1q_f16(dout_ptr + 8, divq_ps_f16(vec_a2, vec_b2));
+    vec_a1 = vld1q_f16(dinx_ptr + 16);
+    vec_a2 = vld1q_f16(dinx_ptr + 24);
+    vec_b1 = vld1q_f16(diny_ptr + 16);
+    vec_b2 = vld1q_f16(diny_ptr + 24);
+    vst1q_f16(dout_ptr + 16, divq_ps_f16(vec_a1, vec_b1));
+    vst1q_f16(dout_ptr + 24, divq_ps_f16(vec_a2, vec_b2));
+  }
+  int stride = cnt << 5;
+  if (rem_cnt > 0) {
+    const float16_t* dinx_ptr = dinx + stride;
+    const float16_t* diny_ptr = diny + stride;
+    float16_t* dout_ptr = dout + stride;
+    for (int loop = 0; loop < rem_cnt; loop++) {
+      float16x8_t vec_a1 = vld1q_f16(dinx_ptr + loop * 8);
+      float16x8_t vec_b1 = vld1q_f16(diny_ptr + loop * 8);
+      vst1q_f16(dout_ptr + loop * 8, divq_ps_f16(vec_a1, vec_b1));
+    }
+  }
+  if (rem_rem > 0) {
+    stride += (rem_cnt << 3);
+    const float16_t* dinx_ptr = dinx + stride;
+    const float16_t* diny_ptr = diny + stride;
+    float16_t* dout_ptr = dout + stride;
+    for (int i = 0; i < rem_rem; i++) {
+      *dout_ptr = naive_div(*dinx_ptr, *diny_ptr);
+      dout_ptr++;
+      dinx_ptr++;
+      diny_ptr++;
+    }
+  }
+}
+
+// Non-aarch64 fp16 broadcast division: dinx/dout are indexed as
+// [batch][channels][num] and each element of channel j is divided by diny[j].
+void elementwise_div_broadcast(const float16_t* dinx,
+                               const float16_t* diny,
+                               float16_t* dout,
+                               int batch,
+                               int channels,
+                               int num) {
+  OMP_PARA_INTERNAL_COLLASPE_2
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const auto* dinx_ptr = dinx + offset;
+      const auto* diny_ptr = diny + j;
+      auto* dout_ptr = dout + offset;
+      LOOP_CNT(num)
+      // The divisor is fixed for the whole channel, so load it once.
+      const float16_t val = diny_ptr[0];
+      const float16x8_t val_y = vdupq_n_f16(val);
+      for (int k = 0; k < cnt; k++) {
+        int stride = k << 5;
+        const float16_t* dinx_ptr_1 = dinx_ptr + stride;
+        float16_t* dout_ptr_1 = dout_ptr + stride;
+        float16x8_t vec_x1 = vld1q_f16(dinx_ptr_1);
+        float16x8_t vec_x2 = vld1q_f16(dinx_ptr_1 + 8);
+        vst1q_f16(dout_ptr_1, divq_ps_f16(vec_x1, val_y));
+        vst1q_f16(dout_ptr_1 + 8, divq_ps_f16(vec_x2, val_y));
+        vec_x1 = vld1q_f16(dinx_ptr_1 + 16);
+        vec_x2 = vld1q_f16(dinx_ptr_1 + 24);
+        vst1q_f16(dout_ptr_1 + 16, divq_ps_f16(vec_x1, val_y));
+        vst1q_f16(dout_ptr_1 + 24, divq_ps_f16(vec_x2, val_y));
+      }
+      int stride = cnt << 5;
+      if (rem_cnt > 0) {
+        const float16_t* dinx_ptr_1 = dinx_ptr + stride;
+        float16_t* dout_ptr_1 = dout_ptr + stride;
+        for (int loop = 0; loop < rem_cnt; loop++) {
+          float16x8_t vec_x1 = vld1q_f16(dinx_ptr_1 + loop * 8);
+          vst1q_f16(dout_ptr_1 + loop * 8, divq_ps_f16(vec_x1, val_y));
+        }
+      }
+      if (rem_rem > 0) {
+        stride += (rem_cnt << 3);
+        const float16_t* dinx_ptr_1 = dinx_ptr + stride;
+        float16_t* dout_ptr_1 = dout_ptr + stride;
+        for (int r = 0; r < rem_rem; r++) {
+          *dout_ptr_1 = naive_div(*dinx_ptr_1, val);
+          dinx_ptr_1++;
+          dout_ptr_1++;
+        }
+      }
+    }
+  }
+}
+
 #endif
 }  // namespace fp16
 }  // namespace math
diff --git a/lite/backends/arm/math/fp16/elementwise_fp16.h b/lite/backends/arm/math/fp16/elementwise_fp16.h
index 561b96ef9b7..7cd9f3f1acd 100644
--- a/lite/backends/arm/math/fp16/elementwise_fp16.h
+++ b/lite/backends/arm/math/fp16/elementwise_fp16.h
@@ -48,7 +48,21 @@ typedef __fp16 float16_t;
 elementwise_simple_compute_declare(add);
 elementwise_simple_compute_declare(mul);
 elementwise_simple_compute_declare(sub);
+#ifdef __aarch64__
 elementwise_simple_compute_declare(div);
+#else
+void elementwise_div(const float16_t* dinx,
+                     const float16_t* diny,
+                     float16_t* dout,
+                     int num);
+
+void elementwise_div_broadcast(const float16_t* dinx,
+                               const float16_t* diny,
+                               float16_t* dout,
+                               int batch,
+                               int channels,
+                               int num);
+#endif
 
 }  // namespace fp16
 }  // namespace math
diff --git a/lite/backends/arm/math/fp16/type_trans_fp16.cc b/lite/backends/arm/math/fp16/type_trans_fp16.cc
index 587cf5900bc..180f5af59bf 100644
--- a/lite/backends/arm/math/fp16/type_trans_fp16.cc
+++ b/lite/backends/arm/math/fp16/type_trans_fp16.cc
@@ -304,15 +304,6 @@ void fp32_to_fp16(const float* in, float16_t* out, int size) {
       "vst1.32 {d16-d17}, [%[out]]!\n"
       "bne 4b\n"
       "2: \n"
-      "cmp %[remain_remain], #1\n"
-      "blt 3f\n"
-      "5: \n"
-      "vld1.16 d0[0], [%[in]]!\n"
-      "subs %[remain_remain], #1\n"
-      "vcvt.f16.f32 d16, q0\n"
-      "vst1.32 d16[0], [%[out]]!\n"
-      "bne 5b\n"
-      "3: \n"
       : [in] "+r"(in),
         [out] "+r"(out),
         [cnt] "+r"(cnt),
@@ -333,6 +324,12 @@ void fp32_to_fp16(const float* in, float16_t* out, int size) {
         "q9",
         "q10",
         "q11");
+  // Scalar tail: convert the remain_remain leftover elements one by one.
+  for (int i = 0; i < remain_remain; i++) {
+    *out = static_cast<float16_t>(*in);
+    out++;
+    in++;
+  }
 #endif
 }
 }  // namespace fp16
diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index e1742e04929..8fd2f6af30b 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -452,12 +452,18 @@ void ElementwiseDivCompute::Run() {
                                OprandSwapable::NO,
                                arm_math::NullNeonConfig>(
       this,
       lite::arm::math::fp16::elementwise_div_broadcast,
       lite::arm::math::fp16::elementwise_div,
       paddle::lite::kernels::host::naive_div);
 #else
-  LOG(FATAL) << "it doesn't support v7 fp16 elementwise_div compute";
+  elementwise_compute_template(
+      this,
+      lite::arm::math::fp16::elementwise_div_broadcast,
+      lite::arm::math::fp16::elementwise_div,
+      paddle::lite::kernels::host::naive_div);
 #endif
 }