diff --git a/paddle/phi/kernels/funcs/top_k_function_cuda.h b/paddle/phi/kernels/funcs/top_k_function_cuda.h
index de58c05149a53d..26374ca36007a3 100644
--- a/paddle/phi/kernels/funcs/top_k_function_cuda.h
+++ b/paddle/phi/kernels/funcs/top_k_function_cuda.h
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -49,6 +50,10 @@ namespace detail {
 template <>
 struct radix_key_codec_base<phi::dtype::float16>
     : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
+
+template <>
+struct radix_key_codec_base<phi::dtype::bfloat16>
+    : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
 }  // namespace detail
 }  // namespace rocprim
 namespace cub = hipcub;
@@ -58,6 +63,12 @@
 namespace cub {
 template <>
 struct NumericTraits<phi::dtype::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
+
+template <>
+struct NumericTraits<phi::dtype::bfloat16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
+};
+
 }  // namespace cub
 #endif
@@ -586,6 +597,24 @@
   }
 };

+template <>
+struct RadixTypeConfig<phi::dtype::bfloat16> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType Convert(phi::dtype::bfloat16 v) {
+    RadixType x = v.x;
+    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
+    return (v == v) ? (x ^ mask) : 0xffff;
+  }
+
+  static inline __device__ phi::dtype::bfloat16 Deconvert(RadixType v) {
+    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
+    phi::dtype::bfloat16 r;
+    r.x = (v ^ mask);
+    return r;
+  }
+};
+
 /*---------------------------Helper Functions------------------*/
 __device__ __forceinline__ int GetLaneId() {
   int lane_id;
diff --git a/paddle/phi/kernels/gpu/top_k_grad_kernel.cu b/paddle/phi/kernels/gpu/top_k_grad_kernel.cu
index 638d53c010ce64..6c2e880e9a9efb 100644
--- a/paddle/phi/kernels/gpu/top_k_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_k_grad_kernel.cu
@@ -13,8 +13,8 @@
 // limitations under the License.

 #include "paddle/phi/kernels/top_k_grad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
@@ -89,4 +89,5 @@ PD_REGISTER_KERNEL(topk_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu
index 6811b3e31db544..e2793955ef9c17 100644
--- a/paddle/phi/kernels/gpu/top_k_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_k_kernel.cu
@@ -15,11 +15,13 @@
 #include "paddle/phi/kernels/top_k_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
+
 namespace phi {

 #define FIXED_BLOCK_DIM_BASE(dim, ...) \
@@ -348,6 +350,7 @@ PD_REGISTER_KERNEL(topk,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py
index 5c0f6ff707fb45..27fc92292f36f9 100644
--- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py
@@ -17,7 +17,7 @@
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
@@ -202,6 +202,56 @@ def test_check_output(self):
         self.check_output()


+# Situation 7: input x is Float16
+class TestExpandV2FP16Op(OpTest):
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.prim_op_type = "prim"
+        self.dtype = np.float16
+        self.python_api = paddle.expand
+        self.public_python_api = paddle.expand
+        self.inputs = {
+            'X': np.random.randint(10, size=(8, 8, 5)).astype(self.dtype)
+        }
+        self.attrs = {'shape': [8, 8, 5]}
+        output = np.tile(self.inputs['X'], (1, 1, 1))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+
+# Situation 8: input x is BF16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestExpandV2BF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.prim_op_type = "prim"
+        self.dtype = np.uint16
+        self.python_api = paddle.expand
+        self.public_python_api = paddle.expand
+        x = np.random.randint(10, size=(8, 8, 5)).astype(np.float32)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {'shape': [8, 8, 5]}
+        output = np.tile(x, (1, 1, 1)).astype(np.float32)
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
+
+
 class TestExpandV2Error(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -338,7 +388,7 @@ def test_grad(self):
         self.func(p)


-# Situation 7: comp case, shape is a list(without tensor)
+# Situation 9: comp case, shape is a list(without tensor)
 class TestExpandV2CompOpRank1(OpTest):
     def setUp(self):
         self.op_type = "expand_v2"
@@ -392,7 +442,7 @@ def init_data(self):
         self.expand_times = (1, 1, 1, 1)


-# Situation 8: comp case, input x is Integer
+# Situation 10: comp case, input x is Integer
 class TestExpandV2CompOpInteger(OpTest):
     def setUp(self):
         self.op_type = "expand_v2"
@@ -410,7 +460,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)


-# Situation 9: comp case, input x is Bool
+# Situation 11: comp case, input x is Bool
 class TestExpandV2CompOpBoolean(OpTest):
     def setUp(self):
         self.op_type = "expand_v2"
@@ -426,7 +476,7 @@ def test_check_output(self):
         self.check_output(check_prim=True)


-# Situation 10: comp case, input x is Integer
+# Situation 12: comp case, input x is Integer
 class TestExpandV2CompOpInt64_t(OpTest):
     def setUp(self):
         self.op_type = "expand_v2"
diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
index d64906560dc092..5612703968dad0 100644
--- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle.fluid import core
@@ -189,6 +189,56 @@ def setUp(self):
         self.outputs = {'Out': output, 'Indices': indices}


+class TestTopkFP16Op(TestTopkOp):
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
+        self.public_python_api = paddle.topk
+        self.dtype = np.float16
+        self.prim_op_type = "prim"
+        self.input_data = np.random.rand(10, 20).astype(self.dtype)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest
+        )
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestTopkBF16Op(TestTopkOp):
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.python_api = paddle.topk
+        self.public_python_api = paddle.topk
+        self.dtype = np.uint16
+        self.prim_op_type = "prim"
+        self.input_data = np.random.rand(10, 20).astype(np.float32)
+        self.init_args()
+        self.inputs = {'X': convert_float_to_uint16(self.input_data)}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest
+        )
+        self.outputs = {
+            'Out': convert_float_to_uint16(output),
+            'Indices': indices,
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=True)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, {'X'}, 'Out', check_eager=True)
+
+
 class TestTopKAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(123)
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 40a83f6dbf3707..09aaff08c3ca5e 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -3418,7 +3418,15 @@ def expand(x, shape, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+            [
+                'bool',
+                'float16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'uint16',
+            ],
             'expand',
         )
         check_type(shape, 'shape', (list, tuple, Variable), 'expand')
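
The RadixTypeConfig<phi::dtype::bfloat16> specialization added above relies on the usual order-preserving radix-sort key transform: a non-negative value has only its sign bit flipped, a negative value has all of its bits flipped, and NaNs (v != v) are mapped to 0xffff, so unsigned comparison of the resulting keys matches bfloat16 comparison. Below is a minimal host-side NumPy sketch of the same mapping, for illustration only; the helper names bf16_bits and to_radix_key are not part of the patch.

# Illustration of the Convert() bit trick from RadixTypeConfig<phi::dtype::bfloat16>,
# written with NumPy on the host instead of CUDA.
import numpy as np


def bf16_bits(values):
    # bfloat16 is the upper 16 bits of the float32 representation
    # (truncation is good enough for this illustration).
    f32 = np.asarray(values, dtype=np.float32)
    return (f32.view(np.uint32) >> 16).astype(np.uint16)


def to_radix_key(bits):
    # Negative numbers: flip every bit; non-negative numbers: flip only the sign bit.
    negative = (bits & 0x8000) != 0
    mask = np.where(negative, np.uint16(0xFFFF), np.uint16(0x8000))
    return bits ^ mask


vals = np.array([-3.5, -1.0, -0.0, 0.0, 0.5, 2.0, 7.25], dtype=np.float32)
keys = to_radix_key(bf16_bits(vals))
# Sorting the unsigned keys yields the same order as sorting the original values.
assert list(np.argsort(keys, kind="stable")) == list(np.argsort(vals, kind="stable"))

Deconvert in the patch applies the inverse mask, so the kernel can recover the original bfloat16 bit pattern once the radix pass has finished.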