Skip to content

Commit

Permalink
support printing bf16 tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Feb 7, 2022
1 parent f1f74e9 commit cfff45a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,20 @@ def test_tensor_str_linewidth2(self):
self.assertEqual(a_str, expected)
paddle.enable_static()

def test_tensor_str_bf16(self):
paddle.disable_static(paddle.CPUPlace())
a = paddle.to_tensor([[1.5, 1.0], [0, 0]])
a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16)
paddle.set_printoptions(precision=4)
a_str = str(a)

expected = '''Tensor(shape=[2, 2], dtype=bfloat16, place=Place(cpu), stop_gradient=True,
[[1.5000, 1. ],
[0. , 0. ]])'''

self.assertEqual(a_str, expected)
paddle.enable_static()

def test_print_tensor_dtype(self):
paddle.disable_static(paddle.CPUPlace())
a = paddle.rand([1])
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/tensor/to_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,18 @@ def _format_tensor(var, summary, indent=0, max_width=0, signed=False):
def to_string(var, prefix='Tensor'):
indent = len(prefix) + 1

dtype = convert_dtype(var.dtype)
if var.dtype == core.VarDesc.VarType.BF16:
dtype = 'bfloat16'

_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"

tensor = var.value().get_tensor()
if not tensor._is_initialized():
return "Tensor(Not initialized)"

if var.dtype == core.VarDesc.VarType.BF16:
var = var.astype('float32')
np_var = var.numpy()

if len(var.shape) == 0:
Expand All @@ -250,7 +256,7 @@ def to_string(var, prefix='Tensor'):
return _template.format(
prefix=prefix,
shape=var.shape,
dtype=convert_dtype(var.dtype),
dtype=dtype,
place=var._place_str,
stop_gradient=var.stop_gradient,
indent=' ' * indent,
Expand Down

1 comment on commit cfff45a

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.