diff --git a/paddle2onnx/command.py b/paddle2onnx/command.py
index 3ee21879b..b101158ab 100755
--- a/paddle2onnx/command.py
+++ b/paddle2onnx/command.py
@@ -92,6 +92,12 @@ def arg_parser():
         default=True,
         help="whether enable auto_update_opset, default is True",
     )
+    parser.add_argument(
+        "--enable_dist_prim_all",
+        type=ast.literal_eval,
+        default=False,
+        help="whether to enable dist_prim_all, default is False",
+    )
     parser.add_argument(
         "--external_filename",
         type=str,
@@ -160,6 +166,7 @@ def main():
         save_file=args.save_file,
         opset_version=args.opset_version,
         auto_upgrade_opset=args.enable_auto_update_opset,
+        dist_prim_all=args.enable_dist_prim_all,
         verbose=True,
         enable_onnx_checker=args.enable_onnx_checker,
         enable_experimental_op=True,
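For context: the CLI value goes through `ast.literal_eval`, so `True`/`False` are passed literally, and the flag is only honored under PIR together with opset auto-upgrade (see the `export()` guard in convert.py below). A minimal sketch of the Python entry point; the file paths are placeholders:

```python
from paddle2onnx.convert import export

# CLI equivalent (the usual model/save flags are elided here):
#   paddle2onnx ... --enable_dist_prim_all True
export(
    model_filename="inference/model.json",        # placeholder PIR program path
    params_filename="inference/model.pdiparams",  # placeholder params path
    save_file="model.onnx",
    auto_upgrade_opset=True,  # decomposition runs only when both flags are set
    dist_prim_all=True,
)
```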
diff --git a/paddle2onnx/convert.py b/paddle2onnx/convert.py
index ff75f9b58..a35b6cc24 100755
--- a/paddle2onnx/convert.py
+++ b/paddle2onnx/convert.py
@@ -14,9 +14,96 @@
 
 import os
 import paddle
+import tempfile
 import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
 from paddle2onnx.utils import logging, paddle_jit_save_configs
 from contextlib import contextmanager
+from paddle.decomposition import decomp
+from paddle.base.executor import global_scope
+
+
+def load_model(model_filename):
+    """Loads the pir model from a json file."""
+    assert os.path.exists(
+        model_filename
+    ), f"Model file {model_filename} does not exist."
+    if model_filename.endswith(".json"):
+        model_filename = model_filename[:-5]
+    return paddle.jit.load(model_filename)
+
+
+def compare_programs(original_program, new_program):
+    """Compares the operations of two pir programs."""
+    original_ops = [op.name() for op in original_program.global_block().ops]
+    new_ops = [op.name() for op in new_program.global_block().ops]
+    return original_ops == new_ops
+
+
+def save_program(program, model_file):
+    """Saves the decomposed program to a file."""
+    place = paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+
+    tmp_dir = tempfile.mkdtemp()
+    # Strip the extension before appending the suffix, otherwise splitext
+    # would throw the "_decompose" suffix away.
+    base_name, _ = os.path.splitext(os.path.basename(model_file))
+    save_dir = os.path.join(tmp_dir, base_name + "_decompose")
+
+    # Find feed and fetch operations
+    feed, fetch = [], []
+    for op in program.global_block().ops:
+        if op.name() == "pd_op.feed":
+            feed.extend(op.results())
+        if op.name() == "pd_op.fetch" or op.name() == "builtin.shadow_output":
+            fetch.extend(op.operands_source())
+
+    with paddle.pir_utils.IrGuard():
+        paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)
+
+    new_model_file = save_dir + ".json"
+    assert os.path.exists(
+        new_model_file
+    ), f"Pir Model file {new_model_file} does not exist."
+    logging.info(f"Decomposed Model file path: {new_model_file}")
+    return new_model_file
+
+
+def load_parameter(program):
+    """Loads the program's persistable parameters into the global scope."""
+    params = []
+    opts = []
+    for var in program.list_vars():
+        if var.is_parameter or var.get_defining_op().name() == "builtin.parameter":
+            params.append(var)
+        elif var.persistable and var.get_defining_op().name() == "pd_op.data":
+            opts.append(var)
+    vars_list = params + opts
+    vars = [var for var in vars_list if var.persistable]
+
+    # The comprehension above never yields None; bail out when it is empty.
+    if len(vars) == 0:
+        return
+
+    place = paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+    paddle.base.libpaddle.pir.create_loaded_parameter(
+        vars, global_scope(), exe._default_executor
+    )
+
+
+def decompose_program(model_filename):
+    """Decomposes the given pir program."""
+    model = load_model(model_filename)
+    new_program = model.program().clone()
+    with decomp.prim_guard():
+        decomp.decompose_dist_program(new_program)
+
+    if compare_programs(model.program(), new_program):
+        return model_filename
+
+    load_parameter(new_program)
+    return save_program(new_program, model_filename)
 
 
 def get_old_ir_guard():
@@ -39,6 +126,7 @@ def export(
     save_file=None,
     opset_version=7,
     auto_upgrade_opset=True,
+    dist_prim_all=False,
     verbose=True,
     enable_onnx_checker=True,
     enable_experimental_op=True,
@@ -49,6 +137,10 @@ def export(
     external_file="",
     export_fp16_model=False,
 ):
+    if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
+        if dist_prim_all and auto_upgrade_opset:
+            model_filename = decompose_program(model_filename)
+
     deploy_backend = deploy_backend.lower()
     if custom_op_info is None:
         onnx_model_str = c_p2o.export(
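Standalone usage of the helper above, as a minimal sketch; it assumes `FLAGS_enable_pir_api` is on and the `.json` path (a placeholder here) was produced by `paddle.jit.save` under PIR:

```python
import paddle
from paddle2onnx.convert import decompose_program

# Same guard export() applies before calling the helper.
assert paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]

# Returns the input path unchanged when decomposition is a no-op, otherwise
# the path of the decomposed program re-saved as .json in a temp directory.
new_model_file = decompose_program("inference/model.json")
```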
diff --git a/paddle2onnx/mapper/activation/activation.cc b/paddle2onnx/mapper/activation/activation.cc
index 12bc578e2..eb23f3160 100644
--- a/paddle2onnx/mapper/activation/activation.cc
+++ b/paddle2onnx/mapper/activation/activation.cc
@@ -65,6 +65,7 @@ REGISTER_PIR_MAPPER(size, SizeMapper)
 REGISTER_MAPPER(softmax, SoftMaxMapper)
 REGISTER_PIR_MAPPER(softmax, SoftMaxMapper)
 REGISTER_MAPPER(softplus, ActivationMapper)
+REGISTER_PIR_MAPPER(rsqrt, RsqrtMapper)
 REGISTER_MAPPER(softshrink, SoftShrinkMapper)
 REGISTER_MAPPER(softsign, ActivationMapper)
 REGISTER_MAPPER(sqrt, ActivationMapper)
diff --git a/paddle2onnx/mapper/activation/activation.h b/paddle2onnx/mapper/activation/activation.h
index 9a9e65ad3..7f67d6b1e 100644
--- a/paddle2onnx/mapper/activation/activation.h
+++ b/paddle2onnx/mapper/activation/activation.h
@@ -294,6 +294,11 @@ class RsqrtMapper : public Mapper {
               int64_t block_id,
               int64_t op_id)
       : Mapper(p, helper, block_id, op_id) {}
+  RsqrtMapper(const PaddlePirParser& p,
+              OnnxHelper* helper,
+              int64_t i,
+              bool c)
+      : Mapper(p, helper, i, c) {}
   void Opset7() override;
 };
 
diff --git a/paddle2onnx/mapper/tensor/elementwise.cc b/paddle2onnx/mapper/tensor/elementwise.cc
index 28f07820d..929bab558 100755
--- a/paddle2onnx/mapper/tensor/elementwise.cc
+++ b/paddle2onnx/mapper/tensor/elementwise.cc
@@ -34,7 +34,8 @@ REGISTER_PIR_MAPPER(elementwise_mod, ElementWiseModMapper)
 REGISTER_PIR_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper)
 
 int32_t ElementwiseMapper::GetMinOpsetVersion(bool verbose) {
-  if (OpType() == "elementwise_min" || OpType() == "elementwise_max") {
+  if (convert_pir_op_name(OpType()) == "elementwise_min" ||
+      convert_pir_op_name(OpType()) == "elementwise_max") {
     Logger(verbose, 8) << RequireOpset(8) << std::endl;
     return 8;
   }
diff --git a/paddle2onnx/mapper/tensor/reduce_sum.cc b/paddle2onnx/mapper/tensor/reduce_sum.cc
index 87f174c00..08380840a 100644
--- a/paddle2onnx/mapper/tensor/reduce_sum.cc
+++ b/paddle2onnx/mapper/tensor/reduce_sum.cc
@@ -27,7 +27,7 @@ int32_t ReduceMapperSum::GetMinOpsetVersion(bool verbose) {
 void ReduceMapperSum::Opset13() {
   auto axis_name_ = "dim";
   GetAttr("keep_dim", &keep_dim_);
-if (!in_pir_mode) {
+  if (!in_pir_mode) {
     GetAttr("reduce_all", &reduce_all_);
     GetAttr("in_dtype", &in_dtype_);
     GetAttr("out_dtype", &out_dtype_);
@@ -47,6 +47,13 @@ void ReduceMapperSum::Opset13() {
   }
 
   auto x_info = GetInput("X");
+  auto x_name = x_info[0].name;
+  auto x_dtype = x_info[0].dtype;
+  if (x_info[0].dtype == P2ODataType::BOOL) {
+    x_name = helper_->AutoCast(x_name, x_dtype, P2ODataType::INT32);
+    x_dtype = P2ODataType::INT32;
+  }
+
   std::string dims;
   if (IsAttrVar(axis_name_)) {
     auto info = GetAttrVar(axis_name_);
@@ -61,7 +68,7 @@ void ReduceMapperSum::Opset13() {
   }
 
   // Add attribute
-  auto reduce_node = helper_->MakeNode("ReduceSum", {x_info[0].name, dims});
+  auto reduce_node = helper_->MakeNode("ReduceSum", {x_name, dims});
   AddAttribute(reduce_node, "keepdims", static_cast<int64_t>(keep_dim_));
   auto out_node_name = reduce_node->output(0);
@@ -73,7 +80,6 @@ void ReduceMapperSum::Opset13() {
     out_node_name = helper_->Reshape(out_node_name, {-1});
   }
   auto out_info = GetOutput("Out");
-  helper_->AutoCast(out_node_name, out_info[0].name,
-                    x_info[0].dtype, out_info[0].dtype);
+  helper_->AutoCast(out_node_name, out_info[0].name, x_dtype, out_info[0].dtype);
 }
 }  // namespace paddle2onnx
diff --git a/tests/onnxbase.py b/tests/onnxbase.py
index 7be69c018..f9a8ee778 100644
--- a/tests/onnxbase.py
+++ b/tests/onnxbase.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import onnx
 from inspect import isfunction
 import logging
 from onnxruntime import InferenceSession
@@ -20,7 +21,7 @@
 import paddle
 import paddle.static as static
 import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
-from paddle2onnx.convert import dygraph2onnx
+from paddle2onnx.convert import dygraph2onnx, decompose_program
 import shutil
 from functools import wraps
@@ -231,6 +232,8 @@ def __init__(
         self.input_spec_shape = input_spec_shape
         self.input_dtype = []
         self.res_fict = {}
+        self.dist_prim_all = False
+        self.auto_upgrade_opset = False
 
         if isfunction(self.func):
             # self._func = self.BuildFunc(self.func, **self.kwargs_dict_dygraph["params_group1"])
@@ -351,10 +354,19 @@
     def _mk_onnx_res(self, ver):
         """
         make onnx res
         """
+        model_path = os.path.join(
+            self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"
+        )
+        model = onnx.load(model_path)
         sess = InferenceSession(
-            os.path.join(self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"),
+            model_path,
             providers=["CPUExecutionProvider"],
         )
+
+        input_feed = {}
+        if len(model.graph.input) == 0:
+            return sess.run(output_names=None, input_feed=input_feed)
+
         ort_outs = sess.run(output_names=None, input_feed=self.input_feed)
         return ort_outs
@@ -473,7 +485,10 @@ def run(self):
         # clip extra
         model_file = None
         if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
-            model_file = original_model_file
+            if self.dist_prim_all and self.auto_upgrade_opset:
+                model_file = decompose_program(original_model_file)
+            else:
+                model_file = original_model_file
         else:
             model_file = os.path.join(self.name, "cliped_model.pdmodel")
             self.clip_extra_program_only(original_model_file, model_file)
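Both new test attributes default to False, so existing cases are unaffected. A sketch of how a case opts in, assuming the `APIOnnx` harness and `randtool` helper defined elsewhere in tests/onnxbase.py; the op under test and shapes are illustrative:

```python
import paddle
from onnxbase import APIOnnx, randtool

obj = APIOnnx(paddle.nn.Silu(), "nn_silu", [13])
# run() only routes through decompose_program when both knobs are set.
obj.dist_prim_all = True
obj.auto_upgrade_opset = True
obj.set_input_data(
    "input_data",
    paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype("float32")),
)
obj.run()
```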