Skip to content

Commit

Permalink
add the shape check for the matmul (#35791)
Browse files Browse the repository at this point in the history
* add the shape check for the matmul

* remove the test case for the linear
  • Loading branch information
wawltor authored Sep 24, 2021
1 parent 4e7bd9c commit 8e19d1b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/matmul_v2_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,14 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
auto* Out = ctx.Output<Tensor>("Out");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
PADDLE_ENFORCE_NE(framework::product(X->dims()), 0,
platform::errors::InvalidArgument(
"The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0,
platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
MatMulFunction<DeviceContext, T>(X, Y, Out, trans_x, trans_y, ctx);
}
};
Expand Down
9 changes: 0 additions & 9 deletions python/paddle/fluid/tests/unittests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,6 @@ def test_error(self, place=paddle.CPUPlace()):
np.testing.assert_array_almost_equal(res_f, res_nn)
np.testing.assert_array_almost_equal(res_nn, res_np)

def test_error_dummy_input(self, place=paddle.CPUPlace()):
with self.assertRaises(RuntimeError):
x_arr = np.array([], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32')
weight = paddle.zeros([4, 4, 4], dtype='float32')
bias = paddle.to_tensor([], dtype='float32')
paddle.nn.functional.linear(x, weight, bias=bias)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8e19d1b

Please sign in to comment.