diff --git a/fluid_onnx/ops.py b/fluid_onnx/ops.py index 238ee2977..75291d7c4 100644 --- a/fluid_onnx/ops.py +++ b/fluid_onnx/ops.py @@ -155,6 +155,14 @@ def clip_op(operator, block): max=attrs['max']) +def compare_ops(op_type, operator, block): + ''' Conversion for compare ops, including 'Less', 'Equal', 'Greater' + ''' + inputs, attrs, outputs = op_io_info(operator) + return make_node( + op_type, inputs=inputs['X'] + inputs['Y'], outputs=outputs['Out']) + + def concat_op(operator, block): inputs, attrs, outputs = op_io_info(operator) return make_node( @@ -270,8 +278,10 @@ def elementwise_ops(op_type, operator, block): broadcast=1) -def elu_op(): - pass +def elu_op(operator, block): + inputs, attrs, outputs = op_io_info(operator) + return make_node( + 'Elu', inputs=inputs['X'], outputs=outputs['Out'], alpha=attrs['alpha']) def equal_op(): @@ -286,8 +296,10 @@ def gru_op(): pass -def gather_op(): - pass +def gather_op(operator, block): + inputs, attrs, outputs = op_io_info(operator) + return make_node( + 'Gather', inputs=inputs['X'] + inputs['Index'], outputs=outputs['Out']) def gemm_op(): @@ -298,10 +310,6 @@ def globallppool_op(): pass -def greater_op(): - pass - - def hardsigmoid_op(operator, block): inputs, attrs, outputs = op_io_info(operator) return make_node( @@ -337,12 +345,13 @@ def lstm_op(): pass -def leakyrelu_op(): - pass - - -def less_op(): - pass +def leaky_relu_op(operator, block): + inputs, attrs, outputs = op_io_info(operator) + return make_node( + 'LeakyRelu', + inputs=inputs['X'], + outputs=outputs['Out'], + alpha=attrs['alpha']) def binary_logical_ops(op_type, operator, block): @@ -477,22 +486,10 @@ def neg_op(): pass -def not_op(): - """ - Need to support broadcast. - """ - pass - - -def or_op(): - """ - Need to support broadcast. - """ - pass - - -def prelu_op(): - pass +def prelu_op(operator, block): + inputs, attrs, outputs = op_io_info(operator) + return make_node( + 'PRelu', inputs=inputs['X'] + inputs['Alpha'], outputs=outputs['Out']) def pad_op(): @@ -630,10 +627,6 @@ def split_op(): pass -def sqrt_op(): - pass - - def squeeze_op(): pass @@ -662,11 +655,13 @@ def unsqueeze_op(): pass -def xor_op(): - """ - Need to support broadcast. - """ - pass +def thresholded_relu_op(operator, block): + inputs, attrs, outputs = op_io_info(operator) + return make_node( + 'ThresholdedRelu', + inputs=inputs['X'], + outputs=outputs['Out'], + alpha=attrs['threshold']) # Based on the ONNX 1.0 operator list generated on March 26th, 2018. @@ -687,6 +682,7 @@ def xor_op(): 'conv2d': conv2d_op, # Need to continue the mapping below. 'conv2d_transpose': conv2d_transpose_op, + # 'cos': partial(activation_ops, 'Cos'), '': 'DepthToSpace', 'depthwise_conv2d': conv2d_op, 'dropout': dropout_op, @@ -695,23 +691,23 @@ def xor_op(): 'elementwise_mul': partial(elementwise_ops, 'Mul'), 'elementwise_pow': partial(elementwise_ops, 'Pow'), 'elementwise_sub': partial(elementwise_ops, 'Sub'), - '': 'Elu', - '': 'Equal', + 'elu': elu_op, + 'equal': partial(compare_ops, 'Equal'), 'exp': partial(activation_ops, 'Exp'), '': 'Flatten', 'floor': partial(activation_ops, 'Floor'), '': 'GRU', - '': 'Gather', + 'gather': gather_op, '': 'Gemm', '': 'GlobalLpPool', - '': 'Greater', + 'greater_than': partial(compare_ops, 'Greater'), 'hard_sigmoid': 'HardSigmoid', # Caffe2 error # 'Hardmax', NEEDS ATTENTION. # 'InstanceNormalization', NEEDS ATTENTION. + 'less_than': partial(compare_ops, 'Less'), 'lrn': lrn_op, '': 'LSTM', - '': 'LeakyRelu', - '': 'Less', + 'leaky_relu': leaky_relu_op, 'log': partial(activation_ops, 'Log'), 'logical_and': partial(binary_logical_ops, 'And'), 'logical_or': partial(binary_logical_ops, 'Or'), @@ -728,7 +724,7 @@ def xor_op(): '': 'Min', 'mul': mul_op, ',': 'Neg', - '': 'PRelu', + 'prelu': prelu_op, '': 'Pad', 'pool2d': pool2d_op, ',': 'RNN', @@ -752,6 +748,7 @@ def xor_op(): # 'Selu', NEEDS ATTENTION. '': 'Shape', 'sigmoid': partial(activation_ops, 'Sigmoid'), + # 'sin': partial(activation_ops, 'Sin'), '': 'Size', # 'Slice', NEEDS ATTENTION. 'softmax': softmax_op, @@ -783,6 +780,6 @@ def xor_op(): # 'experimental ParametricSoftplus' # 'experimental Scale' # 'experimental ScaledTanh' - # 'experimental ThresholdedRelu' + 'thresholded_relu': thresholded_relu_op, # 'experimental Upsample' } diff --git a/tests/test_activation_ops.py b/tests/test_activation_ops.py index b1310c2bd..605a7a2b1 100644 --- a/tests/test_activation_ops.py +++ b/tests/test_activation_ops.py @@ -86,5 +86,23 @@ def init_op_type(self): self.op_type = 'tanh' +class TestEluOp(TestAbsOp): + def init_op_type(self): + self.op_type = 'elu' + self.attrs = {'alpha': 2.0} + + +class TestLeakyReluOp(TestAbsOp): + def init_op_type(self): + self.op_type = 'leaky_relu' + self.attrs = {'alpha': 0.1} + + +class TestThresholdedReluOp(TestAbsOp): + def init_op_type(self): + self.op_type = 'thresholded_relu' + self.attrs = {'alpha': 0.1} + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_compare_ops.py b/tests/test_compare_ops.py new file mode 100644 index 000000000..e92873e36 --- /dev/null +++ b/tests/test_compare_ops.py @@ -0,0 +1,47 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import op_test +import unittest +import numpy + + +def create_test_class(op_type, typename, callback): + class Cls(op_test.OpTest): + def setUp(self): + a = numpy.random.random(size=(10, 7)).astype(typename) + b = numpy.random.random(size=(10, 7)).astype(typename) + c = callback(a, b) + self.inputs = {'X': a, 'Y': b} + self.outputs = {'Out': c} + self.op_type = op_type + + def test_output(self): + self.check_output() + + cls_name = "{0}_{1}".format(op_type, typename) + Cls.__name__ = cls_name + globals()[cls_name] = Cls + + +for _type_name in {'float32', 'float64', 'int32', 'int64'}: + create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) + #create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + #create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) + #create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + #create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_gather_op.py b/tests/test_gather_op.py new file mode 100644 index 000000000..93a07f7d6 --- /dev/null +++ b/tests/test_gather_op.py @@ -0,0 +1,32 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestGatherOp(OpTest): + def setUp(self): + self.op_type = "gather" + xnp = np.random.random((10, 20)).astype("float32") + self.inputs = {'X': xnp, 'Index': np.array([1, 3, 5]).astype("int32")} + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prelu_op.py b/tests/test_prelu_op.py new file mode 100644 index 000000000..908a39efb --- /dev/null +++ b/tests/test_prelu_op.py @@ -0,0 +1,33 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class PReluTest(OpTest): + def setUp(self): + self.op_type = "prelu" + x = np.random.normal((10, 10)).astype("float32") + alpha = np.array([.1], dtype="float32") + self.inputs = {'X': x, 'Alpha': alpha} + self.outputs = {'Out': np.zeros((1, 1))} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()