Skip to content

Commit

Permalink
Merge pull request #1501 from risemeup1/support_round
Browse files Browse the repository at this point in the history
support round
  • Loading branch information
risemeup1 authored Feb 12, 2025
2 parents ec577c5 + 65b7851 commit 400040d
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 1 deletion.
21 changes: 20 additions & 1 deletion paddle2onnx/mapper/activation/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ REGISTER_MAPPER(reciprocal, ActivationMapper)
REGISTER_MAPPER(relu, ActivationMapper)
REGISTER_PIR_MAPPER(relu, ActivationMapper)
REGISTER_MAPPER(round, ActivationMapper)
REGISTER_PIR_MAPPER(round, ActivationMapper)
REGISTER_MAPPER(rsqrt, RsqrtMapper)
REGISTER_MAPPER(sel, ActivationMapper)
REGISTER_MAPPER(selu, SeluMapper)
Expand Down Expand Up @@ -113,9 +114,27 @@ void ActivationMapper::Opset7() {
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
} else {
helper_->MakeNode(iter->second, {input_info[0].name},
if (convert_pir_op_name(OpType()) == "abs" && input_info[0].dtype == P2ODataType::COMPLEX64){
input_info[0].dtype = P2ODataType::FP32;
int shape_size = input_info[0].shape.size();
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({1}));
auto split_node = helper_->MakeNode("Split", {input_info[0].name},2);
AddAttribute(split_node,"axis",int64_t(-1));

auto real_squre = helper_->MakeNode("Mul", {split_node->output(0),split_node->output(0)});
auto imag_squre = helper_->MakeNode("Mul", {split_node->output(1),split_node->output(1)});

auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)});

helper_->MakeNode("Sqrt", {node_add->output(0)},
{output_info[0].name});
}else{
helper_->MakeNode(iter->second, {input_info[0].name},
{output_info[0].name});
}

}

}

int32_t PReluMapper::GetMinOpsetVersion(bool verbose) {
Expand Down
43 changes: 43 additions & 0 deletions paddle2onnx/mapper/tensor/fft_r2c.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/fft_r2c.h"

#include <cmath>
#include <string>
#include <vector>

namespace paddle2onnx {
REGISTER_PIR_MAPPER(fft_r2c, FftR2cMapper);

int32_t FftR2cMapper::GetMinOpsetVersion(bool verbose) {
return 17;
}

void FftR2cMapper::Opset17() {
auto input_info =GetInput("x");
auto output_info = GetOutput("out");
output_info[0].dtype = P2ODataType::FP32;
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({-1}));
std::string zero_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({0}));
auto node1 = helper_->MakeNode("Unsqueeze", {input_info[0].name, one_str});
auto node2 = helper_->MakeNode("Unsqueeze", {node1->output(0), zero_str});
auto dft_node = helper_->MakeNode("DFT", {node2->output(0)});
AddAttribute(dft_node, "onesided", int64_t(onesided_));
AddAttribute(dft_node, "inverse", int64_t(0));
AddAttribute(dft_node, "axis", int64_t(2));
helper_->MakeNode("Squeeze", {dft_node->output(0), zero_str}, {output_info[0].name});
}

} // namespace paddle2onnx
48 changes: 48 additions & 0 deletions paddle2onnx/mapper/tensor/fft_r2c.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class FftR2cMapper : public Mapper {
public:
FftR2cMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t op_id,
bool c)
: Mapper(p, helper, op_id, c) {

in_pir_mode = true;
GetAttr("normalization", &normalization_);
GetAttr("onesided", &onesided_);
GetAttr("forward", &forward_);
GetAttr("axes",&axes_);
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset17() override;

private:
std::string normalization_;
bool onesided_;
bool forward_;
std::vector<int64_t> axes_;
};

} // namespace paddle2onnx
52 changes: 52 additions & 0 deletions tests/test_fft_r2c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from onnxbase import APIOnnx
from onnxbase import randtool
from onnxbase import _test_only_pir


class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, inputs):
"""
forward
"""
x = paddle.fft.rfft(inputs, axis=1)
x = paddle.abs(x)
return x


@_test_only_pir
def test_fftr2c_17():
"""
api: paddle.fft.rfft
op version: 17
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, "fft_r2c", [17])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 10, 10]).astype("float32")),
)
obj.run()

0 comments on commit 400040d

Please sign in to comment.