Fix KeOps regressions from #2296. #2413

Merged: 2 commits into master, Nov 13, 2023
Conversation

@gpleiss (Member) commented on Sep 21, 2023

KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel. (This computation happens during preconditioning, which requires the diagonal of the already-formed kernel LinearOperator object.) The error arose because KeopsLinearOperator.diagonal calls to_dense on the output of a batch kernel operation, but to_dense is not defined for KeOps LazyTensors.
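
A minimal sketch of the failure mode being described (assumes KeOps is installed; the kernel choice, data sizes, and the evaluate_kernel call are illustrative, not taken from the PR):

    import torch
    from gpytorch.kernels.keops import RBFKernel

    x = torch.randn(2000, 3)  # large enough that the KeOps path is engaged
    K = RBFKernel()(x, x).evaluate_kernel()  # the already-formed kernel LinearOperator
    K.diagonal()  # previously errored: to_dense is not defined for KeOps LazyTensors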

This PR is, in some sense, a hacky fix for this bug (a less hacky fix will require changes to KernelLinearOperator), but it is also a generally helpful refactor that should improve KeOps kernels overall.

The fixes:

  • KeOpsKernels now only define a forward function, which is used both when we want to use KeOps and when we want to bypass it.
  • KeOpsKernels now use a _lazify_inputs helper method, which either wraps the inputs as KeOpsLazyTensors or leaves them as plain torch Tensors (see the sketch after this list).
  • The KeOps wrapping happens unless we want to bypass KeOps, which occurs either (1) when the matrix is small (below the Cholesky size threshold) or (2) when the user has turned off the gpytorch.settings.use_keops option (NEW IN THIS PR).
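
A minimal sketch of the dispatch idea in the bullets above (the name _lazify_inputs comes from this PR's description; the body below is an assumption, not the merged implementation):

    from gpytorch import settings
    from pykeops.torch import LazyTensor

    def _lazify_inputs(x1, x2):
        # (In the merged code this is a method on the kernel; shown standalone here.)
        # Bypass KeOps for small matrices, or when the new use_keops flag is off;
        # forward then runs on plain torch Tensors. The .on() accessor follows
        # other gpytorch feature flags (assumption).
        if x1.size(-2) < settings.max_cholesky_size.value() or not settings.use_keops.on():
            return x1, x2
        # Otherwise wrap the inputs so forward builds a symbolic KeOps kernel.
        return LazyTensor(x1[..., :, None, :]), LazyTensor(x2[..., None, :, :])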

Why this is beneficial:

  • KeOps kernels now follow the same API as non-KeOps kernels (they define a forward method).
  • The user now only has to define one forward method, which works in both the KeOps and non-KeOps cases.
  • The diagonal call in KeopsLinearOperator constructs a batch of 1x1 matrices, which are small enough to bypass KeOps and thus avoid the current bug; see the sketch after this list. (Hence this solution is currently a hack, but it could become less hacky with a small modification to KernelLinearOperator and/or the to_dense method in LinearOperator.)
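
A sketch of the diagonal trick from the last bullet (the kernel choice and shapes are assumptions based on the description, not the merged code):

    import torch
    from gpytorch.kernels import RBFKernel  # any kernel with a forward method

    kernel = RBFKernel()
    x1 = torch.randn(1000, 3)
    x2 = torch.randn(1000, 3)

    # diagonal() evaluates the kernel on paired points: a batch of n 1x1 matrices.
    x1_ = x1.unsqueeze(-2)            # (n, 1, d)
    x2_ = x2.unsqueeze(-2)            # (n, 1, d)
    res = kernel.forward(x1_, x2_)    # (n, 1, 1): each matrix is tiny, so
    # _lazify_inputs leaves the inputs as torch Tensors and to_dense never
    # sees a KeOps LazyTensor.
    diag = res.squeeze(-1).squeeze(-1)  # (n,)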

Other changes:

  • Fix stability issues with the KeOps MaternKernel (there were some NaN issues).
  • Introduce a gpytorch.settings.use_keops feature flag (see the usage example after this list).
  • Clean up the KeOps notebook.
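
Hypothetical usage of the new flag (the flag name is from this PR; the context-manager pattern mirrors other gpytorch.settings flags and is an assumption here):

    import torch
    import gpytorch
    from gpytorch.kernels.keops import MaternKernel

    x = torch.randn(2000, 3)
    # Force the pure-PyTorch code path even for large matrices:
    with gpytorch.settings.use_keops(False):
        K = MaternKernel(nu=2.5)(x, x)  # forward receives plain torch Tensors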

[Fixes #2363]

class KeOpsKernel(Kernel):
    def __init__(self, *args: Any, **kwargs: Any):
        raise RuntimeError("You must have KeOps installed to use a KeOpsKernel")

    def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor]:
        ...
Why are we being softer about this failure? If someone has gone to the trouble to import and try to instantiate a KeOpsKernel instead of the standard kernel, doesn't the hard stop make sense?

@jacobrgardner (Member) left a comment:
lgtm other than the question about whether not having keops installed should be a hard failure

@gpleiss enabled auto-merge (squash) November 13, 2023 18:41
@gpleiss merged commit efb6121 into master on Nov 13, 2023
Development

Successfully merging this pull request may close these issues.

[Bug] KeOps LazyTensors are no longer compatible with linear-operator