[Pten] Add CUDA implementation of cast kernel #37610

Merged
48 changes: 28 additions & 20 deletions paddle/pten/kernels/cuda/manipulation.cu
@@ -16,8 +16,8 @@
 #include "paddle/pten/infermeta/unary.h"
 #include "paddle/pten/kernels/cuda/manipulation.h"
 #include "paddle/pten/kernels/cuda/utils.h"
+#include "paddle/pten/kernels/functions/cuda/cast_kernel_impl.h"
 #include "paddle/pten/kernels/functions/general/manipulation.h"
-#include "paddle/pten/kernels/functions/math/cast_func.h"
 
 namespace pten {
@@ -123,8 +123,7 @@ void Cast(const CUDAContext& dev_ctx,
           DataType in_dtype,
           DenseTensor* out) {
   PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
-                       math::CastKernelImpl<CUDAContext, T, data_t>(
-                           dev_ctx, x, out);
+                       detail::CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
                      }));
 }
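For readers skimming the diff: PD_VISIT_ALL_TYPES maps the runtime out_dtype to a compile-time type and runs the lambda with that type bound to data_t, which is how a single Cast entry point reaches the templated detail::CastCUDAKernelImpl<T, data_t>. A minimal, self-contained sketch of the pattern (VisitAllTypes and this DataType enum are illustrative stand-ins, not the pten API):

#include <stdexcept>
#include <string>

enum class DataType { FLOAT32, FLOAT64, INT32 };  // illustrative subset

// Stand-in for PD_VISIT_ALL_TYPES: switch on the runtime dtype and invoke
// the functor with a value of the matching static type.
template <typename Functor>
void VisitAllTypes(DataType dtype, const char* name, Functor&& f) {
  switch (dtype) {
    case DataType::FLOAT32: f(float{}); break;
    case DataType::FLOAT64: f(double{}); break;
    case DataType::INT32: f(int{}); break;
    default:
      throw std::runtime_error(std::string(name) + ": unsupported dtype");
  }
}

// Usage mirroring the Cast kernel above; decltype(tag) plays the role that
// data_t plays in the real macro expansion:
//   VisitAllTypes(out_dtype, "CastKernelImpl", [&](auto tag) {
//     using data_t = decltype(tag);
//     detail::CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
//   });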

@@ -158,23 +157,32 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
                    int8_t,
                    int,
                    int64_t) {}
-PT_REGISTER_KERNEL("cast",
-                   CUDA,
-                   ANY,
-                   pten::Cast,
-                   float,
-                   double,
-                   int,
-                   int64_t,
-                   int16_t,
-                   bool,
-                   uint8_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {
-  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
-}
+// TODO: HIP needs to support bfloat16
+#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)   \
+  PT_REGISTER_KERNEL("cast",                              \
+                     CUDA,                                \
+                     ANY,                                 \
+                     pten::Cast,                          \
+                     float,                               \
+                     double,                              \
+                     int,                                 \
+                     int64_t,                             \
+                     int16_t,                             \
+                     bool,                                \
+                     uint8_t,                             \
+                     paddle::platform::float16,           \
+                     paddle::platform::complex<float>,    \
+                     paddle::platform::complex<double>,   \
+                     ##__VA_ARGS__) {                     \
+    kernel->OutputAt(0).SetDataType(                      \
+        paddle::experimental::DataType::UNDEFINED);       \
+  }
+
+#if !defined(PADDLE_WITH_HIP)
+PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
+#else
+PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
+#endif
 
 PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
                                 CUDA,
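One detail worth calling out in the macro above: ##__VA_ARGS__ is the GNU/NVCC preprocessor extension that swallows the trailing comma when the variadic list is empty, which is what lets the same macro register the kernel both with the extra bfloat16 type (CUDA build) and without it (HIP build). A self-contained sketch of the idiom, with illustrative names:

#include <cstdio>

// REGISTER_TYPES expands a fixed base list plus optional extra types;
// ##__VA_ARGS__ drops the trailing comma when no extras are passed.
#define REGISTER_TYPES(name, ...) \
  const char* name##_types[] = {"float", "double", ##__VA_ARGS__};

#if !defined(PADDLE_WITH_HIP)
REGISTER_TYPES(cast, "bfloat16")  // CUDA build: base types plus bfloat16
#else
REGISTER_TYPES(cast)  // HIP build: base types only
#endif

int main() {
  for (const char* t : cast_types) std::printf("%s\n", t);
}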
79 changes: 79 additions & 0 deletions paddle/pten/kernels/functions/cuda/cast_kernel_impl.h
@@ -0,0 +1,79 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/dense_tensor.h"

#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace pten {
namespace detail {
using CUDAContext = paddle::platform::CUDADeviceContext;

template <typename InT, typename OutT, int VecSize>
__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
Contributor:
Copying the code over is not recommended. Please make sure there is only one copy of this code to maintain; otherwise, once the original implementation is optimized, this copy will not be updated and the problem will come back.

Contributor (Author):
OK, I'll open a follow-up PR to fix this.

  using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
  using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;

  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx * VecSize; i < N;
       i += blockDim.x * gridDim.x * VecSize) {
    LoadT in_val;
    paddle::platform::Load<InT, VecSize>(&in[i], &in_val);

    StoreT out_val;
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      out_val[j] = static_cast<OutT>(in_val[j]);
    }

    paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
  }
}
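
// Sketch (not part of this diff): AlignedVector<T, Size> is, roughly, a
// fixed-size array whose alignment is raised to the full vector width, so
// the Load/Store calls above compile to a single wide memory transaction
// (128 bits for four floats) instead of Size scalar accesses:
//
//   template <typename T, int Size>
//   struct alignas(sizeof(T) * Size) AlignedVector {
//     T val[Size];
//   };
//
//   template <typename T, int Size>
//   __device__ void Load(const T* addr, AlignedVector<T, Size>* vec) {
//     *vec = *reinterpret_cast<const AlignedVector<T, Size>*>(addr);
//   }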

template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
}
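
// Sketch (not part of this diff): CUDA_KERNEL_LOOP(index, N) expands to,
// roughly, a standard grid-stride loop, so a fixed-size grid covers any N:
//
//   for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < N;
//        index += blockDim.x * gridDim.x) {
//     out[index] = static_cast<OutT>(in[index]);
//   }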

template <typename InT, typename OutT>
void CastCUDAKernelImpl(const CUDAContext& dev_ctx,
                        const DenseTensor& x,
                        DenseTensor* out) {
  auto* in_data = x.data<InT>();
  auto size = x.numel();
  auto* out_data = out->mutable_data<OutT>();

  paddle::platform::GpuLaunchConfig config =
      paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
  int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
  if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
    VecCastCUDAKernel<InT, OutT, 4><<<config.block_per_grid,
                                      config.thread_per_block,
                                      0,
                                      dev_ctx.stream()>>>(in_data, size,
                                                          out_data);
  } else {
    CastCUDAKernel<InT, OutT><<<config.block_per_grid,
                                config.thread_per_block,
                                0,
                                dev_ctx.stream()>>>(in_data, size, out_data);
  }
}

} // namespace detail

} // namespace pten
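A note on the dispatch in CastCUDAKernelImpl: GetVectorizedSize probes pointer alignment to decide how wide an access the pointer supports, and the vectorized path additionally requires size % 4 == 0 because VecCastCUDAKernel has no scalar tail loop. A minimal sketch of such an alignment probe (assumed behavior, not the verbatim Paddle implementation):

#include <cstdint>

// A pointer can use VecSize-wide accesses only if its address is a
// multiple of the vector width (sizeof(T) * VecSize).
template <typename T>
int GetVectorizedSizeSketch(const T* ptr) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  if (addr % (sizeof(T) * 4) == 0) return 4;
  if (addr % (sizeof(T) * 2) == 0) return 2;
  return 1;
}

Note that the code above probes only out_data; the input pointer is then loaded at the same vector width.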