Fix issues with indexing broadcasted LazyEvaluatedKernelTensors #1971
When batch-evaluating the posterior of a model (i.e. when, in eval mode, passing in a tensor whose batch shape has more dimensions than the batch shape of the training data), a lazily evaluated kernel leads to obscure indexing errors. The issue is that the raw parameters of the lazily evaluated kernel have a different batch shape than the batch shape of the covariance matrix.
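For concreteness, here is a rough, untested sketch of the failure mode. The posterior machinery isn't needed to trigger it; directly indexing a lazily evaluated kernel exhibits the same mismatch (all shapes below are made up for illustration):

```python
import torch
import gpytorch

# Kernel whose raw parameters carry train_batch_shape = (2,):
kernel = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([2]))

# Inputs with a larger eval_batch_shape = (3, 2):
x = torch.rand(3, 2, 5, 1)

with gpytorch.settings.lazily_evaluate_kernels(True):
    lazy_covar = kernel(x)   # LazyEvaluatedKernelTensor
    print(lazy_covar.shape)  # torch.Size([3, 2, 5, 5])
    # A batch index that is valid for eval_batch_shape gets applied, under
    # the hood, to the raw lengthscale of shape (2, 1, 1), which only
    # carries train_batch_shape -> obscure indexing error (before this fix).
    lazy_covar[2, 1].evaluate()
```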
Specifically, if `train_batch_shape` is the batch shape of the kernel (unbroadcasted, as used for computing the train-train covariance), and the posterior is evaluated with some batch shape that has more dimensions, e.g. `eval_batch_shape = add_batch_shape + train_batch_shape`, then under the hood, when the `LazyEvaluatedKernelTensor` is indexed, we use a batch index meant for the `eval_batch_shape` to index the underlying kernel parameters (e.g. lengthscales), which only have the leading `train_batch_shape`; this causes indexing issues. This doesn't happen with eager evaluation of the kernel, since the parameters are appropriately broadcast to the new, larger `eval_batch_shape` before any indexing happens (i.e. when computing the kernel, the parameters are automatically "broadcasted up" to the shape required to produce the kernel with the `eval_batch_shape`).
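In plain torch terms, the mismatch looks like this (a schematic with made-up shapes, not gpytorch code):

```python
import torch

train_batch_shape = torch.Size([2])
eval_batch_shape = torch.Size([3, 2])  # add_batch_shape = (3,)

lengthscale = torch.rand(*train_batch_shape, 1, 1)  # raw kernel parameter
covar = torch.rand(*eval_batch_shape, 5, 5)         # evaluated covariance

batch_index = (2, 1)  # valid for eval_batch_shape
print(covar[batch_index].shape)  # torch.Size([5, 5])
try:
    lengthscale[batch_index]  # what lazy indexing effectively attempts
except IndexError as e:
    print(e)  # index 2 is out of bounds for dimension 0 with size 2

# Eager evaluation sidesteps this: broadcasting expands the parameter to the
# full eval_batch_shape before any indexing takes place.
expanded = lengthscale.expand(*eval_batch_shape, 1, 1)
print(expanded[batch_index].shape)  # torch.Size([1, 1])
```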
It's not entirely clear how to best fix this. The path pursued in this PR is to check, before indexing, whether the batch shapes have a different number of dimensions, and, if so, to automatically modify the indices that are passed to the `__getitem__` method of the kernel (and thus subsequently to the various lazy tensors that the kernel consists of). It appears to work but seems quite dicey and easy to break, so I'm open to alternative suggestions.

There will likely be similar issues if the posterior of a batched model evaluated on a non-batch input is indexed while lazily evaluated, but that's something for a separate PR.
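As a rough standalone sketch of the index-adjustment idea (the helper name and signature here are hypothetical, not the actual code in this PR):

```python
import torch

def adjust_batch_indices(batch_indices, eval_batch_shape, param_batch_shape):
    # Hypothetical helper: map a batch index meant for eval_batch_shape onto
    # a parameter that only carries param_batch_shape (= train_batch_shape).
    # Assumes eval_batch_shape = add_batch_shape + param_batch_shape and that
    # batch_indices has one entry per eval batch dimension.
    n_extra = len(eval_batch_shape) - len(param_batch_shape)
    # Drop the indices for the added leading batch dimensions; they have no
    # counterpart in the un-broadcasted parameter.
    trailing = list(batch_indices[n_extra:])
    # Indices over size-1 (broadcasted) parameter dimensions must map to 0.
    for i, size in enumerate(param_batch_shape):
        if size == 1:
            trailing[i] = 0 if isinstance(trailing[i], int) else torch.zeros_like(trailing[i])
    return tuple(trailing)

# Example: a lengthscale carries train_batch_shape = (2,) while the covariance
# has eval_batch_shape = (3, 2); the eval-level index (2, 1) becomes (1,).
lengthscale = torch.rand(2, 1, 1)
idx = adjust_batch_indices((2, 1), torch.Size([3, 2]), torch.Size([2]))
print(lengthscale[idx].shape)  # torch.Size([1, 1])
```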