Skip to content

Commit

Permalink
[OpenCL]instance_norm support fp32 (PaddlePaddle#8021)
Browse files Browse the repository at this point in the history
* instance_norm support fp32 test=develop
  • Loading branch information
sprouteer authored and WeiLi233 committed Mar 29, 2022
1 parent 2310afa commit 09d6e54
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 180 deletions.
54 changes: 34 additions & 20 deletions lite/backends/opencl/cl_kernel/image/instance_norm_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,27 @@ __kernel void instance_norm(__private const int in_width,
const int local_total_size = local_work_size_x * local_work_size_y;

#ifdef LOCAL_MEM_128
__local float4 shared_mem[128];
__local CL_COMPUTE_DTYPE4 shared_mem[128];
#elif defined(LOCAL_MEM_64)
__local float4 shared_mem[64];
__local CL_COMPUTE_DTYPE4 shared_mem[64];
#else
__local float4 shared_mem[256];
__local CL_COMPUTE_DTYPE4 shared_mem[256];
#endif

int xOffset = c * in_width;
int yOffset = n * in_height;
float4 sum = 0.0f;

CL_COMPUTE_DTYPE4 sum = (CL_COMPUTE_DTYPE4)(0.0f);
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
sum += read_imagef(
input, SAMPLER, (int2)(xOffset + xIndex, yOffset + yIndex));
sum += READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR,
input,
SAMPLER,
(int2)(xOffset + xIndex, yOffset + yIndex));
}
}
shared_mem[local_id] = sum;

barrier(CLK_LOCAL_MEM_FENCE);

sum = 0.0f;
if (local_id < 32) {
for (int i = local_id + 32; i < local_total_size; i += 32) {
Expand All @@ -74,16 +75,18 @@ __kernel void instance_norm(__private const int in_width,

barrier(CLK_LOCAL_MEM_FENCE);

const float4 mean_val = shared_mem[0];
const CL_COMPUTE_DTYPE4 mean_val = shared_mem[0];

barrier(CLK_LOCAL_MEM_FENCE);

sum = 0.0f;
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
float4 temp =
read_imagef(
input, SAMPLER, (int2)(xOffset + xIndex, yOffset + yIndex)) -
CL_COMPUTE_DTYPE4 temp =
READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR,
input,
SAMPLER,
(int2)(xOffset + xIndex, yOffset + yIndex)) -
mean_val;
sum += temp * temp;
}
Expand Down Expand Up @@ -113,22 +116,33 @@ __kernel void instance_norm(__private const int in_width,

barrier(CLK_LOCAL_MEM_FENCE);

const float4 sigma = sqrt(shared_mem[0] + (float4)(epsilon));
const CL_COMPUTE_DTYPE4 sigma =
sqrt(shared_mem[0] + (CL_COMPUTE_DTYPE4)(epsilon));

CL_COMPUTE_DTYPE4 s = 1 / sigma;

CL_COMPUTE_DTYPE4 vscale =
READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, scale, SAMPLER, (int2)(c, n));
CL_COMPUTE_DTYPE4 vbias =
READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, bias, SAMPLER, (int2)(c, n));

float4 s = 1 / sigma;
float4 vscale = read_imagef(scale, SAMPLER, (int2)(c, n * in_c_group));
float4 vbias = read_imagef(bias, SAMPLER, (int2)(c, n * in_c_group));
vscale *= s;

for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
int2 intout_pos = (int2)(xOffset + xIndex, yOffset + yIndex);
float4 in_val = read_imagef(input, SAMPLER, intout_pos);
half4 out_val = convert_half4((in_val - mean_val) * vscale + vbias);
CL_COMPUTE_DTYPE4 in_val =
READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, SAMPLER, intout_pos);
CL_COMPUTE_DTYPE4 output0 = (in_val - mean_val) * vscale + vbias;
CL_DTYPE4 out_val;
out_val.x = CONVERT_TYPE_TO(output0.x, CL_DTYPE);
out_val.y = CONVERT_TYPE_TO(output0.y, CL_DTYPE);
out_val.z = CONVERT_TYPE_TO(output0.z, CL_DTYPE);
out_val.w = CONVERT_TYPE_TO(output0.w, CL_DTYPE);
#ifdef RELU
out_val = max((half4)(0.0f, 0.0f, 0.0f, 0.0f), out_val);
out_val = max((CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), out_val);
#endif
write_imageh(output, intout_pos, out_val);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, intout_pos, out_val);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions lite/kernels/opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ add_kernel(matmul_opencl_image OPENCL basic SRCS matmul_image_compute.cc)
######################
# image kernel test #
######################

lite_cc_test(test_gather_image_opencl SRCS gather_image_compute_test.cpp
DEPS kernels core)

Expand Down Expand Up @@ -127,8 +128,8 @@ lite_cc_test(test_bilinear_interp_image_opencl SRCS bilinear_interp_image_comput
#lite_cc_test(test_slice_image_opencl SRCS slice_image_compute_test.cc
# DEPS kernels core)

#lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc
# DEPS kernels core)
lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc
DEPS kernels core)

lite_cc_test(test_dropout_image_opencl SRCS dropout_image_compute_test.cc
DEPS kernels core)
Expand Down
74 changes: 55 additions & 19 deletions lite/kernels/opencl/instance_norm_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),

// TODO(ysh329): add instance_norm + relu pass
// std::string build_options_ += "-DRELU";
const bool enable_fp16 =
CLRuntime::Global()->get_precision() == lite_api::CL_PRECISION_FP16;
if (enable_fp16) {
build_options_ += " -DCL_DTYPE_half -DCL_DTYPE_FLOAT_FORCE ";
}
if (out_h == 128) {
build_options_ += " -DLOCAL_MEM_128";
} else if (out_h == 64) {
Expand All @@ -75,29 +80,60 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
int cgroup = (channel + 3) / 4;
int cround = cgroup * 4;

std::vector<half_t> scale_img(cround * batch);
std::vector<half_t> bias_img(cround * batch);
const float* scale_data = instance_norm_param_->scale->data<float>();
const float* bias_data = instance_norm_param_->bias->data<float>();

for (int i = 0; i < channel; ++i) {
scale_img[i] = Float2Half(scale_data[i]);
bias_img[i] = Float2Half(bias_data[i]);
}
std::vector<float> scale_img(cround * batch);
std::vector<float> bias_img(cround * batch);

std::vector<half_t> scale_img_h(cround * batch);
std::vector<half_t> bias_img_h(cround * batch);

for (int i = 1; i < batch; ++i) {
memcpy(scale_img.data() + i * cround,
scale_img.data(),
cround * sizeof(half_t));
memcpy(bias_img.data() + i * cround,
bias_img.data(),
cround * sizeof(half_t));
}
DDim scale_img_size{{ cgroup, batch }};
MUTABLE_DATA_GPU(
&scale_image_, scale_img_size[0], scale_img_size[1], scale_img.data());
MUTABLE_DATA_GPU(
&bias_image_, scale_img_size[0], scale_img_size[1], bias_img.data());

if (enable_fp16) {
for (int i = 0; i < channel; ++i) {
scale_img_h[i] = Float2Half(scale_data[i]);
bias_img_h[i] = Float2Half(bias_data[i]);
}

for (int i = 1; i < batch; ++i) {
memcpy(scale_img_h.data() + i * cround,
scale_img_h.data(),
cround * sizeof(half_t));
memcpy(bias_img_h.data() + i * cround,
bias_img_h.data(),
cround * sizeof(half_t));
}
MUTABLE_DATA_GPU(&scale_image_,
scale_img_size[0],
scale_img_size[1],
scale_img_h.data());
MUTABLE_DATA_GPU(&bias_image_,
scale_img_size[0],
scale_img_size[1],
bias_img_h.data());
} else {
for (int i = 0; i < channel; ++i) {
scale_img[i] = scale_data[i];
bias_img[i] = bias_data[i];
}

for (int i = 1; i < batch; ++i) {
memcpy(scale_img.data() + i * cround,
scale_img.data(),
cround * sizeof(float));
memcpy(bias_img.data() + i * cround,
bias_img.data(),
cround * sizeof(float));
}
MUTABLE_DATA_GPU(&scale_image_,
scale_img_size[0],
scale_img_size[1],
scale_img.data());
MUTABLE_DATA_GPU(
&bias_image_, scale_img_size[0], scale_img_size[1], bias_img.data());
}
}

void ReInitWhenNeeded() override {
Expand Down Expand Up @@ -182,6 +218,7 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),

status = EnqueueNDRangeKernel(
context, kernel_, cl::NullRange, gws_, lws_, nullptr, event_);

CL_CHECK_FATAL(status);
}

Expand Down Expand Up @@ -324,7 +361,6 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
std::string time_stamp_{GetTimeStamp()};
cl::Kernel kernel_;
cl::NDRange gws_, lws_;

Tensor scale_image_;
Tensor bias_image_;
};
Expand Down
Loading

0 comments on commit 09d6e54

Please sign in to comment.