Skip to content

Commit

Permalink
[OpenCL] fuse conv prelu pass (#5461)
Browse files Browse the repository at this point in the history
* fuse conv prelu pass test=develop

* add #ifdef PRELU to conv2d_1x1_opt_kernel.cl test=develop

* modify activation and activation_type4, fix prelu bug test=develop

* test=develop

* add alpha_image_p_ init test=develop

* add annotation for prelu test=develop

* test=develop

* rm some waste code test=develop

* test=develop

* add support for depth_conv2d_common test=develop

* fix arg bug test=develop
  • Loading branch information
daming5432 authored Mar 3, 2021
1 parent 68e358b commit d51bd58
Show file tree
Hide file tree
Showing 25 changed files with 756 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ __kernel void depthwise_conv2d(const int numel, // num of elements
v += bias_data[c];
}
#ifdef RELU
output_data[index] = activation(v);
CL_DTYPE alpha;
output_data[index] = activation(v, alpha);
#else
output_data[index] = v;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ __kernel void elementwise_add(__global const CL_DTYPE* x_data,
for (int n = 0; n < num; ++n) { // n: [0, h*w)
*dout_ptr = *din_ptr + diny_data;
#ifdef RELU
*dout_ptr = activation(*dout_ptr);
CL_DTYPE alpha;
*dout_ptr = activation(*dout_ptr, alpha);
#endif
++dout_ptr;
++din_ptr;
Expand Down
18 changes: 11 additions & 7 deletions lite/backends/opencl/cl_kernel/buffer/fc_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ void fc_gemm_naive(__global const CL_DTYPE* a,
}

#ifdef RELU
c[row * N + col] = activation(c0);
CL_DTYPE alpha;
c[row * N + col] = activation(c0, alpha);
#else
c[row * N + col] = c0;
#endif
Expand Down Expand Up @@ -91,7 +92,8 @@ void gemm_batch_naive(__global const CL_DTYPE* a,
c0 += a0 * b0;
}

cur_c[row * N + col] = activation(c0);
CL_DTYPE alpha;
cur_c[row * N + col] = activation(c0, alpha);
}


Expand Down Expand Up @@ -235,7 +237,8 @@ void fc_gemv_naive(__global const CL_DTYPE* a,
}

#ifdef RELU
c[col] = activation(c0);
CL_DTYPE alpha;
c[col] = activation(c0, alpha);
#else
c[col] = c0;
#endif
Expand All @@ -254,6 +257,7 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
const int M, const int N, const int K) {
const int col = get_global_id(0) << 2; // gws[0]: [0, N >> 2) height of B == N

half alpha;
if (col + 3 < N) {
half4 c0 = 0.0f;
if (bias) {
Expand Down Expand Up @@ -310,11 +314,11 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
} else {
switch (col % 4) {
case 3:
c[col + 2] = activation(c0.z);
c[col + 2] = activation(c0.z, alpha);
case 2:
c[col + 1] = activation(c0.y);
c[col + 1] = activation(c0.y, alpha);
case 1:
c[col] = activation(c0.x);
c[col] = activation(c0.x, alpha);
}
}
#else
Expand All @@ -341,7 +345,7 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
c0 += a0 * b0;
}
#ifdef RELU
c[col + col_offset] = activation(c0);
c[col + col_offset] = activation(c0, alpha);
#else
c[col + col_offset] = c0;
#endif
Expand Down
3 changes: 2 additions & 1 deletion lite/backends/opencl/cl_kernel/buffer/relu_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ limitations under the License. */

__kernel void relu(__global const CL_DTYPE* x_data, const int count, __global CL_DTYPE* out_data) {
const int index = get_global_id(0);
CL_DTYPE alpha;
if (index < count) {
out_data[index] = activation(x_data[index]);
out_data[index] = activation(x_data[index], alpha);
}
}
14 changes: 2 additions & 12 deletions lite/backends/opencl/cl_kernel/cl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,7 @@ __constant sampler_t SAMPLER =
/////////////////////////////////
// activation / activation_type4
/////////////////////////////////
inline CL_DTYPE activation(CL_DTYPE in
#ifdef PRELU
,
CL_DTYPE prelu_alpha
#endif
) {
inline CL_DTYPE activation(CL_DTYPE in, CL_DTYPE prelu_alpha) {
CL_DTYPE output = in;
#ifdef PRELU
output = select(prelu_alpha * in, in, in >= (CL_DTYPE)0);
Expand Down Expand Up @@ -138,12 +133,7 @@ inline CL_DTYPE activation(CL_DTYPE in
return output;
}

inline CL_DTYPE4 activation_type4(CL_DTYPE4 in
#ifdef PRELU
,
CL_DTYPE4 prelu_alpha
#endif
) {
inline CL_DTYPE4 activation_type4(CL_DTYPE4 in, CL_DTYPE4 prelu_alpha) {
CL_DTYPE4 output = in;
#ifdef PRELU
output = select(prelu_alpha * in, in, isgreaterequal(in, (CL_DTYPE4)0));
Expand Down
68 changes: 58 additions & 10 deletions lite/backends/opencl/cl_kernel/image/conv2d_1x1_opt_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ __kernel void conv2d_1x1_opt(
__private const int input_height, /* of one block */
__private const int output_width,
__private const int output_height,
__private const int old_w) {
__private const int old_w,
__read_only image2d_t prelu_alpha) {

const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
Expand Down Expand Up @@ -251,10 +252,33 @@ __kernel void conv2d_1x1_opt(
READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, SAMPLER, (int2)(out_c, 0));
#endif

output0 = activation_type4(output0);
output1 = activation_type4(output1);
output2 = activation_type4(output2);
output3 = activation_type4(output3);
CL_DTYPE4 alpha0,alpha1,alpha2,alpha3;
#ifdef PRELU_CH //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(out_c, 0));
alpha1 = alpha0;
alpha2 = alpha0;
alpha3 = alpha0;
//}
#elif defined(PRELU_ELE) //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, output_pos0);
alpha1 = alpha0;
alpha2 = alpha0;
alpha3 = alpha0;
//}
#elif defined(PRELU_ALL) //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(0, 0));
alpha0.y = alpha0.x;
alpha0.z = alpha0.x;
alpha0.w = alpha0.x;
alpha1 = alpha0;
alpha2 = alpha0;
alpha3 = alpha0;
//}
#endif
output0 = activation_type4(output0, alpha0);
output1 = activation_type4(output1, alpha1);
output2 = activation_type4(output2, alpha2);
output3 = activation_type4(output3, alpha3);

#ifdef SCALE_ACTIVATION
output0 = fuse_scale(output0, 1.f, 0.f, 0.f);
Expand Down Expand Up @@ -301,7 +325,8 @@ __kernel void conv2d_1x1_simple(
__private const int input_height, /* of one block */
__private const int output_width,
__private const int output_height,
__private const int old_w) {
__private const int old_w,
__read_only image2d_t prelu_alpha) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
Expand Down Expand Up @@ -421,10 +446,33 @@ __kernel void conv2d_1x1_simple(
READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, SAMPLER, (int2)(out_c, 0));
#endif

output0 = activation_type4(output0);
output1 = activation_type4(output1);
output2 = activation_type4(output2);
output3 = activation_type4(output3);
CL_DTYPE4 alpha0,alpha1,alpha2,alpha3;
#ifdef PRELU_CH //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(out_c, 0));
alpha1 = alpha0;
alpha2 = alpha0;
alpha3 = alpha0;
//}
#elif defined(PRELU_ELE) //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, output_pos0);
alpha1 = alpha0;
alpha2 = alpha0;
alpha3 = alpha0;
//}
#elif defined(PRELU_ALL) //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(0, 0));
alpha0.y = alpha0.x;
alpha0.z = alpha0.x;
alpha0.w = alpha0.x;
alpha1 = alpha0;
alpha2 = alpha0;
alpha3 = alpha0;
//}
#endif
output0 = activation_type4(output0, alpha0);
output1 = activation_type4(output1, alpha1);
output2 = activation_type4(output2, alpha2);
output3 = activation_type4(output3, alpha3);

#ifdef SCALE_ACTIVATION
output0 = fuse_scale(output0, 1.f, 0.f, 0.f);
Expand Down
20 changes: 18 additions & 2 deletions lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
__private const int filter_width,
__private const int filter_height,
__private const int group,
__private const int input_tensor_c) {
__private const int input_tensor_c,
__read_only image2d_t prelu_alpha) {

const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
Expand Down Expand Up @@ -251,7 +252,22 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
output.w = (i == 3) ? output.w + tmp_out : output.w;
}
}
output = activation_type4(output);

CL_DTYPE4 alpha0;
#ifdef PRELU_CH //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(out_c, 0));
//}
#elif defined(PRELU_ELE) //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, output_pos);
//}
#elif defined(PRELU_ALL) //{
alpha0 = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(0, 0));
alpha0.y = alpha0.x;
alpha0.z = alpha0.x;
alpha0.w = alpha0.x;
//}
#endif
output = activation_type4(output, alpha0);

#ifdef SCALE_ACTIVATION
output = fuse_scale(output, 1.f, 0.f, 0.f);
Expand Down
116 changes: 104 additions & 12 deletions lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
__private const int in_w,
__private const int in_h,
__private const int out_w,
__private const int out_h) {
__private const int out_h,
__read_only image2d_t prelu_alpha) {
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
Expand Down Expand Up @@ -216,11 +217,56 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
}
}

output[0] = activation_type4(output[0]);
output[1] = activation_type4(output[1]);
output[2] = activation_type4(output[2]);
output[3] = activation_type4(output[3]);
output[4] = activation_type4(output[4]);
CL_DTYPE4 alpha[5];
#ifdef PRELU_CH //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(item_ch_id, 0));
alpha[1] = alpha[0];
alpha[2] = alpha[0];
alpha[3] = alpha[0];
alpha[4] = alpha[0];
//}
#elif defined(PRELU_ELE) //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
alpha[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
alpha[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
alpha[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
alpha[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
//}
#elif defined(PRELU_ALL) //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(0, 0));
alpha[0].y = alpha[0].x; alpha[0].z = alpha[0].x; alpha[0].w = alpha[0].x;
alpha[1] = alpha[0]; alpha[2] = alpha[0];
alpha[3] = alpha[0]; alpha[4] = alpha[0];
//}
#endif
output[0] = activation_type4(output[0], alpha[0]);
output[1] = activation_type4(output[1], alpha[1]);
output[2] = activation_type4(output[2], alpha[2]);
output[3] = activation_type4(output[3], alpha[3]);
output[4] = activation_type4(output[4], alpha[4]);

#ifdef SCALE_ACTIVATION
output[0] = fuse_scale(output[0], 1.f, 0.f, 0.f);
Expand Down Expand Up @@ -276,7 +322,8 @@ __kernel void conv2d_3x3_multi_batch(__private const int item_ch,
__private const int in_w,
__private const int in_h,
__private const int out_w,
__private const int out_h) {
__private const int out_h,
__read_only image2d_t prelu_alpha) {
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
Expand Down Expand Up @@ -464,11 +511,56 @@ __kernel void conv2d_3x3_multi_batch(__private const int item_ch,
}
}

output[0] = activation_type4(output[0]);
output[1] = activation_type4(output[1]);
output[2] = activation_type4(output[2]);
output[3] = activation_type4(output[3]);
output[4] = activation_type4(output[4]);
CL_DTYPE4 alpha[5];
#ifdef PRELU_CH //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(item_ch_id, 0));
alpha[1] = alpha[0];
alpha[2] = alpha[0];
alpha[3] = alpha[0];
alpha[4] = alpha[0];
//}
#elif defined(PRELU_ELE) //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
alpha[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
alpha[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
alpha[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
alpha[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
//}
#elif defined(PRELU_ALL) //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(0, 0));
alpha[0].y = alpha[0].x; alpha[0].z = alpha[0].x; alpha[0].w = alpha[0].x;
alpha[1] = alpha[0]; alpha[2] = alpha[0];
alpha[3] = alpha[0]; alpha[4] = alpha[0];
//}
#endif
output[0] = activation_type4(output[0], alpha[0]);
output[1] = activation_type4(output[1], alpha[1]);
output[2] = activation_type4(output[2], alpha[2]);
output[3] = activation_type4(output[3], alpha[3]);
output[4] = activation_type4(output[4], alpha[4]);

#ifdef SCALE_ACTIVATION
output[0] = fuse_scale(output[0], 1.f, 0.f, 0.f);
Expand Down
Loading

0 comments on commit d51bd58

Please sign in to comment.