Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenCL][Kernel]support group conv_transpose for opencl test=develop #8494

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions lite/backends/opencl/cl_kernel/image/conv2d_transpose_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,258 @@ __kernel void conv2d_transpose(

WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, out_pos, out0);
}

/*
 * Grouped transposed 2-D convolution on image2d-packed tensors.
 *
 * Each work-item produces one output texel: the four output channels
 * out_c_blk * 4 + {0,1,2,3} at a single (n, h, w) position. Every one of the
 * four channels is accumulated independently over the input channels of the
 * group it belongs to and over the kernel taps that can reach this output
 * position.
 *
 * Arguments, as they are used by the body:
 *   global_size_dim0/1/2  work-size bounds; extra work-items return early.
 *   input / filter / bias read-only images (bias is read only when BIASE_CH
 *                         is defined, at x = output channel block, y = 0).
 *   output                destination image, written once per work-item.
 *   input_shape           .x bounds the input column index and is the x
 *                         stride between input channel blocks; .y bounds the
 *                         input row index and is the y stride between batches.
 *   output_shape          .x is the x stride between output channel blocks;
 *                         .y splits the third global id into (n, h).
 *   stride_shape          step applied to the kernel tap index per input
 *                         pixel visited.
 *   align_shape,
 *   padding_shape         offsets used to derive the first input pixel and
 *                         the first kernel tap (presumably precomputed by the
 *                         host from the paddings; confirm against the caller).
 *   kernel_shape          dilated kernel extent used for the starting tap.
 *   dilation_shape        only taps that are multiples of the dilation carry
 *                         a weight.
 *   kernel_prev_shape     .x is the x stride between filter channel blocks,
 *                         .y the y stride between filter rows of successive
 *                         input channels.
 *   kernel_size,
 *   input_c_blks          not referenced by this kernel body.
 *   in_channels_per_group,
 *   out_channels_per_group  channel counts of a single group.
 */
__kernel void group_conv2d_transpose(
    __private const int global_size_dim0,  // (out_c + 3) / 4
    __private const int global_size_dim1,  // out_w
    __private const int global_size_dim2,  // out_n * out_h
    __read_only image2d_t input,
    __read_only image2d_t filter,
    __read_only image2d_t bias,
    __write_only image2d_t output,
    __private const int2 input_shape,
    __private const int2 output_shape,
    __private const int2 stride_shape,
    __private const int2 align_shape,
    __private const int2 padding_shape,
    __private const int2 kernel_shape,
    __private const int2 dilation_shape,
    __private const int2 kernel_prev_shape,
    __private const int kernel_size,
    __private const int input_c_blks,
    __private const int in_channels_per_group,
    __private const int out_channels_per_group) {
  const int oc_blk = get_global_id(0);
  const int ow = get_global_id(1);
  const int onh = get_global_id(2);

  // Work-items outside the logical problem size have nothing to do.
  const bool in_range = oc_blk < global_size_dim0 &&
                        ow < global_size_dim1 && onh < global_size_dim2;
  if (!in_range) {
    return;
  }

  const int2 dst_pos = (int2)(oc_blk * output_shape.x + ow, onh);

#ifdef BIASE_CH
  CL_DTYPE4 acc =
      READ_IMG_TYPE(CL_DTYPE_CHAR, bias, SAMPLER, (int2)(oc_blk, 0));
#else
  CL_DTYPE4 acc = 0.f;
#endif

  const int batch = onh / output_shape.y;
  const int oh = onh % output_shape.y;

  // First input pixel that can contribute, and the kernel tap it pairs with.
  const int first_iw = max(0, (ow + align_shape.x) / stride_shape.x);
  const int first_ih = max(0, (oh + align_shape.y) / stride_shape.y);
  const int first_kx =
      kernel_shape.x - mad24(first_iw, stride_shape.x, padding_shape.x) + ow -
      1;
  const int first_ky =
      kernel_shape.y - mad24(first_ih, stride_shape.y, padding_shape.y) + oh -
      1;

  for (int lane = 0; lane < 4; ++lane) {
    const int oc = oc_blk * 4 + lane;
    const int group = oc / out_channels_per_group;
    // Component of the filter texel holding this output channel's weight.
    const int w_lane = (oc - group * out_channels_per_group) % 4;
    // Left edge of the filter channel block holding this output channel.
    const int filter_x_base =
        ((oc % out_channels_per_group) / 4) * kernel_prev_shape.x;

    for (int ic = group * in_channels_per_group;
         ic < (group + 1) * in_channels_per_group;
         ++ic) {
      const int in_x_base = mul24(ic / 4, input_shape.x);
      const int in_lane = ic % 4;
      const int filter_y_base = ic * kernel_prev_shape.y;

      for (int ky = first_ky, ih = first_ih; ky >= 0;
           ky -= stride_shape.y, ih++) {
        // Rows outside the input map to coordinate -1.
        const int in_row = (ih < 0 || ih >= input_shape.y)
                               ? -1
                               : mad24(batch, input_shape.y, ih);

        int iw = first_iw;
        for (int kx = first_kx; kx >= 0; kx -= stride_shape.x, ++iw) {
          // Columns outside the input map to coordinate -1.
          const int in_col =
              (iw < 0 || iw >= input_shape.x) ? -1 : in_x_base + iw;
          CL_DTYPE4 in_px = READ_IMG_TYPE(
              CL_DTYPE_CHAR, input, SAMPLER, (int2)(in_col, in_row));

          // Taps that fall between dilated kernel elements carry no weight.
          if (kx % dilation_shape.x != 0 || ky % dilation_shape.y != 0) {
            continue;
          }

          CL_DTYPE4 w_px =
              READ_IMG_TYPE(CL_DTYPE_CHAR,
                            filter,
                            SAMPLER,
                            (int2)(filter_x_base + kx / dilation_shape.x,
                                   filter_y_base + ky / dilation_shape.y));

          // Bring the addressed input channel into component .x.
          switch (in_lane) {
            case 0:
              break;
            case 1:
              in_px.x = in_px.y;
              break;
            case 2:
              in_px.x = in_px.z;
              break;
            case 3:
              in_px.x = in_px.w;
              break;
            default:
              continue;  // no matching lane: this tap contributes nothing
          }

          // Bring the addressed weight into component .x.
          switch (w_lane) {
            case 0:
              break;
            case 1:
              w_px.x = w_px.y;
              break;
            case 2:
              w_px.x = w_px.z;
              break;
            case 3:
              w_px.x = w_px.w;
              break;
            default:
              continue;  // no matching lane: this tap contributes nothing
          }

          // Accumulate into the component that belongs to this output lane.
          switch (lane) {
            case 0:
              acc.x = mad(in_px.x, w_px.x, acc.x);
              break;
            case 1:
              acc.y = mad(in_px.x, w_px.x, acc.y);
              break;
            case 2:
              acc.z = mad(in_px.x, w_px.x, acc.z);
              break;
            case 3:
              acc.w = mad(in_px.x, w_px.x, acc.w);
              break;
          }
        }
      }
    }
  }

  acc = activation_type4(acc, 0.f);

#ifdef SCALE_ACTIVATION
  acc = fuse_scale(acc, 1.f, 0.f, 0.f);
#endif

  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, dst_pos, acc);
}
31 changes: 31 additions & 0 deletions lite/kernels/opencl/conv_transpose_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void ConvTransposeImageCompute::PrepareForRun() {
output_tensor_w_ = output_dims[3];

auto filter_dims = conv_param_->filter->dims();
filter_tensor_n_ = filter_dims[0];
filter_tensor_c_ = filter_dims[1];
filter_tensor_h_ = filter_dims[2];
filter_tensor_w_ = filter_dims[3];
Expand Down Expand Up @@ -104,6 +105,26 @@ void ConvTransposeImageCompute::PrepareForRun() {
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
auto* filter_image_data = MUTABLE_DATA_CPU(tensor_hold_filter_image_);

converter.NCHWToImage(reinterpret_cast<float*>(filter_cpu),
filter_image_data,
filter_trans_dims);
MUTABLE_DATA_GPU(
filter_gpu_image_, filter_image_w_, filter_image_h_, filter_image_data);
} else if (groups_ > 1) {
CHECK_EQ(filter_tensor_n_ % groups_, 0);
std::string kernel_name = "group_conv2d_transpose";
is_group_conv_ = true;
kernel_func_names_.push_back(kernel_name);

DDimLite filter_trans_dims{
{filter_dims[0], filter_dims[1], filter_dims[2], filter_dims[3]}};
CLImageConverterDefault converter;
const DDim& filter_image_dims =
converter.InitImageDimInfoWith(filter_trans_dims);
filter_image_w_ = filter_image_dims[0];
filter_image_h_ = filter_image_dims[1];
tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
auto* filter_image_data = MUTABLE_DATA_CPU(tensor_hold_filter_image_);
converter.NCHWToImage(reinterpret_cast<float*>(filter_cpu),
filter_image_data,
filter_trans_dims);
Expand Down Expand Up @@ -322,6 +343,16 @@ void ConvTransposeImageCompute::SetArgs() {
CL_CHECK_FATAL(status);
kernel->setArg(idx++, static_cast<int32_t>(maptofactor(input_tensor_c_, 4)));
CL_CHECK_FATAL(status);
if (is_group_conv_) {
int in_channels_per_group = input_tensor_c_ / groups_;
kernel->setArg(idx++, in_channels_per_group);
CL_CHECK_FATAL(status);
int out_channels_per_group = output_tensor_c_ / groups_;
kernel->setArg(idx++, out_channels_per_group);
CL_CHECK_FATAL(status);
VLOG(4) << "in_per_group: " << in_channels_per_group
<< ", out_per_group: " << out_channels_per_group;
}
}

void ConvTransposeImageCompute::Run() {
Expand Down
1 change: 1 addition & 0 deletions lite/kernels/opencl/conv_transpose_image_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class ConvTransposeImageCompute : public KernelLite<TARGET(kOpenCL),
int filter_image_h_{-1};
int filter_image_w_{-1};

bool is_group_conv_{false};
DDim last_input_dims_{};
bool is_first_epoch_for_run_{true};

Expand Down
15 changes: 4 additions & 11 deletions lite/tests/unittest_py/op/test_conv2d_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def sample_program_configs(self, draw):
output_size = []
groups = draw(st.integers(min_value=1, max_value=input_c))
assume(filter_c % groups == 0)
assume(filter_m >= groups)
assume(filter_m >= groups and filter_m % groups == 0)
assume(groups != filter_m or groups != filter_c)
data_format = draw(st.sampled_from(['NCHW']))
padding_algorithm = draw(st.sampled_from(['VALID', 'SAME']))
dilations = draw(
Expand Down Expand Up @@ -262,18 +263,10 @@ def sample_predictor_configs(self):
return self.get_predictor_configs(), ["conv2d_transpose"], (atol, rtol)

def add_ignore_pass_case(self):
def teller1(program_config, predictor_config):
groups = program_config.ops[0].attrs["groups"]
if predictor_config.target() == TargetType.OpenCL and groups > 1:
return True

self.add_ignore_check_case(
teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT,
"Lite does not support this op in a specific case on opencl. We need to fix it as soon as possible."
)
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=100)
self.run_and_statis(quant=False, max_examples=150)


if __name__ == "__main__":
Expand Down