Skip to content

Commit

Permalink
Merge pull request #1516 from zhanghonggeng/ignore_test_1
Browse files Browse the repository at this point in the history
support conv3d_transpose  uniform & fix some tests
  • Loading branch information
risemeup1 authored Feb 25, 2025
2 parents a9c03e1 + b482888 commit 1e29c15
Show file tree
Hide file tree
Showing 16 changed files with 468 additions and 273 deletions.
84 changes: 84 additions & 0 deletions paddle2onnx/mapper/nn/conv3d_transpose.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/nn/conv3d_transpose.h"

namespace paddle2onnx {
REGISTER_PIR_MAPPER(conv3d_transpose, Conv3dTransposeMapper)

int32_t Conv3dTransposeMapper::GetMinOpsetVersion(bool verbose) {
if (data_format_ != "NCHW" && data_format_ != "NHWC") {
Error() << "[ERROR] only support NCDHW or NDHWC format for "
"conv3d_transpose "
<< std::endl;
return -1;
}
return 7;
}

void Conv3dTransposeMapper::Opset7() {
auto kernel_info = GetInput("filter");
auto input_info = GetInput("x");
auto output_info = GetOutput("out");

auto input = helper_->AutoCast(
input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
if (data_format_ == "NHWC") {
input = helper_->Transpose(input, {0, 4, 1, 2, 3}); // NDHWC -> NCDHW
}

auto kernel = helper_->AutoCast(
kernel_info[0].name, kernel_info[0].dtype, P2ODataType::FP32);

auto node = helper_->MakeNode("ConvTranspose", {input, kernel});

std::vector<int64_t> kernel_shape = {kernel_info[0].shape[2],
kernel_info[0].shape[3],
kernel_info[0].shape[4]};
AddAttribute(node, "kernel_shape", kernel_shape);

AddAttribute(node, "dilations", dilations_);
AddAttribute(node, "strides", strides_);
AddAttribute(node, "group", groups_);

if (padding_algorithm_ == "SAME") {
AddAttribute(node, "auto_pad", "SAME_UPPER");
} else if (padding_algorithm_ == "VALID") {
AddAttribute(node, "auto_pad", "VALID");
} else {
std::vector<int64_t> paddings;
if (paddings_.size() == 3) {
paddings.insert(paddings.begin(), paddings_.begin(), paddings_.end());
paddings.insert(paddings.begin(), paddings_.begin(), paddings_.end());
} else {
std::vector<int64_t> index = {0, 2, 4, 1, 3, 5};
for (auto& i : index) {
paddings.push_back(paddings_[i]);
}
}
AddAttribute(node, "pads", paddings);
}

if (!output_padding_.empty()) {
AddAttribute(node, "output_padding", output_padding_);
}

auto output = node->output(0);
if (data_format_ == "NHWC") {
output = helper_->Transpose(output, {0, 2, 3, 4, 1}); // NCDHW -> NDHWC
}
helper_->AutoCast(
output, output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
}
} // namespace paddle2onnx
59 changes: 59 additions & 0 deletions paddle2onnx/mapper/nn/conv3d_transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class Conv3dTransposeMapper : public Mapper {
public:
Conv3dTransposeMapper(const PaddlePirParser& p,
OnnxHelper* helper,
int64_t i,
bool c)
: Mapper(p, helper, i, c) {
GetAttr("groups", &groups_);
GetAttr("dilations", &dilations_);
GetAttr("strides", &strides_);
GetAttr("paddings", &paddings_);
GetAttr("padding_algorithm", &padding_algorithm_);
GetAttr("data_format", &data_format_);

if (HasAttr("output_padding")) {
GetAttr("output_padding", &output_padding_);
}
if (HasAttr("output_size")) {
GetAttr("output_size", &output_size_);
}
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

private:
std::vector<int64_t> dilations_;
std::vector<int64_t> strides_;
std::vector<int64_t> paddings_;
std::vector<int64_t> output_padding_;
std::vector<int64_t> output_size_;
std::string padding_algorithm_;
std::string data_format_;
int64_t groups_;
};

} // namespace paddle2onnx
55 changes: 55 additions & 0 deletions paddle2onnx/mapper/tensor/uniform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/uniform.h"

namespace paddle2onnx {
REGISTER_PIR_MAPPER(uniform, UniformMapper)

int32_t UniformMapper::GetMinOpsetVersion(bool verbose) { return 7; }

void UniformMapper::Opset7() {
auto output_info = GetOutput("out");
auto shape_info = GetInput("shape");
auto min_info = GetInput("min");
auto max_info = GetInput("max");

if (min_info[0].Rank() != 0 || max_info[0].Rank() != 0) {
Error() << "[ERROR] min/max must be scalar tensors for op "
"uniform "
<< std::endl;
}
std::vector<float> min_val{0.0f}, max_val{1.0f};
bool is_min_const =
helper_->TryGetTensorValue<float>(min_info[0].name, &min_val);
bool is_max_const =
helper_->TryGetTensorValue<float>(max_info[0].name, &max_val);

std::vector<int64_t> shape_values;
helper_->TryGetTensorValue<int64_t>(shape_info[0].name, &shape_values);

auto onnx_dtype = GetOnnxDtype(dtype_);

auto random_node =
helper_->MakeNode("RandomUniform", {}, {output_info[0].name});

AddAttribute(random_node, "shape", shape_values);
AddAttribute(random_node, "low", min_val[0]);
AddAttribute(random_node, "high", max_val[0]);
AddAttribute(random_node, "dtype", static_cast<int64_t>(onnx_dtype));
if (seed_ != 0) {
AddAttribute(random_node, "seed", static_cast<float>(seed_));
}
}
} // namespace paddle2onnx
39 changes: 39 additions & 0 deletions paddle2onnx/mapper/tensor/uniform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class UniformMapper : public Mapper {
public:
UniformMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t i, bool c)
: Mapper(p, helper, i, c) {
GetAttr("dtype", &dtype_);
GetAttr("seed", &seed_);
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

private:
int64_t dtype_;
int64_t seed_;
};

} // namespace paddle2onnx
9 changes: 2 additions & 7 deletions tests/onnxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import paddle
import paddle2onnx
import paddle.static as static
from paddle2onnx.convert import dygraph2onnx, decompose_program
from paddle2onnx.convert import dygraph2onnx
import shutil
from functools import wraps

Expand Down Expand Up @@ -232,8 +232,6 @@ def __init__(
self.input_spec_shape = input_spec_shape
self.input_dtype = []
self.res_fict = {}
self.dist_prim_all = False
self.auto_upgrade_opset = False

if isfunction(self.func):
# self._func = self.BuildFunc(self.func, **self.kwargs_dict_dygraph["params_group1"])
Expand Down Expand Up @@ -497,10 +495,7 @@ def run(self):
# clip extra
model_file = None
if paddle.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]:
if self.dist_prim_all and self.auto_upgrade_opset:
model_file = decompose_program(original_model_file)
else:
model_file = original_model_file
model_file = original_model_file
else:
model_file = os.path.join(self.name, "cliped_model.pdmodel")
self.clip_extra_program_only(original_model_file, model_file)
Expand Down
12 changes: 0 additions & 12 deletions tests/run.bat
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,16 @@ set ignore=test_auto_scan_multiclass_nms.py
set ignore=!ignore! test_auto_scan_roi_align.py
set ignore=!ignore! test_auto_scan_pool_adaptive_max_ops.py
set ignore=!ignore! test_auto_scan_pad2d.py
set ignore=!ignore! test_auto_scan_roll.py
set ignore=!ignore! test_auto_scan_unfold.py
set ignore=!ignore! test_auto_scan_uniform_random_batch_size_like.py
set ignore=!ignore! test_auto_scan_uniform_random.py
set ignore=!ignore! test_auto_scan_dist.py
set ignore=!ignore! test_auto_scan_distribute_fpn_proposals1.py
set ignore=!ignore! test_auto_scan_distribute_fpn_proposals_v2.py
set ignore=!ignore! test_auto_scan_fill_constant_batch_size_like.py
set ignore=!ignore! test_auto_scan_generate_proposals.py
set ignore=!ignore! test_uniform.py
set ignore=!ignore! test_ceil.py
set ignore=!ignore! test_deform_conv2d.py
set ignore=!ignore! test_floor_divide.py
set ignore=!ignore! test_has_nan.py
set ignore=!ignore! test_median.py
set ignore=!ignore! test_nn_Conv3DTranspose.py
set ignore=!ignore! test_nn_GroupNorm.py
set ignore=!ignore! test_nn_InstanceNorm3D.py
set ignore=!ignore! test_nn_Upsample.py
set ignore=!ignore! test_normalize.py
set ignore=!ignore! test_scatter_nd_add.py
set ignore=!ignore! test_unsqueeze.py
set ignore=!ignore! test_quantize_model.py
set ignore=!ignore! test_quantize_model_minist.py
Expand All @@ -73,7 +62,6 @@ set ignore=!ignore! test_auto_scan_conv2d.py
set ignore=!ignore! test_auto_scan_conv2d_transpose.py
set ignore=!ignore! test_auto_scan_conv3d.py
set ignore=!ignore! test_auto_scan_grid_sampler.py
set ignore=!ignore! test_auto_scan_set_value.py
set ignore=!ignore! test_auto_scan_dequantize_linear.py
set ignore=!ignore! test_auto_scan_gaussian_random.py
set ignore=!ignore! test_auto_scan_partial_ops.py
Expand Down
15 changes: 1 addition & 14 deletions tests/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,19 @@ ignore="test_auto_scan_multiclass_nms.py
test_auto_scan_roi_align.py \ # need to be rewrite
test_auto_scan_pool_adaptive_max_ops.py \
test_auto_scan_pad2d.py \
test_auto_scan_roll.py \
test_auto_scan_unfold.py \
test_auto_scan_uniform_random_batch_size_like.py \
test_auto_scan_uniform_random.py \
test_auto_scan_gaussian_random.py \
test_auto_scan_dist.py \
test_auto_scan_distribute_fpn_proposals1.py \
test_auto_scan_distribute_fpn_proposals_v2.py \
test_auto_scan_fill_constant_batch_size_like.py \
test_auto_scan_unary_ops.py \
test_auto_scan_generate_proposals.py \
test_uniform.py \
test_ceil.py \
test_deform_conv2d.py \
test_floor_divide.py \
test_has_nan.py \
test_median.py \
test_nn_Conv3DTranspose.py \
test_nn_GroupNorm.py \
test_nn_InstanceNorm3D.py \
test_nn_Upsample.py \
test_normalize.py \
test_hardtanh.py \
test_nn_GRU.py \
test_scatter_nd_add.py \
test_quantize_model.py \
test_quantize_model_minist.py \
test_auto_scan_partial_ops.py \
Expand All @@ -64,8 +52,7 @@ ignore="test_auto_scan_multiclass_nms.py
test_resnet_fp16.py \
test_empty.py \
test_auto_scan_pool_max_ops.py \
test_auto_scan_fill_constant.py \
test_auto_scan_set_value.py"
test_auto_scan_fill_constant.py"
bug=0

# Install Python Packet
Expand Down
4 changes: 2 additions & 2 deletions tests/test_auto_scan_roll.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import hypothesis.strategies as st
import unittest
import paddle
from onnxbase import _test_only_pir
from onnxbase import _test_with_pir


class Net(BaseNet):
Expand Down Expand Up @@ -112,7 +112,7 @@ def sample_convert_config(self, draw):

return (config, models)

@_test_only_pir
@_test_with_pir
def test(self):
self.run_and_statis(max_examples=80)

Expand Down
Loading

0 comments on commit 1e29c15

Please sign in to comment.