Skip to content

Commit

Permalink
[Cherey-Pick]Support 0D for slogdet (#53087)
Browse files Browse the repository at this point in the history
* add 0D output support for inalg.slogdet,test=allcase

* fix zerom dime test error test=allcase

* fix test error test=allcase

* add static backward test, test=allcase
  • Loading branch information
GGBond8488 authored Apr 20, 2023
1 parent d131e67 commit 3f5058e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
if (input_dim.size() == static_cast<size_t>(2)) {
// when input is a two-dimension matrix, The det value is a number.
output_dim_vec = {1};
output_dim_vec = {};
}
output_dim_vec.insert(output_dim_vec.begin(),
2); // make the output dims as same as numpy
Expand Down
45 changes: 45 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,27 @@ def body(i, x):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0))

def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
out.retain_grads()
out.backward()

self.assertTrue(out.shape, [2])
self.assertTrue(x.grad.shape, [3, 3])

# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
out1.retain_grads()
out1.backward()

self.assertTrue(out1.shape, [2, 3])
self.assertTrue(x1.grad.shape, [3, 3, 3])


class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -3609,6 +3630,30 @@ def test_broadcast_tensors(self):
self.assertEqual(out1.shape, (2, 3))
self.assertEqual(out2.shape, (2, 3))

@prog_scope()
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
paddle.static.append_backward(out.sum())

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (3, 3))

# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
paddle.static.append_backward(out1.sum())

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, x1.grad_name])
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (3, 3, 3))


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down

0 comments on commit 3f5058e

Please sign in to comment.