[bf16] add bf16 kernel: gaussian_random fill_constant fill_any_like (#40027)

* add gaussian random

* add full

* refine reduce

* refine code

* refine gaussian_random unittest

* add unittest for fill_any_like fill_constant

zhangbo9674 authored Mar 7, 2022
1 parent fd36ede commit 6a0d60d

Showing 8 changed files with 110 additions and 14 deletions.
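The Python tests below round-trip bf16 data through np.uint16 using op_test's convert_float_to_uint16 and convert_uint16_to_float helpers. As a point of reference, here is a minimal numpy sketch of that conversion, assuming bfloat16 is simply the upper 16 bits of an IEEE-754 float32 (the real helpers may differ in rounding):

import numpy as np

# Assumption: bfloat16 keeps float32's sign, 8-bit exponent, and top 7
# mantissa bits, i.e. the upper 16 bits of the float32 bit pattern.
def float_to_bf16_bits(x):
    x = np.atleast_1d(np.asarray(x, dtype=np.float32))
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(u):
    u = np.atleast_1d(np.asarray(u, dtype=np.uint16))
    return (u.astype(np.uint32) << 16).view(np.float32)

print(bf16_bits_to_float(float_to_bf16_bits(3.8)))  # [3.796875]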
paddle/fluid/operators/gaussian_random_op.cu (2 additions, 1 deletion)
@@ -45,7 +45,8 @@ struct GaussianGenerator {
     thrust::minstd_rand rng;
     rng.seed(seed_);
     using MT = typename details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
paddle/phi/kernels/funcs/distribution_helper.h (6 additions, 3 deletions)
@@ -23,6 +23,7 @@ limitations under the License. */

 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/core/hostdevice.h"
@@ -255,11 +256,13 @@ __global__ void DistributionKernel(size_t size,
   using SType = hiprandStatePhilox4_32_10_t;
 #endif
   size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
-  T args[kCount];
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT args[kCount];
   T result[kCount];
   for (size_t i = idx; i < size; i += total_thread * kCount) {
-    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
-    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(
+    kps::ElementwiseRandom<SType, MT, kCount, 1, DistOp>(
+        &args[0], dist, &state);
+    kps::ElementwiseUnary<MT, T, kCount, 1, 1, TransformOp>(
         &result[0], &args[0], trans);
     kps::WriteData<T, T, kCount, 1, 1, true>(
         out_data + i, &result[0], size - i, 1, stride, 1);
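The kernel change above follows a compute-in-higher-precision pattern: random values are drawn and transformed in MT (float when T is bfloat16 or float16, per MPTypeTrait) and only narrowed to T by WriteData at the end, so bf16's 7-bit mantissa never degrades the sampling math. A numpy analogue of that flow, reusing the truncation assumption from the sketch above (names are illustrative, not Paddle's):

import numpy as np

def narrow_to_bf16_precision(x):
    # Round-trip float32 -> bfloat16 bits -> float32 (assumed truncation).
    x = np.atleast_1d(np.asarray(x, dtype=np.float32))
    return ((x.view(np.uint32) >> 16) << 16).view(np.float32)

rng = np.random.default_rng(10)
args = rng.standard_normal(4, dtype=np.float32)   # ElementwiseRandom: sample in MT
result = 1.0 + 2.0 * args                         # TransformOp: apply mean/std in MT
out = narrow_to_bf16_precision(result)            # WriteData: single MT -> T narrowing
print(result)
print(out)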
paddle/phi/kernels/gpu/full_kernel.cu (7 additions, 3 deletions)
@@ -63,9 +63,11 @@ void FullLikeKernel(const Context& dev_ctx,
   auto value = val.to<float>();
   using CommonType = typename std::common_type<
       float,
-      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
-                                float,
-                                T>::type>::type;
+      typename std::conditional<
+          std::is_same<T, phi::dtype::float16>::value ||
+              std::is_same<T, phi::dtype::bfloat16>::value,
+          float,
+          T>::type>::type;

   auto common_type_value = static_cast<CommonType>(value);

@@ -110,6 +112,7 @@ PD_REGISTER_KERNEL(full,
                    int64_t,
                    bool,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}

@@ -123,6 +126,7 @@ PD_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
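FullLikeKernel range-checks the requested fill value in CommonType before casting, and the change above routes bfloat16 (like float16) through float for that check: a half-width type would round the value before validating it. A sketch of the idea using numpy's float16 as a stand-in (vanilla numpy has no bfloat16 dtype; bf16 is assumed to behave analogously):

import numpy as np

def value_fits(value, dtype=np.float16):
    # Compare in float32 (the "CommonType"), not in the storage dtype,
    # so an out-of-range value is rejected before the final cast.
    info = np.finfo(dtype)
    v = np.float32(value)
    return bool(np.float32(info.min) <= v <= np.float32(info.max))

print(value_fits(3.8))      # True
print(value_fits(70000.0))  # False: exceeds float16's max of 65504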
paddle/phi/kernels/gpu/gaussian_random_kernel.cu (8 additions, 5 deletions)
@@ -18,8 +18,8 @@
 #include <thrust/host_vector.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -46,8 +46,9 @@ struct GaussianGenerator {
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed_);
-    using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
@@ -83,9 +84,10 @@ void GaussianRandomKernel(const Context& dev_ctx,

   if (gen_cuda->GetIsInitPy() && seed_flag) {
     if (FLAGS_use_curand) {
-      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
       funcs::normal_distribution<MT> dist;
-      funcs::normal_transform<MT> trans(mean, std);
+      funcs::normal_transform<MT> trans(static_cast<MT>(mean),
+                                        static_cast<MT>(std));
       funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
     } else {
       auto seed_offset = gen_cuda->IncrementOffset(1);
@@ -110,5 +112,6 @@ PD_REGISTER_KERNEL(gaussian_random,
                    ALL_LAYOUT,
                    phi::GaussianRandomKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {}
paddle/phi/kernels/primitive/compute_primitives.h (1 addition)
@@ -22,6 +22,7 @@
 #endif

 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+// #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"

 namespace phi {
python/paddle/fluid/tests/unittests/test_fill_any_like_op.py (20 additions, 1 deletion)
@@ -21,7 +21,7 @@
 import paddle.compat as cpt
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16


 class TestFillAnyLikeOp(OpTest):
@@ -47,6 +47,25 @@ def init(self):
         self.value = 0.0


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillAnyLikeOpBfloat16(OpTest):
+    def setUp(self):
+        self.op_type = "fill_any_like"
+        self.dtype = np.uint16
+        self.value = 0.0
+        self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)}
+        self.attrs = {'value': self.value, 'dtype': core.VarDesc.VarType.BF16}
+        self.outputs = {
+            'Out':
+            convert_float_to_uint16(self.value * np.ones_like(self.inputs["X"]))
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
     def init(self):
         self.value = 1.0
python/paddle/fluid/tests/unittests/test_fill_constant_op.py (21 additions)
@@ -83,6 +83,27 @@ def test_check_output(self):
         self.check_output()


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillConstantBF16Op(OpTest):
+    def setUp(self):
+        '''Test fill_constant op with specified value
+        '''
+        self.op_type = "fill_constant"
+        self.dtype = np.uint16
+        self.inputs = {}
+        self.attrs = {
+            'shape': [123, 92],
+            'value': 3.8,
+            'dtype': core.VarDesc.VarType.BF16
+        }
+        self.outputs = {'Out': convert_float_to_uint16(np.full((123, 92), 3.8))}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillConstantOpWithSelectedRows(unittest.TestCase):
     def check_with_place(self, place):
         scope = core.Scope()
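The expected output is built by pushing 3.8 through convert_float_to_uint16 rather than comparing against 3.8 directly, because 3.8 is not exactly representable in bfloat16. Under the truncation assumption from the earlier sketch, the stored value works out as follows:

import numpy as np

x = np.float32(3.8)                       # bit pattern 0x40733333
bits = int(np.atleast_1d(x).view(np.uint32)[0])
bf16_bits = bits >> 16                    # 0x4073: the uint16 the op stores
restored = np.array([bf16_bits << 16], dtype=np.uint32).view(np.float32)[0]
print(hex(bits), hex(bf16_bits))          # 0x40733333 0x4073
print(restored)                           # 3.796875, the value actually filled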
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py (45 additions, 1 deletion)
@@ -22,7 +22,7 @@
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from paddle.fluid.executor import Executor
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_uint16_to_float
 import paddle


@@ -65,6 +65,50 @@ def verify_output(self, outs):
             "hist: " + str(hist) + " hist2: " + str(hist2))


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestGaussianRandomBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "gaussian_random"
+        self.set_attrs()
+        self.inputs = {}
+        self.use_mkldnn = False
+        self.attrs = {
+            "shape": [123, 92],
+            "mean": self.mean,
+            "std": self.std,
+            "seed": 10,
+            "dtype": paddle.fluid.core.VarDesc.VarType.BF16,
+            "use_mkldnn": self.use_mkldnn
+        }
+        paddle.seed(10)
+
+        self.outputs = {'Out': np.zeros((123, 92), dtype='float32')}
+
+    def set_attrs(self):
+        self.mean = 1.0
+        self.std = 2.
+
+    def test_check_output(self):
+        self.check_output_with_place_customized(
+            self.verify_output, place=core.CUDAPlace(0))
+
+    def verify_output(self, outs):
+        outs = convert_uint16_to_float(outs)
+        self.assertEqual(outs[0].shape, (123, 92))
+        hist, _ = np.histogram(outs[0], range=(-3, 5))
+        hist = hist.astype("float32")
+        hist /= float(outs[0].size)
+        data = np.random.normal(size=(123, 92), loc=1, scale=2)
+        hist2, _ = np.histogram(data, range=(-3, 5))
+        hist2 = hist2.astype("float32")
+        hist2 /= float(outs[0].size)
+        self.assertTrue(
+            np.allclose(
+                hist, hist2, rtol=0, atol=0.05),
+            "hist: " + str(hist) + " hist2: " + str(hist2))
+
+
 class TestMeanStdAreInt(TestGaussianRandomOp):
     def set_attrs(self):
         self.mean = 1
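verify_output converts the uint16 output back to float, then compares a normalized 10-bin histogram over [-3, 5] against one from a fresh numpy N(1, 2) sample, allowing atol=0.05 per bin. A quick closed-form sanity check of that tolerance (a sketch; the test itself compares two sampled histograms):

import math

# Probability mass of each of np.histogram's 10 default bins over
# range=(-3, 5) under a normal distribution with mean 1 and std 2.
def normal_cdf(x, mu=1.0, sigma=2.0):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

edges = [-3.0 + 0.8 * i for i in range(11)]
masses = [normal_cdf(b) - normal_cdf(a) for a, b in zip(edges, edges[1:])]
print([round(m, 3) for m in masses])  # peak bins hold about 0.155, so 0.05 is generous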
