From 0936ee1055c792028733dfd0b10280e38493fb24 Mon Sep 17 00:00:00 2001 From: CaiZhi Date: Fri, 22 Sep 2023 11:58:07 +0800 Subject: [PATCH] [MTAI] feat(op): suppport Gaussian op on MUSA (#73) --- paddle/phi/kernels/CMakeLists.txt | 1 - paddle/phi/kernels/gpu/gaussian_kernel.cu | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 0a3a0f82e7c935..4437d34e68c7bb 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -66,7 +66,6 @@ if(WITH_MUSA) "gpu/fft_grad_kernel.cu" "gpu/fft_kernel.cu" "gpu/fused_softmax_mask_grad_kernel.cu" - "gpu/gaussian_kernel.cu" "gpu/gelu_grad_kernel.cu" "gpu/gelu_kernel.cu" "gpu/histogram_kernel.cu" diff --git a/paddle/phi/kernels/gpu/gaussian_kernel.cu b/paddle/phi/kernels/gpu/gaussian_kernel.cu index d0f839bd677d47..9befe4899ddd72 100644 --- a/paddle/phi/kernels/gpu/gaussian_kernel.cu +++ b/paddle/phi/kernels/gpu/gaussian_kernel.cu @@ -78,6 +78,15 @@ void GaussianKernel(const Context& dev_ctx, } // namespace phi +#ifdef PADDLE_WITH_MUSA +PD_REGISTER_KERNEL(gaussian, + GPU, + ALL_LAYOUT, + phi::GaussianKernel, + phi::dtype::float16, + phi::dtype::bfloat16, + float) {} +#else PD_REGISTER_KERNEL(gaussian, GPU, ALL_LAYOUT, @@ -86,3 +95,4 @@ PD_REGISTER_KERNEL(gaussian, phi::dtype::bfloat16, float, double) {} +#endif