Skip to content

Commit

Permalink
[OpenCL] Support OpenCL expand (PaddlePaddle#8078)
Browse files Browse the repository at this point in the history
* [OpenCL] Support OpenCL expand test=develop

* Remove some notes test=develop

* Remove test_expand_image_opencl test=document_fix

* test  test=document_fix

* test  test=develop
  • Loading branch information
daming5432 authored and WeiLi233 committed Mar 29, 2022
1 parent 09d6e54 commit ebe9c53
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 39 deletions.
16 changes: 7 additions & 9 deletions lite/backends/opencl/cl_kernel/image/expand_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -201,25 +201,23 @@ __kernel void expend_cn(__private const int OUT_C,
__private const int output_height,

__read_only image2d_t input,
__write_only image2d_t output,
__private const int n_times,
__private const int c_times,
__private const int h_times,
__private const int w_times) {
__write_only image2d_t output) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);

if (out_c >= OUT_C || out_w >= OUT_W || out_nh >= OUT_NH) {
return;
}

const int IN_N = IN_NH / input_height;
const int OUT_N = OUT_NH / output_height;
const int out_n = out_nh / output_height;
const int out_h = out_nh % output_height;
const int in_c = out_c;
const int in_w = out_w / w_times;
const int in_h = out_h / h_times;
const int in_n = out_n / n_times;
const int in_w = out_w % input_width;
const int in_h = out_h % input_height;
const int in_n = out_n % IN_N;

const int in_nh = in_n * input_height + in_h;

int2 output_pos = (int2)(out_c * OUT_W + out_w, out_nh);
Expand Down
4 changes: 2 additions & 2 deletions lite/kernels/opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ lite_cc_test(test_layout_image_opencl SRCS layout_image_compute_test.cc
lite_cc_test(test_pixel_shuffle_image_opencl SRCS pixel_shuffle_image_compute_test.cc
DEPS kernels core)

lite_cc_test(test_expand_image_opencl SRCS expand_image_compute_test.cc
DEPS kernels core)
#lite_cc_test(test_expand_image_opencl SRCS expand_image_compute_test.cc
# DEPS kernels core)

#lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc
# DEPS kernels core)
Expand Down
25 changes: 2 additions & 23 deletions lite/kernels/opencl/expand_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,7 @@ class ExpandComputeImage2D : public KernelLite<TARGET(kOpenCL),
CHECK(expand_times.size() == 4)
<< "expand image now only support in_expand_timesdims size 4";
CHECK(expand_times[1] == 1) << "expand image do not support expend c now";

// do not confuse with these cases.it is use to support expend c in future
if (in_dims[1] == 1) {
kernel_func_name_ = "expend_c1";
} else if (in_dims[1] == 2) {
kernel_func_name_ = "expend_c2";
} else if (in_dims[1] == 3) {
kernel_func_name_ = "expend_c3";
} else if (in_dims[1] == 4) {
kernel_func_name_ = "expend_c4";
} else {
kernel_func_name_ = "expend_cn";
}

kernel_func_name_ = "expend_cn";
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
Expand Down Expand Up @@ -87,6 +74,7 @@ class ExpandComputeImage2D : public KernelLite<TARGET(kOpenCL),
first_epoch_for_reinit_) {
last_x_dims_ = x_dims;
first_epoch_for_reinit_ = false;

// compute image shape
paddle::lite::CLImageConverterDefault default_convertor;
out_img_shape_ = default_convertor.InitImageDimInfoWith(out_dims);
Expand Down Expand Up @@ -161,15 +149,6 @@ class ExpandComputeImage2D : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(11, *out_img);
CL_CHECK_FATAL(status);

status = kernel.setArg(12, expand_times_n);
CL_CHECK_FATAL(status);
status = kernel.setArg(13, expand_times_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(14, expand_times_h);
CL_CHECK_FATAL(status);
status = kernel.setArg(15, expand_times_w);
CL_CHECK_FATAL(status);

status = EnqueueNDRangeKernel(context,
kernel,
cl::NullRange,
Expand Down
38 changes: 33 additions & 5 deletions lite/tests/unittest_py/op/test_expand_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def __init__(self, *args, **kwargs):
Place(TargetType.Host, PrecisionType.FP32, DataLayoutType.NCHW)
]
self.enable_testing_on_place(thread=[1, 4], places=host_places)
opencl_places = [
Place(TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
Place(TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=opencl_places)

def is_program_valid(self,
program_config: ProgramConfig,
Expand All @@ -47,10 +61,20 @@ def is_program_valid(self,
return True

def sample_program_configs(self, draw):
in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=3, max_size=4))
if self.get_target() == "OpenCL":
in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8),
min_size=4,
max_size=4))
else:
in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8),
min_size=2,
max_size=4))
expand_shape = draw(
st.lists(
st.integers(
Expand Down Expand Up @@ -108,6 +132,10 @@ def gnerate_inputs(with_tensor):
min_value=1, max_value=8),
min_size=len(in_shape),
max_size=len(in_shape)))
if self.get_target() == "OpenCL":
with_tensor = False
attr_shape[1] = 1

inputs = gnerate_inputs(with_tensor)
expand_op = OpConfig(
type="expand",
Expand All @@ -129,7 +157,7 @@ def add_ignore_pass_case(self):
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=300)
self.run_and_statis(quant=False, max_examples=100)


if __name__ == "__main__":
Expand Down

0 comments on commit ebe9c53

Please sign in to comment.