[ARM]Add rnn fp16 arm #9402

Merged
merged 5 commits on Sep 16, 2022
43 changes: 43 additions & 0 deletions lite/backends/arm/math/fp16/activation_fp16.cc
@@ -365,6 +365,49 @@ void act_tanh<float16_t>(const float16_t* din,
ptr_out++;
}
}

template <>
void act_sigmoid<float16_t>(const float16_t* din,
float16_t* dout,
int size,
int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim8 = nums_per_thread >> 3;
int neon_loop_remain_dim8 = nums_per_thread - (neon_loop_cnt_dim8 << 3);

float16x8_t vzero = vdupq_n_f16(0.f);
LITE_PARALLEL_BEGIN(i, tid, threads) {
float16x8_t exp_vec = vdupq_n_f16(0.0f);
float16x8_t recip = vdupq_n_f16(0.0f);
const float16_t* ptr_in_thread = din + i * nums_per_thread;
float16_t* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim8; ++k) {
exp_vec = expq_ps_f16(vnegq_f16(vld1q_f16(ptr_in_thread)));
exp_vec = vaddq_f16(exp_vec, vdupq_n_f16(1.0f));
recip = vrecpeq_f16(exp_vec);
recip = vmulq_f16(vrecpsq_f16(exp_vec, recip), recip);
recip = vmulq_f16(vrecpsq_f16(exp_vec, recip), recip);
vst1q_f16(ptr_out_thread, recip);
ptr_out_thread += 8;
ptr_in_thread += 8;
}
for (int j = 0; j < neon_loop_remain_dim8; ++j) {
ptr_out_thread[0] = 1.f / (1 + expf(-ptr_in_thread[0]));
ptr_in_thread++;
ptr_out_thread++;
}
}
LITE_PARALLEL_END();
float16_t* ptr_out = dout + threads * nums_per_thread;
const float16_t* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
ptr_out[0] = 1.f / (1 + expf(-ptr_in[0]));
ptr_in++;
ptr_out++;
}
}

} // namespace fp16
} // namespace math
} // namespace arm
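Note: the vectorized act_sigmoid above computes sigmoid(x) = 1 / (1 + e^-x) by taking a rough reciprocal estimate with vrecpeq_f16 and refining it with two Newton-Raphson steps, where vrecpsq_f16(d, r) returns 2 - d * r. A minimal scalar sketch of that refinement, in plain C++ rather than fp16 NEON (the helper names here are illustrative, not part of the library):

#include <cmath>
#include <cstdio>

// One Newton-Raphson step toward 1 / d: r_next = r * (2 - d * r).
// This mirrors each vrecpsq_f16/vmulq_f16 pair in the NEON loop.
static float refine_recip(float d, float r) { return r * (2.0f - d * r); }

static float sigmoid_ref(float x) {
  float denom = 1.0f + std::exp(-x);
  float r = 0.9f / denom;  // deliberately rough estimate, standing in for vrecpeq_f16
  r = refine_recip(denom, r);
  r = refine_recip(denom, r);
  return r;
}

int main() {
  const float xs[] = {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f};
  for (float x : xs) {
    std::printf("sigmoid(%5.2f) ~= %.6f (exact %.6f)\n",
                x, sigmoid_ref(x), 1.0f / (1.0f + std::exp(-x)));
  }
  return 0;
}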
3 changes: 3 additions & 0 deletions lite/backends/arm/math/fp16/activation_fp16.h
@@ -53,6 +53,9 @@ void act_prelu(const T* din,
template <typename T>
void act_tanh(const T* din, T* dout, int size, int threads);

template <typename T>
void act_sigmoid(const T* din, T* dout, int size, int threads);

} // namespace fp16
} // namespace math
} // namespace arm
159 changes: 159 additions & 0 deletions lite/backends/arm/math/gru.h
@@ -18,6 +18,9 @@
#ifdef LITE_WITH_ARM
#include <arm_neon.h>
#endif
#ifdef ENABLE_ARM_FP16
#include "lite/backends/arm/math/fp16/funcs_fp16.h"
#endif

namespace paddle {
namespace lite {
@@ -63,6 +66,37 @@ void rnn_activation(const T* din,
}
}

#ifdef ENABLE_ARM_FP16
template <>
void rnn_activation<float16_t>(const float16_t* din,
float16_t* dout,
int size,
lite_api::ActivationType act_type,
int threads) {
switch (act_type) {
case lite_api::ActivationType::kSigmoid:
fp16::act_sigmoid<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kSigmoid_v2:
fp16::act_sigmoid<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kTanh:
fp16::act_tanh<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kTanh_v2:
fp16::act_tanh<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kRelu:
fp16::act_relu<float16_t>(din, dout, size, threads);
break;
default:
LOG(FATAL) << "unsupport fp16 activation type:"
<< static_cast<int>(act_type);
break;
}
}
#endif

template <typename T>
void compute_kernel(RNNGRUValue<T> value,
int frame_size,
@@ -209,6 +243,94 @@ void compute_kernel<float>(RNNGRUValue<float> value,
}
}

#ifdef ENABLE_ARM_FP16
template <>
void compute_kernel<float16_t>(RNNGRUValue<float16_t> value,
int frame_size,
int batch_size,
lite_api::ActivationType active_node,
lite_api::ActivationType active_gate) {
auto value_reset_gate = value.gate_value;
auto value_update_gate = value.gate_value + frame_size;
auto value_reset_output = value.reset_output_value;
auto value_reset_bias = value.reset_bias;
auto cell_state_value = value.gate_value + 2 * frame_size;
auto value_output = value.output_value;
auto value_prev_out = value.prev_out_value;
int i = 0;
float16x8_t vec_one = vdupq_n_f16(1.f);

for (int b = 0; b < batch_size; b++) {
rnn_activation(value_reset_gate,
value_reset_gate,
frame_size,
lite_api::ActivationType::kSigmoid_v2,
1);
rnn_activation(value_update_gate,
value_update_gate,
frame_size,
lite_api::ActivationType::kSigmoid_v2,
1);

for (i = 0; i + 7 < frame_size; i += 8) {
float16x8_t vec_out = vld1q_f16(value_reset_output + i);
float16x8_t vec_reset = vld1q_f16(value_reset_gate + i);
float16x8_t vec_bias = vld1q_f16(value_reset_bias + i);
vec_out = vmulq_f16(vaddq_f16(vec_out, vec_bias), vec_reset);
vst1q_f16(value_reset_output + i, vec_out);
vst1q_f16(cell_state_value + i,
vaddq_f16(vec_out, vld1q_f16(cell_state_value + i)));
}
for (; i < frame_size; i++) {
value_reset_output[i] =
(value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
cell_state_value[i] += value_reset_output[i];
}

rnn_activation(cell_state_value,
cell_state_value,
frame_size,
lite_api::ActivationType::kTanh_v2,
1);

if (value.prev_out_value) {
for (i = 0; i + 7 < frame_size; i += 8) {
float16x8_t vec_vug = vld1q_f16(value_update_gate + i);
float16x8_t vec_vpo = vld1q_f16(value_prev_out + i);
float16x8_t vec_csv = vld1q_f16(cell_state_value + i);
vec_vpo = vmulq_f16(vec_vug, vec_vpo);
float16x8_t vec_out =
vfmaq_f16(vec_vpo, vsubq_f16(vec_one, vec_vug), vec_csv);
vst1q_f16(value_output + i, vec_out);
}
for (; i < frame_size; i++) {
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
value_update_gate[i] * value_prev_out[i];
}
} else {
for (i = 0; i + 7 < frame_size; i += 8) {
float16x8_t vec_vug = vld1q_f16(value_update_gate + i);
float16x8_t vec_csv = vld1q_f16(cell_state_value + i);
float16x8_t vec_out = vmulq_f16(vsubq_f16(vec_one, vec_vug), vec_csv);
vst1q_f16(value_output + i, vec_out);
}
for (; i < frame_size; i++) {
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
}
}

value_reset_gate += frame_size * 3;
value_update_gate += frame_size * 3;
value_reset_output += frame_size;
cell_state_value += frame_size * 3;
value_output += frame_size;
if (value.prev_out_value) {
value_prev_out += frame_size;
}
}
}
#endif
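For readers of the NEON loops above, the per-element math is the standard GRU update also used by the float kernel: apply sigmoid to the reset gate r and update gate u, form r * (reset_output + reset_bias), add it into the cell pre-activation, apply tanh, and blend h = u * h_prev + (1 - u) * c (or just (1 - u) * c when there is no previous state). A scalar sketch of one batch row under the [r | u | c] gate layout used here (plain C++, illustration only, not the library API):

#include <cmath>

static float sigmoid_scalar(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// gate:         [3 * frame_size] pre-activations laid out as r | u | c
// reset_output: [frame_size], reset_bias: [frame_size], output: [frame_size]
// prev_out may be nullptr (first step).
void gru_step_ref(float* gate, float* reset_output, const float* reset_bias,
                  const float* prev_out, float* output, int frame_size) {
  float* r = gate;
  float* u = gate + frame_size;
  float* c = gate + 2 * frame_size;
  for (int i = 0; i < frame_size; ++i) {
    r[i] = sigmoid_scalar(r[i]);
    u[i] = sigmoid_scalar(u[i]);
    reset_output[i] = (reset_output[i] + reset_bias[i]) * r[i];
    c[i] = std::tanh(c[i] + reset_output[i]);
    output[i] = prev_out ? u[i] * prev_out[i] + (1.0f - u[i]) * c[i]
                         : (1.0f - u[i]) * c[i];
  }
}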

template <typename T>
struct RnnGruUnitFunctorV2 {
static void compute(ARMContext* ctx,
@@ -242,6 +364,43 @@ struct RnnGruUnitFunctorV2 {
}
};

#ifdef ENABLE_ARM_FP16
template <>
struct RnnGruUnitFunctorV2<float16_t> {
static void compute(ARMContext* ctx,
RNNGRUValue<float16_t> value,
int frame_size,
int batch_size,
lite_api::ActivationType active_node,
lite_api::ActivationType active_gate) {
if (value.prev_out_value) {
operators::ActivationParam act_param;
act_param.has_active = false;
lite::arm::math::fp16::sgemm_fp16(false,
true,
batch_size,
frame_size,
frame_size,
1.f,
value.prev_out_value,
frame_size,
value.state_weight,
frame_size,
0.f,
value.reset_output_value,
frame_size,
nullptr,
false,
act_param,
ctx);
}
compute_kernel<float16_t>(
value, frame_size, batch_size, active_node, active_gate);
}
};

#endif
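When a previous hidden state exists, the fp16 functor mirrors the float path: it first issues an fp16 GEMM that overwrites reset_output with prev_out * state_weight^T (alpha = 1, beta = 0, B transposed), then calls compute_kernel<float16_t>. Assuming the usual convention that a transposed B is stored as [N, K] with ldb = K, the GEMM reduces to the following plain C++ reference (illustration only, not the sgemm_fp16 API):

// C[b][n] = sum_k A[b][k] * B[n][k], with A = prev_out [batch, frame],
// B = state_weight [frame, frame] (transposed), C = reset_output [batch, frame].
void gemm_nt_ref(const float* A, const float* B, float* C,
                 int batch, int frame) {
  for (int b = 0; b < batch; ++b) {
    for (int n = 0; n < frame; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < frame; ++k) {
        acc += A[b * frame + k] * B[n * frame + k];
      }
      C[b * frame + n] = acc;  // beta = 0: previous contents are discarded
    }
  }
}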

} // namespace math
} // namespace arm
} // namespace lite
112 changes: 112 additions & 0 deletions lite/backends/arm/math/lstm.cc
@@ -75,6 +75,118 @@ void vector_dot(
}
}

float* row_offset(Tensor& input, int start) { // NOLINT
auto in_dim = input.dims();
int width = input.numel() / in_dim[0];
int offset = start < in_dim[0] ? start * width : input.numel();
return input.mutable_data<float>() + offset;
}

#ifdef ENABLE_ARM_FP16
void add_bias_rowwise_fp16(Tensor* input,
const Tensor* bias,
int start_w,
int end_w) {
auto in_dim = input->dims();
int width = input->numel() / in_dim[0];
int w_adds = width < end_w ? width : end_w;
float16_t* i_data = input->mutable_data<float16_t>();
const float16_t* b_data = bias->data<float16_t>();
for (int i = 0; i < in_dim[0]; ++i) {
for (int w = start_w; w < w_adds; ++w) {
i_data[w] += b_data[w];
}
i_data += width;
}
}

void vector_dot_fp16(float16_t* out,
const float16_t* in,
const float16_t* v1,
int size,
const float16_t* v2) {
int loop = size >> 3;
int remain = size & 7;
const float16_t* in_ptr = in;
float16_t* out_ptr = out;
const float16_t* v1_ptr = v1;
const float16_t* v2_ptr = v2;
for (int i = 0; i < loop; ++i) {
float16x8_t in = vld1q_f16(in_ptr);
float16x8_t data1 = vld1q_f16(v1_ptr);
if (!v2) {
// in_out * v1
float16x8_t out = vmulq_f16(in, data1);
vst1q_f16(out_ptr, out);
in_ptr += 8;
v1_ptr += 8;
out_ptr += 8;
} else {
// in_out + v1 * v2
float16x8_t data2 = vld1q_f16(v2_ptr);
float16x8_t out = vfmaq_f16(in, data1, data2);
vst1q_f16(out_ptr, out);
in_ptr += 8;
v1_ptr += 8;
out_ptr += 8;
v2_ptr += 8;
}
}
for (int i = 0; i < remain; ++i) {
if (!v2) {
out_ptr[i] = in_ptr[i] * v1_ptr[i];
} else {
out_ptr[i] = in_ptr[i] + v1_ptr[i] * v2_ptr[i];
}
}
}
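vector_dot_fp16 has two modes selected by the v2 pointer: with v2 == nullptr it computes an elementwise product out = in * v1, otherwise the fused form out = in + v1 * v2 (vfmaq_f16 is a fused multiply-add). A scalar reference of both modes (plain C++, illustration only):

// Scalar reference of vector_dot_fp16's two modes.
void vector_dot_ref(float* out, const float* in, const float* v1,
                    int size, const float* v2 /* may be nullptr */) {
  for (int i = 0; i < size; ++i) {
    out[i] = (v2 == nullptr) ? in[i] * v1[i]           // elementwise product
                             : in[i] + v1[i] * v2[i];  // fused multiply-add form
  }
}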

template <>
void activation<float16_t>(const float16_t* din,
float16_t* dout,
int size,
std::string act_str,
int threads) {
if (act_str == "sigmoid") {
fp16::act_sigmoid<float16_t>(din, dout, size, threads);
} else if (act_str == "tanh") {
fp16::act_tanh<float16_t>(din, dout, size, threads);
} else if (act_str == "relu") {
fp16::act_relu<float16_t>(din, dout, size, threads);
} else {
LOG(FATAL) << "unsupport fp16 activation " << act_str;
}
}
template <>
void activation<float16_t>(const float16_t* din,
float16_t* dout,
int size,
lite_api::ActivationType act_type,
int threads) {
switch (act_type) {
case lite_api::ActivationType::kSigmoid:
fp16::act_sigmoid<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kSigmoid_v2:
fp16::act_sigmoid<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kTanh:
fp16::act_tanh<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kTanh_v2:
fp16::act_tanh<float16_t>(din, dout, size, threads);
break;
case lite_api::ActivationType::kRelu:
fp16::act_relu<float16_t>(din, dout, size, threads);
break;
default:
LOG(FATAL) << "unsupport fp16 activation type:"
<< static_cast<int>(act_type);
break;
}
}
#endif

} // namespace math
} // namespace arm
} // namespace lite