From 93366372a43ec3c775f9ca1dd1905e8dc0cb3f70 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 10 May 2023 10:44:16 +0800 Subject: [PATCH] Fix the index calculation in cross_entroy_kernel. (#53659) --- paddle/phi/kernels/gpu/cross_entropy_kernel.cu | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index a223cd7c738347..5e5ddec9912cc1 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/phi/kernels/cross_entropy_kernel.h" +#include "glog/logging.h" + #ifdef __NVCC__ #include "cub/cub.cuh" #endif @@ -468,8 +470,8 @@ __global__ void VectorizedSoftmaxForward(T* loss, using VecT = kps::details::VectorType; // each block deal with one batch - logits += blockIdx.x * mid_dim; - softmax += blockIdx.x * mid_dim; + logits += static_cast(blockIdx.x) * static_cast(mid_dim); + softmax += static_cast(blockIdx.x) * static_cast(mid_dim); const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T); const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T); @@ -1165,6 +1167,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, int dim, int D, const int ignore_index) { + VLOG(7) << "rank=" << rank << ", axis = " << axis << ", N = " << N + << ", dim = " << dim << ", D = " << D; auto stream = dev_ctx.stream(); constexpr int max_dim = 320; if (D == 1) { @@ -1247,11 +1251,11 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, int axis, DenseTensor* softmax, DenseTensor* loss) { - PADDLE_ENFORCE_EQ( - dev_ctx.GetPlace().GetType(), - AllocationType::GPU, - phi::errors::Unavailable("softmax_with_cross_entropy operator's " - "CUDA kernel only runs on GPU device.")); + VLOG(7) << "logits.shape={" << logits.dims() << "}, label.shape={" + << label.dims() << "}, soft_label=" << soft_label + << ", use_softmax=" << use_softmax + << ", numeric_stable_mode=" << numeric_stable_mode + << ", ignore_index=" << ignore_index << ", axis=" << axis; // do not with softmax op, and input is softmax if (!use_softmax) {