diff --git a/gpytorch/test/base_keops_test_case.py b/gpytorch/test/base_keops_test_case.py index 9b8cbb13b..e63206b8d 100644 --- a/gpytorch/test/base_keops_test_case.py +++ b/gpytorch/test/base_keops_test_case.py @@ -46,6 +46,7 @@ def test_forward_x1_eq_x2(self, ard=False, use_keops=True, **kwargs): d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2) d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(d1 - d2), 1e-4) + self.assertTrue(torch.equal(k1.diag(), d1)) if use_keops: self.assertTrue(keops_mock.called)