add ne_p gt_p primitive operators
cxxly committed Sep 9, 2022
1 parent 81c8173 commit 44f6771
Showing 8 changed files with 197 additions and 1 deletion.
2 changes: 2 additions & 0 deletions paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -30,6 +30,8 @@ set(PRIM_OP_SRCS
log_p_op.cc
select_p_op.cc
eq_p_op.cc
gt_p_op.cc
ne_p_op.cc
pow_p_op.cc
max_p_op.cc
erf_p_op.cc)
60 changes: 60 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
@@ -978,6 +978,66 @@ def init_data(self):
]


class TestGtPJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
# Set prim op
self.op_type = 'gt_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')

self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}

# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}

self.all_ops = [
# prim op:
'gt_p',
# jvp op:
'fill_constant_p',
# transpose op:
]


class TestNePJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
# Set prim op
self.op_type = 'ne_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')

self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}

# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}

self.all_ops = [
# prim op:
'ne_p',
# jvp op:
'fill_constant_p',
# transpose op:
]


class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
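The expected-op lists above encode the linearization rule for the new comparisons: a comparison's output is piecewise constant in its inputs, so the JVP is a zero tensor produced by a single fill_constant_p, and no transpose op is needed. A minimal NumPy illustration of that fact (reference code, not part of the commit):

import numpy as np

x, y = np.random.rand(4, 5), np.random.rand(4, 5)
x_dot, y_dot = np.random.rand(4, 5), np.random.rand(4, 5)
eps = 1e-6

# For almost every input, a small perturbation along (x_dot, y_dot) does
# not flip the comparison, so the directional derivative is a zero tensor.
unperturbed = x > y
perturbed = (x + eps * x_dot) > (y + eps * y_dot)
print((perturbed == unperturbed).all())  # True for almost every draw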
38 changes: 38 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -539,6 +539,44 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestNeOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'not_equal'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')

self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['not_equal', 'ne_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}


class TestGtOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'greater_than'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')

self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['greater_than', 'gt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}


class TestPowOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
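Both new cases pin the lowered output's dtype to 'bool' even though the inputs are floats, matching what the original comparison ops produce. A quick dynamic-mode check of that invariant (illustrative only, not part of the commit):

import paddle

x = paddle.rand([5, 8])
y = paddle.rand([5, 8])
# not_equal and greater_than are the orig ops lowered to ne_p and gt_p;
# both return boolean tensors, hence dtype='bool' for the prim outputs.
print(paddle.not_equal(x, y).dtype)     # paddle.bool
print(paddle.greater_than(x, y).dtype)  # paddle.bool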
38 changes: 38 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -555,6 +555,44 @@ def init_data(self):
self.out_map = {self.output['Z']: 0}


class TestNePPrim2Orig(TestAddPPrim2Orig):

def init_data(self):
self.op_type = 'ne_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}

self.prim2orig_args = (X, Y)
self.all_ops = ['ne_p', 'not_equal']
self.out_map = {self.output['Z']: 0}


class TestGtPPrim2Orig(TestAddPPrim2Orig):

def init_data(self):
self.op_type = 'gt_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}

self.prim2orig_args = (X, Y)
self.all_ops = ['gt_p', 'greater_than']
self.out_map = {self.output['Z']: 0}


class TestPowPPrim2Orig(TestAddPPrim2Orig):

def init_data(self):
6 changes: 5 additions & 1 deletion python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -292,7 +292,11 @@ def test_illegal_param(self):
('mean_with_axis', lambda x: paddle.mean(x, axis=1),
(np.random.rand(200, 345), ), None, 'float32'),
('mean_with_keepdim', lambda x: paddle.mean(x, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32')))
(np.random.rand(200, 345), ), None, 'float32'),
('mean_with_axis_keepdim',
lambda x: paddle.mean(x, axis=0, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):

def setUp(self):
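The added 'mean_with_axis_keepdim' case covers the combination of axis and keepdim, which the existing cases only exercised separately. For reference, the expected shape behavior of the tested lambda (illustrative):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(200, 345))
y = paddle.mean(x, axis=0, keepdim=True)
print(y.shape)  # [1, 345]: axis 0 is reduced but kept as a size-1 dim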
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_primops.py
@@ -99,6 +99,8 @@
('select', primops.select,
(randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('ne', primops.ne, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('gt', primops.gt, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
))
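Each tuple here is (name, op, args, kwargs, expected_shape, expected_dtype); like the existing eq entry, the new ne and gt entries expect a bool result with the operands' shape. A sketch of calling the new primitives directly in a static program (using only APIs added or shown in this commit):

import paddle
from paddle.incubate.autograd import primops

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', shape=[2, 3], dtype='float64')
    y = paddle.static.data('y', shape=[2, 3], dtype='float64')
    z = primops.ne(x, y)  # appends an ne_p op to the current block
    print(z.shape)        # expected: (2, 3), with a bool element type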
10 changes: 10 additions & 0 deletions python/paddle/incubate/autograd/primops.py
@@ -350,6 +350,16 @@ def eq(x, y, out=None):
return _simple_binop(LayerHelper('eq_p', **locals()))


@REGISTER_FN('gt_p', 'X', 'Y', 'Z')
def gt(x, y, out=None):
return _simple_binop(LayerHelper('gt_p', **locals()))


@REGISTER_FN('ne_p', 'X', 'Y', 'Z')
def ne(x, y, out=None):
return _simple_binop(LayerHelper('ne_p', **locals()))


@REGISTER_FN('pow_p', 'X', 'Y', 'Z')
def pow(x, y, out=None):
return _simple_binop(LayerHelper('pow_p', **locals()))
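gt and ne reuse the same pattern as the other binary primitives: REGISTER_FN records the primitive's name and input/output slots ('X', 'Y' -> 'Z'), and _simple_binop emits the op through the LayerHelper. Roughly, the existing private helper works like this (a simplified sketch, not code added by this commit):

def _simple_binop(helper):
    optype = helper.layer_type  # e.g. 'ne_p' or 'gt_p'
    x, y, out = tuple(map(helper.kwargs.get, ('x', 'y', 'out')))
    if out is None:
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(type=optype,
                     inputs={'X': x, 'Y': y},
                     outputs={'Z': out})
    return out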
42 changes: 42 additions & 0 deletions python/paddle/incubate/autograd/primrules.py
@@ -317,6 +317,20 @@ def equal_orig2prim(op, x, y):
return eq(x, y)


@REGISTER_ORIG2PRIM('not_equal')
def ne_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return ne(x, y)


@REGISTER_ORIG2PRIM('greater_than')
def gt_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return gt(x, y)
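
The explicit broadcast is needed because the original framework ops broadcast their operands implicitly, while the primitives require operands of equal shape. A NumPy check that the rewrite preserves semantics (illustrative):

import numpy as np

x = np.random.rand(5, 8)
y = np.random.rand(8)              # broadcastable against x
y_b = np.broadcast_to(y, x.shape)  # what broadcast(y, shape=x.shape) does
assert ((x != y) == (x != y_b)).all()
assert ((x > y) == (x > y_b)).all()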


@REGISTER_ORIG2PRIM('elementwise_pow')
def elementwise_pow_orig2prim(op, x, y):
if x.shape != y.shape:
@@ -523,6 +537,16 @@ def eq_prim2orig(op, x, y):
return paddle.equal(x, y)


@REGISTER_PRIM2ORIG('gt_p')
def gt_prim2orig(op, x, y):
return paddle.greater_than(x, y)


@REGISTER_PRIM2ORIG('ne_p')
def ne_prim2orig(op, x, y):
return paddle.not_equal(x, y)


@REGISTER_PRIM2ORIG('pow_p')
def pow_prim2orig(op, x, y):
return paddle.pow(x, y)
@@ -787,6 +811,24 @@ def eq_jvp(op, x_dot, y_dot):
return z_dot


@REGISTER_JVP('gt_p')
def gt_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot


@REGISTER_JVP('ne_p')
def ne_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
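
As with eq_jvp above, the new rules return None only when neither input carries a tangent; if either tangent is present they must emit a zeros tensor of the primal's shape so the linearized program stays well-formed downstream. A pure-Python mirror of the rule (reference only):

import numpy as np

def comparison_jvp(x_shape, x_dot, y_dot):
    # Mirrors gt_jvp/ne_jvp: no incoming tangent -> no outgoing tangent;
    # any incoming tangent -> a zero tangent with the primal's shape.
    if x_dot is None and y_dot is None:
        return None
    return np.zeros(x_shape)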


@REGISTER_JVP('pow_p')
def pow_jvp(op, x_dot, y_dot):
