Skip to content

Commit

Permalink
Merge pull request #1505 from risemeup1/fix_abs_bug
Browse files Browse the repository at this point in the history
fix abs bug
  • Loading branch information
risemeup1 authored Feb 17, 2025
2 parents 5735b0a + 5f1e448 commit 0b04b06
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 70 deletions.
23 changes: 2 additions & 21 deletions paddle2onnx/mapper/activation/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
#include "paddle2onnx/mapper/exporter.h"

namespace paddle2onnx {

REGISTER_MAPPER(abs, ActivationMapper)
REGISTER_PIR_MAPPER(abs, ActivationMapper)
REGISTER_MAPPER(acos, ActivationMapper)
REGISTER_MAPPER(asin, ActivationMapper)
REGISTER_MAPPER(atan, ActivationMapper)
Expand Down Expand Up @@ -114,28 +111,12 @@ void ActivationMapper::Opset7() {
auto output = helper_->MakeNode(iter->second, {input})->output(0);
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
} else {
if (convert_pir_op_name(OpType()) == "abs" && input_info[0].dtype == P2ODataType::COMPLEX64){
input_info[0].dtype = P2ODataType::FP32;
int shape_size = input_info[0].shape.size();
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({1}));
auto split_node = helper_->MakeNode("Split", {input_info[0].name},2);
AddAttribute(split_node,"axis",int64_t(-1));

auto real_squre = helper_->MakeNode("Mul", {split_node->output(0),split_node->output(0)});
auto imag_squre = helper_->MakeNode("Mul", {split_node->output(1),split_node->output(1)});

auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)});

helper_->MakeNode("Sqrt", {node_add->output(0)},
{output_info[0].name});
}else{
} else{
helper_->MakeNode(iter->second, {input_info[0].name},
{output_info[0].name});
}

}


}

int32_t PReluMapper::GetMinOpsetVersion(bool verbose) {
Expand Down
2 changes: 0 additions & 2 deletions paddle2onnx/mapper/activation/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class ActivationMapper : public Mapper {
op_mapper_["cos"] = "Cos";
op_mapper_["sin"] = "Sin";
op_mapper_["round"] = "Round";
op_mapper_["abs"] = "Abs";
op_mapper_["acos"] = "Acos";
op_mapper_["asin"] = "Asin";
op_mapper_["atan"] = "Atan";
Expand All @@ -64,7 +63,6 @@ class ActivationMapper : public Mapper {
op_mapper_["cos"] = "Cos";
op_mapper_["sin"] = "Sin";
op_mapper_["round"] = "Round";
op_mapper_["abs"] = "Abs";
op_mapper_["acos"] = "Acos";
op_mapper_["asin"] = "Asin";
op_mapper_["atan"] = "Atan";
Expand Down
67 changes: 67 additions & 0 deletions paddle2onnx/mapper/tensor/abs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/abs.h"

namespace paddle2onnx {
REGISTER_PIR_MAPPER(abs, AbsMapper)
REGISTER_MAPPER(abs, AbsMapper)

int32_t AbsMapper::GetMinOpsetVersion(bool verbose) {
return 13;

}

void AbsMapper::Opset13() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
if (input_info[0].dtype == P2ODataType::COMPLEX64){
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({1}));
auto split_node = helper_->MakeNode("Split", {input_info[0].name},2);
AddAttribute(split_node,"axis",int64_t(-1));
std::string split_node1 = helper_->Squeeze(split_node->output(0), {-1});
std::string split_node2 = helper_->Squeeze(split_node->output(1), {-1});
auto real_squre = helper_->MakeNode("Mul", {split_node1,split_node1});
auto imag_squre = helper_->MakeNode("Mul", {split_node2 ,split_node2});
auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)});
helper_->MakeNode("Sqrt", {node_add->output(0)},
{output_info[0].name});
}else{
helper_->MakeNode("Abs", {input_info[0].name},
{output_info[0].name});
}
}
void AbsMapper::Opset18() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
if (input_info[0].dtype == P2ODataType::COMPLEX64){
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({1}));
auto split_node = helper_->MakeNode("Split", {input_info[0].name},2);
AddAttribute(split_node,"axis",int64_t(-1));
AddAttribute(split_node,"num_outputs",int64_t(2));
std::string split_node1 = helper_->Squeeze(split_node->output(0), {-1});
std::string split_node2 = helper_->Squeeze(split_node->output(1), {-1});
auto real_squre = helper_->MakeNode("Mul", {split_node1,split_node1});
auto imag_squre = helper_->MakeNode("Mul", {split_node2 ,split_node2});
auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)});
helper_->MakeNode("Sqrt", {node_add->output(0)},
{output_info[0].name});
}else{
helper_->MakeNode("Abs", {input_info[0].name},
{output_info[0].name});
}

}

} // namespace paddle2onnx
40 changes: 40 additions & 0 deletions paddle2onnx/mapper/tensor/abs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class AbsMapper : public Mapper {
public:
AbsMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
AbsMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t i,
bool c)
: Mapper(p, helper, i, c) {
in_pir_mode = true;
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset13() override;
void Opset18() override;

};

} // namespace paddle2onnx
47 changes: 6 additions & 41 deletions tests/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,66 +35,31 @@ def forward(self, inputs):


@_test_with_pir
def test_abs_9():
def test_abs_13():
"""
api: paddle.abs
op version: 9
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, "abs", [9])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")),
)
obj.run()


@_test_with_pir
def test_abs_10():
"""
api: paddle.abs
op version: 10
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, "abs", [10])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")),
)
obj.run()


@_test_with_pir
def test_abs_11():
"""
api: paddle.abs
op version: 11
op version: 12
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, "abs", [11])
obj = APIOnnx(op, "abs", [13])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")),
)
obj.run()


@_test_with_pir
def test_abs_12():
def test_abs_18():
"""
api: paddle.abs
op version: 12
op version: 18
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, "abs", [12])
obj = APIOnnx(op, "abs", [18])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")),
Expand Down
9 changes: 3 additions & 6 deletions tests/test_auto_scan_unary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import OPConvertAutoScanTest, BaseNet
from hypothesis import reproduce_failure
import hypothesis.strategies as st
import numpy as np
import unittest
import paddle

Expand Down Expand Up @@ -56,7 +54,7 @@
}

opset_version_map = {
"abs": [7, 13, 15],
"abs": [13, 18],
"acos": [7, 15],
"asin": [7, 15],
"atan": [7, 15],
Expand Down Expand Up @@ -103,9 +101,8 @@ class TestUnaryOPConvert(OPConvertAutoScanTest):

def sample_convert_config(self, draw):
input_shape = draw(
st.lists(
st.integers(
min_value=2, max_value=20), min_size=0, max_size=4))
st.lists(st.integers(min_value=2, max_value=20), min_size=0, max_size=4)
)
data_shapes = input_shape
dtype = draw(st.sampled_from(["float32"]))
config = {
Expand Down

0 comments on commit 0b04b06

Please sign in to comment.