Skip to content

Commit

Permalink
remove raw pad3d infershape function
Browse files Browse the repository at this point in the history
  • Loading branch information
MingMingShangTian committed Mar 18, 2022
1 parent a209149 commit c243a39
Showing 1 changed file with 0 additions and 70 deletions.
70 changes: 0 additions & 70 deletions paddle/fluid/operators/pad3d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,76 +30,6 @@ class Pad3dOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad3d");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad3d");

auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dim.size(), 5,
platform::errors::InvalidArgument(
"The size of Input(X)'s dimension should be equal to "
"5, but received %d. ",
x_dim.size()));

std::vector<int64_t> out_dims(x_dim.size());
auto data_format = ctx->Attrs().Get<std::string>("data_format");
out_dims[0] = x_dim[0];
if (ctx->HasInput("Paddings")) {
auto paddings_dim = ctx->GetInputDim("Paddings");
PADDLE_ENFORCE_EQ(paddings_dim.size(), 1,
platform::errors::InvalidArgument(
"Size of Input(Paddings)'s dimension should be "
"equal to 1, but received %d.",
paddings_dim.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(paddings_dim[0], 6,
platform::errors::InvalidArgument(
"Shape of Input(Paddings) should be equal to "
"[6], but received [%d].",
paddings_dim[0]));
}
out_dims[1] = x_dim[1];
out_dims[2] = x_dim[2];
out_dims[3] = x_dim[3];
} else {
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE_EQ(
paddings.size(), 6,
platform::errors::InvalidArgument(
"Size of paddings should be equal to 4, but received %d.",
static_cast<int>(paddings.size())));
if (data_format == "NCDHW") {
out_dims[1] = x_dim[1]; // channel
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
? x_dim[2]
: (x_dim[2] + paddings[4] + paddings[5]); // depth

out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
? x_dim[3]
: (x_dim[3] + paddings[2] + paddings[3]); // height

out_dims[4] = ((!ctx->IsRuntime()) && (x_dim[4] < 0))
? x_dim[4]
: (x_dim[4] + paddings[0] + paddings[1]); // width
} else { // NDHWC
out_dims[4] = x_dim[4]; // channel

out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0))
? x_dim[1]
: (x_dim[1] + paddings[4] + paddings[5]); // depth
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
? x_dim[2]
: (x_dim[2] + paddings[2] + paddings[3]); // height
out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
? x_dim[3]
: (x_dim[3] + paddings[0] + paddings[1]); // width
}
}

ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down

1 comment on commit c243a39

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.