support op hardtanh & fix unfold empty
zhanghonggeng committed Feb 26, 2025
1 parent 1e29c15 commit ebea7b9
Showing 14 changed files with 437 additions and 282 deletions.
94 changes: 48 additions & 46 deletions paddle2onnx/mapper/activation/activation.cc
@@ -84,7 +84,7 @@ REGISTER_PIR_MAPPER(sqrt, ActivationMapper)
 REGISTER_MAPPER(square, SquareMapper)
 REGISTER_PIR_MAPPER(square, SquareMapper)
 REGISTER_MAPPER(tan, ActivationMapper)
-REGISTER_PIR_MAPPER(hardtanh, ActivationMapper)
+// REGISTER_PIR_MAPPER(hardtanh, ActivationMapper)
 REGISTER_PIR_MAPPER(tan, ActivationMapper)
 REGISTER_MAPPER(tanh, ActivationMapper)
 REGISTER_PIR_MAPPER(tanh, ActivationMapper)
@@ -93,7 +93,6 @@ REGISTER_PIR_MAPPER(tanh_shrink, TanhShrinkMapper)
 REGISTER_MAPPER(thresholded_relu, ThresholdedReluMapper)
 REGISTER_PIR_MAPPER(thresholded_relu, ThresholdedReluMapper)
 
-
 int32_t ActivationMapper::GetMinOpsetVersion(bool verbose) {
   if (convert_pir_op_name(OpType()) == "softplus") {
     float beta = 0.0;
@@ -114,27 +113,23 @@ int32_t ActivationMapper::GetMinOpsetVersion(bool verbose) {
   return 7;
 }
 
-
 void ActivationMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
   auto iter = op_mapper_.find(convert_pir_op_name(OpType()));
   Assert(op_mapper_.end() != iter,
-         "Cannot find " +
-             convert_pir_op_name(OpType()) +
-             " in activation op_mapper.");
+         "Cannot find " + convert_pir_op_name(OpType()) +
+             " in activation op_mapper.");
   if (convert_pir_op_name(OpType()) == "erf") {
-    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
+    auto input = helper_->AutoCast(
+        input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
     auto output = helper_->MakeNode(iter->second, {input})->output(0);
-    helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
-                      output_info[0].dtype);
-  } else{
-    helper_->MakeNode(iter->second, {input_info[0].name},
-                      {output_info[0].name});
+    helper_->AutoCast(
+        output, output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
+  } else {
+    helper_->MakeNode(
+        iter->second, {input_info[0].name}, {output_info[0].name});
   }
-
-
 }
 
 int32_t PReluMapper::GetMinOpsetVersion(bool verbose) {
@@ -158,8 +153,8 @@ void PReluMapper::Opset7() {
 
   std::string slope_cast_name = slope_info[0].name;
   if (slope_info[0].dtype == P2ODataType::FP64) {
-    slope_cast_name = helper_->AutoCast({slope_info[0].name}, P2ODataType::FP64,
-                                        P2ODataType::FP32);
+    slope_cast_name = helper_->AutoCast(
+        {slope_info[0].name}, P2ODataType::FP64, P2ODataType::FP32);
   }
 
   if (slope_info[0].Rank() != input_info[0].Rank()) {
@@ -178,11 +173,13 @@ void PReluMapper::Opset7() {
     std::string x_cast_name = helper_->AutoCast(
         {input_info[0].name}, P2ODataType::FP64, P2ODataType::FP32);
     auto node = helper_->MakeNode("PRelu", {x_cast_name, slope_cast_name});
-    helper_->AutoCast(node->output(0), {output_info[0].name}, P2ODataType::FP32,
+    helper_->AutoCast(node->output(0),
+                      {output_info[0].name},
+                      P2ODataType::FP32,
                       P2ODataType::FP64);
   } else {
-    helper_->MakeNode("PRelu", {input_info[0].name, slope_cast_name},
-                      {output_info[0].name});
+    helper_->MakeNode(
+        "PRelu", {input_info[0].name, slope_cast_name}, {output_info[0].name});
   }
 }
 
@@ -198,8 +195,8 @@ void SeluMapper::Opset7() {
 void LeakyReluMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-  auto node = helper_->MakeNode("LeakyRelu", {input_info[0].name},
-                                {output_info[0].name});
+  auto node = helper_->MakeNode(
+      "LeakyRelu", {input_info[0].name}, {output_info[0].name});
   AddAttribute(node, "alpha", alpha_);
 }
 
@@ -217,8 +214,8 @@ void GeluMapper::Opset9() {
   auto const_1 =
       helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, const_1_value);
 
-  auto input_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                      P2ODataType::FP32);
+  auto input_name = helper_->AutoCast(
+      input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
 
   // the computation formula follows
   // https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/functional/gelu_cn.html#gelu
@@ -250,8 +247,8 @@ void SoftMaxMapper::Opset7() {
     axis_ = axis_ + output_info[0].Rank();
   }
   if (axis_ == output_info[0].Rank() - 1) {
-    auto node = helper_->MakeNode("Softmax", {input_info[0].name},
-                                  {output_info[0].name});
+    auto node = helper_->MakeNode(
+        "Softmax", {input_info[0].name}, {output_info[0].name});
     AddAttribute(node, "axis", axis_);
   } else {
     std::vector<int64_t> perm = Arange(0, output_info[0].Rank());
@@ -282,21 +279,24 @@ void SoftMaxMapper::Opset13() {
     AddAttribute(node, "axis", static_cast<int64_t>(0));
     helper_->Squeeze(node->output(0), output_info[0].name, {0});
   } else {
-    auto node = helper_->MakeNode("Softmax", {input_info[0].name},
-                                  {output_info[0].name});
+    auto node = helper_->MakeNode(
+        "Softmax", {input_info[0].name}, {output_info[0].name});
     AddAttribute(node, "axis", axis);
   }
 }
 
 void BReluMapper::Opset7() {
   auto x_info = GetInput("X");
-  helper_->Clip(x_info[0].name, GetOutput("Out")[0].name, t_min_, t_max_,
+  helper_->Clip(x_info[0].name,
+                GetOutput("Out")[0].name,
+                t_min_,
+                t_max_,
                 x_info[0].dtype);
 }
 
 void EluMapper::Opset7() {
-  auto node = helper_->MakeNode("Elu", {GetInput("X")[0].name},
-                                {GetOutput("Out")[0].name});
+  auto node = helper_->MakeNode(
+      "Elu", {GetInput("X")[0].name}, {GetOutput("Out")[0].name});
   AddAttribute(node, "alpha", alpha_);
 }
 
@@ -311,24 +311,25 @@ int32_t MishMapper::GetMinOpsetVersion(bool verbose) {
 void MishMapper::Opset7() {
   auto input_info = GetInput("X");
   auto out_info = GetOutput("Out");
-  auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                 P2ODataType::FP32);
+  auto input = helper_->AutoCast(
+      input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
   auto softplus = helper_->MakeNode("Softplus", {input})->output(0);
   auto tanh = helper_->MakeNode("Tanh", {softplus})->output(0);
   auto output = helper_->MakeNode("Mul", {input, tanh})->output(0);
-  helper_->AutoCast(output, out_info[0].name, P2ODataType::FP32,
-                    out_info[0].dtype);
+  helper_->AutoCast(
+      output, out_info[0].name, P2ODataType::FP32, out_info[0].dtype);
 }
 
 void SquareMapper::Opset7() {
   auto input_info = GetInput("X");
-  helper_->MakeNode("Mul", {input_info[0].name, input_info[0].name},
+  helper_->MakeNode("Mul",
+                    {input_info[0].name, input_info[0].name},
                     {GetOutput("Out")[0].name});
 }
 
 void SoftShrinkMapper::Opset9() {
-  auto node = helper_->MakeNode("Shrink", {GetInput("X")[0].name},
-                                {GetOutput("Out")[0].name});
+  auto node = helper_->MakeNode(
+      "Shrink", {GetInput("X")[0].name}, {GetOutput("Out")[0].name});
   AddAttribute(node, "lambd", lambda_);
   AddAttribute(node, "bias", lambda_);
 }
@@ -337,8 +338,8 @@ void SizeMapper::Opset7() {
   auto out_info = GetOutput("Out");
   auto output =
       helper_->MakeNode("Size", {GetInput("Input")[0].name})->output(0);
-  output = helper_->AutoCast(output, out_info[0].name, P2ODataType::INT64,
-                             out_info[0].dtype);
+  output = helper_->AutoCast(
+      output, out_info[0].name, P2ODataType::INT64, out_info[0].dtype);
 }
 
 void RsqrtMapper::Opset7() {
@@ -371,8 +372,8 @@ void LogSoftmaxMapper::Opset7() {
     axis += input_info[0].Rank();
   }
   if (axis == input_info[0].Rank() - 1) {
-    auto node = helper_->MakeNode("LogSoftmax", {input_info[0].name},
-                                  {GetOutput("Out")[0].name});
+    auto node = helper_->MakeNode(
+        "LogSoftmax", {input_info[0].name}, {GetOutput("Out")[0].name});
     AddAttribute(node, "axis", axis);
   } else {
     auto perm = Arange(0, input_info[0].Rank());
@@ -394,7 +395,9 @@ void ThresholdedReluMapper::Opset10() {
     input = helper_->AutoCast(input, x_info[0].dtype, P2ODataType::FP32);
     auto node = helper_->MakeNode("ThresholdedRelu", {input});
     AddAttribute(node, "alpha", threshold_);
-    helper_->AutoCast(node->output(0), out_info[0].name, P2ODataType::FP32,
+    helper_->AutoCast(node->output(0),
+                      out_info[0].name,
+                      P2ODataType::FP32,
                       out_info[0].dtype);
   } else {
     auto node =
@@ -406,9 +409,8 @@ void ThresholdedReluMapper::Opset10() {
 void Log1PMapper::Opset7() {
   auto x_info = GetInput("X");
   auto out_info = GetOutput("Out");
-  auto one = helper_->Constant({},
-                               GetOnnxDtype(x_info[0].dtype),
-                               static_cast<float>(1.0));
+  auto one = helper_->Constant(
+      {}, GetOnnxDtype(x_info[0].dtype), static_cast<float>(1.0));
   auto input = helper_->MakeNode("Add", {x_info[0].name, one})->output(0);
   helper_->MakeNode("Log", {input}, {out_info[0].name});
 }
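Note: the `REGISTER_PIR_MAPPER(hardtanh, ActivationMapper)` registration commented out above is superseded by a dedicated mapper; the new `hardtanh.cc` below registers `HardtanhMapper`, which lowers the op through `helper_->Clip`.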
32 changes: 32 additions & 0 deletions paddle2onnx/mapper/activation/hardtanh.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle2onnx/mapper/activation/hardtanh.h"
+
+namespace paddle2onnx {
+REGISTER_PIR_MAPPER(hardtanh, HardtanhMapper)
+
+int32_t HardtanhMapper::GetMinOpsetVersion(bool verbose) {
+  Logger(verbose, 7) << RequireOpset(7) << std::endl;
+  return 7;
+}
+
+void HardtanhMapper::Opset7() {
+  auto input_info = GetInput("x");
+  auto output_info = GetOutput("out");
+
+  helper_->Clip(
+      input_info[0].name, output_info[0].name, min_, max_, input_info[0].dtype);
+}
+}  // namespace paddle2onnx
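Hardtanh is a pure clamp, hardtanh(x) = min(max(x, min), max), which is exactly the computation an ONNX Clip node performs, so the mapper above delegates everything to `helper_->Clip`. The following is a minimal, self-contained sketch of the reference semantics the lowering targets. It is plain C++17 rather than the paddle2onnx helper API; the function name `hardtanh` and the sample values are illustrative only, while the default bounds mirror the `min_ = -1.0` / `max_ = 1.0` fallbacks declared in the header below.

#include <algorithm>
#include <cstdio>
#include <vector>

// Reference semantics of the hardtanh -> ONNX Clip lowering: every element
// is clamped into [min, max]. The default bounds match the mapper's
// fallbacks (min_ = -1.0, max_ = 1.0) used when the Paddle op carries no
// explicit "min"/"max" attributes.
std::vector<float> hardtanh(const std::vector<float>& x,
                            float min = -1.0f,
                            float max = 1.0f) {
  std::vector<float> out;
  out.reserve(x.size());
  for (float v : x) {
    out.push_back(std::clamp(v, min, max));  // same element-wise op as Clip
  }
  return out;
}

int main() {
  // Values inside [-1, 1] pass through; values outside saturate.
  for (float v : hardtanh({-2.5f, -0.5f, 0.0f, 0.5f, 2.5f})) {
    std::printf("%.1f ", v);  // prints: -1.0 -0.5 0.0 0.5 1.0
  }
  std::printf("\n");
  return 0;
}

One detail worth noting: ONNX Clip takes min/max as attributes up to opset 10 and as inputs from opset 11 onward; presumably `helper_->Clip` absorbs that difference, which is consistent with `GetMinOpsetVersion` requiring only opset 7.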
46 changes: 46 additions & 0 deletions paddle2onnx/mapper/activation/hardtanh.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <cmath>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle2onnx/mapper/mapper.h"
+
+namespace paddle2onnx {
+class HardtanhMapper : public Mapper {
+ public:
+  HardtanhMapper(const PaddlePirParser& p,
+                 OnnxHelper* helper,
+                 int64_t i,
+                 bool c)
+      : Mapper(p, helper, i, c) {
+    if (HasAttr("min")) {
+      GetAttr("min", &min_);
+    }
+    if (HasAttr("max")) {
+      GetAttr("max", &max_);
+    }
+  }
+
+  int32_t GetMinOpsetVersion(bool verbose) override;
+  void Opset7() override;
+
+ private:
+  float min_ = -1.0;
+  float max_ = 1.0;
+};
+}  // namespace paddle2onnx