Skip to content

Commit

Permalink
expand covar matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
saitcakmak committed Jan 21, 2025
1 parent 9eaecdc commit b37a85b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal:
)
super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
# Set the covar matrix, since it is always available for GPyTorch MVN.
new.covariance_matrix = self.covariance_matrix
new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:])
return new

def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
Expand Down
3 changes: 3 additions & 0 deletions test/distributions/test_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,11 @@ def test_multivariate_normal_expand(self, cuda=False):
self.assertEqual(expanded.batch_shape, torch.Size([2]))
self.assertEqual(expanded.event_shape, mvn.event_shape)
self.assertTrue(torch.equal(expanded.mean, mean.expand(2, -1)))
self.assertEqual(expanded.mean.shape, torch.Size([2, 3]))
self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1)))
self.assertEqual(expanded.covariance_matrix.shape, torch.Size([2, 3, 3]))
self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1)))
self.assertEqual(expanded.scale_tril.shape, torch.Size([2, 3, 3]))


if __name__ == "__main__":
Expand Down

0 comments on commit b37a85b

Please sign in to comment.