Commit

clean unittest test_model_cast_to_bf16 (PaddlePaddle#48705)
jiahy0825 authored Dec 9, 2022
1 parent c6d2a2f commit a0385cf
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions
python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
@@ -97,26 +97,26 @@ def _graph_common(self, _amp_fun, startup_prog=None):
 
         with self.static_graph():
             t_bf16 = layers.data(
-                name='t_bf16', shape=[size, size], dtype=np.uint16
+                name='t_bf16', shape=[size, size], dtype=np.int32
             )
             tt_bf16 = layers.data(
-                name='tt_bf16', shape=[size, size], dtype=np.uint16
+                name='tt_bf16', shape=[size, size], dtype=np.int32
             )
             t = layers.data(name='t', shape=[size, size], dtype='float32')
             tt = layers.data(name='tt', shape=[size, size], dtype='float32')
 
-            ret = layers.elementwise_add(t, tt)
-            ret = layers.elementwise_mul(ret, t)
+            ret = paddle.add(t, tt)
+            ret = paddle.multiply(ret, t)
             ret = paddle.reshape(ret, [0, 0])
 
             with amp.bf16.bf16_guard():
-                ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16)
-                ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16)
+                ret_bf16 = paddle.add(t_bf16, tt_bf16)
+                ret_bf16 = paddle.multiply(ret_bf16, t_bf16)
                 ret_bf16 = paddle.reshape(ret_bf16, [0, 0])
 
             with amp.bf16.bf16_guard():
-                ret_fp32bf16 = layers.elementwise_add(t, tt)
-                ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t)
+                ret_fp32bf16 = paddle.add(t, tt)
+                ret_fp32bf16 = paddle.multiply(ret_fp32bf16, t)
                 ret_fp32bf16 = paddle.reshape(ret_fp32bf16, [0, 0])
 
             (
@@ -147,11 +147,11 @@ def _graph_common(self, _amp_fun, startup_prog=None):
             tt = layers.data(name='tt', shape=[size, size], dtype='float32')
 
             with amp.bf16.bf16_guard():
-                ret = layers.elementwise_add(t, tt)
+                ret = paddle.add(t, tt)
                 ret = paddle.reshape(ret, [0, 0])
                 ret = paddle.nn.functional.elu(ret)
-                ret = layers.elementwise_mul(ret, t)
-                ret = layers.elementwise_add(ret, tt)
+                ret = paddle.multiply(ret, t)
+                ret = paddle.add(ret, tt)
 
             static_ret_bf16 = self.get_static_graph_result(
                 feed={'t': n, 'tt': nn},
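For context, the cleanup replaces the legacy paddle.fluid.layers.elementwise_* calls with their paddle.* equivalents while keeping the existing bf16_guard regions intact. A minimal standalone sketch of the post-cleanup pattern follows; it assumes paddle.static.amp is importable as amp (as the test does), and the Program/Executor scaffolding, place choice, and variable names are illustrative rather than taken from the test file.

# Minimal sketch (not part of the test): builds a small static-graph program
# with the post-cleanup paddle APIs exercised by this commit.
import numpy as np
import paddle
import paddle.static.amp as amp

paddle.enable_static()

size = 3
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()

with paddle.static.program_guard(main_prog, startup_prog):
    t = paddle.static.data(name='t', shape=[size, size], dtype='float32')
    tt = paddle.static.data(name='tt', shape=[size, size], dtype='float32')

    # Ops created inside bf16_guard are marked as candidates for bfloat16
    # execution once a bf16 AMP rewrite pass is applied to the program.
    with amp.bf16.bf16_guard():
        ret = paddle.add(t, tt)            # replaces layers.elementwise_add
        ret = paddle.multiply(ret, t)      # replaces layers.elementwise_mul
        ret = paddle.reshape(ret, [0, 0])  # 0 keeps the corresponding dim

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
n = np.ones([size, size], dtype='float32') * 3.2
nn = np.ones([size, size], dtype='float32') * -2.7
(out,) = exe.run(main_prog, feed={'t': n, 'tt': nn}, fetch_list=[ret])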
