Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize bilinear interpolation foward #39243

Merged
merged 23 commits into from
Feb 11, 2022
Merged
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
7cb2c97
Merge pull request #8 from PaddlePaddle/develop
AshburnLee Mar 31, 2021
db9fc91
Merge pull request #9 from PaddlePaddle/develop
AshburnLee Apr 7, 2021
c7b68c8
Merge branch 'develop' of /~https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 26, 2021
0fd630e
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Aug 16, 2021
4bbb33b
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Sep 28, 2021
30a1a89
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Nov 22, 2021
ce3deec
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 21, 2021
925eb06
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 6, 2022
03aa00b
bilinear_fw init
AshburnLee Jan 26, 2022
f53ab6a
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 26, 2022
8830460
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 26, 2022
7c974f1
optimize code
AshburnLee Feb 9, 2022
cc3fd84
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 9, 2022
a6682bc
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 11, 2022
3353b9f
pre-compute linear_interp input index
AshburnLee Feb 11, 2022
f37dea1
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 101 additions & 42 deletions paddle/fluid/operators/interpolate_v2_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -416,34 +416,90 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
}
}

template <typename T>
__global__ void KeBilinearInterpNCHWFw(const T* in, const size_t in_img_h,
const size_t in_img_w, T* out,
const size_t out_img_h,
const size_t out_img_w, const size_t nc,
const float ratio_h, const float ratio_w,
const bool align_corners,
const int align_mode) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;

bool align_flag = (align_mode == 0 && !align_corners);
// bilinear sampling by multiple read in_addr and write to out_addr
int in_img_idx = align_flag
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

h和w的lambda计算过程基本一样,是不是可以写个函数减少重复代码

Copy link
Contributor Author

@AshburnLee AshburnLee Feb 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. 写成了最小函数,linear- bilinear- trilinear- 都可调用

? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;

int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;

int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;

int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;

// prevent from multiple threads writing
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
const T* in_pos = &in[in_index];
out[out_index] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);

in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}

template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;

bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];

int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];

int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
Expand All @@ -467,28 +523,17 @@ __global__ void KeBilinearInterpFw(
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;

if (data_layout == DataLayout::kNCHW) {
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];

// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];

// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda * in_pos[h_id * in_img_w * num_channels +
w_id * num_channels]);
}
// bilinear interpolation
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
out[tid] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda *
(w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda *
in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
}
}

Expand Down Expand Up @@ -1395,11 +1440,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
thread_num = 512;
}
#endif

KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
platform::GpuLaunchConfig config_3d =
GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
KeBilinearInterpNCHWFw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
ratio_w, align_corners, align_mode);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
interp_divmods);
}
} else if ("bicubic" == interp_method) {
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
Expand Down