[phi] transfer unsqueeze and squeeze to phi #40596

Merged · 7 commits · Mar 23, 2022
Changes from 5 commits
47 changes: 8 additions & 39 deletions paddle/fluid/operators/squeeze_op.cc
@@ -113,13 +113,13 @@ class SqueezeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

//#ifdef PADDLE_WITH_MKLDNN
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -140,13 +140,13 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));

//#ifdef PADDLE_WITH_MKLDNN
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -241,13 +241,13 @@ class Squeeze2Op : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

//#ifdef PADDLE_WITH_MKLDNN
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -287,13 +287,13 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));

//#ifdef PADDLE_WITH_MKLDNN
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -411,34 +411,3 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);

REGISTER_OP_CPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);

REGISTER_OP_CPU_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
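
The registrations deleted above are not dropped; the squeeze2 kernels now live in the phi kernel library. A minimal sketch of the phi-side CPU registration that supersedes them, assuming the PD_REGISTER_KERNEL conventions of this period (the kernel name, file placement, and exact dtype list are assumptions, not taken verbatim from this PR):

// Hedged sketch: assumed phi replacement for the removed
// REGISTER_OP_CPU_KERNEL(squeeze2, ...) block above.
PD_REGISTER_KERNEL(squeeze,
                   CPU,
                   ALL_LAYOUT,
                   phi::SqueezeKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   bool,
                   int,
                   uint8_t,
                   int8_t,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

One registration per backend covers every dtype, which is why the long per-type REGISTER_OP_CPU_KERNEL lists can be deleted wholesale here.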
30 changes: 0 additions & 30 deletions paddle/fluid/operators/squeeze_op.cu.cc
@@ -46,33 +46,3 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<float>>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
39 changes: 0 additions & 39 deletions paddle/fluid/operators/squeeze_op.h
@@ -60,44 +60,5 @@ class SqueezeGradKernel : public framework::OpKernel<T> {
d_x->Resize(in_dims);
}
};

template <typename DeviceContext, typename T>
class Squeeze2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
auto *in = context.Input<framework::LoDTensor>("X");

auto &axes = context.Attr<std::vector<int>>("axes");

auto x_dims = in->dims();
auto out_dims = GetOutputShape(axes, x_dims, true);

out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy(
*in, context.GetPlace(),
context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
};

template <typename DeviceContext, typename T>
class Squeeze2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
// auto in_dims = d_x->dims();

auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());

d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
d_x->Resize(x_dims);
}
};

} // namespace operators
} // namespace paddle
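
The removed Squeeze2Kernel follows a copy-then-resize pattern: copy X into Out under the source shape, then shrink Out's dims to the squeezed shape. A hedged sketch of how that same logic reads as a phi kernel, assuming the phi signature convention (device context first, then inputs, attributes, outputs); the helper name funcs::GetOutputSqueezeShape is an assumption standing in for the fluid GetOutputShape helper:

// Sketch under stated assumptions; not the verbatim phi implementation.
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const std::vector<int>& axes,
                   DenseTensor* xshape,  // shape-only output; dims set by InferMeta
                   DenseTensor* out) {
  auto x_dims = x.dims();
  auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true);

  out->Resize(x_dims);  // copy with the source shape first
  dev_ctx.template Alloc<T>(out);
  phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), /*blocking=*/false, out);
  out->Resize(out_dims);  // then collapse the size-1 axes
}

The grad kernel mirrors this: copy dOut into dX, then resize dX to the input shape recorded in XShape.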
53 changes: 8 additions & 45 deletions paddle/fluid/operators/unsqueeze_op.cc
@@ -18,7 +18,9 @@ limitations under the License. */
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
@@ -251,19 +253,6 @@ class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
class Unsqueeze2Op : public UnsqueezeOp {
public:
using UnsqueezeOp::UnsqueezeOp;
void InferShape(framework::InferShapeContext *ctx) const override {
UnsqueezeOp::InferShape(ctx);
const auto &x_dims = ctx->GetInputDim("X");

if (!ctx->HasOutput("XShape")) return;
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
};

class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
@@ -339,10 +328,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2, Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
@@ -351,7 +344,8 @@ REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>,
ops::UnsqueezeInplaceInferer);
Unsqueeze2InferShapeFunctor, ops::UnsqueezeInplaceInferer);

REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
@@ -388,34 +382,3 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
34 changes: 0 additions & 34 deletions paddle/fluid/operators/unsqueeze_op.cu.cc
@@ -50,37 +50,3 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<float>>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
unsqueeze2,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
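
As on the CPU side, these CUDA registrations are superseded by a single phi GPU registration over the same dtype list. A hedged sketch, with the kernel name and placement assumed rather than taken from this diff:

// Hedged sketch: assumed phi replacement for the removed
// REGISTER_OP_CUDA_KERNEL(unsqueeze2, ...) block above.
PD_REGISTER_KERNEL(unsqueeze,
                   GPU,
                   ALL_LAYOUT,
                   phi::UnsqueezeKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   bool,
                   int,
                   int16_t,
                   uint8_t,
                   int8_t,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}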
4 changes: 4 additions & 0 deletions paddle/phi/core/compat/op_utils.h
@@ -42,6 +42,10 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"flatten_grad",
"isinf",
"isnan",
"unsqueeze",
"unsqueeze_grad",
"squeeze",
"squeeze_grad",
"isfinite",
"matmul",
"matmul_grad",
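
Marking the legacy v1 names deprecated keeps phi's op-name mapping from binding the old squeeze/unsqueeze ops to the new kernels: only squeeze2/unsqueeze2 route to phi, through compat argument-mapping functions. A hedged sketch of such a mapping (the input and attribute names follow the fluid op definition, but this exact function is an assumption, not part of this diff):

// Hedged sketch of a compat argument mapping pairing the fluid op
// unsqueeze2 with a phi kernel named "unsqueeze".
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.InputSize("AxesTensorList") > 0) {
    return KernelSignature(
        "unsqueeze", {"X"}, {"AxesTensorList"}, {"XShape", "Out"});
  } else if (ctx.HasInput("AxesTensor")) {
    return KernelSignature(
        "unsqueeze", {"X"}, {"AxesTensor"}, {"XShape", "Out"});
  }
  return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"XShape", "Out"});
}
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2, phi::UnsqueezeOpArgumentMapping);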
37 changes: 36 additions & 1 deletion paddle/phi/infermeta/unary.cc
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"

namespace phi {

@@ -1661,6 +1662,41 @@ void UnfoldInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims));
}

void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes,
MetaTensor* xshape,
MetaTensor* out) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE_LE(x_dims.size(),
6,
phi::errors::InvalidArgument(
"Invalid "
"dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)"));
if (!axes.GetData().empty()) {
std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(),
axes.GetData().end(),
[&tmp](const int64_t& t) { tmp.push_back(t); });
auto out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
out->set_dims(out_dims);
if (x_dims[0] == out_dims[0]) {
out->share_lod(x);
}
}
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
out->set_dtype(x.dtype());
xshape->set_dtype(x.dtype());
}
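
A worked example of what this meta function computes (illustrative values, not from the PR):

// x.dims() = [3, 4], axes = {0, 2}
//   out->dims()    -> [1, 3, 1, 4]  (size-1 dims inserted at positions 0 and 2)
//   xshape->dims() -> [0, 3, 4]     (leading 0 marks XShape as shape-only
//                                    metadata kept for the backward pass)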

void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth,
DataType dtype,
@@ -1671,7 +1707,6 @@ void OneHotRawInferMeta(const MetaTensor& x,
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));

auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -34,6 +34,11 @@ class MetaConfig;
//
// The InferMeta Functions in this file are arranged in alphabetic order.

void UnsqueezeInferMeta(const MetaTensor& x,

[Review comment, Contributor] The InferMeta functions in this file are now required to be in alphabetical order; please move this declaration accordingly.

const ScalarArray& axes,
MetaTensor* xshape,
MetaTensor* out);

void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,