Skip to content

Commit

Permalink
Merge pull request #2249 from saitcakmak/fix_zip
Browse files Browse the repository at this point in the history
Fix silently ignored arguments in IndependentModelList
  • Loading branch information
saitcakmak authored Jan 13, 2023
2 parents 0fc5408 + 4b07b5e commit 41a8386
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 26 deletions.
3 changes: 2 additions & 1 deletion gpytorch/kernels/lcm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.nn import ModuleList

from ..priors import Prior
from ..utils.generic import length_safe_zip
from .kernel import Kernel
from .multitask_kernel import MultitaskKernel

Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
self.covar_module_list = ModuleList(
[
MultitaskKernel(base_kernel, num_tasks=num_tasks, rank=r, task_covar_prior=task_covar_prior)
for base_kernel, r in zip(base_kernels, rank)
for base_kernel, r in length_safe_zip(base_kernels, rank)
]
)

Expand Down
14 changes: 8 additions & 6 deletions gpytorch/likelihoods/likelihood_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.nn import ModuleList

from gpytorch.likelihoods import Likelihood
from gpytorch.utils.generic import length_safe_zip


def _get_tuple_args_(*args):
Expand All @@ -21,7 +22,7 @@ def __init__(self, *likelihoods):
def expected_log_prob(self, *args, **kwargs):
return [
likelihood.expected_log_prob(*args_, **kwargs)
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
]

def forward(self, *args, **kwargs):
Expand All @@ -30,18 +31,18 @@ def forward(self, *args, **kwargs):
# if noise kwarg is passed, assume it's an iterable of noise tensors
return [
likelihood.forward(*args_, {**kwargs, "noise": noise_})
for likelihood, args_, noise_ in zip(self.likelihoods, _get_tuple_args_(*args), noise)
for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
]
else:
return [
likelihood.forward(*args_, **kwargs)
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
]

def pyro_sample_output(self, *args, **kwargs):
return [
likelihood.pyro_sample_output(*args_, **kwargs)
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
]

def __call__(self, *args, **kwargs):
Expand All @@ -50,9 +51,10 @@ def __call__(self, *args, **kwargs):
# if noise kwarg is passed, assume it's an iterable of noise tensors
return [
likelihood(*args_, {**kwargs, "noise": noise_})
for likelihood, args_, noise_ in zip(self.likelihoods, _get_tuple_args_(*args), noise)
for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
]
else:
return [
likelihood(*args_, **kwargs) for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))
likelihood(*args_, **kwargs)
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
]
5 changes: 3 additions & 2 deletions gpytorch/mlls/sum_marginal_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.nn import ModuleList

from gpytorch.mlls import ExactMarginalLogLikelihood, MarginalLogLikelihood
from gpytorch.utils.generic import length_safe_zip


class SumMarginalLogLikelihood(MarginalLogLikelihood):
Expand Down Expand Up @@ -30,10 +31,10 @@ def forward(self, outputs, targets, *params):
(e.g. parameters in case of heteroskedastic likelihoods)
"""
if len(params) == 0:
sum_mll = sum(mll(output, target) for mll, output, target in zip(self.mlls, outputs, targets))
sum_mll = sum(mll(output, target) for mll, output, target in length_safe_zip(self.mlls, outputs, targets))
else:
sum_mll = sum(
mll(output, target, *iparams)
for mll, output, target, iparams in zip(self.mlls, outputs, targets, params)
for mll, output, target, iparams in length_safe_zip(self.mlls, outputs, targets, params)
)
return sum_mll.div_(len(self.mlls))
28 changes: 20 additions & 8 deletions gpytorch/models/exact_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .. import settings
from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from ..utils.generic import length_safe_zip
from ..utils.warnings import GPInputWarning
from .exact_prediction_strategies import prediction_strategy
from .gp import GP
Expand Down Expand Up @@ -113,7 +114,7 @@ def set_train_data(self, inputs=None, targets=None, strict=True):
inputs = (inputs,)
inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
if strict:
for input_, t_input in zip(inputs, self.train_inputs or (None,)):
for input_, t_input in length_safe_zip(inputs, self.train_inputs or (None,)):
for attr in {"shape", "dtype", "device"}:
expected_attr = getattr(t_input, attr, None)
found_attr = getattr(input_, attr, None)
Expand Down Expand Up @@ -200,10 +201,16 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[-1:])

full_inputs = [
torch.cat([train_input, input.expand(input_batch_shape + input.shape[-2:])], dim=-2)
for train_input, input in zip(train_inputs, inputs)
torch.cat(
[train_input, input.expand(input_batch_shape + input.shape[-2:])],
dim=-2,
)
for train_input, input in length_safe_zip(train_inputs, inputs)
]
full_targets = torch.cat([train_targets, targets.expand(target_batch_shape + targets.shape[-1:])], dim=-1)
full_targets = torch.cat(
[train_targets, targets.expand(target_batch_shape + targets.shape[-1:])],
dim=-1,
)

try:
fantasy_kwargs = {"noise": kwargs.pop("noise")}
Expand Down Expand Up @@ -253,7 +260,9 @@ def __call__(self, *args, **kwargs):
"Call .eval() for prior predictions, or call .set_train_data() to add training data."
)
if settings.debug.on():
if not all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs)):
if not all(
torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
):
raise RuntimeError("You must train on the training inputs!")
res = super().__call__(*inputs, **kwargs)
return res
Expand All @@ -270,7 +279,7 @@ def __call__(self, *args, **kwargs):
# Posterior mode
else:
if settings.debug.on():
if all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs)):
if all(torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)):
warnings.warn(
"The input matches the stored training data. Did you forget to call model.train()?",
GPInputWarning,
Expand All @@ -291,7 +300,7 @@ def __call__(self, *args, **kwargs):
# Concatenate the input to the training input
full_inputs = []
batch_shape = train_inputs[0].shape[:-2]
for train_input, input in zip(train_inputs, inputs):
for train_input, input in length_safe_zip(train_inputs, inputs):
# Make sure the batch shapes agree for training/test data
if batch_shape != train_input.shape[:-2]:
batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
Expand All @@ -317,7 +326,10 @@ def __call__(self, *args, **kwargs):

# Make the prediction
with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)
(
predictive_mean,
predictive_covar,
) = self.prediction_strategy.exact_prediction(full_mean, full_covar)

# Reshape predictive mean to match the appropriate event shape
predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
Expand Down
16 changes: 12 additions & 4 deletions gpytorch/models/model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from gpytorch.likelihoods import LikelihoodList
from gpytorch.models import GP
from gpytorch.utils.generic import length_safe_zip


class AbstractModelList(GP, ABC):
Expand Down Expand Up @@ -46,7 +47,9 @@ def likelihood_i(self, i, *args, **kwargs):
return self.likelihood.likelihoods[i](*args, **kwargs)

def forward(self, *args, **kwargs):
return [model.forward(*args_, **kwargs) for model, args_ in zip(self.models, _get_tensor_args(*args))]
return [
model.forward(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
]

def get_fantasy_model(self, inputs, targets, **kwargs):
"""
Expand All @@ -72,14 +75,19 @@ def get_fantasy_model(self, inputs, targets, **kwargs):

fantasy_models = [
model.get_fantasy_model(*inputs_, *targets_, **kwargs_)
for model, inputs_, targets_, kwargs_ in zip(
self.models, _get_tensor_args(*inputs), _get_tensor_args(*targets), kwargs
for model, inputs_, targets_, kwargs_ in length_safe_zip(
self.models,
_get_tensor_args(*inputs),
_get_tensor_args(*targets),
kwargs,
)
]
return self.__class__(*fantasy_models)

def __call__(self, *args, **kwargs):
return [model.__call__(*args_, **kwargs) for model, args_ in zip(self.models, _get_tensor_args(*args))]
return [
model.__call__(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
]

@property
def train_inputs(self):
Expand Down
3 changes: 2 additions & 1 deletion gpytorch/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

import linear_operator

from . import deprecation, errors, grid, interpolation, quadrature, transforms, warnings
from . import deprecation, errors, generic, grid, interpolation, quadrature, transforms, warnings
from .memoize import cached
from .nearest_neighbors import NNUtil

__all__ = [
"cached",
"deprecation",
"errors",
"generic",
"grid",
"interpolation",
"quadrature",
Expand Down
17 changes: 17 additions & 0 deletions gpytorch/utils/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env python3


def length_safe_zip(*args):
    """Equivalent of the builtin `zip(...)`, but first verifies that every
    argument contains the same number of elements, raising a `ValueError`
    otherwise.
    NOTE: This converts all args that do not define "__len__" to a list.
    """
    # Materialize any sized-less iterables (e.g. generators) so they can be measured.
    args = [arg if hasattr(arg, "__len__") else list(arg) for arg in args]
    first_len = len(args[0]) if args else 0
    if any(len(arg) != first_len for arg in args):
        raise ValueError(
            "Expected the lengths of all arguments to be equal. Got lengths "
            f"{[len(a) for a in args]} for args {args}. Did you pass in "
            "fewer inputs than expected?"
        )
    return zip(*args)
10 changes: 6 additions & 4 deletions test/models/test_model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def test_forward_eval(self):
models = [self.create_model() for _ in range(2)]
model = IndependentModelList(*models)
model.eval()
model(torch.rand(3))
with self.assertRaises(ValueError):
model(torch.rand(3))
model(torch.rand(3), torch.rand(3))

def test_forward_eval_fixed_noise(self):
models = [self.create_model(fixed_noise=True) for _ in range(2)]
model = IndependentModelList(*models)
model.eval()
model(torch.rand(3))
model(torch.rand(3), torch.rand(3))

def test_get_fantasy_model(self):
models = [self.create_model() for _ in range(2)]
Expand All @@ -39,7 +41,7 @@ def test_get_fantasy_model(self):
fant_x = [torch.randn(2), torch.randn(3)]
fant_y = [torch.randn(2), torch.randn(3)]
fmodel = model.get_fantasy_model(fant_x, fant_y)
fmodel(torch.randn(4))
fmodel(torch.randn(4), torch.randn(4))

def test_get_fantasy_model_fixed_noise(self):
models = [self.create_model(fixed_noise=True) for _ in range(2)]
Expand All @@ -50,7 +52,7 @@ def test_get_fantasy_model_fixed_noise(self):
fant_y = [torch.randn(2), torch.randn(3)]
fant_noise = [0.1 * torch.ones(2), 0.1 * torch.ones(3)]
fmodel = model.get_fantasy_model(fant_x, fant_y, noise=fant_noise)
fmodel(torch.randn(4))
fmodel(torch.randn(4), torch.randn(4))


if __name__ == "__main__":
Expand Down

0 comments on commit 41a8386

Please sign in to comment.