Skip to content

Commit

Permalink
support group conv_transpose for opencl test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenlin-work committed Feb 21, 2022
1 parent 7f42f3e commit 48ae3ec
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 10 deletions.
255 changes: 255 additions & 0 deletions lite/backends/opencl/cl_kernel/image/conv2d_transpose_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,258 @@ __kernel void conv2d_transpose(

WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, out_pos, out0);
}

__kernel void group_conv2d_transpose(
__private const int global_size_dim0, // (out_c + 3) / 4
__private const int global_size_dim1, // out_w
__private const int global_size_dim2, // out_n * out_h
__read_only image2d_t input,
__read_only image2d_t filter,
__read_only image2d_t bias,
__write_only image2d_t output,
__private const int2 input_shape,
__private const int2 output_shape,
__private const int2 stride_shape,
__private const int2 align_shape,
__private const int2 padding_shape,
__private const int2 kernel_shape,
__private const int2 dilation_shape,
__private const int2 kernel_prev_shape,
__private const int kernel_size,
__private const int input_c_blks,
__private const int in_channels_per_group,
__private const int out_channels_per_group) {
const int out_c_blk_idx = get_global_id(0);
const int out_w_idx = get_global_id(1);
const int out_nh_idx = get_global_id(2);

if (out_c_blk_idx >= global_size_dim0 || out_w_idx >= global_size_dim1 ||
out_nh_idx >= global_size_dim2) {
return;
}

int2 out_pos = (int2)(out_c_blk_idx * output_shape.x + out_w_idx, out_nh_idx);

#ifdef BIASE_CH
CL_DTYPE4 out0 =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, SAMPLER, (int2)(out_c_blk_idx, 0));
#else
CL_DTYPE4 out0 = 0.f;
#endif

const int out_n_idx = out_nh_idx / output_shape.y;
const int out_h_idx = out_nh_idx % output_shape.y;

int kernel_start_x = max(0, (out_w_idx + align_shape.x) / stride_shape.x);
int kernel_start_y = max(0, (out_h_idx + align_shape.y) / stride_shape.y);
int valid_kernel_width =
kernel_shape.x - mad24(kernel_start_x, stride_shape.x, padding_shape.x) +
out_w_idx - 1;
int valid_kernel_height =
kernel_shape.y - mad24(kernel_start_y, stride_shape.y, padding_shape.y) +
out_h_idx - 1;

CL_DTYPE4 in0;
CL_DTYPE4 weights0;
for (int o = 0; o < 4; ++o) {
int group_id = (out_c_blk_idx * 4 + o) / out_channels_per_group;
int remain =
(out_c_blk_idx * 4 + o - group_id * out_channels_per_group) % 4;
for (int ic = group_id * in_channels_per_group;
ic < (group_id + 1) * in_channels_per_group;
++ic) {
int in_idx = mul24(ic / 4, input_shape.x);
for (int k_y = valid_kernel_height, idx_h = kernel_start_y; k_y >= 0;
k_y -= stride_shape.y, idx_h++) {
int in_y_idx = mad24(
out_n_idx, input_shape.y, idx_h); // height idx of input image2d
int in_nh_value =
select(in_y_idx, -1, idx_h < 0 || idx_h >= input_shape.y);
int in_width0 = kernel_start_x;
for (int k_x = valid_kernel_width; k_x >= 0; k_x -= stride_shape.x) {
int in_width_value0 = in_width0;
in_width_value0 =
select(in_idx + in_width_value0,
-1,
(in_width_value0 < 0 || in_width_value0 >= input_shape.x));
in0 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input,
SAMPLER,
(int2)(in_width_value0, in_nh_value));
if (k_x % dilation_shape.x == 0 && k_y % dilation_shape.y == 0) {
int kernel_y_0 = ic * kernel_prev_shape.y + k_y / dilation_shape.y;
int kernel_x_0 =
(((out_c_blk_idx * 4 + o) % out_channels_per_group) / 4) *
kernel_prev_shape.x +
k_x / dilation_shape.x;
weights0 = READ_IMG_TYPE(
CL_DTYPE_CHAR, filter, SAMPLER, (int2)(kernel_x_0, kernel_y_0));
if (ic % 4 == 0) {
if (o == 0) {
out0.x =
(remain == 0) ? mad(in0.x, weights0.x, out0.x) : out0.x;
out0.x =
(remain == 1) ? mad(in0.x, weights0.y, out0.x) : out0.x;
out0.x =
(remain == 2) ? mad(in0.x, weights0.z, out0.x) : out0.x;
out0.x =
(remain == 3) ? mad(in0.x, weights0.w, out0.x) : out0.x;
} else if (o == 1) {
out0.y =
(remain == 0) ? mad(in0.x, weights0.x, out0.y) : out0.y;
out0.y =
(remain == 1) ? mad(in0.x, weights0.y, out0.y) : out0.y;
out0.y =
(remain == 2) ? mad(in0.x, weights0.z, out0.y) : out0.y;
out0.y =
(remain == 3) ? mad(in0.x, weights0.w, out0.y) : out0.y;
} else if (o == 2) {
out0.z =
(remain == 0) ? mad(in0.x, weights0.x, out0.z) : out0.z;
out0.z =
(remain == 1) ? mad(in0.x, weights0.y, out0.z) : out0.z;
out0.z =
(remain == 2) ? mad(in0.x, weights0.z, out0.z) : out0.z;
out0.z =
(remain == 3) ? mad(in0.x, weights0.w, out0.z) : out0.z;
} else if (o == 3) {
out0.w =
(remain == 0) ? mad(in0.x, weights0.x, out0.w) : out0.w;
out0.w =
(remain == 1) ? mad(in0.x, weights0.y, out0.w) : out0.w;
out0.w =
(remain == 2) ? mad(in0.x, weights0.z, out0.w) : out0.w;
out0.w =
(remain == 3) ? mad(in0.x, weights0.w, out0.w) : out0.w;
}
} else if (ic % 4 == 1) {
if (o == 0) {
out0.x =
(remain == 0) ? mad(in0.y, weights0.x, out0.x) : out0.x;
out0.x =
(remain == 1) ? mad(in0.y, weights0.y, out0.x) : out0.x;
out0.x =
(remain == 2) ? mad(in0.y, weights0.z, out0.x) : out0.x;
out0.x =
(remain == 3) ? mad(in0.y, weights0.w, out0.x) : out0.x;
} else if (o == 1) {
out0.y =
(remain == 0) ? mad(in0.y, weights0.x, out0.y) : out0.y;
out0.y =
(remain == 1) ? mad(in0.y, weights0.y, out0.y) : out0.y;
out0.y =
(remain == 2) ? mad(in0.y, weights0.z, out0.y) : out0.y;
out0.y =
(remain == 3) ? mad(in0.y, weights0.w, out0.y) : out0.y;
} else if (o == 2) {
out0.z =
(remain == 0) ? mad(in0.y, weights0.x, out0.z) : out0.z;
out0.z =
(remain == 1) ? mad(in0.y, weights0.y, out0.z) : out0.z;
out0.z =
(remain == 2) ? mad(in0.y, weights0.z, out0.z) : out0.z;
out0.z =
(remain == 3) ? mad(in0.y, weights0.w, out0.z) : out0.z;
} else if (o == 3) {
out0.w =
(remain == 0) ? mad(in0.y, weights0.x, out0.w) : out0.w;
out0.w =
(remain == 1) ? mad(in0.y, weights0.y, out0.w) : out0.w;
out0.w =
(remain == 2) ? mad(in0.y, weights0.z, out0.w) : out0.w;
out0.w =
(remain == 3) ? mad(in0.y, weights0.w, out0.w) : out0.w;
}
} else if (ic % 4 == 2) {
if (o == 0) {
out0.x =
(remain == 0) ? mad(in0.z, weights0.x, out0.x) : out0.x;
out0.x =
(remain == 1) ? mad(in0.z, weights0.y, out0.x) : out0.x;
out0.x =
(remain == 2) ? mad(in0.z, weights0.z, out0.x) : out0.x;
out0.x =
(remain == 3) ? mad(in0.z, weights0.w, out0.x) : out0.x;
} else if (o == 1) {
out0.y =
(remain == 0) ? mad(in0.z, weights0.x, out0.y) : out0.y;
out0.y =
(remain == 1) ? mad(in0.z, weights0.y, out0.y) : out0.y;
out0.y =
(remain == 2) ? mad(in0.z, weights0.z, out0.y) : out0.y;
out0.y =
(remain == 3) ? mad(in0.z, weights0.w, out0.y) : out0.y;
} else if (o == 2) {
out0.z =
(remain == 0) ? mad(in0.z, weights0.x, out0.z) : out0.z;
out0.z =
(remain == 1) ? mad(in0.z, weights0.y, out0.z) : out0.z;
out0.z =
(remain == 2) ? mad(in0.z, weights0.z, out0.z) : out0.z;
out0.z =
(remain == 3) ? mad(in0.z, weights0.w, out0.z) : out0.z;
} else if (o == 3) {
out0.w =
(remain == 0) ? mad(in0.z, weights0.x, out0.w) : out0.w;
out0.w =
(remain == 1) ? mad(in0.z, weights0.y, out0.w) : out0.w;
out0.w =
(remain == 2) ? mad(in0.z, weights0.z, out0.w) : out0.w;
out0.w =
(remain == 3) ? mad(in0.z, weights0.w, out0.w) : out0.w;
}
} else if (ic % 4 == 3) {
if (o == 0) {
out0.x =
(remain == 0) ? mad(in0.w, weights0.x, out0.x) : out0.x;
out0.x =
(remain == 1) ? mad(in0.w, weights0.y, out0.x) : out0.x;
out0.x =
(remain == 2) ? mad(in0.w, weights0.z, out0.x) : out0.x;
out0.x =
(remain == 3) ? mad(in0.w, weights0.w, out0.x) : out0.x;
} else if (o == 1) {
out0.y =
(remain == 0) ? mad(in0.w, weights0.x, out0.y) : out0.y;
out0.y =
(remain == 1) ? mad(in0.w, weights0.y, out0.y) : out0.y;
out0.y =
(remain == 2) ? mad(in0.w, weights0.z, out0.y) : out0.y;
out0.y =
(remain == 3) ? mad(in0.w, weights0.w, out0.y) : out0.y;
} else if (o == 2) {
out0.z =
(remain == 0) ? mad(in0.w, weights0.x, out0.z) : out0.z;
out0.z =
(remain == 1) ? mad(in0.w, weights0.y, out0.z) : out0.z;
out0.z =
(remain == 2) ? mad(in0.w, weights0.z, out0.z) : out0.z;
out0.z =
(remain == 3) ? mad(in0.w, weights0.w, out0.z) : out0.z;
} else if (o == 3) {
out0.w =
(remain == 0) ? mad(in0.w, weights0.x, out0.w) : out0.w;
out0.w =
(remain == 1) ? mad(in0.w, weights0.y, out0.w) : out0.w;
out0.w =
(remain == 2) ? mad(in0.w, weights0.z, out0.w) : out0.w;
out0.w =
(remain == 3) ? mad(in0.w, weights0.w, out0.w) : out0.w;
}
}
} else {
weights0 = (CL_DTYPE4)(0.0f);
}
in_width0++;
}
}
}
}
out0 = activation_type4(out0, 0.f);

#ifdef SCALE_ACTIVATION
out0 = fuse_scale(out0, 1.f, 0.f, 0.f);
#endif

WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, out_pos, out0);
}
30 changes: 30 additions & 0 deletions lite/kernels/opencl/conv_transpose_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,26 @@ void ConvTransposeImageCompute::PrepareForRun() {
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
auto* filter_image_data = MUTABLE_DATA_CPU(tensor_hold_filter_image_);

converter.NCHWToImage(reinterpret_cast<float*>(filter_cpu),
filter_image_data,
filter_trans_dims);
MUTABLE_DATA_GPU(
filter_gpu_image_, filter_image_w_, filter_image_h_, filter_image_data);
} else if (groups_ > 1) {
CHECK_EQ(filter_tensor_n_ % groups_, 0);
kernel_name = "group_conv2d_transpose";
is_group_conv_ = true;
kernel_func_names_.push_back(kernel_name);

DDimLite filter_trans_dims{
{filter_dims[0], filter_dims[1], filter_dims[2], filter_dims[3]}};
CLImageConverterDefault converter;
const DDim& filter_image_dims =
converter.InitImageDimInfoWith(filter_trans_dims);
filter_image_w_ = filter_image_dims[0];
filter_image_h_ = filter_image_dims[1];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
auto* filter_image_data = MUTABLE_DATA_CPU(tensor_hold_filter_image_);
converter.NCHWToImage(reinterpret_cast<float*>(filter_cpu),
filter_image_data,
filter_trans_dims);
Expand Down Expand Up @@ -322,6 +342,16 @@ void ConvTransposeImageCompute::SetArgs() {
CL_CHECK_FATAL(status);
kernel->setArg(idx++, static_cast<int32_t>(maptofactor(input_tensor_c_, 4)));
CL_CHECK_FATAL(status);
if (is_group_conv_) {
int in_channels_per_group = input_tensor_c_ / groups_;
kernel->setArg(idx++, in_channels_per_group);
CL_CHECK_FATAL(status);
int out_channels_per_group = output_tensor_c_ / groups_;
kernel->setArg(idx++, out_channels_per_group);
CL_CHECK_FATAL(status);
VLOG(4) << "in_per_group: " << in_channels_per_group
<< ", out_per_group: " << out_channels_per_group;
}
}

void ConvTransposeImageCompute::Run() {
Expand Down
1 change: 1 addition & 0 deletions lite/kernels/opencl/conv_transpose_image_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class ConvTransposeImageCompute : public KernelLite<TARGET(kOpenCL),
int filter_image_h_{-1};
int filter_image_w_{-1};

bool is_group_conv_{false};
DDim last_input_dims_{};
bool is_first_epoch_for_run_{true};

Expand Down
13 changes: 3 additions & 10 deletions lite/tests/unittest_py/op/test_conv2d_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def sample_program_configs(self, draw):
output_size = []
groups = draw(st.integers(min_value=1, max_value=input_c))
assume(filter_c % groups == 0)
assume(filter_m >= groups)
assume(filter_m >= groups and filter_m % groups == 0)
assume(groups != filter_m or groups != filter_c)
data_format = draw(st.sampled_from(['NCHW']))
padding_algorithm = draw(st.sampled_from(['VALID', 'SAME']))
dilations = draw(
Expand Down Expand Up @@ -262,15 +263,7 @@ def sample_predictor_configs(self):
return self.get_predictor_configs(), ["conv2d_transpose"], (atol, rtol)

def add_ignore_pass_case(self):
def teller1(program_config, predictor_config):
groups = program_config.ops[0].attrs["groups"]
if predictor_config.target() == TargetType.OpenCL and groups > 1:
return True

self.add_ignore_check_case(
teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT,
"Lite does not support this op in a specific case on opencl. We need to fix it as soon as possible."
)
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=100)
Expand Down

0 comments on commit 48ae3ec

Please sign in to comment.