diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index 4bf91ebeb..adf65225e 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -158,7 +158,7 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal: ) super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril) # Set the covar matrix, since it is always available for GPyTorch MVN. - new.covariance_matrix = self.covariance_matrix + new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:]) return new def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor: diff --git a/test/distributions/test_multivariate_normal.py b/test/distributions/test_multivariate_normal.py index ed54b50ec..a344c92fd 100644 --- a/test/distributions/test_multivariate_normal.py +++ b/test/distributions/test_multivariate_normal.py @@ -343,8 +343,11 @@ def test_multivariate_normal_expand(self, cuda=False): self.assertEqual(expanded.batch_shape, torch.Size([2])) self.assertEqual(expanded.event_shape, mvn.event_shape) self.assertTrue(torch.equal(expanded.mean, mean.expand(2, -1))) + self.assertEqual(expanded.mean.shape, torch.Size([2, 3])) self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1))) + self.assertEqual(expanded.covariance_matrix.shape, torch.Size([2, 3, 3])) self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1))) + self.assertEqual(expanded.scale_tril.shape, torch.Size([2, 3, 3])) if __name__ == "__main__":