Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gpytorch.module: fix typing annotation #2611

Merged
merged 6 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gpytorch/kernels/cosine_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def period_length(self):

@period_length.setter
def period_length(self, value):
return self._set_period_length(value)
self._set_period_length(value)

def _set_period_length(self, value):
if not torch.is_tensor(value):
Expand Down
35 changes: 21 additions & 14 deletions gpytorch/kernels/cylindrical_kernel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python3

from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor

from .. import settings
from ..constraints import Interval, Positive
Expand Down Expand Up @@ -94,39 +95,45 @@ def __init__(
self.register_prior("beta_prior", beta_prior, lambda m: m.beta, lambda m, v: m._set_beta(v))

@property
def angular_weights(self) -> torch.Tensor:
def angular_weights(self) -> Tensor:
return self.raw_angular_weights_constraint.transform(self.raw_angular_weights)

@angular_weights.setter
def angular_weights(self, value: torch.Tensor) -> None:
def angular_weights(self, value: Tensor) -> None:
if not torch.is_tensor(value):
value = torch.tensor(value)

self.initialize(raw_angular_weights=self.raw_angular_weights_constraint.inverse_transform(value))

@property
def alpha(self) -> torch.Tensor:
def alpha(self) -> Tensor:
return self.raw_alpha_constraint.transform(self.raw_alpha)

@alpha.setter
def alpha(self, value: torch.Tensor) -> None:
if not torch.is_tensor(value):
value = torch.tensor(value)
def alpha(self, value: Tensor) -> None:
self._set_alpha(value)

def _set_alpha(self, value: Union[Tensor, float]) -> None:
# Used by the alpha_prior
if not isinstance(value, Tensor):
value = torch.as_tensor(value).to(self.raw_alpha)
self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

@property
def beta(self) -> torch.Tensor:
def beta(self) -> Tensor:
return self.raw_beta_constraint.transform(self.raw_beta)

@beta.setter
def beta(self, value: torch.Tensor) -> None:
if not torch.is_tensor(value):
value = torch.tensor(value)
def beta(self, value: Tensor) -> None:
self._set_beta(value)

def _set_beta(self, value: Union[Tensor, float]) -> None:
# Used by the beta_prior
if not isinstance(value, Tensor):
value = torch.as_tensor(value).to(self.raw_beta)
self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = False, **params) -> torch.Tensor:
def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> Tensor:

x1_, x2_ = x1.clone(), x2.clone()
# Jitter datapoints that are exactly 0
Expand Down Expand Up @@ -156,12 +163,12 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
return radial_kernel.mul(angular_kernel)

def kuma(self, x: torch.Tensor) -> torch.Tensor:
def kuma(self, x: Tensor) -> Tensor:
alpha = self.alpha.view(*self.batch_shape, 1, 1)
beta = self.beta.view(*self.batch_shape, 1, 1)

res = 1 - (1 - x.pow(alpha) + self.eps).pow(beta)
return res

def num_outputs_per_input(self, x1: torch.Tensor, x2: torch.Tensor) -> int:
def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
return self.radial_base_kernel.num_outputs_per_input(x1, x2)
18 changes: 9 additions & 9 deletions gpytorch/kernels/hamming_kernel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch import nn, Tensor
Expand Down Expand Up @@ -95,13 +95,13 @@ def _alpha_param(self, m: Kernel) -> Tensor:
# Used by the alpha_prior
return m.alpha

def _alpha_closure(self, m: Kernel, v: Tensor) -> Tensor:
def _alpha_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
# Used by the alpha_prior
return m._set_alpha(v)
m._set_alpha(v)

def _set_alpha(self, value: Tensor):
def _set_alpha(self, value: Union[Tensor, float]) -> None:
# Used by the alpha_prior
if not torch.is_tensor(value):
if not isinstance(value, Tensor):
value = torch.as_tensor(value).to(self.raw_alpha)
self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

Expand All @@ -117,13 +117,13 @@ def _beta_param(self, m: Kernel) -> Tensor:
# Used by the beta_prior
return m.beta

def _beta_closure(self, m: Kernel, v: Tensor) -> Tensor:
def _beta_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
# Used by the beta_prior
return m._set_beta(v)
m._set_beta(v)

def _set_beta(self, value: Tensor):
def _set_beta(self, value: Union[Tensor, float]) -> None:
# Used by the beta_prior
if not torch.is_tensor(value):
if not isinstance(value, Tensor):
value = torch.as_tensor(value).to(self.raw_beta)
self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

Expand Down
4 changes: 2 additions & 2 deletions gpytorch/kernels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ def _lengthscale_param(self, m: Kernel) -> Tensor:
# Used by the lengthscale_prior
return m.lengthscale

def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
def _lengthscale_closure(self, m: Kernel, v: Tensor) -> None:
# Used by the lengthscale_prior
return m._set_lengthscale(v)
m._set_lengthscale(v)

def _set_lengthscale(self, value: Tensor):
# Used by the lengthscale_prior
Expand Down
2 changes: 1 addition & 1 deletion gpytorch/kernels/scale_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _outputscale_param(self, m):
return m.outputscale

def _outputscale_closure(self, m, v):
return m._set_outputscale(v)
m._set_outputscale(v)

@property
def outputscale(self):
Expand Down
16 changes: 8 additions & 8 deletions gpytorch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
ModuleSelf = TypeVar("ModuleSelf", bound="Module") # TODO: replace w/ typing.Self in Python 3.11
RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin") # TODO: replace w/ typing.Self in Python 3.11

Closure = Callable[[nn.Module], Tensor]
SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], ModuleSelf]
Closure = Callable[[NnModuleSelf], Tensor]
SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], None]
SamplesDict = Mapping[str, Union[Tensor, float]]


Expand Down Expand Up @@ -111,7 +111,7 @@ def added_loss_terms(self):
def forward(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
raise NotImplementedError

def constraints(self):
def constraints(self) -> Iterator[Interval]:
for _, constraint in self.named_constraints():
yield constraint

Expand Down Expand Up @@ -294,8 +294,8 @@ def closure_new(module: nn.Module) -> Tensor:
if setting_closure is not None:
raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure")

def setting_closure_new(module: ModuleSelf, val: Union[Tensor, float]) -> ModuleSelf:
return module.initialize(**{param: val})
def setting_closure_new(module: Module, val: Union[Tensor, float]) -> None:
module.initialize(**{param: val})

setting_closure = setting_closure_new

Expand Down Expand Up @@ -443,7 +443,7 @@ def pyro_sample_from_prior(self) -> Module:
new_module = self.to_pyro_random_module()
return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")

def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[str], prefix: str) -> None:
def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[Prior], prefix: str) -> None:
"""
Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro
sampling mechanism.
Expand Down Expand Up @@ -516,7 +516,7 @@ def _set_strict(module: nn.Module, value: bool) -> None:


def _pyro_sample_from_prior(
module: NnModuleSelf, memo: Optional[MutableSet[str]] = None, prefix: str = ""
module: NnModuleSelf, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
) -> NnModuleSelf:
try:
import pyro
Expand Down Expand Up @@ -546,7 +546,7 @@ def _pyro_sample_from_prior(


def _pyro_load_from_samples(
module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[str]] = None, prefix: str = ""
module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
) -> None:
if memo is None:
memo = set()
Expand Down
Loading