From 4dbc4b832158378b626079a9b2cc65945ffb20fd Mon Sep 17 00:00:00 2001 From: daming5432 Date: Wed, 21 Apr 2021 10:35:27 +0800 Subject: [PATCH] fix pool2d bug test=develop (#5924) --- .../opencl/cl_kernel/image/pool_kernel.cl | 22 +++++++++++-------- lite/kernels/opencl/pool_image_compute.cc | 13 ++++++++--- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl index 0ac0d17bcbf..234d20f8a79 100644 --- a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl @@ -24,18 +24,19 @@ __kernel void pool_max(__read_only image2d_t input, __private const int ksize_w, __private const int stride_h, __private const int stride_w, - __private const int4 pad) { + __private const int pad_top, + __private const int pad_left) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); const int out_nh = get_global_id(2); const int out_n = out_nh / out_height; const int out_h = out_nh % out_height; - int start_h = out_h * stride_h - (pad.x - pad.y); + int start_h = out_h * stride_h - pad_top; int end_h = min(start_h + ksize_h, in_height); start_h = max(start_h, 0); - int start_w = out_w * stride_w - (pad.w - pad.z); + int start_w = out_w * stride_w - pad_left; int end_w = min(start_w + ksize_w, in_width); start_w = max(start_w, 0); @@ -64,18 +65,19 @@ __kernel void pool_avg(__read_only image2d_t input, __private const int ksize_w, __private const int stride_h, __private const int stride_w, - __private const int4 pad) { + __private const int pad_top, + __private const int pad_left) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); const int out_nh = get_global_id(2); const int out_n = out_nh / out_height; const int out_h = out_nh % out_height; - int start_h = out_h * stride_h - pad.x; + int start_h = out_h * stride_h - pad_top; int end_h = min(start_h + ksize_h, in_height); start_h = max(start_h, 0); - int start_w = out_w * stride_w - pad.z; + int start_w = out_w * stride_w - pad_left; int end_w = min(start_w + ksize_w, in_width); start_w = max(start_w, 0); @@ -94,7 +96,7 @@ __kernel void pool_avg(__read_only image2d_t input, div = (CL_DTYPE)((end_h - start_h)*(end_w - start_w)); #else div = (CL_DTYPE)(ksize_w * ksize_h); -#endif +#endif CL_DTYPE4 avg = sum / div; const int pos_out_x = mad24(out_c, out_width, out_w); WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), avg); @@ -110,7 +112,8 @@ __kernel void pool_avg_global(__read_only image2d_t input, __private const int ksize_w, __private const int stride_h, __private const int stride_w, - __private const int4 pad) { + __private const int pad_top, + __private const int pad_left) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); // =1 const int out_nh = get_global_id(2); // = n*1 @@ -179,7 +182,8 @@ __kernel void pool_max_global(__read_only image2d_t input, __private const int ksize_w, __private const int stride_h, __private const int stride_w, - __private const int4 pad) { + __private const int pad_top, + __private const int pad_left) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); // =1 const int out_nh = get_global_id(2); // = n*1 diff --git a/lite/kernels/opencl/pool_image_compute.cc b/lite/kernels/opencl/pool_image_compute.cc index c28ef2d8e45..faf6389952e 100644 --- a/lite/kernels/opencl/pool_image_compute.cc +++ b/lite/kernels/opencl/pool_image_compute.cc @@ -42,6 +42,7 @@ class PoolComputeImage2D : public KernelLite(); + kernel_func_name_ += param.pooling_type; const bool global_pooling = param.global_pooling; const bool exclusive = param.exclusive; @@ -90,8 +91,6 @@ class PoolComputeImage2D : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); @@ -155,7 +160,9 @@ class PoolComputeImage2D : public KernelLite(strides[1])); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, pad); + status = kernel.setArg(++arg_idx, static_cast(paddings[2])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(paddings[0])); CL_CHECK_FATAL(status); status = EnqueueNDRangeKernel(context,