-
Notifications
You must be signed in to change notification settings - Fork 176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support PT in paddle2onnx #1481
Support PT in paddle2onnx #1481
Conversation
risemeup1
commented
Jan 20, 2025
•
edited
Loading
edited
- 使用PT将旧IR转换成新IR,用户传的是infernece.pdmodel和inference.pdiparams,先转换成json格式,然后再走新IR下转ONNX逻辑,便于后续维护
- 支持多个算子转换
- 修复控制流算子bug
35cd933
to
a24dee7
Compare
5af9855
to
9071329
Compare
1272ed8
to
45eef56
Compare
paddle2onnx/mapper/nn/pool3d.h
Outdated
GetAttr("strides", &strides_); | ||
GetAttr("paddings", &pads_); | ||
GetAttr("ksize", &k_size_); | ||
if (OpType() != "max_pool3d_with_index") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成convert_pir_op_name(OpType())
paddle2onnx/mapper/onnx_helper.h
Outdated
else | ||
{ | ||
std::cout<<"dtype: "<<dtype; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除这里的打印,Assert提示中增加uint8
void BuiltinSliceMapper::Opset7() { | ||
auto input_info = GetInput(0); | ||
auto output_info = GetOutput(0); | ||
if (HasAttr("index")) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
index不存在的情况有吗?默认值是什么?
GaussianRandomMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t i, | ||
bool c) | ||
: Mapper(p, helper, i, c) { | ||
in_pir_mode = true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以删除
in_pir_mode = true;
// GetAttr("shape", &shape_);
paddle2onnx/parser/pir_parser.cc
Outdated
// print_stream << "ForwardProgram is :\n"; | ||
// pir_program_->Print(print_stream); | ||
// std::cout << "Program (fwd | bwd): \n" << print_stream.str() << | ||
// std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除注释
paddle2onnx/parser/pir_parser.cc
Outdated
if (op->name() == "pd_op.data" || op->name() == "pd_op.feed"){ | ||
std::string var_name = GenOpInputOutputName(op->name()); | ||
// std::string var_name = op->attribute<pir::StrAttribute>("name").AsString(); | ||
// inputs.push_back(GetTensorInfo(var_name, op->result(0).type())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除注释
tests/run.sh
Outdated
test_auto_scan_dist.py \ | ||
test_auto_scan_distribute_fpn_proposals1.py \ | ||
test_auto_scan_distribute_fpn_proposals_v2.py \ | ||
test_auto_scan_fill_constant_batch_size_like.py \ | ||
test_auto_scan_gaussian_random \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
写错了,而且和38行重复了
tests/test_auto_scan_linspace.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
添加@_test_with_pir
tests/test_auto_scan_range.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
添加@_test_with_pir
tests/test_auto_scan_tile.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个单测为什么不测PT?
|