add add_n for the 0d tensor (#49854)
wawltor authored Jan 16, 2023
1 parent 8fdb908 commit 65b0181
Showing 3 changed files with 94 additions and 2 deletions.
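
For context, a minimal dygraph sketch of the behavior this commit enables, mirroring the new test_add_n cases added below (illustrative only, not part of the diff):

import paddle

# 0-D (scalar) tensors as inputs to add_n
x = paddle.rand([])
y = paddle.rand([])

out = paddle.add_n([x, y])   # add_n also accepts a single tensor, as the new tests exercise
print(out.shape)             # expected after this change: [] (a 0-D tensor)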
9 changes: 7 additions & 2 deletions paddle/phi/infermeta/multiary.cc
@@ -297,7 +297,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
   if (N == 1) {
     VLOG(3) << "Warning: SumOp have only one input, may waste memory";
   }
-
+  bool is_all_0d_tensor = true;
   phi::DDim in_dim({0});
   for (size_t i = 0; i < x.size(); ++i) {
     auto x_dim = x[i]->dims();
@@ -313,6 +313,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
     if (x_dim.size() == 0) {
       continue;
     }
+    is_all_0d_tensor = false;
     if (phi::product(in_dim) == 0) {
       in_dim = x_dim;
     } else {
@@ -360,7 +361,11 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
       }
     }
   }
-  out->set_dims(in_dim);
+  if (is_all_0d_tensor) {
+    out->set_dims(make_ddim({}));
+  } else {
+    out->set_dims(in_dim);
+  }
   out->share_lod(*x[0]);
 }

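In plain terms, AddNInferMeta now records whether every input is 0-D; only in that case does it set the output dims via make_ddim({}), i.e. a 0-D result, and otherwise it keeps the pre-existing shape inference. A rough Python mirror of that decision, with a hypothetical helper name and the validation logic omitted:

def infer_add_n_out_shape(input_shapes):
    # Only when every input is 0-D does the output stay 0-D (make_ddim({})).
    if all(len(shape) == 0 for shape in input_shapes):
        return []
    # Otherwise fall back to the existing inference, greatly simplified here:
    # take the first non-0-D shape (the real code also enforces that input shapes match).
    for shape in input_shapes:
        if len(shape) > 0:
            return list(shape)

print(infer_add_n_out_shape([[], []]))      # []  -> 0-D output
print(infer_add_n_out_shape([[], [2, 3]]))  # [2, 3]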
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -900,6 +900,31 @@ def test_floor_divide(self):
         np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
         np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))

+    def test_add_n(self):
+        x1 = paddle.rand([])
+        x1.stop_gradient = False
+        x2 = paddle.rand([])
+        x2.stop_gradient = False
+        x3 = paddle.rand([])
+        x3.stop_gradient = False
+
+        out1 = paddle.add_n(x1)
+        out2 = paddle.add_n([x2, x3])
+
+        out1.backward()
+        out2.backward()
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertTrue(x1.grad.numpy() == 1)
+        self.assertEqual(x2.grad.shape, [])
+        self.assertTrue(x2.grad.numpy() == 1)
+        self.assertEqual(x3.grad.shape, [])
+        self.assertTrue(x3.grad.numpy() == 1)
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out1.grad.shape, [])
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(out2.grad.shape, [])
+
     def test_reshape_list(self):
         x = paddle.rand([])
         x.stop_gradient = False
@@ -1534,6 +1559,46 @@ def test_floor_divide(self):
         np.testing.assert_array_equal(out3_1, out3_2)
         np.testing.assert_array_equal(out3_2, np.asarray(1))

+    @prog_scope()
+    def test_add_n(self):
+        x1 = paddle.rand([])
+        x1.stop_gradient = False
+        x2 = paddle.rand([])
+        x2.stop_gradient = False
+        x3 = paddle.rand([])
+        x3.stop_gradient = False
+
+        out1 = paddle.add_n(x1)
+        out2 = paddle.add_n([x2, x3])
+
+        paddle.static.append_backward(out1.sum())
+        paddle.static.append_backward(out2.sum())
+
+        prog = paddle.static.default_main_program()
+        block = prog.global_block()
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                out1,
+                out2,
+                x1.grad_name,
+                x2.grad_name,
+                x3.grad_name,
+                out1.grad_name,
+                out2.grad_name,
+            ],
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, ())
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[2], 1)
+        self.assertEqual(res[3].shape, ())
+        self.assertEqual(res[3], 1)
+        self.assertEqual(res[4].shape, ())
+        self.assertEqual(res[4], 1)
+        self.assertEqual(res[5].shape, ())
+        self.assertEqual(res[6].shape, ())
+
     @prog_scope()
     def test_reshape_list(self):
         x1 = paddle.rand([])
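The static-graph test above pairs each 0-D input with paddle.static.append_backward and checks that every fetched value comes back from Executor.run with shape (). A standalone sketch of the same forward expectation, assuming static mode and a default-place Executor (the setup names here are illustrative, not taken from the test class):

import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x1 = paddle.rand([])
    x2 = paddle.rand([])
    out = paddle.add_n([x1, x2])

exe = paddle.static.Executor()
exe.run(startup)
(res,) = exe.run(main, fetch_list=[out])
print(res.shape)  # expected: () -- a 0-D result fetched as a rank-0 numpy array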
@@ -592,6 +592,28 @@ def test_floor_divide(self):
         np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
         np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))

+    def test_add_n(self):
+        x1 = paddle.rand([])
+        x1.stop_gradient = False
+        x2 = paddle.rand([])
+        x2.stop_gradient = False
+        x3 = paddle.rand([])
+        x3.stop_gradient = False
+
+        out1 = paddle.add_n(x1)
+        out2 = paddle.add_n([x2, x3])
+
+        out1.retain_grads()
+        out2.retain_grads()
+
+        out1.backward()
+        out2.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out1.grad.shape, [])
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(out2.grad.shape, [])
+
     def test_reshape_list(self):
         x = paddle.rand([])
         x.stop_gradient = False