Skip to content

Commit

Permalink
[Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml (#…
Browse files Browse the repository at this point in the history
…41920)

* [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml

* add unittest for full_like

* fix unittest
  • Loading branch information
Aurelius84 committed Apr 19, 2022
1 parent cc728bb commit 7063037
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
16 changes: 11 additions & 5 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ inline bool NeedTransformDataType(const DataType& input,
inline bool NeedTransformPlace(const paddle::platform::Place& input,
const Backend& target,
const TransformFlag& transform_flag) {
bool ret =
input.GetType() == AllocationType::GPUPINNED ||
(transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) !=
(target != Backend::GPUDNN ? target : Backend::GPU));
// NOTE(dev): The default value of TransformFlag is True, if it is set with
// False
// somewhere such as api.yaml or backward.yaml that means we should skip data
// transform. Because "stop_transform_" has highest priority.
if (!transform_flag.need_trans_backend()) {
return false;
}
bool ret = input.GetType() == AllocationType::GPUPINNED ||
(target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) !=
(target != Backend::GPUDNN ? target : Backend::GPU));
return ret;
}

Expand Down
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_full_like_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
from op_test import OpTest
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.framework import _test_eager_guard


class TestFullOp(unittest.TestCase):
Expand Down Expand Up @@ -133,5 +134,19 @@ def init_data(self):
self.dtype = np.int64


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFullLikeOp4(unittest.TestCase):
def test_skip_data_transform(self):
paddle.disable_static()
with _test_eager_guard():
x = paddle.to_tensor(
[1., 2., 3., 4.], place=paddle.CUDAPinnedPlace())
out = paddle.full_like(x, 1.)
self.assertTrue(
(out.numpy() == np.ones([4]).astype(np.float32)).all(), True)
paddle.enable_static()


if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@
- api : deformable_conv
args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
output : Tensor(out)
infer_meta :
infer_meta :
func : DeformableConvInferMeta
kernel :
func : deformable_conv
Expand Down Expand Up @@ -763,6 +763,8 @@
param : [x, value, dtype]
data_type : dtype > x
backend : place > x
data_transform :
skip_transform : x

- api : gather
args : (Tensor x, Tensor index, Scalar axis=0)
Expand Down

1 comment on commit 7063037

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.