Skip to content

Commit

Permalink
fix pool2d bug test=develop (#5924)
Browse files Browse the repository at this point in the history
  • Loading branch information
daming5432 authored Apr 21, 2021
1 parent d2ac50e commit 4dbc4b8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
22 changes: 13 additions & 9 deletions lite/backends/opencl/cl_kernel/image/pool_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ __kernel void pool_max(__read_only image2d_t input,
__private const int ksize_w,
__private const int stride_h,
__private const int stride_w,
__private const int4 pad) {
__private const int pad_top,
__private const int pad_left) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;

int start_h = out_h * stride_h - (pad.x - pad.y);
int start_h = out_h * stride_h - pad_top;
int end_h = min(start_h + ksize_h, in_height);
start_h = max(start_h, 0);

int start_w = out_w * stride_w - (pad.w - pad.z);
int start_w = out_w * stride_w - pad_left;
int end_w = min(start_w + ksize_w, in_width);
start_w = max(start_w, 0);

Expand Down Expand Up @@ -64,18 +65,19 @@ __kernel void pool_avg(__read_only image2d_t input,
__private const int ksize_w,
__private const int stride_h,
__private const int stride_w,
__private const int4 pad) {
__private const int pad_top,
__private const int pad_left) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;

int start_h = out_h * stride_h - pad.x;
int start_h = out_h * stride_h - pad_top;
int end_h = min(start_h + ksize_h, in_height);
start_h = max(start_h, 0);

int start_w = out_w * stride_w - pad.z;
int start_w = out_w * stride_w - pad_left;
int end_w = min(start_w + ksize_w, in_width);
start_w = max(start_w, 0);

Expand All @@ -94,7 +96,7 @@ __kernel void pool_avg(__read_only image2d_t input,
div = (CL_DTYPE)((end_h - start_h)*(end_w - start_w));
#else
div = (CL_DTYPE)(ksize_w * ksize_h);
#endif
#endif
CL_DTYPE4 avg = sum / div;
const int pos_out_x = mad24(out_c, out_width, out_w);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), avg);
Expand All @@ -110,7 +112,8 @@ __kernel void pool_avg_global(__read_only image2d_t input,
__private const int ksize_w,
__private const int stride_h,
__private const int stride_w,
__private const int4 pad) {
__private const int pad_top,
__private const int pad_left) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1); // =1
const int out_nh = get_global_id(2); // = n*1
Expand Down Expand Up @@ -179,7 +182,8 @@ __kernel void pool_max_global(__read_only image2d_t input,
__private const int ksize_w,
__private const int stride_h,
__private const int stride_w,
__private const int4 pad) {
__private const int pad_top,
__private const int pad_left) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1); // =1
const int out_nh = get_global_id(2); // = n*1
Expand Down
13 changes: 10 additions & 3 deletions lite/kernels/opencl/pool_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),

void PrepareForRun() override {
const auto& param = *param_.get_mutable<param_t>();

kernel_func_name_ += param.pooling_type;
const bool global_pooling = param.global_pooling;
const bool exclusive = param.exclusive;
Expand Down Expand Up @@ -90,8 +91,6 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
}
}

cl_int4 pad = {paddings[0], paddings[1], paddings[2], paddings[3]};

#ifdef LITE_WITH_LOG
VLOG(4) << "in_dims : [" << in_dims.size() << "]" << in_dims[0] << " "
<< in_dims[1] << " " << in_dims[2] << " " << in_dims[3];
Expand All @@ -107,6 +106,12 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
<< paddings[1] << " " << paddings[2] << " " << paddings[3];
#endif

bool pads_equal = ((abs(paddings[0] - paddings[1]) < 2) &&
(abs(paddings[2] - paddings[3]) < 2));
if (!pads_equal) {
LOG(FATAL)
<< "padding requires pad_left == pad_right, pad_top == pad_bottom";
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);

Expand Down Expand Up @@ -155,7 +160,9 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, pad);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);

status = EnqueueNDRangeKernel(context,
Expand Down

0 comments on commit 4dbc4b8

Please sign in to comment.