-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add axis for mul_op
and rowwise_add_op
#3888
Changes from all commits
e76fa85
86655cb
af0264a
69fbc54
d71396b
e168fc4
256d6a3
f2a66ff
823bdd6
3d62c6d
0c13660
5aacd64
d7c8bdc
b744430
1d9a4d2
f6e72c9
b6a4666
856611c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel { | |
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto dim0 = ctx.Input<Tensor>("X")->dims(); | ||
auto dim1 = ctx.Input<Tensor>("Y")->dims(); | ||
PADDLE_ENFORCE_EQ(dim0.size(), 2, | ||
"input X(%s) should be a tensor with 2 dims, a matrix", | ||
ctx.op().Input("X")); | ||
PADDLE_ENFORCE_EQ(dim1.size(), 2, | ||
"input Y(%s) should be a tensor with 2 dims, a matrix", | ||
ctx.op().Input("Y")); | ||
auto x_dims = ctx.Input<Tensor>("X")->dims(); | ||
auto y_dims = ctx.Input<Tensor>("Y")->dims(); | ||
int x_num_col_dims = Attr<int>("x_num_col_dims"); | ||
int y_num_col_dims = Attr<int>("y_num_col_dims"); | ||
|
||
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, | ||
"The rank of input tensor X(%s) should be larger than " | ||
"`mul_op`'s `x_num_col_dims`.", | ||
ctx.op().Input("X")); | ||
PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, | ||
"The rank of input tensor Y(%s) should be larger than " | ||
"`mul_op`'s `y_num_col_dims`.", | ||
ctx.op().Input("Y")); | ||
|
||
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); | ||
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
dim0[1], dim1[0], | ||
x_mat_dims[1], y_mat_dims[0], | ||
"First matrix's width must be equal with second matrix's height."); | ||
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]}); | ||
ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]}); | ||
} | ||
}; | ||
|
||
|
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { | |
AddInput("X", "The first input of mul op"); | ||
AddInput("Y", "The second input of mul op"); | ||
AddOutput("Out", "The output of mul op"); | ||
AddAttr<int>( | ||
Reviewer comment: There is a very useful syntax for this — a raw string literal:
AddAttr<int>("x_num_col_dims", R"DOC(mul_op can take ...
....
)DOC");
See http://en.cppreference.com/w/cpp/language/string_literal
Author reply: Got it, thank you!
||
"x_num_col_dims", | ||
R"DOC(mul_op can take tensors with more than two dimensions as input `X`, | ||
in that case, tensors will be reshaped to a matrix. The matrix's first | ||
dimension(column length) will be the product of tensor's last | ||
`num_col_dims` dimensions, and the matrix's second dimension(row length) | ||
will be the product of tensor's first `rank - num_col_dims` dimensions. | ||
)DOC") | ||
.SetDefault(1) | ||
.EqualGreaterThan(1); | ||
AddAttr<int>( | ||
"y_num_col_dims", | ||
R"DOC(mul_op can take tensors with more than two dimensions as input `Y`, | ||
in that case, tensors will be reshaped to a matrix. Just like input `X`. | ||
)DOC") | ||
.SetDefault(1) | ||
.EqualGreaterThan(1); | ||
AddComment(R"DOC( | ||
Two Element Mul Operator. | ||
|
||
|
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel { | |
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); | ||
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
PADDLE_ENFORCE(x_dims[0] == out_dims[0], | ||
"Out@GRAD M X N must equal to X dims 0, M "); | ||
PADDLE_ENFORCE(y_dims[1] == out_dims[1], | ||
"Out@GRAD M X N must equal to Y dims 1, N "); | ||
|
||
auto x_mat_dims = | ||
framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims")); | ||
auto y_mat_dims = | ||
framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims")); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
x_mat_dims[0], out_dims[0], | ||
"The first dimension of Out@GRAD must equal to the first dimension of " | ||
"the first operand."); | ||
PADDLE_ENFORCE_EQ( | ||
y_mat_dims[1], out_dims[1], | ||
"The second dimension of Out@GRAD must equal to the second " | ||
"dimension of the second operand."); | ||
|
||
if (x_grad) x_grad->Resize(x_dims); | ||
if (y_grad) y_grad->Resize(y_dims); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel { | |
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto dim0 = ctx.Input<Tensor>("X")->dims(); | ||
auto dim1 = ctx.Input<Tensor>("b")->dims(); | ||
|
||
PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); | ||
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); | ||
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); | ||
PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1"); | ||
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims()); | ||
auto x_dims = ctx.Input<Tensor>("X")->dims(); | ||
auto b_dims = ctx.Input<Tensor>("b")->dims(); | ||
PADDLE_ENFORCE_GT( | ||
x_dims.size(), b_dims.size(), | ||
"The rank of input `X` must be larger than the one of input `b`."); | ||
|
||
int num_col_dims = x_dims.size() - b_dims.size(); | ||
Reviewer comment: Interesting implementation here — so rowwise_add's `num_col_dims` is inferred from the rank difference between `X` and `b` rather than taken as an attribute. Author reply: Yes. It makes sure that the trailing dimensions of `X` exactly match the dimensions of `b`.
||
|
||
PADDLE_ENFORCE_EQ( | ||
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, | ||
"The width of two operands must be same"); | ||
PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); | ||
ctx.Output<Tensor>("Out")->Resize(x_dims); | ||
} | ||
}; | ||
|
||
|
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { | |
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), | ||
"Input(Out@GRAD) should not be null"); | ||
auto dims0 = ctx.Input<Tensor>("X")->dims(); | ||
auto dims1 = ctx.Input<Tensor>("b")->dims(); | ||
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") | ||
auto x_dims = ctx.Input<Tensor>("X")->dims(); | ||
auto b_dims = ctx.Input<Tensor>("b")->dims(); | ||
PADDLE_ENFORCE_GT( | ||
x_dims.size(), b_dims.size(), | ||
"The rank of input `X` must be larger than the one of input `b`."); | ||
|
||
int num_col_dims = x_dims.size() - b_dims.size(); | ||
PADDLE_ENFORCE_EQ( | ||
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, | ||
"The width of two operands must be same"); | ||
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto *db = ctx.Output<Tensor>(framework::GradVarName("b")); | ||
if (dx) dx->Resize(dims0); | ||
if (db) db->Resize(dims1); | ||
if (dx) dx->Resize(x_dims); | ||
if (db) db->Resize(b_dims); | ||
} | ||
}; | ||
|
||
|
Reviewer comment: `inline` is not needed in a class method. It will be the compiler's choice whether to inline or not.
Author reply: It's not a class method. It's a global function.