Skip to content

Commit

Permalink
Add yaml for eye OP (PaddlePaddle#41476)
Browse files Browse the repository at this point in the history
  • Loading branch information
From00 authored and wu.zeng committed Apr 10, 2022
1 parent d170187 commit 2b1300a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
6 changes: 4 additions & 2 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,10 +1724,12 @@ def eye(num_rows,
else:
num_columns = num_rows

if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
_current_expected_place())
elif _in_legacy_dygraph():
out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
num_columns)

else:
helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype',
Expand Down
9 changes: 6 additions & 3 deletions python/paddle/fluid/tests/unittests/test_eye_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def setUp(self):
'''
Test eye op with specified shape
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
Expand All @@ -39,37 +40,39 @@ def setUp(self):
self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)


class TestEyeOp1(OpTest):
def setUp(self):
'''
Test eye op with default parameters
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
self.attrs = {'num_rows': 50}
self.outputs = {'Out': np.eye(50, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)


class TestEyeOp2(OpTest):
def setUp(self):
'''
Test eye op with specified shape
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
self.attrs = {'num_rows': 99, 'num_columns': 1}
self.outputs = {'Out': np.eye(99, 1, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)


class API_TestTensorEye(unittest.TestCase):
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,18 @@
func : expm1
backward : expm1_grad

- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
infer_meta :
func : EyeInferMeta
param : [num_rows, num_columns, dtype]
kernel :
func : eye
param : [num_rows, num_columns, dtype]
data_type : dtype
backend : place

- api : flatten
args : (Tensor x, int start_axis, int stop_axis)
output : Tensor(out), Tensor(xshape)
Expand Down

0 comments on commit 2b1300a

Please sign in to comment.