diff --git a/paddle2onnx/command.py b/paddle2onnx/command.py index 3ee21879b..b101158ab 100755 --- a/paddle2onnx/command.py +++ b/paddle2onnx/command.py @@ -92,6 +92,12 @@ def arg_parser(): default=True, help="whether enable auto_update_opset, default is True", ) + parser.add_argument( + "--enable_dist_prim_all", + type=ast.literal_eval, + default=False, + help="whether enable dist_prim_all, default is False", + ) parser.add_argument( "--external_filename", type=str, @@ -160,6 +166,7 @@ def main(): save_file=args.save_file, opset_version=args.opset_version, auto_upgrade_opset=args.enable_auto_update_opset, + dist_prim_all=args.enable_dist_prim_all, verbose=True, enable_onnx_checker=args.enable_onnx_checker, enable_experimental_op=True, diff --git a/paddle2onnx/convert.py b/paddle2onnx/convert.py index ef9304a5c..1e43a88c9 100755 --- a/paddle2onnx/convert.py +++ b/paddle2onnx/convert.py @@ -18,6 +18,92 @@ import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o from paddle2onnx.utils import logging, paddle_jit_save_configs from contextlib import contextmanager +from paddle.decomposition import decomp +from paddle.base.executor import global_scope + + +def load_model(model_filename): + """Loads the pir model from json file.""" + assert os.path.exists( + model_filename + ), f"Model file {model_filename} does not exist." + if model_filename.endswith(".json"): + model_filename = model_filename[:-5] + return paddle.jit.load(model_filename) + + +def compare_programs(original_program, new_program): + """Compares two pir programs' operations.""" + original_ops = [op.name() for op in original_program.global_block().ops] + new_ops = [op.name() for op in new_program.global_block().ops] + return original_ops == new_ops + + +def save_program(program, model_file): + """Saves the decomposed program to a file.""" + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + + tmp_dir = tempfile.mkdtemp() + filename = os.path.basename(model_file) + "_decompose" + filename_without_extension, _ = os.path.splitext(filename) + save_dir = os.path.join(tmp_dir, filename_without_extension) + + # Find feed and fetch operations + feed, fetch = [], [] + for op in program.global_block().ops: + if op.name() == "pd_op.feed": + feed.extend(op.results()) + if op.name() == "pd_op.fetch" or op.name() == "builtin.shadow_output": + fetch.extend(op.operands_source()) + + with paddle.pir_utils.IrGuard(): + paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program) + + new_model_file = save_dir + ".json" + assert os.path.exists( + new_model_file + ), f"Pir Model file {new_model_file} does not exist." + logging.info(f"Decomposed Model file path: {new_model_file}") + return new_model_file + + +def load_parameter(program): + params = [] + opts = [] + for var in program.list_vars(): + if var.is_parameter or var.get_defining_op().name() == "builtin.parameter": + params.append(var) + elif var.persistable and var.get_defining_op().name() == "pd_op.data": + opts.append(var) + vars_list = params + opts + vars = [var for var in vars_list if var.persistable] + + if vars is None: + return + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + paddle.base.libpaddle.pir.create_loaded_parameter( + vars, global_scope(), exe._default_executor + ) + + +def decompose_program(model_filename): + """Decomposes the given pir program.""" + model = load_model(model_filename) + new_program = model.program().clone() + with decomp.prim_guard(): + decomp.decompose_dist_program(new_program) + + if compare_programs(model.program(), new_program): + return model_filename + + # logging.info(f"Origin program: {model.program()}") + # logging.info(f"Decomposed program: {new_program}") + + load_parameter(new_program) + return save_program(new_program, model_filename) def get_old_ir_guard(): @@ -40,6 +126,7 @@ def export( save_file=None, opset_version=7, auto_upgrade_opset=True, + dist_prim_all=False, verbose=True, enable_onnx_checker=True, enable_experimental_op=True, @@ -102,6 +189,9 @@ def export( assert os.path.exists( model_filename ), f"Pir Model file {model_filename} does not exist." + if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]: + if dist_prim_all and auto_upgrade_opset: + model_filename = decompose_program(model_filename) deploy_backend = deploy_backend.lower() if custom_op_info is None: diff --git a/paddle2onnx/mapper/activation/activation.cc b/paddle2onnx/mapper/activation/activation.cc index dbfe958ce..74b39da0a 100644 --- a/paddle2onnx/mapper/activation/activation.cc +++ b/paddle2onnx/mapper/activation/activation.cc @@ -15,9 +15,6 @@ #include "paddle2onnx/mapper/exporter.h" namespace paddle2onnx { - -REGISTER_MAPPER(abs, ActivationMapper) -REGISTER_PIR_MAPPER(abs, ActivationMapper) REGISTER_MAPPER(acos, ActivationMapper) REGISTER_PIR_MAPPER(acos, ActivationMapper) REGISTER_MAPPER(asin, ActivationMapper) @@ -132,10 +129,12 @@ void ActivationMapper::Opset7() { auto output = helper_->MakeNode(iter->second, {input})->output(0); helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, output_info[0].dtype); - } else { - helper_->MakeNode(iter->second, {input_info[0].name}, + } else{ + helper_->MakeNode(iter->second, {input_info[0].name}, {output_info[0].name}); } + + } int32_t PReluMapper::GetMinOpsetVersion(bool verbose) { diff --git a/paddle2onnx/mapper/activation/activation.h b/paddle2onnx/mapper/activation/activation.h index 832b90d1b..ce7b3ff7a 100644 --- a/paddle2onnx/mapper/activation/activation.h +++ b/paddle2onnx/mapper/activation/activation.h @@ -38,7 +38,6 @@ class ActivationMapper : public Mapper { op_mapper_["cos"] = "Cos"; op_mapper_["sin"] = "Sin"; op_mapper_["round"] = "Round"; - op_mapper_["abs"] = "Abs"; op_mapper_["acos"] = "Acos"; op_mapper_["asin"] = "Asin"; op_mapper_["atan"] = "Atan"; @@ -64,7 +63,6 @@ class ActivationMapper : public Mapper { op_mapper_["cos"] = "Cos"; op_mapper_["sin"] = "Sin"; op_mapper_["round"] = "Round"; - op_mapper_["abs"] = "Abs"; op_mapper_["acos"] = "Acos"; op_mapper_["asin"] = "Asin"; op_mapper_["atan"] = "Atan"; @@ -329,10 +327,10 @@ class RsqrtMapper : public Mapper { int64_t op_id) : Mapper(p, helper, block_id, op_id) {} RsqrtMapper(const PaddlePirParser& p, - OnnxHelper* helper, - int64_t i, - bool c) - :Mapper(p, helper, i, c) {} + OnnxHelper* helper, + int64_t i, + bool c) + : Mapper(p, helper, i, c) {} void Opset7() override; }; diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index a7c9507a2..65f4f6ec8 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -252,6 +252,7 @@ int32_t ModelExporter::GetMinOpsetVersion(const PaddlePirParser& pir_parser, current_opset = current_opset > 11 ? current_opset : 11; } else if (op_name == "pd_op.while") { auto while_op = op->dyn_cast(); + pir_parser.GetWhileInputValuesAndArgsMappings(&while_op); current_opset = GetCfBlockMinOpsetVersion(pir_parser, while_op.body()); current_opset = current_opset > 11 ? current_opset : 11; @@ -483,7 +484,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock( temp_outputs.push_back(std::move(MakeValueInfo(cond_info[0]))); if (value.defining_op() == nullptr) { value = - pir::Value(pir_parser.while_op_input_value_map[&(*(value.impl()))]); + pir::Value(pir_parser.while_op_values_args_map[&(*(value.impl()))]); } if (value.defining_op()->GetParent() != &block) { temp_inputs.push_back(std::move(MakeValueInfo(cond_info[0]))); diff --git a/paddle2onnx/mapper/mapper.h b/paddle2onnx/mapper/mapper.h index 25ef29792..f36bc8859 100644 --- a/paddle2onnx/mapper/mapper.h +++ b/paddle2onnx/mapper/mapper.h @@ -23,6 +23,7 @@ namespace paddle2onnx { class Mapper { public: + using ScalarData = PaddlePirParser::ScalarData; Mapper() {} Mapper(const PaddleParser& p, OnnxHelper* helper, @@ -238,6 +239,12 @@ class Mapper { std::string Name() const { return name_; } + void GetScalarAttr(const std::string& scalar_name, ScalarData* scalar_data) { + Assert(in_pir_mode, "Only support PIR mode."); + pir_parser_->GetOpScalarValue( + pir_op_idx_, if_in_cf_block, scalar_name, scalar_data); + } + bool HasInput(const std::string& name) const { if (in_pir_mode) { return pir_parser_->OpHasInput(pir_op_idx_, name, if_in_cf_block); diff --git a/paddle2onnx/mapper/tensor/abs.cc b/paddle2onnx/mapper/tensor/abs.cc new file mode 100644 index 000000000..bc58123fd --- /dev/null +++ b/paddle2onnx/mapper/tensor/abs.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle2onnx/mapper/tensor/abs.h" + +namespace paddle2onnx { +REGISTER_PIR_MAPPER(abs, AbsMapper) +REGISTER_MAPPER(abs, AbsMapper) + +int32_t AbsMapper::GetMinOpsetVersion(bool verbose) { + return 13; + +} + +void AbsMapper::Opset13() { + auto input_info = GetInput("X"); + auto output_info = GetOutput("Out"); + if (input_info[0].dtype == P2ODataType::COMPLEX64){ + std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector({1})); + auto split_node = helper_->MakeNode("Split", {input_info[0].name},2); + AddAttribute(split_node,"axis",int64_t(-1)); + std::string split_node1 = helper_->Squeeze(split_node->output(0), {-1}); + std::string split_node2 = helper_->Squeeze(split_node->output(1), {-1}); + auto real_squre = helper_->MakeNode("Mul", {split_node1,split_node1}); + auto imag_squre = helper_->MakeNode("Mul", {split_node2 ,split_node2}); + auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)}); + helper_->MakeNode("Sqrt", {node_add->output(0)}, + {output_info[0].name}); + }else{ + helper_->MakeNode("Abs", {input_info[0].name}, + {output_info[0].name}); + } +} +void AbsMapper::Opset18() { + auto input_info = GetInput("X"); + auto output_info = GetOutput("Out"); + if (input_info[0].dtype == P2ODataType::COMPLEX64){ + std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector({1})); + auto split_node = helper_->MakeNode("Split", {input_info[0].name},2); + AddAttribute(split_node,"axis",int64_t(-1)); + AddAttribute(split_node,"num_outputs",int64_t(2)); + std::string split_node1 = helper_->Squeeze(split_node->output(0), {-1}); + std::string split_node2 = helper_->Squeeze(split_node->output(1), {-1}); + auto real_squre = helper_->MakeNode("Mul", {split_node1,split_node1}); + auto imag_squre = helper_->MakeNode("Mul", {split_node2 ,split_node2}); + auto node_add = helper_->MakeNode("Add", {real_squre->output(0),imag_squre->output(0)}); + helper_->MakeNode("Sqrt", {node_add->output(0)}, + {output_info[0].name}); + }else{ + helper_->MakeNode("Abs", {input_info[0].name}, + {output_info[0].name}); + } + +} + +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/abs.h b/paddle2onnx/mapper/tensor/abs.h new file mode 100644 index 000000000..bb196163b --- /dev/null +++ b/paddle2onnx/mapper/tensor/abs.h @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle2onnx/mapper/mapper.h" + +namespace paddle2onnx { + +class AbsMapper : public Mapper { + public: + AbsMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id, + int64_t op_id) + : Mapper(p, helper, block_id, op_id) {} + AbsMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t i, + bool c) + : Mapper(p, helper, i, c) { + in_pir_mode = true; + } + + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset13() override; + void Opset18() override; + +}; + +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/elementwise.cc b/paddle2onnx/mapper/tensor/elementwise.cc index 28f07820d..929bab558 100755 --- a/paddle2onnx/mapper/tensor/elementwise.cc +++ b/paddle2onnx/mapper/tensor/elementwise.cc @@ -34,7 +34,8 @@ REGISTER_PIR_MAPPER(elementwise_mod, ElementWiseModMapper) REGISTER_PIR_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper) int32_t ElementwiseMapper::GetMinOpsetVersion(bool verbose) { - if (OpType() == "elementwise_min" || OpType() == "elementwise_max") { + if (convert_pir_op_name(OpType()) == "elementwise_min" || + convert_pir_op_name(OpType()) == "elementwise_max") { Logger(verbose, 8) << RequireOpset(8) << std::endl; return 8; } diff --git a/paddle2onnx/mapper/tensor/fft_r2c.cc b/paddle2onnx/mapper/tensor/fft_r2c.cc new file mode 100644 index 000000000..6b3177bf9 --- /dev/null +++ b/paddle2onnx/mapper/tensor/fft_r2c.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle2onnx/mapper/tensor/fft_r2c.h" + +#include +#include +#include + +namespace paddle2onnx { +REGISTER_PIR_MAPPER(fft_r2c, FftR2cMapper); + +int32_t FftR2cMapper::GetMinOpsetVersion(bool verbose) { + return 17; +} + +void FftR2cMapper::Opset17() { + auto input_info =GetInput("x"); + auto output_info = GetOutput("out"); + output_info[0].dtype = P2ODataType::FP32; + std::string one_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector({-1})); + std::string zero_str = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), std::vector({0})); + auto node1 = helper_->MakeNode("Unsqueeze", {input_info[0].name, one_str}); + auto node2 = helper_->MakeNode("Unsqueeze", {node1->output(0), zero_str}); + auto dft_node = helper_->MakeNode("DFT", {node2->output(0)}); + AddAttribute(dft_node, "onesided", int64_t(onesided_)); + AddAttribute(dft_node, "inverse", int64_t(0)); + AddAttribute(dft_node, "axis", int64_t(2)); + helper_->MakeNode("Squeeze", {dft_node->output(0), zero_str}, {output_info[0].name}); +} + +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/fft_r2c.h b/paddle2onnx/mapper/tensor/fft_r2c.h new file mode 100644 index 000000000..60f5830a5 --- /dev/null +++ b/paddle2onnx/mapper/tensor/fft_r2c.h @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle2onnx/mapper/mapper.h" + +namespace paddle2onnx { + +class FftR2cMapper : public Mapper { + public: + FftR2cMapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool c) + : Mapper(p, helper, op_id, c) { + + in_pir_mode = true; + GetAttr("normalization", &normalization_); + GetAttr("onesided", &onesided_); + GetAttr("forward", &forward_); + GetAttr("axes",&axes_); + } + + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset17() override; + + private: + std::string normalization_; + bool onesided_; + bool forward_; + std::vector axes_; +}; + +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/full.cc b/paddle2onnx/mapper/tensor/full.cc index d10a91170..7ea237526 100644 --- a/paddle2onnx/mapper/tensor/full.cc +++ b/paddle2onnx/mapper/tensor/full.cc @@ -13,9 +13,9 @@ // limitations under the License. #include "paddle2onnx/mapper/tensor/full.h" - #include #include +#include #include namespace paddle2onnx { @@ -23,9 +23,32 @@ REGISTER_PIR_MAPPER(full, FullMapper) void FullMapper::Opset7() { auto output_info = GetOutput("Out"); - std::cout << "full value is : " << value_ << std::endl; - helper_->Constant(output_info[0].name, shape_, - GetOnnxDtype(output_info[0].dtype), value_); -} + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + helper_->Constant(output_info[0].name, + shape_, + GetOnnxDtype(output_info[0].dtype), + std::get(value_)); + } else if constexpr (std::is_same_v) { + helper_->Constant(output_info[0].name, + shape_, + GetOnnxDtype(output_info[0].dtype), + std::get(value_)); + } else if constexpr (std::is_same_v) { + helper_->Constant(output_info[0].name, + shape_, + GetOnnxDtype(output_info[0].dtype), + std::get(value_)); + } else if constexpr (std::is_same_v) { + helper_->Constant(output_info[0].name, + shape_, + GetOnnxDtype(output_info[0].dtype), + std::get(value_)); + } + }, + value_); } +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/full.h b/paddle2onnx/mapper/tensor/full.h index 70c819b16..47cade57c 100644 --- a/paddle2onnx/mapper/tensor/full.h +++ b/paddle2onnx/mapper/tensor/full.h @@ -29,7 +29,7 @@ class FullMapper : public Mapper { bool if_in_cf_block) : Mapper(p, helper, op_id, if_in_cf_block) { GetAttr("dtype", &dtype_); - GetAttr("value", &value_); + GetScalarAttr("value", &value_); GetAttr("shape", &shape_); } @@ -37,7 +37,7 @@ class FullMapper : public Mapper { private: int64_t dtype_; - double value_; + ScalarData value_; std::vector shape_; }; diff --git a/paddle2onnx/mapper/tensor/reduce_sum.cc b/paddle2onnx/mapper/tensor/reduce_sum.cc index 87f174c00..08380840a 100644 --- a/paddle2onnx/mapper/tensor/reduce_sum.cc +++ b/paddle2onnx/mapper/tensor/reduce_sum.cc @@ -27,7 +27,7 @@ int32_t ReduceMapperSum::GetMinOpsetVersion(bool verbose) { void ReduceMapperSum::Opset13() { auto axis_name_ = "dim"; GetAttr("keep_dim", &keep_dim_); -if (!in_pir_mode) { + if (!in_pir_mode) { GetAttr("reduce_all", &reduce_all_); GetAttr("in_dtype", &in_dtype_); GetAttr("out_dtype", &out_dtype_); @@ -47,6 +47,13 @@ if (!in_pir_mode) { } auto x_info = GetInput("X"); + auto x_name = x_info[0].name; + auto x_tpye = x_info[0].dtype; + if (x_info[0].dtype == P2ODataType::BOOL) { + x_name = helper_->AutoCast(x_name, x_tpye, P2ODataType::INT32); + x_tpye = P2ODataType::INT32; + } + std::string dims; if (IsAttrVar(axis_name_)) { auto info = GetAttrVar(axis_name_); @@ -61,7 +68,7 @@ if (!in_pir_mode) { } // Add attribute - auto reduce_node = helper_->MakeNode("ReduceSum", {x_info[0].name, dims}); + auto reduce_node = helper_->MakeNode("ReduceSum", {x_name, dims}); AddAttribute(reduce_node, "keepdims", static_cast(keep_dim_)); auto out_node_name = reduce_node->output(0); @@ -73,7 +80,6 @@ if (!in_pir_mode) { out_node_name = helper_->Reshape(out_node_name, {-1}); } auto out_info = GetOutput("Out"); - helper_->AutoCast(out_node_name, out_info[0].name, - x_info[0].dtype, out_info[0].dtype); + helper_->AutoCast(out_node_name, out_info[0].name, x_tpye, out_info[0].dtype); } } // namespace paddle2onnx diff --git a/paddle2onnx/mapper/while.cc b/paddle2onnx/mapper/while.cc index a868e2e34..281016cfb 100644 --- a/paddle2onnx/mapper/while.cc +++ b/paddle2onnx/mapper/while.cc @@ -25,32 +25,12 @@ void ModelExporter::ExportWhile(PaddlePirParser& pir_parser, std::vector outputs_info; auto while_op = op->dyn_cast(); auto cond_info = pir_parser.GetTensorInfo(while_op.cond()); - // mapping args and inputs in while op using while_op_input_value_map - std::vector while_op_input_value_address; - std::vector while_op_input_arg_address; - pir_parser.while_op_input_value_map - .clear(); // wangmingkai02: handle nested loop situations in future. - - // record input value address for (int index = 1; index < while_op.num_operands(); index++) { const pir::Value& value = while_op.operand_source(index); inputs_info.push_back(pir_parser.GetTensorInfo( - pir_parser.GetOpOutputName(value), value.type())); - while_op_input_value_address.push_back( - &(*(value).impl())); // get value address - } - // record args value address - std::vector args = while_op.block_args(); - for (int i = 0; i < args.size(); i++) { - const pir::Value& value = args[i]; - while_op_input_arg_address.push_back(&(*(value.impl()))); - } - - // mapping - for (int index = 0; index < while_op_input_value_address.size(); index++) { - pir_parser.while_op_input_value_map[while_op_input_arg_address[index]] = - while_op_input_value_address[index]; + pir_parser.GetSubBlockOpOutputName(value), value.type())); } + pir_parser.GetWhileInputValuesAndArgsMappings(&while_op); std::vector sub_blocks_ops_copy(pir_parser.sub_blocks_ops); pir_parser.sub_blocks_ops.clear(); @@ -120,7 +100,7 @@ void ModelExporter::ExportWhile(PaddlePirParser& pir_parser, input_names.push_back(inputs_info[i].name); } for (size_t i = 0; i < op->num_results(); i++) { - output_names.push_back(pir_parser.GetOpOutputName(op->result(i))); + output_names.push_back(pir_parser.GetSubBlockOpOutputName(op->result(i))); } auto loop_node = temp_helper->MakeNode("Loop", input_names, output_names); AddAttribute(loop_node, "body", graph); diff --git a/paddle2onnx/parser/pir_parser.cc b/paddle2onnx/parser/pir_parser.cc index 0bad0c688..16b2d67e9 100644 --- a/paddle2onnx/parser/pir_parser.cc +++ b/paddle2onnx/parser/pir_parser.cc @@ -33,6 +33,7 @@ #include "paddle/pir/include/core/ir_context.h" #include "paddle2onnx/proto/p2o_paddle.pb.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle2onnx/mapper/data_helper.h" phi::DataType TransToPhiDataType(pir::Type dtype) { if (dtype.isa()) { @@ -106,10 +107,10 @@ std::string PaddlePirParser::GetOpOutputName(const pir::Value& source) const { std::string PaddlePirParser::GetSubBlockOpOutputName( const pir::Value& source) const { - auto it = while_op_input_value_map.find(&(*(source.impl()))); + auto it = while_op_values_args_map.find(&(*(source.impl()))); pir::Operation* op; uint32_t output_idx; - if (it != while_op_input_value_map.end()) { + if (it != while_op_values_args_map.end()) { pir::Value value(it->second); op = value.defining_op(); output_idx = value.dyn_cast().index(); @@ -889,6 +890,39 @@ void PaddlePirParser::GetOpAttr(const pir::Operation* op, "Cannot found attribute %s in op %s", name, op->name())); } +void PaddlePirParser::GetOpScalarValue(int64_t op_id, + bool if_in_sub_block, + const std::string& scalar_attr_name, + ScalarData* scalar_data) const { + pir::Operation* op = + if_in_sub_block ? sub_blocks_ops[op_id] : global_blocks_ops[op_id]; + PADDLE_ENFORCE_EQ( + OpHasAttr(op, scalar_attr_name), + true, + common::errors::InvalidArgument( + "Cannot found attribute %s in op %s", scalar_attr_name, op->name())); + auto attr = op->attribute(scalar_attr_name); + if (attr.isa()) { + *scalar_data = + static_cast(attr.dyn_cast<::pir::DoubleAttribute>().data()); + } else if (attr.isa()) { + *scalar_data = + static_cast(attr.dyn_cast<::pir::FloatAttribute>().data()); + } else if (attr.isa()) { + *scalar_data = + static_cast(attr.dyn_cast<::pir::Int64Attribute>().data()); + } else if (attr.isa()) { + *scalar_data = + static_cast(attr.dyn_cast<::pir::Int32Attribute>().data()); + } else if (attr.isa()) { + *scalar_data = + static_cast(attr.dyn_cast<::pir::BoolAttribute>().data()); + } else { + Assert(false, + "ScalarData only support double, float, int64_t, int32_t and bool " + "now."); + } +} std::vector PaddlePirParser::GetOpInput( int64_t op_id, int64_t input_idx, bool if_in_sub_block) const { PADDLE_ENFORCE_GT(input_idx, @@ -1025,4 +1059,34 @@ P2ODataType PaddlePirParser::TransPirDataType2OldIrDataType( "PaddlePirParser::TransPirDataType2OnnxDataType."); } } +void PaddlePirParser::GetWhileInputValuesAndArgsMappings( + paddle::dialect::WhileOp* while_op) const { + // mapping args and inputs in while op using while_op_values_args_map + std::vector while_op_input_value_address; + std::vector while_op_input_arg_address; + // record input value address + for (int index = 1; index < while_op->num_operands(); index++) { + const pir::Value& value = while_op->operand_source(index); + while_op_input_value_address.push_back( + &(*(value).impl())); // get value address + } + // record args value address + std::vector args = while_op->block_args(); + for (int i = 0; i < args.size(); i++) { + const pir::Value& value = args[i]; + while_op_input_arg_address.push_back(&(*(value.impl()))); + } + + // mapping + for (int index = 0; index < while_op_input_value_address.size(); index++) { + auto arg_addr = while_op_input_arg_address[index]; + if (while_op_values_args_map.count(arg_addr)) continue; + auto value_addr = while_op_input_value_address[index]; + while (while_op_values_args_map.count(value_addr)) { + value_addr = while_op_values_args_map[value_addr]; + } + while_op_values_args_map[arg_addr] = value_addr; + } +} + } // namespace paddle2onnx diff --git a/paddle2onnx/parser/pir_parser.h b/paddle2onnx/parser/pir_parser.h index 37c8cc9ac..0c4f56934 100644 --- a/paddle2onnx/parser/pir_parser.h +++ b/paddle2onnx/parser/pir_parser.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once #include +#include #include "paddle/common/errors.h" #include "paddle/phi/common/data_type.h" @@ -22,9 +23,11 @@ #include "paddle/pir/include/core/value.h" #include "paddle2onnx/parser/tensor_utils.h" #include "paddle2onnx/proto/p2o_paddle.pb.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" namespace paddle2onnx { class PaddlePirParser { public: + typedef std::variant ScalarData; bool Init(const std::string& _model, const std::string& _params = ""); std::map params; std::shared_ptr pir_program_; @@ -40,8 +43,8 @@ class PaddlePirParser { // recoring set of operators for all blocks std::set total_blocks_ops; // recording args of while op body name info - std::unordered_map - while_op_input_value_map; + mutable std::unordered_map + while_op_values_args_map; int NumOfBlocks() const; // int NumOfOps(int block_idx) const; int NumOfProgramOps() const; @@ -90,6 +93,11 @@ class PaddlePirParser { const std::string& name, std::vector* res) const; bool OpHasAttr(pir::Operation* op, const std::string& name) const; + + void GetOpScalarValue(int64_t op_id, + bool if_in_sub_block, + const std::string& scalar_attr_name, + ScalarData* scalar_data) const; std::string GetSubBlockOpOutputName(const pir::Value& source) const; std::vector GetOpInput(int64_t op_id, int64_t input_idx, @@ -265,6 +273,8 @@ class PaddlePirParser { std::string tensor_arr_name) const; std::string GetTensorArrayName(int64_t op_id, bool if_in_sub_block) const; std::string GenOpInputOutputName(const std::string& name) const; + void GetWhileInputValuesAndArgsMappings( + paddle::dialect::WhileOp *while_op) const; private: bool IsAttrVar(const pir::Operation* op, const int64_t& attr_id) const; diff --git a/tests/onnxbase.py b/tests/onnxbase.py index a82d6c06d..9632dc10a 100644 --- a/tests/onnxbase.py +++ b/tests/onnxbase.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import onnx from inspect import isfunction import logging from onnxruntime import InferenceSession @@ -20,7 +21,7 @@ import paddle import paddle2onnx import paddle.static as static -from paddle2onnx.convert import dygraph2onnx +from paddle2onnx.convert import dygraph2onnx, decompose_program import shutil from functools import wraps @@ -231,6 +232,8 @@ def __init__( self.input_spec_shape = input_spec_shape self.input_dtype = [] self.res_fict = {} + self.dist_prim_all = False + self.auto_upgrade_opset = False if isfunction(self.func): # self._func = self.BuildFunc(self.func, **self.kwargs_dict_dygraph["params_group1"]) @@ -283,6 +286,7 @@ def set_input_data(self, group_name, *args): self.input_feed[str(i)] = in_data.numpy() i += 1 + self.input_feed_backup = self.input_feed def set_device_mode(self, is_gpu=True): if paddle.device.is_compiled_with_cuda() is True and is_gpu: @@ -351,15 +355,24 @@ def _mk_onnx_res(self, ver): """ make onnx res """ + model_path = os.path.join( + self.pwd, self.name, self.name + "_" + str(ver) + ".onnx" + ) + model = onnx.load(model_path) sess = InferenceSession( - os.path.join(self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"), + model_path, providers=["CPUExecutionProvider"], ) input_names = sess.get_inputs() temp_dict = {} + self.input_feed = self.input_feed_backup for key, value in self.input_feed.items(): temp_dict[input_names[int(key)].name] = value self.input_feed = temp_dict + + input_feed = {} + if len(model.graph.input) == 0: + return sess.run(output_names=None, input_feed=input_feed) ort_outs = sess.run(output_names=None, input_feed=self.input_feed) return ort_outs @@ -477,13 +490,25 @@ def run(self): # # clip extra model_file = original_model_file + # clip extra + model_file = None + if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]: + if self.dist_prim_all and self.auto_upgrade_opset: + model_file = decompose_program(original_model_file) + else: + model_file = original_model_file + else: + model_file = os.path.join(self.name, "cliped_model.pdmodel") + self.clip_extra_program_only(original_model_file, model_file) + for v in self._version: onnx_model_str = paddle2onnx.export( model_file, # model_filename params_file, # params_filename - None, + None, # save_file v, # opset_version False, # auto_upgrade_opset + False, # dist_prim_all True, # verbose True, # enable_onnx_checker True, # enable_experimental_op diff --git a/tests/test_abs.py b/tests/test_abs.py index 9b5950017..0ac9eedc7 100644 --- a/tests/test_abs.py +++ b/tests/test_abs.py @@ -35,49 +35,15 @@ def forward(self, inputs): @_test_with_pir -def test_abs_9(): +def test_abs_13(): """ api: paddle.abs - op version: 9 - """ - op = Net() - op.eval() - # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, "abs", [9]) - obj.set_input_data( - "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), - ) - obj.run() - - -@_test_with_pir -def test_abs_10(): - """ - api: paddle.abs - op version: 10 - """ - op = Net() - op.eval() - # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, "abs", [10]) - obj.set_input_data( - "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), - ) - obj.run() - - -@_test_with_pir -def test_abs_11(): - """ - api: paddle.abs - op version: 11 + op version: 12 """ op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, "abs", [11]) + obj = APIOnnx(op, "abs", [13]) obj.set_input_data( "input_data", paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), @@ -85,16 +51,15 @@ def test_abs_11(): obj.run() -@_test_with_pir -def test_abs_12(): +def test_abs_18(): """ api: paddle.abs - op version: 12 + op version: 18 """ op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, "abs", [12]) + obj = APIOnnx(op, "abs", [18]) obj.set_input_data( "input_data", paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), diff --git a/tests/test_auto_scan_unary_ops.py b/tests/test_auto_scan_unary_ops.py index 8ebe99397..2162caac9 100755 --- a/tests/test_auto_scan_unary_ops.py +++ b/tests/test_auto_scan_unary_ops.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from auto_scan_test import OPConvertAutoScanTest, BaseNet -from hypothesis import reproduce_failure import hypothesis.strategies as st -import numpy as np import unittest import paddle @@ -56,7 +54,7 @@ } opset_version_map = { - "abs": [7, 13, 15], + "abs": [13, 18], "acos": [7, 15], "asin": [7, 15], "atan": [7, 15], @@ -103,9 +101,8 @@ class TestUnaryOPConvert(OPConvertAutoScanTest): def sample_convert_config(self, draw): input_shape = draw( - st.lists( - st.integers( - min_value=2, max_value=20), min_size=0, max_size=4)) + st.lists(st.integers(min_value=2, max_value=20), min_size=0, max_size=4) + ) data_shapes = input_shape dtype = draw(st.sampled_from(["float32"])) config = { diff --git a/tests/test_fft_r2c.py b/tests/test_fft_r2c.py new file mode 100644 index 000000000..dbe6d41f7 --- /dev/null +++ b/tests/test_fft_r2c.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from onnxbase import APIOnnx +from onnxbase import randtool +from onnxbase import _test_only_pir + + +class Net(paddle.nn.Layer): + """ + simple Net + """ + + def __init__(self): + super(Net, self).__init__() + + def forward(self, inputs): + """ + forward + """ + x = paddle.fft.rfft(inputs, axis=1) + x = paddle.abs(x) + return x + + +@_test_only_pir +def test_fftr2c_17(): + """ + api: paddle.fft.rfft + op version: 17 + """ + op = Net() + op.eval() + # net, name, ver_list, delta=1e-6, rtol=1e-5 + obj = APIOnnx(op, "fft_r2c", [17]) + obj.set_input_data( + "input_data", + paddle.to_tensor(randtool("float", -1, 1, [3, 10, 10]).astype("float32")), + ) + obj.run()