Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pickv2.9][OpenCL]V2.9 opt conv3x3 7x7 mali #6393

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lite/backends/opencl/cl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ cl::NDRange CLContext::DefaultGlobalWorkSize(const CLImage &image) {
std::set<cl::NDRange, CLContext::CompareByRange>
CLContext::GenerateLocalWorkSizes(cl::NDRange gws, size_t max_ws) {
size_t tune_type = CLRuntime::Global()->auto_tune();
auto first_lws = DefaultLocalWorkSize(gws, max_ws, 2, false);
auto first_lws = DefaultLocalWorkSize(gws, max_ws, 3, false);
std::set<cl::NDRange, CompareByRange> lwss;
for (auto one_lws : first_lws) {
lwss.insert(one_lws);
Expand Down
128 changes: 19 additions & 109 deletions lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ __kernel void conv2d_3x3_opt_mali(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__global CL_DTYPE4 *filter_buf,
__read_only image2d_t bias,
__global CL_DTYPE4 *bias_buf,
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
Expand All @@ -536,43 +536,31 @@ __kernel void conv2d_3x3_opt_mali(__private const int item_ch,
__read_only image2d_t prelu_alpha) {
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
const int item_w_id = 2 * get_global_id(1);
const int item_h_id = get_global_id(2);

// out_width_id_per_blk
int out_w_base_id = mul24(item_ch_id, out_w);
int out_w_id0 = item_w_id;
int out_w_id1 = out_w_id0 + item_w;
int out_w_id2 = out_w_id1 + item_w;
int out_w_id3 = out_w_id2 + item_w;
int out_w_id4 = out_w_id3 + item_w;
int out_w_id1 = out_w_id0 + 1;

// in_width_id_per_blk and in_height_id_per_batch
int in_h_id = mad24((item_h_id % out_h), stride, (-pad));
int in_w_id0 = mad24(item_w_id, stride, (-pad));
int in_w_id1 = mad24(item_w, stride, in_w_id0);
int in_w_id2 = mad24(item_w, stride, in_w_id1);
int in_w_id3 = mad24(item_w, stride, in_w_id2);
int in_w_id4 = mad24(item_w, stride, in_w_id3);
int in_w_id1 = mad24(item_w_id + 1, stride, (-pad));

#ifdef BIASE_CH
CL_DTYPE4 output[5];
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, SAMPLER, (int2)(item_ch_id, 0));
CL_DTYPE4 output[2];
output[0] = (bias_buf + item_ch_id)[0];
output[1] = output[0];
output[2] = output[0];
output[3] = output[0];
output[4] = output[0];
#else
CL_DTYPE4 output[5] = {0.0f};
CL_DTYPE4 output[2] = {0.0f};
#endif

CL_DTYPE4 filter[4] = {0.0f};
CL_DTYPE4 input[5] = {0.0f};
CL_DTYPE4 filter[2] = {0.0f};
CL_DTYPE4 input[2] = {0.0f};

for (int ch = 0; ch < ((in_ch + 3) >> 2); ch++) {
int ch_surplus = ((ch + 1) << 2) - in_ch > 0 ? ((ch + 1) << 2) - in_ch : 0;

const int in_w_base_id = mul24(ch, in_w);

int filter_w_val = ch << 2;
Expand All @@ -590,66 +578,33 @@ __kernel void conv2d_3x3_opt_mali(__private const int item_ch,
int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
-1,
(in_w_id1 + w < 0 | in_w_id1 + w >= in_w));
int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
-1,
(in_w_id2 + w < 0 | in_w_id2 + w >= in_w));
int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
-1,
(in_w_id3 + w < 0 | in_w_id3 + w >= in_w));
int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
-1,
(in_w_id4 + w < 0 | in_w_id4 + w >= in_w));

input[0] = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, SAMPLER, (int2)(in_w_val0, in_h_val));
input[1] = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, SAMPLER, (int2)(in_w_val1, in_h_val));
input[2] = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, SAMPLER, (int2)(in_w_val2, in_h_val));
input[3] = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, SAMPLER, (int2)(in_w_val3, in_h_val));
input[4] = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, SAMPLER, (int2)(in_w_val4, in_h_val));

output[0] = mad(input[0].x, filter_ptr[0], output[0]);
output[1] = mad(input[1].x, filter_ptr[0], output[1]);
output[2] = mad(input[2].x, filter_ptr[0], output[2]);
output[3] = mad(input[3].x, filter_ptr[0], output[3]);
output[4] = mad(input[4].x, filter_ptr[0], output[4]);

if (ch_surplus < 3) {
output[0] = mad(input[0].y, filter_ptr[1], output[0]);
output[1] = mad(input[1].y, filter_ptr[1], output[1]);
output[2] = mad(input[2].y, filter_ptr[1], output[2]);
output[3] = mad(input[3].y, filter_ptr[1], output[3]);
output[4] = mad(input[4].y, filter_ptr[1], output[4]);
}
if (ch_surplus < 2) {
output[0] = mad(input[0].z, filter_ptr[2], output[0]);
output[1] = mad(input[1].z, filter_ptr[2], output[1]);
output[2] = mad(input[2].z, filter_ptr[2], output[2]);
output[3] = mad(input[3].z, filter_ptr[2], output[3]);
output[4] = mad(input[4].z, filter_ptr[2], output[4]);
}
if (ch_surplus < 1) {
output[0] = mad(input[0].w, filter_ptr[3], output[0]);
output[1] = mad(input[1].w, filter_ptr[3], output[1]);
output[2] = mad(input[2].w, filter_ptr[3], output[2]);
output[3] = mad(input[3].w, filter_ptr[3], output[3]);
output[4] = mad(input[4].w, filter_ptr[3], output[4]);
}
output[0] = mad(input[0].y, filter_ptr[1], output[0]);
output[1] = mad(input[1].y, filter_ptr[1], output[1]);

output[0] = mad(input[0].z, filter_ptr[2], output[0]);
output[1] = mad(input[1].z, filter_ptr[2], output[1]);

output[0] = mad(input[0].w, filter_ptr[3], output[0]);
output[1] = mad(input[1].w, filter_ptr[3], output[1]);

filter_ptr += ((in_ch + 3) >> 2) * 4;
}
}
}
CL_DTYPE4 alpha[5];
CL_DTYPE4 alpha[2];
#ifdef PRELU_CH //{
alpha[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(item_ch_id, 0));
alpha[1] = alpha[0];
alpha[2] = alpha[0];
alpha[3] = alpha[0];
alpha[4] = alpha[0];
//}
#elif defined(PRELU_ELE) //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
Expand All @@ -662,48 +617,21 @@ __kernel void conv2d_3x3_opt_mali(__private const int item_ch,
SAMPLER,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
alpha[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
alpha[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
alpha[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
prelu_alpha,
SAMPLER,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
//}
#elif defined(PRELU_ALL) //{
alpha[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prelu_alpha, SAMPLER, (int2)(0, 0));
alpha[0].y = alpha[0].x;
alpha[0].z = alpha[0].x;
alpha[0].w = alpha[0].x;
alpha[1] = alpha[0];
alpha[2] = alpha[0];
alpha[3] = alpha[0];
alpha[4] = alpha[0];
//}
#endif
output[0] = activation_type4(output[0], alpha[0]);
output[1] = activation_type4(output[1], alpha[1]);
output[2] = activation_type4(output[2], alpha[2]);
output[3] = activation_type4(output[3], alpha[3]);
output[4] = activation_type4(output[4], alpha[4]);

#ifdef SCALE_ACTIVATION
output[0] = fuse_scale(output[0], 1.f, 0.f, 0.f);
output[1] = fuse_scale(output[1], 1.f, 0.f, 0.f);
output[2] = fuse_scale(output[2], 1.f, 0.f, 0.f);
output[3] = fuse_scale(output[3], 1.f, 0.f, 0.f);
output[4] = fuse_scale(output[4], 1.f, 0.f, 0.f);
#endif

WRITE_IMG_TYPE(CL_DTYPE_CHAR,
Expand All @@ -716,22 +644,4 @@ __kernel void conv2d_3x3_opt_mali(__private const int item_ch,
(int2)(out_w_base_id + out_w_id1, item_h_id),
output[1]);
}
if (out_w_id2 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id2, item_h_id),
output[2]);
}
if (out_w_id3 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id3, item_h_id),
output[3]);
}
if (out_w_id4 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id4, item_h_id),
output[4]);
}
}
Loading