Skip to content

Commit

Permalink
[Phi] Move roi_align grad kernel and infershape from fuild to phi (#4…
Browse files Browse the repository at this point in the history
…0556)

* move roi_align_grad kernel

* move roi_align grad kernel and infershape to phi

* remove roi_align infershape
  • Loading branch information
zyfncg authored Mar 16, 2022
1 parent 44d46d0 commit 3898080
Show file tree
Hide file tree
Showing 18 changed files with 646 additions and 513 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,6 @@ void BuildDygraphPhiKernelContext(
}

for (size_t i = 0; i < attr_names.size(); ++i) {
VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute
Expand Down
89 changes: 9 additions & 80 deletions paddle/fluid/operators/roi_align_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/roi_align_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {
Expand All @@ -23,79 +26,6 @@ class ROIAlignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of ROIAlignOp "
"is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
platform::errors::NotFound("Input(ROIs) of ROIAlignOp "
"is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of ROIAlignOp "
"is not found."));
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");

if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(
rois_num_dims.size(), 1,
platform::errors::InvalidArgument("The size of RoisNum should be 1"
", but received size = %d",
rois_num_dims.size()));
}
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument(
"The format of Input(X) in"
"RoIAlignOp is NCHW. And the rank of input must be 4. "
"But received rank = %d",
input_dims.size()));
PADDLE_ENFORCE_EQ(rois_dims.size(), 2, platform::errors::InvalidArgument(
"The rank of Input(ROIs) "
"in RoIAlignOp should be 2. "
"But the rank of RoIs is %d",
rois_dims.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(rois_dims[1], 4,
platform::errors::InvalidArgument(
"The second dimension "
"of Input(ROIs) should be 4. But received the "
"dimension = %d",
rois_dims[1]));
}
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

PADDLE_ENFORCE_GT(pooled_height, 0,
platform::errors::InvalidArgument(
"The 'pooled_height' attribute in RoIAlignOp is "
"invalid. The height must be greater than 0. But "
"received 'pooled_height' = %d",
pooled_height));
PADDLE_ENFORCE_GT(pooled_width, 0,
platform::errors::InvalidArgument(
"The 'pooled_width' attribute in RoIAlignOp is "
"invalid. The width must be greater than 0. But "
"received 'pooled_width' = %d",
pooled_width));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
platform::errors::InvalidArgument(
"The 'spatial_scale' attribute in RoIAlignOp is "
"invalid. The scale must be greater than 0. But "
"received 'spatial_scale' = %f",
spatial_scale));

auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;

ctx->SetOutputDim("Out", out_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -221,17 +151,16 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(RoiAlignGradNoNeedBufVarsInferer, "X");
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(roi_align, RoiAlignInferShapeFunctor,
PD_INFER_META(phi::RoiAlignInferMeta));

REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
ops::ROIAlignGradMaker<paddle::framework::OpDesc>,
ops::ROIAlignGradMaker<paddle::imperative::OpBase>);
ops::ROIAlignGradMaker<paddle::imperative::OpBase>,
RoiAlignInferShapeFunctor);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp,
ops::RoiAlignGradNoNeedBufVarsInferer);

REGISTER_OP_CPU_KERNEL(
roi_align_grad,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_VERSION(roi_align)
.AddCheckpoint(
R"ROC(
Expand Down
227 changes: 0 additions & 227 deletions paddle/fluid/operators/roi_align_op.cu

This file was deleted.

Loading

0 comments on commit 3898080

Please sign in to comment.