Move transpose to pten (#39327)
* migrate transpose to pten (CPU kernel only); test=develop

* fix bug; test=develop

* add transpose cuda api

* bug fix;

* fix bugs

* fix bugs; test=develop

* bug fix;

* move transpose to pten; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* add transpose grad fp16 support; test=develop

* fix bug; test=develop

* fix npu bug; test=develop

* fix numel = 0 bug; test=develop

* add fp16 support; test=develop

* fix data type register bug; test=develop

* fix transpose bug; test=develop

* update transpose

* fix transpose bug; test=develop

* remove useless code; test=develop

* remove useless code; test=develop

* fix transpose alias bug; test=develop

* polish code; test=develop

* resolve conflict; test=develop

* resolve conflict; test=develop

* recover prepared operator; test=develop

* fix bug; test=develop

* polish code; test=develop

* fix bug; test=develop

* fix bug; test=develop
phlrain authored Mar 2, 2022
1 parent 2a5590a commit 7a85792
Showing 20 changed files with 426 additions and 270 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc
@@ -29,7 +29,7 @@ USE_OP(pool2d);
USE_OP_DEVICE_KERNEL(pool2d, MKLDNN);
USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(transpose);
USE_OP_ITSELF(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN);

namespace paddle {
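The only change in this test is USE_OP(transpose) becoming USE_OP_ITSELF(transpose): once the transpose kernels leave fluid, the test can no longer declare a link-time dependency on a fluid CPU kernel that no longer exists. A hedged sketch of the distinction between the two macros, as I understand them from paddle/fluid/framework/op_registry.h (simplified, not the verbatim definitions):

```cpp
// Assumed, simplified relationship between the dependency macros:
//
//   USE_OP_ITSELF(op)   -- require only the operator *definition* to be
//                          linked in (InferShape, maker, grad maker, ...).
//
//   #define USE_OP(op)   \
//     USE_OP_ITSELF(op); \
//     USE_OP_DEVICE_KERNEL(op, CPU)   // ...plus a fluid CPU kernel.
//
// After this commit the fluid CPU kernel for transpose is gone (it lives
// in pten/phi now), so USE_OP(transpose) would fail to resolve, while the
// MKLDNN kernel is still a fluid kernel and keeps its explicit
// USE_OP_DEVICE_KERNEL(transpose, MKLDNN) declaration.
```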
60 changes: 11 additions & 49 deletions paddle/fluid/operators/transpose_op.cc
@@ -339,6 +339,14 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
}
};

class TransposeGradInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType(framework::GradVarName("Out"),
framework::GradVarName("X"));
}
};

} // namespace operators
} // namespace paddle

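The new TransposeGradInferVarType above closes a gap in static-graph var-type inference: it copies both the variable type and the element dtype of Out@GRAD onto X@GRAD, so fp16/bf16 and other non-default-dtype gradients are not coerced back to the default float type (this appears to be the data-type registration bug referenced in the commit log). A hedged expansion of what SyncTypeAndDataType bundles, written against the finer-grained InferVarTypeContext accessors (assumed equivalent, not the actual implementation):

```cpp
// Hypothetical expansion of ctx->SyncTypeAndDataType(src, dst):
void operator()(framework::InferVarTypeContext *ctx) const override {
  auto src = framework::GradVarName("Out");  // "Out@GRAD"
  auto dst = framework::GradVarName("X");    // "X@GRAD"
  // Propagate the variable type (e.g. LOD_TENSOR vs. SELECTED_ROWS)...
  ctx->SetOutputType(dst, ctx->GetInputType(src));
  // ...and the element dtype (e.g. FP16), keeping dX consistent with dOut.
  ctx->SetOutputDataType(dst, ctx->GetInputDataType(src));
}
```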
@@ -347,59 +355,13 @@ REGISTER_OPERATOR(
transpose, ops::TransposeOp, ops::TransposeOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad,
ops::TransposeGradInferVarType);

REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker<paddle::framework::OpDesc>,
ops::Transpose2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad,
ops::TransposeGradInferVarType,
ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
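The deleted REGISTER_OP_CPU_KERNEL blocks are the heart of the migration: the kernel bodies and their registrations move into the pten/phi library, whose new files are presumably among the collapsed portion of this diff. As a sketch only, with the file path, macro spelling, and dtype list assumed from the phi conventions of this period rather than taken from the visible diff, the CPU registration plausibly becomes something like:

```cpp
// Hypothetical phi-side replacement (illustrative; not shown in this
// excerpt), e.g. in paddle/phi/kernels/cpu/transpose_kernel.cc:
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(transpose,             // kernel name
                   CPU,                   // backend
                   ALL_LAYOUT,            // data layout
                   phi::TransposeKernel,  // kernel functor
                   bool,
                   int32_t,
                   int64_t,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
```

The same move on the GPU side is why transpose_op.cu below can be deleted outright: it held only the CUDA kernel registrations, which no longer live in fluid.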
139 changes: 0 additions & 139 deletions paddle/fluid/operators/transpose_op.cu

This file was deleted.

42 changes: 21 additions & 21 deletions paddle/fluid/operators/transpose_op.cu.h
@@ -16,8 +16,9 @@ limitations under the License. */

#include "paddle/fluid/framework/gpu_utils.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"

namespace paddle {
namespace operators {
@@ -258,10 +259,10 @@ struct SystemElemType<16> {
};

template <typename T, int tile_long, int tile_short>
void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d,
int tile_size_i, int tile_size_j,
int total_tiles_count, const T* input,
const Dim3& input_dims, T* output) {
void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims,
T* output) {
constexpr int NumThreads = tile_long;
if (tile_size_i <= tile_long && tile_size_j <= tile_short) {
TilingSwapDim1And2<
@@ -278,7 +279,7 @@ void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d,

template <typename T, int tile_long, int tile_short, typename dummy = void>
struct NarrowDims2TransposeDispatch {
static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i,
static void DoTranspose(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims, T* output) {
PADDLE_ENFORCE_EQ(
@@ -319,7 +320,7 @@ struct NarrowDims2TransposeDispatch<
T, tile_long, tile_short,
typename std::enable_if<
CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> {
static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i,
static void DoTranspose(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims, T* output) {
PADDLE_ENFORCE_EQ(
@@ -351,7 +352,7 @@ struct NarrowDims2TransposeDispatch<
T, tile_long, tile_short,
typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
void>::type> {
static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i,
static void DoTranspose(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims, T* output) {
PADDLE_ENFORCE_EQ(
@@ -368,7 +369,7 @@ struct NarrowDims2TransposeDispatch<
};

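The NarrowDims2TransposeDispatch specializations above select, at compile time, which tile shapes get instantiated for a given element size: std::enable_if on the CheckLongTileSize / CheckNonLongTileSize predicates carves the (tile_long, tile_short, sizeof(T)) space into disjoint cases. A minimal standalone sketch of the same SFINAE pattern (illustrative names and size rule, not Paddle's):

```cpp
#include <cstdio>
#include <type_traits>

// Toy predicate standing in for CheckLongTileSize (the rule is made up).
constexpr bool CheckLong(int tile_long, int elem_size) {
  return tile_long * elem_size <= 128;
}

// Primary template: the fallback when no specialization is enabled.
template <typename T, int TileLong, typename Enable = void>
struct Dispatch {
  static void Do() { std::puts("generic path"); }
};

// Enabled only when the long tile fits the per-type size budget.
template <typename T, int TileLong>
struct Dispatch<T, TileLong,
                typename std::enable_if<CheckLong(TileLong, sizeof(T))>::type> {
  static void Do() { std::puts("long-tile fast path"); }
};

int main() {
  Dispatch<float, 32>::Do();   // 32 * 4 = 128 -> long-tile fast path
  Dispatch<double, 32>::Do();  // 32 * 8 = 256 -> generic path
}
```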
template <typename T, bool conjugate = false>
void SwapDim1And2InNarrow(const platform::CUDADeviceContext& d, const T* input,
void SwapDim1And2InNarrow(const phi::GPUContext& d, const T* input,
const Dim3& input_dims, T* output,
const int kMinTileSize) {
// First get available tile sizes for the data type requested as backups
@@ -473,9 +474,8 @@ __global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input,

// Here suppose convert all tensor to dim3, so just change dim1 and 2.
template <typename T>
void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d,
const T* input, const Dim3& input_dims,
T* output) {
void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input,
const Dim3& input_dims, T* output) {
// Suppose tile size > 16
static const int kMinTileSize = 16;
static const int kMinNarrowTileSize = 96;
@@ -512,7 +512,7 @@ void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d,
} else {
// If input shape is small, such as 8X8, just do simple copy
int total_elements = input_dims[0] * input_dims[1] * input_dims[2];
auto config = GetGpuLaunchConfig1D(d, total_elements);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements);
TransposeSimpleKernel<T, 0, 2, 1><<<
config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
total_elements, input, input_dims, output);
@@ -521,7 +521,7 @@ void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d,

template <typename T>
struct SwapDim1And2InTranspose {
typedef platform::CUDADeviceContext Device;
typedef phi::GPUContext Device;
void operator()(const Device& d, const T* in,
const std::vector<int>& combined_dims, T* out) {
Dim3 input_dims = {static_cast<int>(combined_dims[0]),
@@ -533,15 +533,15 @@ struct SwapDim1And2InTranspose {

template <typename T>
struct SwapDim0And2InTranspose {
typedef platform::CUDADeviceContext Device;
typedef phi::GPUContext Device;
void operator()(const Device& d, const T* in,
const std::vector<int>& combined_dims, T* out) {
Dim3 input_dims = {static_cast<int>(combined_dims[0]),
static_cast<int>(combined_dims[1]),
static_cast<int>(combined_dims[2])};

size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
auto config = GetGpuLaunchConfig1D(d, total_size);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size);

TransposeSimpleKernel<T, 2, 1, 0><<<
config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
@@ -607,7 +607,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,

template <typename T>
struct TransposeSimple {
static bool run(const platform::CUDADeviceContext& ctx, const Tensor& in,
static bool run(const phi::GPUContext& ctx, const Tensor& in,
const std::vector<int32_t> perm, Tensor* out) {
// First reduce the dimensions of the input tensor if possible.
std::vector<int> new_perm;
@@ -654,12 +654,12 @@ struct TransposeSimple {
};

template <typename T>
void TransposeGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const int ndims, const Tensor& in,
const std::vector<int32_t> perm, Tensor* out) {
void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int ndims,
const Tensor& in,
const std::vector<int32_t>& perm, Tensor* out) {
auto ret = TransposeSimple<T>::run(dev_ctx, in, perm, out);
if (!ret) {
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, in, out, perm);
TransCompute<phi::GPUContext, T>(ndims, dev_ctx, in, out, perm);
}
}

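TransposeGPUKernelDriver is a two-tier dispatch: TransposeSimple::run first collapses adjacent dimensions that stay adjacent under the permutation (the truncated CombineTransposeDim3 above), and if the collapsed problem is a rank-3-or-less dimension swap it runs the specialized tiled kernels; otherwise the driver falls back to the generic TransCompute. A standalone illustration of the collapsing step (my own sketch of the idea, independent of the Paddle types, not the actual implementation):

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Merge runs of input dimensions that the permutation keeps adjacent and
// in order, e.g. shape [2,3,4,5] with perm [0,1,3,2] collapses to shape
// [6,4,5] with perm [0,2,1].
int main() {
  std::vector<int> shape = {2, 3, 4, 5};
  std::vector<int> perm = {0, 1, 3, 2};

  std::vector<int> group_start;       // first input dim of each merged group
  std::vector<long long> group_size;  // combined extent of each group
  for (size_t i = 0; i < perm.size();) {
    long long sz = shape[perm[i]];
    size_t j = i + 1;
    while (j < perm.size() && perm[j] == perm[j - 1] + 1) {
      sz *= shape[perm[j]];  // still contiguous: fold into the group
      ++j;
    }
    group_start.push_back(perm[i]);
    group_size.push_back(sz);
    i = j;
  }

  // Rank the groups by their position in the input layout; that rank is
  // each group's new dimension index, so reading the ranks in output
  // order yields the combined permutation.
  std::vector<int> order(group_start.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return group_start[a] < group_start[b]; });
  std::vector<int> new_perm(order.size());
  for (size_t r = 0; r < order.size(); ++r)
    new_perm[order[r]] = static_cast<int>(r);

  for (size_t g = 0; g < group_size.size(); ++g)
    std::printf("output dim %zu: size=%lld new_perm=%d\n", g, group_size[g],
                new_perm[g]);
  // Prints sizes 6, 5, 4 with new_perm 0, 2, 1: a [6,4,5] tensor with
  // perm {0,2,1}, which the dim-1-and-2 swap kernel handles directly.
}
```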
(Diffs for the remaining 16 of the 20 changed files are not loaded in this view.)
