Skip to content

Commit

Permalink
[fp16] suppot fp16 in std (#50936)
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 authored Mar 2, 2023
1 parent b0a604c commit d1dd730
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
25 changes: 25 additions & 0 deletions python/paddle/fluid/tests/unittests/test_std_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,30 @@ def test_error(self):
self.assertRaises(TypeError, paddle.std, x)


class Testfp16Std(unittest.TestCase):
def test_fp16_with_gpu(self):
paddle.enable_static()
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([12, 14]).astype("float16")
x = paddle.static.data(
name="x", shape=[12, 14], dtype="float16"
)

y = paddle.std(x)

exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[y],
)


if __name__ == '__main__':
unittest.main()
12 changes: 8 additions & 4 deletions python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
Computes the variance of ``x`` along ``axis`` .
Args:
x (Tensor): The input Tensor with data type float32, float64.
x (Tensor): The input Tensor with data type float16, float32, float64.
axis (int|list|tuple, optional): The axis along which to perform variance calculations. ``axis`` should be int, list(int) or tuple(int).
- If ``axis`` is a list/tuple of dimension(s), variance is calculated along all element(s) of ``axis`` . ``axis`` or element(s) of ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
Expand All @@ -145,7 +145,9 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
# [1. 4.33333333]
"""
if not in_dygraph_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'var')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'var'
)

u = mean(x, axis, True, name)
out = paddle.sum(paddle.pow((x - u), 2), axis, keepdim=keepdim, name=name)
Expand All @@ -168,7 +170,7 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
Computes the standard-deviation of ``x`` along ``axis`` .
Args:
x (Tensor): The input Tensor with data type float32, float64.
x (Tensor): The input Tensor with data type float16, float32, float64.
axis (int|list|tuple, optional): The axis along which to perform
standard-deviation calculations. ``axis`` should be int, list(int)
or tuple(int). If ``axis`` is a list/tuple of dimension(s),
Expand Down Expand Up @@ -211,7 +213,9 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
"""
if not in_dygraph_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'std')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'std'
)
out = var(**locals())
return paddle.sqrt(out)

Expand Down

0 comments on commit d1dd730

Please sign in to comment.