Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARMv7] add elementwise_div_fp16 && fix elementwise_fp16 bug #10050

Merged
merged 2 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 118 additions & 20 deletions lite/backends/arm/math/fp16/elementwise_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ namespace fp16 {
: "cc", \
"memory", \
ASM_VAR);


#else
#define INIT_1 \
"vld1.16 {d0-d1}, [%[dinx_ptr]]! \n" \
Expand Down Expand Up @@ -260,10 +258,10 @@ namespace fp16 {

#define SIMPLE_COMPUTE_TYPE(op) \
asm volatile(INIT SIMPLE_COMPUTE(v##op.f16) STORE \
: [dinx_ptr] "+r"(dinx_ptr), \
[diny_ptr] "+r"(diny_ptr), \
[dout_ptr] "+r"(dout_ptr) \
: \
: [dinx_ptr] "r"(dinx_ptr), \
[diny_ptr] "r"(diny_ptr), \
[dout_ptr] "r"(dout_ptr) \
: "cc", \
"memory", \
ASM_VAR);
Expand All @@ -281,11 +279,10 @@ namespace fp16 {

#define SIMPLE_COMPUTE_TYPE_RELU(op) \
asm volatile(INIT SIMPLE_COMPUTE(v##op.f16) RELU STORE \
: \
: [dinx_ptr] "r"(dinx_ptr), \
[diny_ptr] "r"(diny_ptr), \
[dout_ptr] "r"(dout_ptr), \
[vzero] "w"(vzero) \
: [dinx_ptr] "+r"(dinx_ptr), \
[diny_ptr] "+r"(diny_ptr), \
[dout_ptr] "+r"(dout_ptr) \
: [vzero] "w"(vzero) \
: "cc", \
"memory", \
ASM_VAR);
Expand All @@ -303,10 +300,9 @@ namespace fp16 {

#define SIMPLE_COMPUTE_TYPE_BROADCAST(op) \
asm volatile(INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST(v##op.f16) STORE \
: \
: [dinx_ptr] "r"(dinx_ptr_1), \
[dout_ptr] "r"(dout_ptr_1), \
[val_y] "w"(val_y) \
: [dinx_ptr] "+r"(dinx_ptr_1), \
[dout_ptr] "+r"(dout_ptr_1) \
: [val_y] "w"(val_y) \
: "cc", \
"memory", \
ASM_VAR);
Expand All @@ -323,17 +319,16 @@ namespace fp16 {

#define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU(op) \
asm volatile(INIT_BROADCAST SIMPLE_COMPUTE_BROADCAST(v##op.f16) RELU STORE \
: \
: [dinx_ptr] "r"(dinx_ptr_1), \
[dout_ptr] "r"(dout_ptr_1), \
[val_y] "w"(val_y), \
: [dinx_ptr] "+r"(dinx_ptr_1), \
[dout_ptr] "+r"(dout_ptr_1) \
: [val_y] "w"(val_y), \
[vzero] "w"(vzero) \
: "cc", \
"memory", \
ASM_VAR);

#define SIMPLE_COMPUTE_TYPE_BROADCAST_RELU_1(op) \
asm volatile(INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST(v##op.f16) RELU STORE_1 \
asm volatile(INIT_1_BROADCAST SIMPLE_COMPUTE_1_BROADCAST(v##op.f16) RELU_1 STORE_1 \
: [cnt_num] "+r"(cnt_num), \
[dinx_ptr] "+r"(dinx_ptr_1), \
[dout_ptr] "+r"(dout_ptr_1) \
Expand All @@ -352,7 +347,6 @@ namespace fp16 {
float16_t* dout, \
int num) { \
LOOP_CNT(num) \
\
for (int i = 0; i < cnt; i++) { \
int stride = i << 5; \
const float16_t* dinx_ptr = dinx + stride; \
Expand Down Expand Up @@ -517,6 +511,110 @@ elmentwise_simple_compute(mul);
elmentwise_simple_compute(sub);
#ifdef __aarch64__
elmentwise_simple_compute(div);
#else
void elementwise_div(const float16_t* dinx,
const float16_t* diny,
float16_t* dout,
int num) {
LOOP_CNT(num)
for (int i = 0; i < cnt; i++) {
int stride = i << 5;
const float16_t* dinx_ptr = dinx + stride;
const float16_t* diny_ptr = diny + stride;
float16_t* dout_ptr = dout + stride;
float16x8_t vec_a1 = vld1q_f16(dinx_ptr);
float16x8_t vec_a2 = vld1q_f16(dinx_ptr + 8);
float16x8_t vec_b1 = vld1q_f16(diny_ptr);
float16x8_t vec_b2 = vld1q_f16(diny_ptr + 8);
vst1q_f16(dout_ptr, divq_ps_f16(vec_a1, vec_b1));
vst1q_f16(dout_ptr + 8, divq_ps_f16(vec_a2, vec_b2));
vec_a1 = vld1q_f16(dinx_ptr + 16);
vec_a2 = vld1q_f16(dinx_ptr + 24);
vec_b1 = vld1q_f16(diny_ptr + 16);
vec_b2 = vld1q_f16(diny_ptr + 24);
vst1q_f16(dout_ptr + 16, divq_ps_f16(vec_a1, vec_b1));
vst1q_f16(dout_ptr + 24, divq_ps_f16(vec_a2, vec_b2));
}
int stride = cnt << 5;
if (rem_cnt > 0) {
const float16_t* dinx_ptr = dinx + stride;
const float16_t* diny_ptr = diny + stride;
float16_t* dout_ptr = dout + stride;
int cnt_num = rem_cnt;
for (int loop = 0; loop < rem_cnt; loop++) {
float16x8_t vec_a1 = vld1q_f16(dinx_ptr + loop * 8);
float16x8_t vec_b1 = vld1q_f16(diny_ptr + loop * 8);
vst1q_f16(dout_ptr + loop * 8, divq_ps_f16(vec_a1, vec_b1));
}
}
if (rem_rem > 0) {
stride += (rem_cnt << 3);
const float16_t* dinx_ptr = dinx + stride;
const float16_t* diny_ptr = diny + stride;
float16_t* dout_ptr = dout + stride;
for (int i = 0; i < rem_rem; i++) {
*dout_ptr = naive_div(*dinx_ptr, *diny_ptr);
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}

void elementwise_div_broadcast(const float16_t* dinx,
const float16_t* diny,
float16_t* dout,
int batch,
int channels,
int num) {
OMP_PARA_INTERNAL_COLLASPE_2
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const auto* dinx_ptr = dinx + offset;
const auto* diny_ptr = diny + j;
auto* dout_ptr = dout + offset;
LOOP_CNT(num)
for (int k = 0; k < cnt; k++) {
int stride = k << 5;
const float16_t* dinx_ptr_1 = dinx_ptr + stride;
float16_t* dout_ptr_1 = dout_ptr + stride;
float16x8_t val_y = vdupq_n_f16(diny_ptr[0]);
float16x8_t vec_x1 = vld1q_f16(dinx_ptr_1);
float16x8_t vec_x2 = vld1q_f16(dinx_ptr_1 + 8);
vst1q_f16(dout_ptr_1, divq_ps_f16(vec_x1, val_y));
vst1q_f16(dout_ptr_1 + 8, divq_ps_f16(vec_x2, val_y));
vec_x1 = vld1q_f16(dinx_ptr_1 + 16);
vec_x2 = vld1q_f16(dinx_ptr_1 + 24);
vst1q_f16(dout_ptr_1 + 16, divq_ps_f16(vec_x1, val_y));
vst1q_f16(dout_ptr_1 + 24, divq_ps_f16(vec_x2, val_y));
}
int stride = cnt << 5;
if (rem_cnt > 0) {
const float16_t* dinx_ptr_1 = dinx_ptr + stride;
float16_t* dout_ptr_1 = dout_ptr + stride;
float16x8_t val_y = vdupq_n_f16(diny_ptr[0]);
int cnt_num = rem_cnt;
for (int loop = 0; loop < rem_cnt; loop++) {
float16x8_t vec_x1 = vld1q_f16(dinx_ptr_1 + loop * 8);
vst1q_f16(dout_ptr_1 + loop * 8, divq_ps_f16(vec_x1, val_y));
}
}
if (rem_rem > 0) {
stride += (rem_cnt << 3);
const float16_t* dinx_ptr_1 = dinx_ptr + stride;
float16_t* dout_ptr_1 = dout_ptr + stride;
float16_t val = diny_ptr[0];
for (int i = 0; i < rem_rem; i++) {
*dout_ptr_1 = naive_div(*dinx_ptr_1, val);
dinx_ptr_1++;
dout_ptr_1++;
}
}
}
}
}

#endif
} // namespace fp16
} // namespace math
Expand Down
14 changes: 14 additions & 0 deletions lite/backends/arm/math/fp16/elementwise_fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,21 @@ typedef __fp16 float16_t;
elementwise_simple_compute_declare(add);
elementwise_simple_compute_declare(mul);
elementwise_simple_compute_declare(sub);
#ifdef __aarch64__
elementwise_simple_compute_declare(div);
#else
void elementwise_div(const float16_t* dinx,
const float16_t* diny,
float16_t* dout,
int num);

void elementwise_div_broadcast(const float16_t* dinx,
const float16_t* diny,
float16_t* dout,
int batch,
int channels,
int num);
#endif

} // namespace fp16
} // namespace math
Expand Down
14 changes: 5 additions & 9 deletions lite/backends/arm/math/fp16/type_trans_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,15 +304,6 @@ void fp32_to_fp16(const float* in, float16_t* out, int size) {
"vst1.32 {d16-d17}, [%[out]]!\n"
"bne 4b\n"
"2: \n"
"cmp %[remain_remain], #1\n"
"blt 3f\n"
"5: \n"
"vld1.16 d0[0], [%[in]]!\n"
"subs %[remain_remain], #1\n"
"vcvt.f16.f32 d16, q0\n"
"vst1.32 d16[0], [%[out]]!\n"
"bne 5b\n"
"3: \n"
: [in] "+r"(in),
[out] "+r"(out),
[cnt] "+r"(cnt),
Expand All @@ -333,6 +324,11 @@ void fp32_to_fp16(const float* in, float16_t* out, int size) {
"q9",
"q10",
"q11");
for (int i = 0; i < remain_remain; i++) {
*out = static_cast<float16_t>(*in);
out++;
in++;
}
#endif
}
} // namespace fp16
Expand Down
10 changes: 8 additions & 2 deletions lite/kernels/arm/elementwise_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,18 @@ void ElementwiseDivCompute<float16_t, PRECISION(kFP16)>::Run() {
OprandSwapable::NO,
arm_math::NullNeonConfig>(
this,

lite::arm::math::fp16::elementwise_div_broadcast<float16_t>,
lite::arm::math::fp16::elementwise_div<float16_t>,
paddle::lite::kernels::host::naive_div<float16_t>);
#else
LOG(FATAL) << "it doesn't support v7 fp16 elementwise_div compute";
elementwise_compute_template<operators::ElementwiseParam,
float16_t,
OprandSwapable::NO,
arm_math::NullNeonConfig>(
this,
lite::arm::math::fp16::elementwise_div_broadcast,
lite::arm::math::fp16::elementwise_div,
paddle::lite::kernels::host::naive_div<float16_t>);
#endif
}

Expand Down