Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
risemeup1 committed Feb 18, 2025
2 parents bf08d32 + 890ed82 commit 45eef56
Show file tree
Hide file tree
Showing 21 changed files with 524 additions and 101 deletions.
7 changes: 7 additions & 0 deletions paddle2onnx/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def arg_parser():
default=True,
help="whether enable auto_update_opset, default is True",
)
parser.add_argument(
"--enable_dist_prim_all",
type=ast.literal_eval,
default=False,
help="whether enable dist_prim_all, default is False",
)
parser.add_argument(
"--external_filename",
type=str,
Expand Down Expand Up @@ -160,6 +166,7 @@ def main():
save_file=args.save_file,
opset_version=args.opset_version,
auto_upgrade_opset=args.enable_auto_update_opset,
dist_prim_all=args.enable_dist_prim_all,
verbose=True,
enable_onnx_checker=args.enable_onnx_checker,
enable_experimental_op=True,
Expand Down
90 changes: 90 additions & 0 deletions paddle2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,92 @@
import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
from paddle2onnx.utils import logging, paddle_jit_save_configs
from contextlib import contextmanager
from paddle.decomposition import decomp
from paddle.base.executor import global_scope


def load_model(model_filename):
"""Loads the pir model from json file."""
assert os.path.exists(
model_filename
), f"Model file {model_filename} does not exist."
if model_filename.endswith(".json"):
model_filename = model_filename[:-5]
return paddle.jit.load(model_filename)


def compare_programs(original_program, new_program):
"""Compares two pir programs' operations."""
original_ops = [op.name() for op in original_program.global_block().ops]
new_ops = [op.name() for op in new_program.global_block().ops]
return original_ops == new_ops


def save_program(program, model_file):
"""Saves the decomposed program to a file."""
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

tmp_dir = tempfile.mkdtemp()
filename = os.path.basename(model_file) + "_decompose"
filename_without_extension, _ = os.path.splitext(filename)
save_dir = os.path.join(tmp_dir, filename_without_extension)

# Find feed and fetch operations
feed, fetch = [], []
for op in program.global_block().ops:
if op.name() == "pd_op.feed":
feed.extend(op.results())
if op.name() == "pd_op.fetch" or op.name() == "builtin.shadow_output":
fetch.extend(op.operands_source())

with paddle.pir_utils.IrGuard():
paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)

new_model_file = save_dir + ".json"
assert os.path.exists(
new_model_file
), f"Pir Model file {new_model_file} does not exist."
logging.info(f"Decomposed Model file path: {new_model_file}")
return new_model_file


def load_parameter(program):
params = []
opts = []
for var in program.list_vars():
if var.is_parameter or var.get_defining_op().name() == "builtin.parameter":
params.append(var)
elif var.persistable and var.get_defining_op().name() == "pd_op.data":
opts.append(var)
vars_list = params + opts
vars = [var for var in vars_list if var.persistable]

if vars is None:
return

place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
paddle.base.libpaddle.pir.create_loaded_parameter(
vars, global_scope(), exe._default_executor
)


def decompose_program(model_filename):
"""Decomposes the given pir program."""
model = load_model(model_filename)
new_program = model.program().clone()
with decomp.prim_guard():
decomp.decompose_dist_program(new_program)

if compare_programs(model.program(), new_program):
return model_filename

# logging.info(f"Origin program: {model.program()}")
# logging.info(f"Decomposed program: {new_program}")

load_parameter(new_program)
return save_program(new_program, model_filename)


def get_old_ir_guard():
Expand All @@ -40,6 +126,7 @@ def export(
save_file=None,
opset_version=7,
auto_upgrade_opset=True,
dist_prim_all=False,
verbose=True,
enable_onnx_checker=True,
enable_experimental_op=True,
Expand Down Expand Up @@ -102,6 +189,9 @@ def export(
assert os.path.exists(
model_filename
), f"Pir Model file {model_filename} does not exist."
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
if dist_prim_all and auto_upgrade_opset:
model_filename = decompose_program(model_filename)

deploy_backend = deploy_backend.lower()
if custom_op_info is None:
Expand Down
9 changes: 4 additions & 5 deletions paddle2onnx/mapper/activation/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
#include "paddle2onnx/mapper/exporter.h"

namespace paddle2onnx {

REGISTER_MAPPER(abs, ActivationMapper)
REGISTER_PIR_MAPPER(abs, ActivationMapper)
REGISTER_MAPPER(acos, ActivationMapper)
REGISTER_PIR_MAPPER(acos, ActivationMapper)
REGISTER_MAPPER(asin, ActivationMapper)
Expand Down Expand Up @@ -132,10 +129,12 @@ void ActivationMapper::Opset7() {
auto output = helper_->MakeNode(iter->second, {input})->output(0);
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
} else {
helper_->MakeNode(iter->second, {input_info[0].name},
} else{
helper_->MakeNode(iter->second, {input_info[0].name},
{output_info[0].name});
}


}

int32_t PReluMapper::GetMinOpsetVersion(bool verbose) {
Expand Down
10 changes: 4 additions & 6 deletions paddle2onnx/mapper/activation/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class ActivationMapper : public Mapper {
op_mapper_["cos"] = "Cos";
op_mapper_["sin"] = "Sin";
op_mapper_["round"] = "Round";
op_mapper_["abs"] = "Abs";
op_mapper_["acos"] = "Acos";
op_mapper_["asin"] = "Asin";
op_mapper_["atan"] = "Atan";
Expand All @@ -64,7 +63,6 @@ class ActivationMapper : public Mapper {
op_mapper_["cos"] = "Cos";
op_mapper_["sin"] = "Sin";
op_mapper_["round"] = "Round";
op_mapper_["abs"] = "Abs";
op_mapper_["acos"] = "Acos";
op_mapper_["asin"] = "Asin";
op_mapper_["atan"] = "Atan";
Expand Down Expand Up @@ -329,10 +327,10 @@ class RsqrtMapper : public Mapper {
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
RsqrtMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
:Mapper(p, helper, i, c) {}
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {}
void Opset7() override;
};

Expand Down
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ int32_t ModelExporter::GetMinOpsetVersion(const PaddlePirParser& pir_parser,
current_opset = current_opset > 11 ? current_opset : 11;
} else if (op_name == "pd_op.while") {
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
pir_parser.GetWhileInputValuesAndArgsMappings(&while_op);
current_opset = GetCfBlockMinOpsetVersion(pir_parser, while_op.body());
current_opset = current_opset > 11 ? current_opset : 11;

Expand Down Expand Up @@ -483,7 +484,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
temp_outputs.push_back(std::move(MakeValueInfo(cond_info[0])));
if (value.defining_op() == nullptr) {
value =
pir::Value(pir_parser.while_op_input_value_map[&(*(value.impl()))]);
pir::Value(pir_parser.while_op_values_args_map[&(*(value.impl()))]);
}
if (value.defining_op()->GetParent() != &block) {
temp_inputs.push_back(std::move(MakeValueInfo(cond_info[0])));
Expand Down
7 changes: 7 additions & 0 deletions paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace paddle2onnx {
class Mapper {
public:
using ScalarData = PaddlePirParser::ScalarData;
Mapper() {}
Mapper(const PaddleParser& p,
OnnxHelper* helper,
Expand Down Expand Up @@ -238,6 +239,12 @@ class Mapper {

std::string Name() const { return name_; }

void GetScalarAttr(const std::string& scalar_name, ScalarData* scalar_data) {
Assert(in_pir_mode, "Only support PIR mode.");
pir_parser_->GetOpScalarValue(
pir_op_idx_, if_in_cf_block, scalar_name, scalar_data);
}

bool HasInput(const std::string& name) const {
if (in_pir_mode) {
return pir_parser_->OpHasInput(pir_op_idx_, name, if_in_cf_block);
Expand Down
67 changes: 67 additions & 0 deletions paddle2onnx/mapper/tensor/abs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/abs.h"

namespace paddle2onnx {
REGISTER_PIR_MAPPER(abs, AbsMapper)
REGISTER_MAPPER(abs, AbsMapper)

int32_t AbsMapper::GetMinOpsetVersion(bool verbose) {
return 13;

}

void AbsMapper::Opset13() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
if (input_info[0].dtype == P2ODataType::COMPLEX64){
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({1}));
auto split_node = helper_->MakeNode("Split", {input_info[0].name},2);
AddAttribute(split_node,"axis",int64_t(-1));
std::string split_node1 = helper_->Squeeze(split_node->output(0), {-1});
std::string split_node2 = helper_->Squeeze(split_node->output(1), {-1});
auto real_squre = helper_->MakeNode("Mul", {split_node1,split_node1});
auto imag_squre = helper_->MakeNode("Mul", {split_node2 ,split_node2});
auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)});
helper_->MakeNode("Sqrt", {node_add->output(0)},
{output_info[0].name});
}else{
helper_->MakeNode("Abs", {input_info[0].name},
{output_info[0].name});
}
}
void AbsMapper::Opset18() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
if (input_info[0].dtype == P2ODataType::COMPLEX64){
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({1}));
auto split_node = helper_->MakeNode("Split", {input_info[0].name},2);
AddAttribute(split_node,"axis",int64_t(-1));
AddAttribute(split_node,"num_outputs",int64_t(2));
std::string split_node1 = helper_->Squeeze(split_node->output(0), {-1});
std::string split_node2 = helper_->Squeeze(split_node->output(1), {-1});
auto real_squre = helper_->MakeNode("Mul", {split_node1,split_node1});
auto imag_squre = helper_->MakeNode("Mul", {split_node2 ,split_node2});
auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)});
helper_->MakeNode("Sqrt", {node_add->output(0)},
{output_info[0].name});
}else{
helper_->MakeNode("Abs", {input_info[0].name},
{output_info[0].name});
}

}

} // namespace paddle2onnx
40 changes: 40 additions & 0 deletions paddle2onnx/mapper/tensor/abs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class AbsMapper : public Mapper {
public:
AbsMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
AbsMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t i,
bool c)
: Mapper(p, helper, i, c) {
in_pir_mode = true;
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset13() override;
void Opset18() override;

};

} // namespace paddle2onnx
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/tensor/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ REGISTER_PIR_MAPPER(elementwise_mod, ElementWiseModMapper)
REGISTER_PIR_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper)

int32_t ElementwiseMapper::GetMinOpsetVersion(bool verbose) {
if (OpType() == "elementwise_min" || OpType() == "elementwise_max") {
if (convert_pir_op_name(OpType()) == "elementwise_min" ||
convert_pir_op_name(OpType()) == "elementwise_max") {
Logger(verbose, 8) << RequireOpset(8) << std::endl;
return 8;
}
Expand Down
43 changes: 43 additions & 0 deletions paddle2onnx/mapper/tensor/fft_r2c.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/fft_r2c.h"

#include <cmath>
#include <string>
#include <vector>

namespace paddle2onnx {
REGISTER_PIR_MAPPER(fft_r2c, FftR2cMapper);

int32_t FftR2cMapper::GetMinOpsetVersion(bool verbose) {
return 17;
}

void FftR2cMapper::Opset17() {
auto input_info =GetInput("x");
auto output_info = GetOutput("out");
output_info[0].dtype = P2ODataType::FP32;
std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({-1}));
std::string zero_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector<int64_t>({0}));
auto node1 = helper_->MakeNode("Unsqueeze", {input_info[0].name, one_str});
auto node2 = helper_->MakeNode("Unsqueeze", {node1->output(0), zero_str});
auto dft_node = helper_->MakeNode("DFT", {node2->output(0)});
AddAttribute(dft_node, "onesided", int64_t(onesided_));
AddAttribute(dft_node, "inverse", int64_t(0));
AddAttribute(dft_node, "axis", int64_t(2));
helper_->MakeNode("Squeeze", {dft_node->output(0), zero_str}, {output_info[0].name});
}

} // namespace paddle2onnx
Loading

0 comments on commit 45eef56

Please sign in to comment.