[Divide by 0 Error] add lu check (#49974)
* [Divide by 0 Error] add lu check

* [Divide by 0 Error] lu check migrate to c++
gouzil authored Feb 1, 2023
1 parent f0811bb commit f71796b
Showing 2 changed files with 22 additions and 0 deletions.
8 changes: 8 additions & 0 deletions paddle/phi/kernels/impl/lu_kernel_impl.h
@@ -520,6 +520,14 @@ DenseTensor Transpose2DTo6D(const Context& dev_ctx, const DenseTensor& x) {
  auto x_dim = x.dims();
  auto x_vec = phi::vectorize<int>(x_dim);
  int rank = x_vec.size();

  for (int i = 0; i < x_dim.size(); i++) {
    PADDLE_ENFORCE_LT(0,
                      x_dim[i],
                      errors::InvalidArgument(
                          "The dims of Input(X) should be greater than 0."));
  }

  std::swap(x_vec[rank - 1], x_vec[rank - 2]);
  std::vector<int> out_shape = x_vec;
  std::vector<int> axis(rank);
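For context, PADDLE_ENFORCE_LT(0, x_dim[i], ...) asserts 0 < x_dim[i], so every dimension of the input must be strictly positive before the transpose and LU path runs. Below is a minimal Python sketch of the same guard; the helper name check_positive_dims is illustrative only and is not part of the patch.

import numpy as np

def check_positive_dims(x):
    # Mirrors the PADDLE_ENFORCE_LT(0, x_dim[i], ...) loop above:
    # every dimension must satisfy 0 < dim, otherwise raise.
    for i, dim in enumerate(np.asarray(x).shape):
        if dim <= 0:
            raise ValueError(
                "The dims of Input(X) should be greater than 0, "
                "but dim %d is %d." % (i, dim)
            )

check_positive_dims(np.ones((2, 3, 3)))        # passes
# check_positive_dims(np.empty((0, 0, 0)))     # would raise ValueError
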
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_lu_op.py
@@ -303,6 +303,20 @@ def run_lu_static(shape, dtype):
run_lu_static(tensor_shape, dtype)


class TestLUAPIError(unittest.TestCase):
    def test_errors(self):
        with paddle.fluid.dygraph.guard():
            # The size of input in lu should not be 0.
            def test_0_size():
                array = np.array([], dtype=np.float32)
                x = paddle.to_tensor(
                    np.reshape(array, [0, 0, 0]), dtype='float32'
                )
                paddle.linalg.lu(x, get_infos=True)

            self.assertRaises(ValueError, test_0_size)


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
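
For contrast with the zero-size failure case exercised by the test, a well-formed input passes the new dimension check. The sketch below assumes Paddle's default dynamic-graph mode; the shapes are chosen for illustration, and the commented output shapes are the expected results of get_infos=True.

import numpy as np
import paddle

# All dimensions are > 0, so the check added in lu_kernel_impl.h is satisfied.
x = paddle.to_tensor(np.random.rand(4, 3, 3), dtype='float32')
lu, pivots, info = paddle.linalg.lu(x, get_infos=True)
print(lu.shape, pivots.shape, info.shape)  # expected: [4, 3, 3] [4, 3] [4]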
