Skip to content

Commit

Permalink
Merge pull request #1500 from zhanghonggeng/support_decompose_program
Browse files Browse the repository at this point in the history
Support decompose pir program in paddle2onnx
  • Loading branch information
risemeup1 authored Feb 14, 2025
2 parents 400040d + 12339cd commit 5735b0a
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 8 deletions.
7 changes: 7 additions & 0 deletions paddle2onnx/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def arg_parser():
default=True,
help="whether enable auto_update_opset, default is True",
)
parser.add_argument(
"--enable_dist_prim_all",
type=ast.literal_eval,
default=False,
help="whether enable dist_prim_all, default is False",
)
parser.add_argument(
"--external_filename",
type=str,
Expand Down Expand Up @@ -160,6 +166,7 @@ def main():
save_file=args.save_file,
opset_version=args.opset_version,
auto_upgrade_opset=args.enable_auto_update_opset,
dist_prim_all=args.enable_dist_prim_all,
verbose=True,
enable_onnx_checker=args.enable_onnx_checker,
enable_experimental_op=True,
Expand Down
92 changes: 92 additions & 0 deletions paddle2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,96 @@

import os
import paddle
import tempfile
import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
from paddle2onnx.utils import logging, paddle_jit_save_configs
from contextlib import contextmanager
from paddle.decomposition import decomp
from paddle.base.executor import global_scope


def load_model(model_filename):
    """Load a pir model saved by ``paddle.jit.save``.

    Args:
        model_filename: Path to the model file; a trailing ``.json``
            extension is stripped before loading, since ``paddle.jit.load``
            expects the path prefix without it.

    Returns:
        The loaded paddle jit layer.
    """
    assert os.path.exists(
        model_filename
    ), f"Model file {model_filename} does not exist."
    path_prefix = model_filename
    if path_prefix.endswith(".json"):
        path_prefix = path_prefix[: -len(".json")]
    return paddle.jit.load(path_prefix)


def compare_programs(original_program, new_program):
    """Check whether two pir programs contain the same op-name sequence.

    Args:
        original_program: The program before decomposition.
        new_program: The program after decomposition.

    Returns:
        True when both programs' global blocks list identical op names
        in identical order, False otherwise.
    """

    def _op_names(program):
        # Op names of the program's global block, in definition order.
        return [op.name() for op in program.global_block().ops]

    return _op_names(original_program) == _op_names(new_program)


def save_program(program, model_file):
    """Save a (decomposed) pir program as an inference model.

    The program is written into a fresh temporary directory, named after
    *model_file* with a ``_decompose`` suffix.

    Args:
        program: The pir program to persist.
        model_file: Path of the original model file; only its basename is
            used to derive the saved file's name.

    Returns:
        The path of the newly written ``.json`` model file.
    """
    executor = paddle.static.Executor(paddle.CPUPlace())

    # Derive "<tmpdir>/<stem>" where <stem> comes from splitting the
    # extension off "<basename>_decompose".
    decorated_name = os.path.basename(model_file) + "_decompose"
    stem, _unused_ext = os.path.splitext(decorated_name)
    save_dir = os.path.join(tempfile.mkdtemp(), stem)

    # Collect graph inputs (feed) and outputs (fetch) from the global block.
    feed, fetch = [], []
    for op in program.global_block().ops:
        op_name = op.name()
        if op_name == "pd_op.feed":
            feed.extend(op.results())
        if op_name in ("pd_op.fetch", "builtin.shadow_output"):
            fetch.extend(op.operands_source())

    with paddle.pir_utils.IrGuard():
        paddle.static.save_inference_model(
            save_dir, feed, fetch, executor, program=program
        )

    new_model_file = save_dir + ".json"
    assert os.path.exists(
        new_model_file
    ), f"Pir Model file {new_model_file} does not exist."
    logging.info(f"Decomposed Model file path: {new_model_file}")
    return new_model_file


def load_parameter(program):
    """Materialize the program's persistable variables in the global scope.

    Collects parameter variables (``builtin.parameter``) and persistable
    optimizer/data variables (``pd_op.data``) from *program* and creates
    them as loaded parameters in the executor's global scope, so they can
    be saved together with the decomposed program.

    Args:
        program: The pir program whose variables are scanned.
    """
    params = []
    opts = []
    for var in program.list_vars():
        if var.is_parameter or var.get_defining_op().name() == "builtin.parameter":
            params.append(var)
        elif var.persistable and var.get_defining_op().name() == "pd_op.data":
            opts.append(var)
    # Only persistable variables need to be created in the scope.
    # NOTE: renamed from `vars` to avoid shadowing the builtin.
    persistable_vars = [var for var in params + opts if var.persistable]

    # Fixed: the original tested `vars is None`, which is never true for a
    # list-comprehension result — an empty collection must short-circuit here.
    if not persistable_vars:
        return

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    paddle.base.libpaddle.pir.create_loaded_parameter(
        persistable_vars, global_scope(), exe._default_executor
    )


def decompose_program(model_filename):
    """Decompose composite ops of a pir model into primitive ops.

    Loads the model, clones its program, and rewrites the clone under
    ``decomp.prim_guard``. When decomposition changes nothing, the original
    file is reused unchanged; otherwise the parameters are materialized and
    the decomposed program is saved to a new file.

    Args:
        model_filename: Path to the saved pir model (``.json``).

    Returns:
        Path of the decomposed model file, or *model_filename* itself when
        decomposition was a no-op.
    """
    model = load_model(model_filename)
    new_program = model.program().clone()
    with decomp.prim_guard():
        decomp.decompose_dist_program(new_program)

    # No op was rewritten: keep the original file to avoid a useless save.
    if compare_programs(model.program(), new_program):
        return model_filename

    load_parameter(new_program)
    return save_program(new_program, model_filename)


def get_old_ir_guard():
Expand All @@ -39,6 +126,7 @@ def export(
save_file=None,
opset_version=7,
auto_upgrade_opset=True,
dist_prim_all=False,
verbose=True,
enable_onnx_checker=True,
enable_experimental_op=True,
Expand All @@ -49,6 +137,10 @@ def export(
external_file="",
export_fp16_model=False,
):
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
if dist_prim_all and auto_upgrade_opset:
model_filename = decompose_program(model_filename)

deploy_backend = deploy_backend.lower()
if custom_op_info is None:
onnx_model_str = c_p2o.export(
Expand Down
1 change: 1 addition & 0 deletions paddle2onnx/mapper/activation/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ REGISTER_PIR_MAPPER(size, SizeMapper)
REGISTER_MAPPER(softmax, SoftMaxMapper)
REGISTER_PIR_MAPPER(softmax, SoftMaxMapper)
REGISTER_MAPPER(softplus, ActivationMapper)
REGISTER_PIR_MAPPER(rsqrt, RsqrtMapper)
REGISTER_MAPPER(softshrink, SoftShrinkMapper)
REGISTER_MAPPER(softsign, ActivationMapper)
REGISTER_MAPPER(sqrt, ActivationMapper)
Expand Down
5 changes: 5 additions & 0 deletions paddle2onnx/mapper/activation/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ class RsqrtMapper : public Mapper {
int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
RsqrtMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {}
void Opset7() override;
};

Expand Down
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/tensor/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ REGISTER_PIR_MAPPER(elementwise_mod, ElementWiseModMapper)
REGISTER_PIR_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper)

int32_t ElementwiseMapper::GetMinOpsetVersion(bool verbose) {
if (OpType() == "elementwise_min" || OpType() == "elementwise_max") {
if (convert_pir_op_name(OpType()) == "elementwise_min" ||
convert_pir_op_name(OpType()) == "elementwise_max") {
Logger(verbose, 8) << RequireOpset(8) << std::endl;
return 8;
}
Expand Down
14 changes: 10 additions & 4 deletions paddle2onnx/mapper/tensor/reduce_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ int32_t ReduceMapperSum::GetMinOpsetVersion(bool verbose) {
void ReduceMapperSum::Opset13() {
auto axis_name_ = "dim";
GetAttr("keep_dim", &keep_dim_);
if (!in_pir_mode) {
if (!in_pir_mode) {
GetAttr("reduce_all", &reduce_all_);
GetAttr("in_dtype", &in_dtype_);
GetAttr("out_dtype", &out_dtype_);
Expand All @@ -47,6 +47,13 @@ if (!in_pir_mode) {
}

auto x_info = GetInput("X");
auto x_name = x_info[0].name;
auto x_tpye = x_info[0].dtype;
if (x_info[0].dtype == P2ODataType::BOOL) {
x_name = helper_->AutoCast(x_name, x_tpye, P2ODataType::INT32);
x_tpye = P2ODataType::INT32;
}

std::string dims;
if (IsAttrVar(axis_name_)) {
auto info = GetAttrVar(axis_name_);
Expand All @@ -61,7 +68,7 @@ if (!in_pir_mode) {
}

// Add attribute
auto reduce_node = helper_->MakeNode("ReduceSum", {x_info[0].name, dims});
auto reduce_node = helper_->MakeNode("ReduceSum", {x_name, dims});
AddAttribute(reduce_node, "keepdims", static_cast<int64_t>(keep_dim_));
auto out_node_name = reduce_node->output(0);

Expand All @@ -73,7 +80,6 @@ if (!in_pir_mode) {
out_node_name = helper_->Reshape(out_node_name, {-1});
}
auto out_info = GetOutput("Out");
helper_->AutoCast(out_node_name, out_info[0].name,
x_info[0].dtype, out_info[0].dtype);
helper_->AutoCast(out_node_name, out_info[0].name, x_tpye, out_info[0].dtype);
}
} // namespace paddle2onnx
21 changes: 18 additions & 3 deletions tests/onnxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import onnx
from inspect import isfunction
import logging
from onnxruntime import InferenceSession
Expand All @@ -20,7 +21,7 @@
import paddle
import paddle.static as static
import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
from paddle2onnx.convert import dygraph2onnx
from paddle2onnx.convert import dygraph2onnx, decompose_program
import shutil
from functools import wraps

Expand Down Expand Up @@ -231,6 +232,8 @@ def __init__(
self.input_spec_shape = input_spec_shape
self.input_dtype = []
self.res_fict = {}
self.dist_prim_all = False
self.auto_upgrade_opset = False

if isfunction(self.func):
# self._func = self.BuildFunc(self.func, **self.kwargs_dict_dygraph["params_group1"])
Expand Down Expand Up @@ -351,10 +354,19 @@ def _mk_onnx_res(self, ver):
"""
make onnx res
"""
model_path = os.path.join(
self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"
)
model = onnx.load(model_path)
sess = InferenceSession(
os.path.join(self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"),
model_path,
providers=["CPUExecutionProvider"],
)

input_feed = {}
if len(model.graph.input) == 0:
return sess.run(output_names=None, input_feed=input_feed)

ort_outs = sess.run(output_names=None, input_feed=self.input_feed)
return ort_outs

Expand Down Expand Up @@ -473,7 +485,10 @@ def run(self):
# clip extra
model_file = None
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
model_file = original_model_file
if self.dist_prim_all and self.auto_upgrade_opset:
model_file = decompose_program(original_model_file)
else:
model_file = original_model_file
else:
model_file = os.path.join(self.name, "cliped_model.pdmodel")
self.clip_extra_program_only(original_model_file, model_file)
Expand Down

0 comments on commit 5735b0a

Please sign in to comment.