Merge pull request #1932 from j-wilson/master
LazyTensor.rmatmul
jacobrgardner authored Mar 8, 2022
2 parents 91fb9ec + ab0fc93 commit 1acff3e
Showing 4 changed files with 100 additions and 2 deletions.
20 changes: 20 additions & 0 deletions gpytorch/lazy/lazy_tensor.py
@@ -1408,6 +1408,23 @@ def matmul(self, other):

return Matmul.apply(self.representation_tree(), other, *self.representation())

    def rmatmul(self, other):
        """
        Multiplies a matrix by self, i.e. computes :math:`MK` rather than :math:`KM`.

        Args:
            other (:obj:`torch.tensor`): Matrix or vector to multiply with. Can be either a :obj:`torch.tensor`
                or a :obj:`gpytorch.lazy.LazyTensor`.

        Returns:
            :obj:`torch.tensor`: Tensor or LazyTensor containing the result of the matrix multiplication :math:`MK`,
            where :math:`M` is the (batched) matrix input to this method, and :math:`K` is the (batched) matrix that
            this :obj:`gpytorch.lazy.LazyTensor` represents.
        """
        # For a vector v, vK = K^T v; for a matrix M, MK = (K^T M^T)^T.
        if other.ndim == 1:
            return self.transpose(-1, -2).matmul(other)
        return self.transpose(-1, -2).matmul(other.transpose(-1, -2)).transpose(-1, -2)

@property
def matrix_shape(self):
"""
@@ -2307,6 +2324,9 @@ def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTensor"]]:
def __matmul__(self, other):
return self.matmul(other)

def __rmatmul__(self, other: Tensor) -> Tensor:
return self.rmatmul(other)

def __mul__(self, other):
return self.mul(other)

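For reference, a minimal usage sketch of the new method (not part of the commit), assuming a NonLazyTensor wrapper around a dense matrix; both the explicit rmatmul call and the reflected @ operator added above should agree with dense matmul:

import torch
from gpytorch.lazy import NonLazyTensor

A = torch.randn(5, 5)
K = NonLazyTensor(A)      # lazy view of the dense matrix A
M = torch.randn(3, 5)
v = torch.randn(5)

# M @ K dispatches to LazyTensor.__rmatmul__, which calls rmatmul
assert torch.allclose(K.rmatmul(M), M @ A)
assert torch.allclose(M @ K, M @ A)
# vector case: v @ K is computed as K^T @ v
assert torch.allclose(v @ K, v @ A)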
55 changes: 53 additions & 2 deletions gpytorch/test/lazy_tensor_test_case.py
@@ -55,6 +55,22 @@ def _test_matmul(self, rhs):
if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"])

def _test_rmatmul(self, lhs):
lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)

res = lhs @ lazy_tensor
actual = lhs @ evaluated
self.assertAllClose(res, actual)

grad = torch.randn_like(res)
res.backward(gradient=grad)
actual.backward(gradient=grad)
for arg, arg_copy in zip(lazy_tensor.representation(), lazy_tensor_copy.representation()):
if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"])

def test_add(self):
lazy_tensor = self.create_lazy_tensor()
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
@@ -72,20 +88,36 @@ def test_add(self):

def test_matmul_vec(self):
lazy_tensor = self.create_lazy_tensor()

# We skip this test if we're dealing with batch LazyTensors
# They shouldn't multiply by a vec
if lazy_tensor.ndimension() > 2:
return

rhs = torch.randn(lazy_tensor.size(-1))
return self._test_matmul(rhs)

    def test_rmatmul_vec(self):
        lazy_tensor = self.create_lazy_tensor()

        # We skip this test if we're dealing with batch LazyTensors
        # They shouldn't multiply by a vec
        if lazy_tensor.ndimension() > 2:
            return

        lhs = torch.randn(lazy_tensor.size(-2))
        return self._test_rmatmul(lhs)

def test_matmul_matrix(self):
lazy_tensor = self.create_lazy_tensor()
rhs = torch.randn(*lazy_tensor.batch_shape, lazy_tensor.size(-1), 4)
return self._test_matmul(rhs)

def test_rmatmul_matrix(self):
lazy_tensor = self.create_lazy_tensor()
lhs = torch.randn(*lazy_tensor.batch_shape, 4, lazy_tensor.size(-2))
return self._test_rmatmul(lhs)

def test_matmul_matrix_broadcast(self):
lazy_tensor = self.create_lazy_tensor()

@@ -105,6 +137,25 @@ def test_matmul_matrix_broadcast(self):
rhs = torch.randn(*batch_shape, lazy_tensor.size(-1), 4)
self._test_matmul(rhs)

def test_rmatmul_matrix_broadcast(self):
lazy_tensor = self.create_lazy_tensor()

        # Left hand side has one more batch dimension
batch_shape = torch.Size((3, *lazy_tensor.batch_shape))
lhs = torch.randn(*batch_shape, 4, lazy_tensor.size(-2))
self._test_rmatmul(lhs)

if lazy_tensor.ndimension() > 2:
            # Left hand side has one fewer batch dimension
batch_shape = torch.Size(lazy_tensor.batch_shape[1:])
lhs = torch.randn(*batch_shape, 4, lazy_tensor.size(-2))
self._test_rmatmul(lhs)

            # Left hand side has a singleton dimension
batch_shape = torch.Size((*lazy_tensor.batch_shape[:-1], 1))
lhs = torch.randn(*batch_shape, 4, lazy_tensor.size(-2))
self._test_rmatmul(lhs)

def test_constant_mul(self):
lazy_tensor = self.create_lazy_tensor()
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
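As an aside (not in the commit), the broadcast test above leans on the same batch broadcasting rules as torch.matmul. A quick shape check with plain tensors, using a batched 5 x 5 operator with batch_shape (2,) as a stand-in:

import torch

K = torch.randn(2, 5, 5)             # stands in for a lazy tensor with batch_shape (2,)
lhs_extra = torch.randn(3, 2, 4, 5)  # one more batch dimension
lhs_fewer = torch.randn(4, 5)        # one fewer batch dimension
lhs_single = torch.randn(1, 4, 5)    # singleton batch dimension

assert (lhs_extra @ K).shape == (3, 2, 4, 5)
assert (lhs_fewer @ K).shape == (2, 4, 5)
assert (lhs_single @ K).shape == (2, 4, 5)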
9 changes: 9 additions & 0 deletions test/lazy/test_identity_lazy_tensor.py
@@ -18,6 +18,15 @@ def _test_matmul(self, rhs):
actual = evaluated.matmul(rhs)
self.assertAllClose(res, actual)

def _test_rmatmul(self, lhs):
lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)

res = lhs @ lazy_tensor
actual = lhs @ evaluated
self.assertAllClose(res, actual)

def _test_inv_matmul(self, rhs, lhs=None, cholesky=False):
lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
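A tiny sanity sketch of the identity special case (not part of the commit, and assuming IdentityLazyTensor is constructed from its diagonal size): right-multiplying by the identity should leave the left-hand side unchanged.

import torch
from gpytorch.lazy import IdentityLazyTensor

lhs = torch.randn(4, 5)
eye = IdentityLazyTensor(5)   # assumed constructor: a lazily represented 5 x 5 identity
assert torch.allclose(lhs @ eye, lhs)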
18 changes: 18 additions & 0 deletions test/lazy/test_lazy_evaluated_kernel_tensor.py
@@ -41,6 +41,24 @@ def _test_matmul(self, rhs):
lazy_tensor.x1.grad + lazy_tensor.x2.grad, lazy_tensor_copy.x1.grad + lazy_tensor_copy.x2.grad, rtol=1e-3
)

def _test_rmatmul(self, lhs):
lazy_tensor = self.create_lazy_tensor().requires_grad_(True)
lazy_tensor_copy = lazy_tensor.clone().detach_().requires_grad_(True)
evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)

res = lhs @ lazy_tensor
actual = lhs @ evaluated
self.assertAllClose(res, actual)

grad = torch.randn_like(res)
res.backward(gradient=grad)
actual.backward(gradient=grad)
for param, param_copy in zip(lazy_tensor.kernel.parameters(), lazy_tensor_copy.kernel.parameters()):
self.assertAllClose(param.grad, param_copy.grad, rtol=1e-3)
self.assertAllClose(
lazy_tensor.x1.grad + lazy_tensor.x2.grad, lazy_tensor_copy.x1.grad + lazy_tensor_copy.x2.grad, rtol=1e-3
)

def _test_inv_matmul(self, rhs, lhs=None, cholesky=False):
lazy_tensor = self.create_lazy_tensor().requires_grad_(True)
lazy_tensor_copy = lazy_tensor.clone().detach_().requires_grad_(True)
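A rough end-to-end sketch of what this test exercises (not part of the commit; it assumes a stock RBFKernel and that lazy kernel evaluation is enabled, so that kernel(x1, x2) yields a LazyEvaluatedKernelTensor): left-multiplying the lazy kernel matrix and backpropagating should populate gradients for both the kernel hyperparameters and the inputs.

import torch
import gpytorch

kernel = gpytorch.kernels.RBFKernel()
x1 = torch.randn(6, 3, requires_grad=True)
x2 = torch.randn(6, 3, requires_grad=True)

lazy_K = kernel(x1, x2)   # lazy kernel matrix; the 6 x 6 entries are not materialized yet
lhs = torch.randn(4, 6)

res = lhs @ lazy_K        # routed through LazyTensor.__rmatmul__ / rmatmul
res.sum().backward()      # gradients flow to the lengthscale and to x1, x2
assert kernel.raw_lengthscale.grad is not None
assert x1.grad is not None and x2.grad is not None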
