Skip to content

Commit

Permalink
Add yaml for eye OP
Browse files Browse the repository at this point in the history
  • Loading branch information
From00 committed Apr 6, 2022
1 parent 0c968b9 commit 22a6f38
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
6 changes: 4 additions & 2 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,10 +1724,12 @@ def eye(num_rows,
else:
num_columns = num_rows

if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
_current_expected_place())
elif _in_legacy_dygraph():
out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
num_columns)

else:
helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype',
Expand Down
9 changes: 6 additions & 3 deletions python/paddle/fluid/tests/unittests/test_eye_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def setUp(self):
'''
Test eye op with specified shape
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
Expand All @@ -39,37 +40,39 @@ def setUp(self):
self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)


class TestEyeOp1(OpTest):
def setUp(self):
'''
Test eye op with default parameters
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
self.attrs = {'num_rows': 50}
self.outputs = {'Out': np.eye(50, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)


class TestEyeOp2(OpTest):
def setUp(self):
'''
Test eye op with specified shape
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
self.attrs = {'num_rows': 99, 'num_columns': 1}
self.outputs = {'Out': np.eye(99, 1, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)


class API_TestTensorEye(unittest.TestCase):
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,18 @@
func : expm1
backward : expm1_grad

- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
infer_meta :
func : EyeInferMeta
param : [num_rows, num_columns, dtype]
kernel :
func : eye
param : [num_rows, num_columns, dtype]
data_type : dtype
backend : place

- api : flatten
args : (Tensor x, int start_axis, int stop_axis)
output : Tensor(out), Tensor(xshape)
Expand Down

1 comment on commit 22a6f38

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.