support some ops in pir #1416

Merged · 6 commits · Oct 22, 2024
19 changes: 17 additions & 2 deletions paddle2onnx/mapper/activation/activation.cc
@@ -17,19 +17,26 @@
namespace paddle2onnx {

REGISTER_MAPPER(abs, ActivationMapper)
REGISTER_PIR_MAPPER(abs, ActivationMapper)
REGISTER_MAPPER(acos, ActivationMapper)
REGISTER_MAPPER(asin, ActivationMapper)
REGISTER_MAPPER(atan, ActivationMapper)
REGISTER_MAPPER(brelu, BReluMapper)
REGISTER_MAPPER(ceil, ActivationMapper)
REGISTER_MAPPER(cos, ActivationMapper)
REGISTER_PIR_MAPPER(cos, ActivationMapper)
REGISTER_MAPPER(elu, EluMapper)
REGISTER_MAPPER(erf, ActivationMapper)
REGISTER_MAPPER(exp, ActivationMapper)
REGISTER_PIR_MAPPER(exp, ActivationMapper)
REGISTER_MAPPER(floor, ActivationMapper)
REGISTER_PIR_MAPPER(floor, ActivationMapper)
REGISTER_MAPPER(gelu, GeluMapper)
REGISTER_PIR_MAPPER(gelu, GeluMapper)
REGISTER_MAPPER(leaky_relu, LeakyReluMapper)
REGISTER_PIR_MAPPER(leaky_relu, LeakyReluMapper)
REGISTER_MAPPER(log, ActivationMapper)
REGISTER_PIR_MAPPER(log, ActivationMapper)
REGISTER_MAPPER(log10, Log10Mapper)
REGISTER_MAPPER(log1p, Log1PMapper)
REGISTER_MAPPER(log2, Log2Mapper)
@@ -45,13 +52,17 @@ REGISTER_MAPPER(rsqrt, RsqrtMapper)
REGISTER_MAPPER(sel, ActivationMapper)
REGISTER_MAPPER(selu, SeluMapper)
REGISTER_MAPPER(silu, SiluMapper)
REGISTER_PIR_MAPPER(silu, SiluMapper)
REGISTER_MAPPER(sin, ActivationMapper)
REGISTER_PIR_MAPPER(sin, ActivationMapper)
REGISTER_MAPPER(size, SizeMapper)
REGISTER_MAPPER(softmax, SoftMaxMapper)
REGISTER_PIR_MAPPER(softmax, SoftMaxMapper)
REGISTER_MAPPER(softplus, ActivationMapper)
REGISTER_MAPPER(softshrink, SoftShrinkMapper)
REGISTER_MAPPER(softsign, ActivationMapper)
REGISTER_MAPPER(sqrt, ActivationMapper)
REGISTER_PIR_MAPPER(sqrt, ActivationMapper)
REGISTER_MAPPER(square, SquareMapper)
REGISTER_MAPPER(tan, ActivationMapper)
REGISTER_MAPPER(tanh, ActivationMapper)
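The REGISTER_PIR_MAPPER lines added above register the same mapper classes a second time under Paddle's PIR op names, so the exporter can presumably resolve ops coming from a PIR program through the same kind of name-keyed lookup as the legacy ops. As a rough, self-contained sketch of what such a registration macro typically boils down to (an illustration only, not the actual paddle2onnx macro; DemoMapper, Registry, REGISTER_DEMO_MAPPER, and ReluDemo are hypothetical names):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Minimal mapper interface standing in for paddle2onnx's Mapper base class.
struct DemoMapper {
  virtual ~DemoMapper() = default;
  virtual void Opset7() = 0;
};

using Factory = std::function<std::unique_ptr<DemoMapper>()>;

// Global registry keyed by op name; a REGISTER_*_MAPPER-style macro fills it
// at static-initialization time.
std::unordered_map<std::string, Factory>& Registry() {
  static std::unordered_map<std::string, Factory> registry;
  return registry;
}

#define REGISTER_DEMO_MAPPER(op, cls)                          \
  static const bool op##_registered = [] {                     \
    Registry()[#op] = [] { return std::make_unique<cls>(); };  \
    return true;                                               \
  }();

struct ReluDemo : DemoMapper {
  void Opset7() override { std::cout << "export Relu\n"; }
};
REGISTER_DEMO_MAPPER(relu, ReluDemo)

int main() {
  auto it = Registry().find("relu");                    // lookup by op name
  if (it != Registry().end()) it->second()->Opset7();   // prints "export Relu"
  return 0;
}

The effect of the double registration in this file appears to be that each activation mapper becomes reachable under both the legacy op-name table and the PIR op-name table.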
@@ -85,7 +96,9 @@ void ActivationMapper::Opset7() {
auto output_info = GetOutput("Out");
auto iter = op_mapper_.find(convert_pir_op_name(OpType()));
Assert(op_mapper_.end() != iter,
"Cannot find " + convert_pir_op_name(OpType()) + " in activation op_mapper.");
"Cannot find " +
convert_pir_op_name(OpType()) +
" in activation op_mapper.");
if (convert_pir_op_name(OpType()) == "erf") {
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
@@ -367,7 +380,9 @@ void ThresholdedReluMapper::Opset10() {
void Log1PMapper::Opset7() {
auto x_info = GetInput("X");
auto out_info = GetOutput("Out");
auto one = helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), float(1.0));
auto one = helper_->Constant({},
GetOnnxDtype(x_info[0].dtype),
static_cast<float>(1.0));
auto input = helper_->MakeNode("Add", {x_info[0].name, one})->output(0);
helper_->MakeNode("Log", {input}, {out_info[0].name});
}
37 changes: 37 additions & 0 deletions paddle2onnx/mapper/activation/activation.h
@@ -122,6 +122,15 @@ class LeakyReluMapper : public Mapper {
GetAttr("alpha", &alpha_);
}

LeakyReluMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {
in_pir_mode = true;
GetAttr("alpha", &alpha_);
}

void Opset7() override;

private:
@@ -136,6 +145,14 @@ class GeluMapper : public Mapper {
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}

GeluMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {
in_pir_mode = true;
}

int32_t GetMinOpsetVersion(bool verbose) override {
Logger(verbose, 9) << RequireOpset(9) << std::endl;
return 9;
@@ -158,6 +175,19 @@ class SoftMaxMapper : public Mapper {
}
}

SoftMaxMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {
in_pir_mode = true;
if (HasAttr("axis")) {
GetAttr("axis", &axis_);
} else {
axis_ = -1;
}
}

void Opset7() override;
void Opset13() override;

@@ -354,6 +384,13 @@ class SiluMapper : public Mapper {
int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
SiluMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {
in_pir_mode = true;
}
void Opset7() override;
};

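Every PIR constructor added in this header follows the same pattern: take a PaddlePirParser instead of a PaddleParser, forward to the Mapper base, set in_pir_mode = true, and then read the same attributes the legacy constructor reads. A hedged template for extending another attribute-carrying mapper the same way (ExampleAlphaMapper and its "alpha" attribute are illustrative only and not part of this PR; the parameter names mirror the diff above):

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
// Illustrative sketch mirroring the LeakyReluMapper, GeluMapper, SoftMaxMapper,
// and SiluMapper constructors added above.
class ExampleAlphaMapper : public Mapper {
 public:
  // Legacy program-desc constructor.
  ExampleAlphaMapper(const PaddleParser& p, OnnxHelper* helper,
                     int64_t block_id, int64_t op_id)
      : Mapper(p, helper, block_id, op_id) {
    GetAttr("alpha", &alpha_);
  }

  // PIR constructor: as in the diff above, mark PIR mode before reading
  // any attributes.
  ExampleAlphaMapper(const PaddlePirParser& p, OnnxHelper* helper,
                     int64_t i, bool c)
      : Mapper(p, helper, i, c) {
    in_pir_mode = true;
    GetAttr("alpha", &alpha_);
  }

  void Opset7() override;

 private:
  float alpha_ = 1.0f;
};
}  // namespace paddle2onnx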
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/activation/sigmoid.cc
@@ -16,10 +16,11 @@

namespace paddle2onnx {
REGISTER_MAPPER(sigmoid, SigmoidMapper)
REGISTER_PIR_MAPPER(sigmoid, SigmoidMapper)

void SigmoidMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
helper_->MakeNode("Sigmoid", {input_info[0].name}, {output_info[0].name});
}
}
}
14 changes: 10 additions & 4 deletions paddle2onnx/mapper/activation/sigmoid.h
@@ -13,20 +13,26 @@
// limitations under the License.
#pragma once


#include "paddle2onnx/mapper/mapper.h"

#include <cmath>
#include <map>
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
class SigmoidMapper : public Mapper {
public:
SigmoidMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
SigmoidMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t op_id,
bool c)
: Mapper(p, helper, op_id, c) {
in_pir_mode = true;
}
void Opset7() override;
};
}
} // namespace paddle2onnx
14 changes: 10 additions & 4 deletions paddle2onnx/mapper/activation/swish.cc
@@ -16,6 +16,7 @@

namespace paddle2onnx {
REGISTER_MAPPER(swish, SwishMapper)
REGISTER_PIR_MAPPER(swish, SwishMapper)

void SwishMapper::Opset7() {
auto input_info = GetInput("X");
@@ -25,13 +26,18 @@ void SwishMapper::Opset7() {
if (HasAttr("beta")) {
float temp_beta = 1.0;
GetAttr("beta", &temp_beta);
    std::string beta_node =
        helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), temp_beta);
    auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
} else {
sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
}

helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)}, {output_info[0].name});
helper_->MakeNode("Mul",
{input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
}
} // namespace paddle2onnx
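For reference, the graph built above implements the swish activation: with a beta attribute the exported nodes compute Mul(x, Sigmoid(beta * x)), otherwise Mul(x, Sigmoid(x)). A small numeric cross-check of that formula, independent of paddle2onnx:

#include <cmath>
#include <cstdio>

// swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)), which is what
// the Mul/Sigmoid node pair emitted by SwishMapper::Opset7 computes.
double SwishRef(double x, double beta) {
  return x / (1.0 + std::exp(-beta * x));
}

int main() {
  std::printf("%.4f\n", SwishRef(1.5, 1.0));  // prints 1.2264
  return 0;
}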
13 changes: 10 additions & 3 deletions paddle2onnx/mapper/activation/swish.h
@@ -14,19 +14,26 @@
#pragma once


#include "paddle2onnx/mapper/mapper.h"

#include <cmath>
#include <map>
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
class SwishMapper : public Mapper {
public:
SwishMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
SwishMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t op_id,
bool c)
: Mapper(p, helper, op_id, c) {
in_pir_mode = true;
}
void Opset7() override;
};
}
} // namespace paddle2onnx
67 changes: 45 additions & 22 deletions paddle2onnx/mapper/exporter.h
@@ -39,6 +39,7 @@ inline std::string convert_pir_op_name(const std::string pir_op_name) {
{"matmul", "matmul_v2"},
// {"relu", "relu6"},
{"batch_norm_", "batch_norm"},
{"assign_value_", "assign_value"},
{"flatten", "flatten_contiguous_range"},
{"add", "elementwise_add"}};
std::string op_name = pir_op_name;
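The table above evidently maps PIR op names (some of which carry a trailing underscore, e.g. assign_value_) back to the legacy op names that the existing mappers are registered under; this hunk adds one new entry, apparently the assign_value_ alias. The rest of convert_pir_op_name is collapsed in this diff, so the sketch below covers only the visible alias lookup and is not the full implementation:

#include <iostream>
#include <string>
#include <unordered_map>

// Standalone sketch of the alias lookup shown in the hunk above; the real
// convert_pir_op_name in exporter.h may do more than this.
std::string ConvertPirOpNameSketch(const std::string& pir_op_name) {
  static const std::unordered_map<std::string, std::string> kAliases = {
      {"matmul", "matmul_v2"},
      {"batch_norm_", "batch_norm"},
      {"assign_value_", "assign_value"},
      {"flatten", "flatten_contiguous_range"},
      {"add", "elementwise_add"},
  };
  auto it = kAliases.find(pir_op_name);
  return it != kAliases.end() ? it->second : pir_op_name;
}

int main() {
  std::cout << ConvertPirOpNameSketch("assign_value_") << "\n";  // assign_value
  std::cout << ConvertPirOpNameSketch("relu") << "\n";           // relu (no alias entry)
  return 0;
}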
@@ -94,7 +95,8 @@ class ModelExporter {
bool* save_external = nullptr,
bool export_fp16_model = false,
std::vector<std::string> disable_fp16_op_types = {});
  std::string Run(PaddlePirParser& pir_parser,
int opset_version = 9,
bool auto_upgrade_opset = true,
bool verbose = false,
@@ -132,34 +134,46 @@
//
void ExportInputOutputs(
const PaddleParser& parser,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs);

void ExportInputOutputs(
const PaddlePirParser& pir_parser,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs);

void ExportParameters(
const PaddleParser& parser,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters);
void ExportParameters(
const PaddlePirParser& pir_parser,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters);
  // Process duplicate tensor names in the Paddle model
std::set<std::string> tensor_names_;
void ProcessGraphDumplicateNames(
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& nodes,
      std::map<std::string, QuantizeInfo>& quantize_info);
  // Update constant nodes in parameters. When processing a quantized model,
  // the weight dtype may be int8; it should be converted to float32, and this
  // function updates the converted params.
void UpdateParameters(
const std::map<std::string, Weight>& params,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters);
//
std::map<std::string, std::pair<int32_t, int32_t>> sub_block_map_;
ONNX_NAMESPACE::GraphProto ExportConditionalBlock(
Expand All @@ -168,21 +182,30 @@ class ModelExporter {
int32_t op_id,
const std::string& output_names);

  ONNX_NAMESPACE::GraphProto ExportIfBlock(PaddlePirParser& pir_parser,
                                           pir::Block& block);

ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddleParser& parser,
int32_t block_id,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs);
ONNX_NAMESPACE::GraphProto ExportBlock(
      PaddlePirParser& pir_parser,
      pir::Block* block,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
bool if_in_subblock);

void ExportOp(const PaddleParser& parser,
2 changes: 2 additions & 0 deletions paddle2onnx/mapper/nn/conv2d_transpose.cc
@@ -19,7 +19,9 @@

namespace paddle2onnx {
REGISTER_MAPPER(conv2d_transpose, Conv2dTransposeMapper)
REGISTER_PIR_MAPPER(conv2d_transpose, Conv2dTransposeMapper)
REGISTER_MAPPER(depthwise_conv2d_transpose, Conv2dTransposeMapper)
REGISTER_PIR_MAPPER(depthwise_conv2d_transpose, Conv2dTransposeMapper)

int32_t Conv2dTransposeMapper::GetMinOpsetVersion(bool verbose) {
// NHWC is not supported