Optimize nearest_interp forward (#38528)
* init commit

* remove comments

* remove nchw branch

* optimize code

* apply fast div mod in 1D kernel, rm 3D kernel

* move init of FastDivMod to CPU

* 3D kernel for nchw, FastDiv for 1D kernel

* debug done. process boundary

* 2^n

* optimize

* optimize

* change code & optimize code
AshburnLee authored Jan 25, 2022
1 parent 2bafd33 commit 232bbce
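
Two optimizations drive this commit: NCHW inputs get a dedicated kernel launched over a 3D grid, and the NHWC path replaces per-element integer division and modulo (expensive on GPUs) with precomputed FastDivMod divisors. As a reading aid, here is a minimal sketch of the index decomposition the optimized NHWC kernel performs, with plain arithmetic standing in for FastDivMod; the struct and function names are illustrative, not part of the patch.

    // Recover (n, h, w, c) from a flat index into an [N, H, W, C] tensor.
    // Each division below is what one FastDivMod::Divmod call computes with
    // a multiply-and-shift instead of a hardware divide.
    struct NhwcIndex {
      int n, h, w, c;
    };

    NhwcIndex DecomposeNhwc(int tid, int C, int H, int W) {
      const int chw = C * H * W;  // elements per image -> output_w_div
      const int wc = W * C;       // elements per row   -> output_wc_div
      NhwcIndex idx;
      idx.n = tid / chw;             // batch index
      const int in_img = tid % chw;  // offset inside one image
      idx.c = tid % C;               // channel  (channels_div)
      idx.h = in_img / wc;           // row      (output_wc_div)
      idx.w = (in_img % wc) / C;     // column   (channels_div again)
      return idx;
    }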
Showing 2 changed files with 118 additions and 29 deletions.
145 changes: 118 additions & 27 deletions paddle/fluid/operators/interpolate_v2_op.cu
@@ -16,39 +16,121 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using platform::FastDivMod;
using DataLayout = framework::DataLayout;

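// Largest power of two that is <= n, clamped below at 1: the shifts smear
// the highest set bit into every lower bit, and n - (n >> 1) then keeps
// only that highest bit (e.g. 5 -> 4, 8 -> 8).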
static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

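// Launch config for a (num_img, height, width) iteration space: width maps
// to threadIdx.x, height to threadIdx.y, and the n*c images to threadIdx.z,
// with at most 256 threads per block.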
inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
    const platform::CUDADeviceContext& context, int num_img, int height,
    int width) {
  const int kThreadsPerBlock = 256;
  int max_threads_per_block = context.GetMaxThreadsPerBlock();  // 1024
  int max_threads = std::min(kThreadsPerBlock, max_threads_per_block);

  int block_x = std::min(GetLastPow2(width), max_threads);
  int block_y = std::min(GetLastPow2(height), max_threads / block_x);
  int block_z = std::min(num_img, max_threads / block_x / block_y);

  dim3 max_grid_dim = context.GetCUDAMaxGridDimSize();
  int grid_x = std::min<int>(max_grid_dim.x, platform::DivUp(width, block_x));
  int grid_y = std::min<int>(max_grid_dim.y, platform::DivUp(height, block_y));
  int grid_z =
      std::min<int>(max_grid_dim.z, platform::DivUp(num_img, block_z * 4));

  const int capability = context.GetComputeCapability();
  platform::GpuLaunchConfig config;
  config.compute_capability = capability;
  config.thread_per_block = dim3(block_x, block_y, block_z);
  config.block_per_grid = dim3(grid_x, grid_y, grid_z);
  return config;
}
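
// Worked example for an assumed shape, nc = 64, height = width = 224:
//   block_x = min(GetLastPow2(224), 256)       = 128
//   block_y = min(GetLastPow2(224), 256 / 128) = 2
//   block_z = min(64, 256 / 128 / 2)           = 1
//   grid    = (DivUp(224, 128), DivUp(224, 2), DivUp(64, 1 * 4)) = (2, 112, 16)
// grid_z is divided by 4 on purpose: each z-thread then covers roughly four
// of the n*c images through the nc-stride loop in the NCHW kernel below.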

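// The three divisors (C, C*H*W, W*C) are everything the NHWC kernel needs
// to unflatten an output index; building them once on the host amortizes
// the magic-number setup over every output element.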
struct FastDivModForInterpolate {
 public:
  FastDivMod channels_div;
  FastDivMod output_w_div;
  FastDivMod output_wc_div;

  explicit HOSTDEVICE FastDivModForInterpolate(const int channels,
                                               const int output_w,
                                               const int output_wc)
      : channels_div(FastDivMod(channels)),
        output_w_div(FastDivMod(output_w)),
        output_wc_div(FastDivMod(output_wc)) {}
};

template <typename T>
__global__ void KeNearestNeighborInterpNCHWFw(
    const T* in, const size_t in_img_h, const size_t in_img_w, T* out,
    const size_t out_img_h, const size_t out_img_w, const size_t nc,
    const float ratio_h, const float ratio_w, const bool align_corners) {
  int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
  int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
  int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
  int nc_stride = blockDim.z * gridDim.z;

  // nearest sampling: each thread computes its source pixel once, then
  // copies it for every image it covers along the n*c stride
  int in_img_idx = (align_corners)
                       ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                       : static_cast<int>(ratio_w * out_img_idx);
  int in_img_idy = (align_corners)
                       ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                       : static_cast<int>(ratio_h * out_img_idy);

  int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
  int in_index_stride = nc_stride * in_img_h * in_img_w;

  int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
  int out_index_stride = nc_stride * out_img_h * out_img_w;

  // threads outside the output image must not read or write out of bounds
  if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
    while (nc_id < nc) {
      out[out_index] = in[in_index];
      in_index += in_index_stride;
      out_index += out_index_stride;
      nc_id += nc_stride;
    }
  }
}

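// In the rewritten kernel below, Divmod(x).val[0] is the quotient and
// .val[1] the remainder, so each / and % pair costs one multiply-and-shift.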
template <typename T>
__global__ void KeNearestNeighborInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
-    const bool align_corners, const DataLayout data_layout) {
+    const bool align_corners, FastDivModForInterpolate divmods) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
+  int in_img_size = in_img_h * in_img_w;
+  int out_img_size = out_img_h * out_img_w;

  for (; tid < nthreads; tid += stride) {
-    int out_id_h = tid / output_w;
-    int out_id_w = tid % output_w;
-    int in_img_size = input_w / num_channels;
-    int out_img_size = output_w / num_channels;
+    auto out_id_divmod = divmods.output_w_div.Divmod(tid);
+    int out_id_h = out_id_divmod.val[0];
+    int out_id_w = out_id_divmod.val[1];

-    int channel_id, out_img_idy, out_img_idx;
-    if (data_layout == DataLayout::kNCHW) {
-      channel_id = out_id_w / out_img_size;
-      out_img_idy = (out_id_w % out_img_size) / out_img_w;
-      out_img_idx = tid % out_img_w;
-    } else {
-      out_img_idy = out_id_w / (out_img_w * num_channels);
-      out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
-      channel_id = tid % num_channels;
-    }
+    int channel_id = divmods.channels_div.Divmod(tid).val[1];
+    auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
+    int out_img_idy = outimg_id_divmod.val[0];
+    int out_img_idx =
+        divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];

    int in_img_idy = (align_corners)
                         ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                         : static_cast<int>(ratio_h * out_img_idy);
    int in_img_idx = (align_corners)
                         ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                         : static_cast<int>(ratio_w * out_img_idx);

-    if (data_layout == DataLayout::kNCHW) {
-      out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
-                    in_img_idy * in_img_w + in_img_idx];
-    } else {
-      out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
-                    in_img_idx * num_channels + channel_id];
-    }
+    out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
+                  in_img_idx * num_channels + channel_id];
  }
}

@@ -1292,11 +1369,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);

  if ("nearest" == interp_method) {
-    KeNearestNeighborInterpFw<
-        T><<<config.block_per_grid, config.thread_per_block, 0,
-             ctx.cuda_device_context().stream()>>>(
-        input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
-        out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
+    if (data_layout == DataLayout::kNCHW) {
+      // build the 3D launch config
+      int nc = n * c;
+      platform::GpuLaunchConfig config_3d =
+          GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
+      KeNearestNeighborInterpNCHWFw<
+          T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
+               ctx.cuda_device_context().stream()>>>(
+          input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
+          ratio_w, align_corners);
+    } else {
+      int64_t cw = c * out_w;
+      auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
+      KeNearestNeighborInterpFw<
+          T><<<config.block_per_grid, config.thread_per_block, 0,
+               ctx.cuda_device_context().stream()>>>(
+          input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
+          out_chw, c, ratio_h, ratio_w, align_corners, interp_divmods);
+    }
  } else if ("bilinear" == interp_method) {
    dim3 thread_num = config.thread_per_block;
#ifdef WITH_NV_JETSON
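
The diff leans on platform::FastDivMod (declared in paddle/fluid/platform/fast_divmod.h) without showing it. Below is a self-contained sketch of the underlying magic-number trick, assuming the standard scheme; it is not Paddle's exact implementation, and it uses the GCC/Clang __uint128_t extension where a device kernel would typically take the high half of the product instead (e.g. via __umulhi).

    #include <cstdint>

    // Division by a fixed divisor d as a multiply plus a shift: precompute
    // magic = ceil(2^(32 + shift) / d) with shift = ceil(log2(d)). For every
    // 32-bit n, (n * magic) >> (32 + shift) then equals n / d exactly.
    struct MagicDivMod {
      uint32_t divisor;
      uint32_t shift;
      uint64_t magic;

      explicit MagicDivMod(uint32_t d) : divisor(d), shift(0) {
        while ((uint64_t{1} << shift) < d) ++shift;
        const __uint128_t numerator = __uint128_t{1} << (32 + shift);
        magic = static_cast<uint64_t>((numerator + d - 1) / d);
      }

      uint32_t Div(uint32_t n) const {
        return static_cast<uint32_t>((__uint128_t{n} * magic) >> (32 + shift));
      }
      uint32_t Mod(uint32_t n) const { return n - Div(n) * divisor; }
    };

Constructing the divisors once on the host and passing them by value into the kernel, as FastDivModForInterpolate does above, is what makes the trade worthwhile: the expensive setup division runs once, while every output element reuses the cheap multiply.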
2 changes: 0 additions & 2 deletions paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -165,8 +165,6 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
  return config;
}

-// TODO(wangchaochaohu): 3D will add later
-
}  // namespace platform
}  // namespace paddle

