Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support decompose pir program in paddle2onnx #1500

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paddle2onnx/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def arg_parser():
default=True,
help="whether enable auto_update_opset, default is True",
)
parser.add_argument(
"--enable_dist_prim_all",
type=ast.literal_eval,
default=False,
help="whether enable dist_prim_all, default is False",
)
parser.add_argument(
"--external_filename",
type=str,
Expand Down Expand Up @@ -160,6 +166,7 @@ def main():
save_file=args.save_file,
opset_version=args.opset_version,
auto_upgrade_opset=args.enable_auto_update_opset,
dist_prim_all=args.enable_dist_prim_all,
verbose=True,
enable_onnx_checker=args.enable_onnx_checker,
enable_experimental_op=True,
Expand Down
92 changes: 92 additions & 0 deletions paddle2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,96 @@

import os
import paddle
import tempfile
import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
from paddle2onnx.utils import logging, paddle_jit_save_configs
from contextlib import contextmanager
from paddle.decomposition import decomp
from paddle.base.executor import global_scope


def load_model(model_filename):
    """Load a pir model with ``paddle.jit.load``.

    Args:
        model_filename: Path to the saved model. A trailing ``.json``
            extension is stripped first, because ``paddle.jit.load``
            expects the path prefix without it.

    Returns:
        The loaded paddle model object.
    """
    assert os.path.exists(
        model_filename
    ), f"Model file {model_filename} does not exist."
    json_suffix = ".json"
    if model_filename.endswith(json_suffix):
        prefix = model_filename[: -len(json_suffix)]
    else:
        prefix = model_filename
    return paddle.jit.load(prefix)


def compare_programs(original_program, new_program):
    """Return True when two pir programs hold the same op sequence.

    Only the op names of each program's global block are compared, in
    order; attributes and operands are not inspected.
    """
    def op_names(program):
        return [op.name() for op in program.global_block().ops]

    return op_names(original_program) == op_names(new_program)


def save_program(program, model_file):
    """Save a (decomposed) pir program into a temp dir and return its path.

    Args:
        program: The pir program to serialize.
        model_file: Path of the original model file; its base name plus a
            ``_decompose`` suffix names the saved model.

    Returns:
        Path of the saved ``.json`` model file.
    """
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)

    # NOTE: the temp dir is deliberately not cleaned up here; the caller
    # still needs to load the returned file afterwards.
    tmp_dir = tempfile.mkdtemp()
    # Bug fix: strip the extension BEFORE appending "_decompose". The old
    # order produced e.g. "model.json_decompose", which os.path.splitext
    # split as ("model", ".json_decompose") and silently dropped the
    # intended "_decompose" suffix from the saved file name.
    base_name, _ = os.path.splitext(os.path.basename(model_file))
    save_dir = os.path.join(tmp_dir, base_name + "_decompose")

    # Collect feed/fetch values so save_inference_model can rebuild the
    # model signature from the program's I/O ops.
    feed, fetch = [], []
    for op in program.global_block().ops:
        if op.name() == "pd_op.feed":
            feed.extend(op.results())
        if op.name() == "pd_op.fetch" or op.name() == "builtin.shadow_output":
            fetch.extend(op.operands_source())

    with paddle.pir_utils.IrGuard():
        paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)

    # In pir mode save_inference_model writes "<prefix>.json".
    new_model_file = save_dir + ".json"
    assert os.path.exists(
        new_model_file
    ), f"Pir Model file {new_model_file} does not exist."
    logging.info(f"Decomposed Model file path: {new_model_file}")
    return new_model_file


def load_parameter(program):
    """Materialize the program's persistable variables in the global scope.

    Collects parameter variables (and persistable ``pd_op.data`` results)
    from ``program`` and creates them as loaded parameters so that a
    subsequent save sees initialized values. Returns None.
    """
    params = []
    opts = []
    for var in program.list_vars():
        if var.is_parameter or var.get_defining_op().name() == "builtin.parameter":
            params.append(var)
        elif var.persistable and var.get_defining_op().name() == "pd_op.data":
            opts.append(var)
    vars = [var for var in params + opts if var.persistable]

    # Bug fix: the old guard was `if vars is None`, which never triggers
    # because a list comprehension always yields a list. Check emptiness
    # instead so we skip executor setup when there is nothing to load.
    if not vars:
        return

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    paddle.base.libpaddle.pir.create_loaded_parameter(
        vars, global_scope(), exe._default_executor
    )


def decompose_program(model_filename):
    """Decompose the ops of a pir model's program into primitive ops.

    Args:
        model_filename: Path of the pir model file to decompose.

    Returns:
        Path of the decomposed model file, or ``model_filename`` unchanged
        when decomposition does not alter the program.
    """
    model = load_model(model_filename)
    decomposed = model.program().clone()
    with decomp.prim_guard():
        decomp.decompose_dist_program(decomposed)

    # Nothing changed -> the original file can be reused as-is.
    if compare_programs(model.program(), decomposed):
        return model_filename

    load_parameter(decomposed)
    return save_program(decomposed, model_filename)


def get_old_ir_guard():
Expand All @@ -39,6 +126,7 @@ def export(
save_file=None,
opset_version=7,
auto_upgrade_opset=True,
dist_prim_all=False,
verbose=True,
enable_onnx_checker=True,
enable_experimental_op=True,
Expand All @@ -49,6 +137,10 @@ def export(
external_file="",
export_fp16_model=False,
):
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
if dist_prim_all and auto_upgrade_opset:
model_filename = decompose_program(model_filename)

deploy_backend = deploy_backend.lower()
if custom_op_info is None:
onnx_model_str = c_p2o.export(
Expand Down
1 change: 1 addition & 0 deletions paddle2onnx/mapper/activation/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ REGISTER_PIR_MAPPER(size, SizeMapper)
REGISTER_MAPPER(softmax, SoftMaxMapper)
REGISTER_PIR_MAPPER(softmax, SoftMaxMapper)
REGISTER_MAPPER(softplus, ActivationMapper)
REGISTER_PIR_MAPPER(rsqrt, RsqrtMapper)
REGISTER_MAPPER(softshrink, SoftShrinkMapper)
REGISTER_MAPPER(softsign, ActivationMapper)
REGISTER_MAPPER(sqrt, ActivationMapper)
Expand Down
5 changes: 5 additions & 0 deletions paddle2onnx/mapper/activation/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ class RsqrtMapper : public Mapper {
int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
RsqrtMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {}
void Opset7() override;
};

Expand Down
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/tensor/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ REGISTER_PIR_MAPPER(elementwise_mod, ElementWiseModMapper)
REGISTER_PIR_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper)

int32_t ElementwiseMapper::GetMinOpsetVersion(bool verbose) {
if (OpType() == "elementwise_min" || OpType() == "elementwise_max") {
if (convert_pir_op_name(OpType()) == "elementwise_min" ||
convert_pir_op_name(OpType()) == "elementwise_max") {
Logger(verbose, 8) << RequireOpset(8) << std::endl;
return 8;
}
Expand Down
14 changes: 10 additions & 4 deletions paddle2onnx/mapper/tensor/reduce_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ int32_t ReduceMapperSum::GetMinOpsetVersion(bool verbose) {
void ReduceMapperSum::Opset13() {
auto axis_name_ = "dim";
GetAttr("keep_dim", &keep_dim_);
if (!in_pir_mode) {
if (!in_pir_mode) {
GetAttr("reduce_all", &reduce_all_);
GetAttr("in_dtype", &in_dtype_);
GetAttr("out_dtype", &out_dtype_);
Expand All @@ -47,6 +47,13 @@ if (!in_pir_mode) {
}

auto x_info = GetInput("X");
auto x_name = x_info[0].name;
auto x_tpye = x_info[0].dtype;
if (x_info[0].dtype == P2ODataType::BOOL) {
x_name = helper_->AutoCast(x_name, x_tpye, P2ODataType::INT32);
x_tpye = P2ODataType::INT32;
}

std::string dims;
if (IsAttrVar(axis_name_)) {
auto info = GetAttrVar(axis_name_);
Expand All @@ -61,7 +68,7 @@ if (!in_pir_mode) {
}

// Add attribute
auto reduce_node = helper_->MakeNode("ReduceSum", {x_info[0].name, dims});
auto reduce_node = helper_->MakeNode("ReduceSum", {x_name, dims});
AddAttribute(reduce_node, "keepdims", static_cast<int64_t>(keep_dim_));
auto out_node_name = reduce_node->output(0);

Expand All @@ -73,7 +80,6 @@ if (!in_pir_mode) {
out_node_name = helper_->Reshape(out_node_name, {-1});
}
auto out_info = GetOutput("Out");
helper_->AutoCast(out_node_name, out_info[0].name,
x_info[0].dtype, out_info[0].dtype);
helper_->AutoCast(out_node_name, out_info[0].name, x_tpye, out_info[0].dtype);
}
} // namespace paddle2onnx
21 changes: 18 additions & 3 deletions tests/onnxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import onnx
from inspect import isfunction
import logging
from onnxruntime import InferenceSession
Expand All @@ -20,7 +21,7 @@
import paddle
import paddle.static as static
import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
from paddle2onnx.convert import dygraph2onnx
from paddle2onnx.convert import dygraph2onnx, decompose_program
import shutil
from functools import wraps

Expand Down Expand Up @@ -231,6 +232,8 @@ def __init__(
self.input_spec_shape = input_spec_shape
self.input_dtype = []
self.res_fict = {}
self.dist_prim_all = False
self.auto_upgrade_opset = False

if isfunction(self.func):
# self._func = self.BuildFunc(self.func, **self.kwargs_dict_dygraph["params_group1"])
Expand Down Expand Up @@ -351,10 +354,19 @@ def _mk_onnx_res(self, ver):
"""
make onnx res
"""
model_path = os.path.join(
self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"
)
model = onnx.load(model_path)
sess = InferenceSession(
os.path.join(self.pwd, self.name, self.name + "_" + str(ver) + ".onnx"),
model_path,
providers=["CPUExecutionProvider"],
)

input_feed = {}
if len(model.graph.input) == 0:
return sess.run(output_names=None, input_feed=input_feed)

ort_outs = sess.run(output_names=None, input_feed=self.input_feed)
return ort_outs

Expand Down Expand Up @@ -473,7 +485,10 @@ def run(self):
# clip extra
model_file = None
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
model_file = original_model_file
if self.dist_prim_all and self.auto_upgrade_opset:
model_file = decompose_program(original_model_file)
else:
model_file = original_model_file
else:
model_file = os.path.join(self.name, "cliped_model.pdmodel")
self.clip_extra_program_only(original_model_file, model_file)
Expand Down
Loading