From 1586b0fb7c94594e114d775c0a872c8932fbe294 Mon Sep 17 00:00:00 2001
From: Christopher Yeh
Date: Tue, 19 Nov 2024 03:25:07 +0000
Subject: [PATCH 1/4] Add typing annotations to gpytorch.Module

---
 gpytorch/module.py | 233 +++++++++++++++++++++++----------------------
 1 file changed, 121 insertions(+), 112 deletions(-)

diff --git a/gpytorch/module.py b/gpytorch/module.py
index ff431a421..3c8b9bce1 100644
--- a/gpytorch/module.py
+++ b/gpytorch/module.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import copy
 import inspect
 import itertools
 import operator
-from collections import OrderedDict
-from typing import Union
+from typing import Callable, Iterator, Mapping, MutableSet, Optional, TypeVar, Union

 import torch
 from linear_operator.operators import LinearOperator
@@ -13,14 +14,64 @@
 from torch.distributions import Distribution

 from .constraints import Interval
+from .priors import Prior
+
+
+NnModuleSelf = TypeVar("NnModuleSelf", bound=nn.Module)  # TODO: replace w/ typing.Self in Python 3.11
+ModuleSelf = TypeVar("ModuleSelf", bound="Module")  # TODO: replace w/ typing.Self in Python 3.11
+RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin")  # TODO: replace w/ typing.Self in Python 3.11
+
+Closure = Callable[[nn.Module], Tensor]
+SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], ModuleSelf]
+SamplesDict = Mapping[str, Union[Tensor, float]]
+
+
+class RandomModuleMixin:
+    def initialize(self: RandomModuleSelf, **kwargs) -> RandomModuleSelf:
+        """
+        Set a value for a parameter
+
+        kwargs: (param_name, value) - parameter to initialize.
+        Can also initialize recursively by passing in the full name of a
+        parameter. For example, if model has attribute model.likelihood,
+        we can initialize the noise with either
+        `model.initialize(**{'likelihood.noise': 0.1})`
+        or
+        `model.likelihood.initialize(noise=0.1)`.
+        The former method would allow users to more easily store the
+        initialization values as one object.
+
+        Value must be a Tensor
+        """
+        for name, value in kwargs.items():
+            if not isinstance(value, Tensor):
+                raise RuntimeError("Initialize in RandomModules can only be done with Tensor values.")
+
+            names = name.rsplit(".")
+            if len(names) > 1:
+                mod_name, param_name = names
+                mod = operator.attrgetter(mod_name)(self)
+            else:
+                mod, param_name = self, name
+
+            old_param = getattr(mod, param_name)
+            is_property = hasattr(type(self), name) and isinstance(getattr(type(self), name), property)
+            if not isinstance(old_param, torch.nn.Parameter) or is_property:
+                # Presumably we're calling a getter that will call initialize again on the actual parameter.
+                setattr(mod, param_name, value.expand(old_param.shape))
+            else:
+                delattr(mod, param_name)
+                setattr(mod, param_name, value.expand(old_param.shape))
+
+        return self


 class Module(nn.Module):
     def __init__(self):
         super().__init__()
-        self._added_loss_terms = OrderedDict()
-        self._priors = OrderedDict()
-        self._constraints = OrderedDict()
+        self._added_loss_terms = {}
+        self._priors: dict[str, tuple[Prior, Closure, Optional[SettingClosure]]] = {}
+        self._constraints: dict[str, Interval] = {}
         self._strict_init = True
         self._load_strict_shapes = True

@@ -40,7 +91,7 @@ def _clear_cache(self):
         """
         pass

-    def _get_module_and_name(self, parameter_name):
+    def _get_module_and_name(self, parameter_name: str) -> tuple[nn.Module, str]:
         """Get module and name from full parameter name."""
         module, name = parameter_name.split(".", 1)
         if module in self._modules:
@@ -50,7 +101,7 @@ def _get_module_and_name(self, parameter_name):
                 "Invalid parameter name {}. {} has no module {}".format(parameter_name, type(self).__name__, module)
             )

-    def _strict(self, value):
+    def _strict(self, value: bool) -> None:
         _set_strict(self, value)

     def added_loss_terms(self):
@@ -68,7 +119,7 @@ def hyperparameters(self):
         for _, param in self.named_hyperparameters():
             yield param

-    def initialize(self, **kwargs):
+    def initialize(self: ModuleSelf, **kwargs) -> ModuleSelf:
         """
         Set a value for a parameter
@@ -98,7 +149,7 @@ def initialize(self, **kwargs):
             raise AttributeError("Unknown parameter {p} for {c}".format(p=name, c=self.__class__.__name__))
         elif name not in self._parameters and name not in self._buffers:
             setattr(self, name, val)
-        elif torch.is_tensor(val):
+        elif isinstance(val, Tensor):
             constraint = self.constraint_for_parameter_name(name)
             if constraint is not None and constraint.enforced and not constraint.check_raw(val):
                 raise RuntimeError(
@@ -158,7 +209,7 @@ def named_hyperparameters(self):
             for elem in module.named_parameters(prefix=module_prefix, recurse=False):
                 yield elem

-    def named_priors(self, memo=None, prefix=""):
+    def named_priors(self) -> Iterator[tuple[str, nn.Module, Prior, Closure, SettingClosure | None]]:
         """Returns an iterator over the module's priors, yielding the name of the prior,
         the prior, the associated parameter names, and the transformation callable.
@@ -172,7 +223,7 @@ def named_priors(self, memo=None, prefix=""):
         """
         return _extract_named_priors(module=self, prefix="")

-    def named_constraints(self, memo=None, prefix=""):
+    def named_constraints(self) -> Iterator[tuple[str, Interval]]:
         return _extract_named_constraints(module=self, memo=None, prefix="")

     def named_variational_parameters(self):
@@ -186,30 +237,22 @@ def named_variational_parameters(self):
     def register_added_loss_term(self, name):
         self._added_loss_terms[name] = None

-    def register_parameter(self, name, parameter):
-        r"""
-        Adds a parameter to the module. The parameter can be accessed as an attribute using the given name.
-
-        Args:
-            name (str):
-                The name of the parameter
-            parameter (torch.nn.Parameter):
-                The parameter
-        """
-        if "_parameters" not in self.__dict__:
-            raise AttributeError("Cannot assign parameter before Module.__init__() call")
-        super().register_parameter(name, parameter)
-
-    def register_prior(self, name, prior, param_or_closure, setting_closure=None):
+    def register_prior(
+        self,
+        name: str,
+        prior: Prior,
+        param_or_closure: Union[str, Closure],
+        setting_closure: Optional[SettingClosure] = None,
+    ) -> None:
         """
         Adds a prior to the module. The prior can be accessed as an attribute using the given name.

         Args:
-            name (str):
+            name:
                 The name of the prior
-            prior (Prior):
+            prior:
                 The prior to be registered
-            param_or_closure (string or callable):
+            param_or_closure:
                 Either the name of the parameter, or a closure (which upon calling evaluates a function on
                 the module instance and one or more parameters):
                 single parameter without a transform: `.register_prior("foo_prior", foo_prior, "foo_param")`
@@ -217,33 +260,36 @@ def register_prior(self, name, prior, param_or_closure, setting_closure=None):
                 `.register_prior("foo_prior", NormalPrior(0, 1), lambda module: torch.log(module.foo_param))`
                 function of multiple parameters:
                 `.register_prior("foo2_prior", foo2_prior, lambda module: f(module.param1, module.param2))`
-            setting_closure (callable, optional):
+            setting_closure:
                 A function taking in the module instance and a tensor in (transformed) parameter space,
                 initializing the internal parameter representation to the proper value by applying the
                 inverse transform. Enables setting parameters directly in the transformed space, as well
                 as sampling parameter values from priors (see `sample_from_prior`)
-
         """
         if isinstance(param_or_closure, str):
-            if param_or_closure not in self._parameters and not hasattr(self, param_or_closure):
+            param = param_or_closure
+            if param not in self._parameters and not hasattr(self, param):
                 raise AttributeError(
-                    "Unknown parameter {name} for {module}".format(
-                        name=param_or_closure, module=self.__class__.__name__
-                    )
+                    "Unknown parameter {name} for {module}".format(name=param, module=self.__class__.__name__)
                     + " Make sure the parameter is registered before registering a prior."
                 )

-            def closure(module):
-                return getattr(module, param_or_closure)
+            def closure_new(module: nn.Module) -> Tensor:
+                return getattr(module, param)
+
+            closure = closure_new

             if setting_closure is not None:
                 raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure")

-            def setting_closure(module, val):
-                return module.initialize(**{param_or_closure: val})
+            def setting_closure_new(module: ModuleSelf, val: Union[Tensor, float]) -> ModuleSelf:
+                return module.initialize(**{param: val})
+
+            setting_closure = setting_closure_new
         else:
-            if len(inspect.signature(param_or_closure).parameters) == 0:
+            closure = param_or_closure
+            if len(inspect.signature(closure).parameters) == 0:
                 raise ValueError(
                     """As of version 1.4, `param_or_closure` must operate on a module instance.
                     For example:

                         likelihood.noise_covar.register_prior(
                             "noise_std_prior",
                             gpytorch.priors.NormalPrior(0, 1),
                             lambda module: module.noise.sqrt()
@@ -266,12 +312,11 @@ def setting_closure(module, val):
                         )
                     """
                 )

-        closure = param_or_closure
         self.add_module(name, prior)
         self._priors[name] = (prior, closure, setting_closure)

-    def register_constraint(self, param_name, constraint, replace=True):
+    def register_constraint(self, param_name: str, constraint: Interval, replace: bool = True) -> None:
         if param_name not in self._parameters:
             raise RuntimeError("Attempting to register constraint for nonexistent parameter.")

@@ -299,7 +344,7 @@ def train(self, mode=True):
         self._clear_cache()
         return super().train(mode=mode)

-    def constraint_for_parameter_name(self, param_name):
+    def constraint_for_parameter_name(self, param_name: str) -> Interval | None:
         base_module = self
         base_name = param_name

@@ -344,11 +389,11 @@ def apply_fn(module):

         self.apply(apply_fn)

-    def named_parameters_and_constraints(self):
+    def named_parameters_and_constraints(self) -> Iterator[tuple[str, nn.Parameter, Interval | None]]:
         for name, param in self.named_parameters():
             yield name, param, self.constraint_for_parameter_name(name)

-    def sample_from_prior(self, prior_name):
+    def sample_from_prior(self, prior_name: str) -> None:
         """Sample parameter values from prior. Modifies the module's parameters in-place."""
         if prior_name not in self._priors:
             raise RuntimeError("Unknown prior name '{}'".format(prior_name))
@@ -357,10 +402,10 @@ def sample_from_prior(self, prior_name):
             raise RuntimeError("Must provide inverse transform to be able to sample from prior.")
         setting_closure(self, prior.sample())

-    def to_pyro_random_module(self):
+    def to_pyro_random_module(self) -> Module:
         return self.to_random_module()

-    def to_random_module(self):
+    def to_random_module(self) -> Module:
         random_module_cls = type("_Random" + self.__class__.__name__, (RandomModuleMixin, self.__class__), {})
         if not isinstance(self, random_module_cls):
             new_module = copy.deepcopy(self)
@@ -375,7 +420,7 @@ def to_random_module(self):

         return new_module

-    def pyro_sample_from_prior(self):
+    def pyro_sample_from_prior(self) -> Module:
         """
         For each parameter in this Module and submodule that have defined priors, sample a value for that parameter
         from its corresponding prior with a pyro.sample primitive and load the resulting value into the parameter.
@@ -386,7 +431,7 @@ def pyro_sample_from_prior(self):
         new_module = self.to_pyro_random_module()
         return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")

-    def local_load_samples(self, samples_dict, memo, prefix):
+    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[str], prefix: str) -> None:
         """
         Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro
         sampling mechanism.
@@ -396,13 +441,15 @@ def local_load_samples(self, samples_dict, memo, prefix):
         acquire an extra batch dimension corresponding to the number of samples drawn.
         """
         self._strict(False)
-        for name, (prior, closure, setting_closure) in self._priors.items():
+        for name, (prior, _, setting_closure) in self._priors.items():
             if prior is not None and prior not in memo:
                 memo.add(prior)
+                if setting_closure is None:
+                    raise RuntimeError("Must provide setting_closure to load samples.")
                 setting_closure(self, samples_dict[prefix + ("." if prefix else "") + name])
         self._strict(True)

-    def pyro_load_from_samples(self, samples_dict):
+    def pyro_load_from_samples(self, samples_dict: SamplesDict) -> None:
         """
         Convert this Module into a batch Module by loading parameters from the given `samples_dict`.
@@ -412,9 +459,9 @@ def pyro_load_from_samples(self, samples_dict):
         `samples_dict` is typically produced by a Pyro sampling mechanism.

         Note that the keys of the samples_dict should correspond to the prior names (covar_module.raw_outputscale_prior)
         rather than parameter names (covar_module.raw_outputscale), because we will use the setting_closure associated
         with the prior to properly set the unconstrained parameter.

         Args:
-            samples_dict (dict): Dictionary mapping *prior names* to sample values.
+            samples_dict: Dictionary mapping *prior names* to sample values.
         """
-        return _pyro_load_from_samples(module=self, samples_dict=samples_dict, memo=None, prefix="")
+        _pyro_load_from_samples(module=self, samples_dict=samples_dict, memo=None, prefix="")

     def update_added_loss_term(self, name, added_loss_term):
         from .mlls import AddedLossTerm
@@ -432,29 +479,23 @@ def variational_parameters(self):

 def _validate_module_outputs(outputs):
     if isinstance(outputs, tuple):
-        if not all(
-            torch.is_tensor(output) or isinstance(output, Distribution) or isinstance(output, LinearOperator)
-            for output in outputs
-        ):
+        if not all(isinstance(output, (Tensor, Distribution, LinearOperator)) for output in outputs):
             raise RuntimeError(
-                "All outputs must be a Distribution, torch.Tensor, or LinearOperator. "
+                "All outputs must be a torch.Tensor, Distribution, or LinearOperator. "
                 "Got {}".format([output.__class__.__name__ for output in outputs])
             )
         if len(outputs) == 1:
             outputs = outputs[0]
         return outputs
-    elif torch.is_tensor(outputs) or isinstance(outputs, Distribution) or isinstance(outputs, LinearOperator):
+    elif isinstance(outputs, (Tensor, Distribution, LinearOperator)):
        return outputs
     else:
         raise RuntimeError(
-            "Output must be a Distribution, torch.Tensor, or LinearOperator. Got {}".format(outputs.__class__.__name__)
+            "Output must be a torch.Tensor, Distribution, or LinearOperator. Got {}".format(outputs.__class__.__name__)
         )


-def _set_strict(module, value, memo=None):
-    if memo is None:
-        memo = set()
-
+def _set_strict(module: nn.Module, value: bool) -> None:
     if hasattr(module, "_strict_init"):
         module._strict_init = value

@@ -462,7 +503,9 @@ def _set_strict(module, value, memo=None):
         _set_strict(module_, value)


-def _pyro_sample_from_prior(module, memo=None, prefix=""):
+def _pyro_sample_from_prior(
+    module: NnModuleSelf, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+) -> NnModuleSelf:
     try:
         import pyro
     except ImportError:
@@ -470,7 +513,7 @@ def _pyro_sample_from_prior(module, memo=None, prefix=""):
     if memo is None:
         memo = set()
-    if hasattr(module, "_priors"):
+    if isinstance(module, Module):
         for prior_name, (prior, closure, setting_closure) in module._priors.items():
             if prior is not None and prior not in memo:
                 if setting_closure is None:
@@ -490,10 +533,12 @@ def _pyro_sample_from_prior(module, memo=None, prefix=""):
     return module


-def _pyro_load_from_samples(module, samples_dict, memo=None, prefix=""):
+def _pyro_load_from_samples(
+    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+) -> None:
     if memo is None:
         memo = set()
-    if hasattr(module, "_priors"):
+    if isinstance(module, Module):
         module.local_load_samples(samples_dict, memo, prefix)

     for mname, module_ in module.named_children():
@@ -515,8 +560,10 @@ def _extract_named_added_loss_terms(module, memo=None, prefix=""):
                 yield name, strategy


-def _extract_named_priors(module, prefix=""):
-    if hasattr(module, "_priors"):
+def _extract_named_priors(
+    module: nn.Module, prefix: str = ""
+) -> Iterator[tuple[str, nn.Module, Prior, Closure, SettingClosure | None]]:
+    if isinstance(module, Module):
         for name, (prior, closure, inv_closure) in module._priors.items():
             if prior is not None:
                 full_name = ("." if prefix else "").join([prefix, name])
                 yield full_name, module, prior, closure, inv_closure

     for mname, module_ in module.named_children():
         submodule_prefix = prefix + ("." if prefix else "") + mname
         for name, parent_module, prior, closure, inv_closure in _extract_named_priors(module_, prefix=submodule_prefix):
             yield name, parent_module, prior, closure, inv_closure

@@ -527,10 +574,12 @@
-def _extract_named_constraints(module, memo=None, prefix=""):
+def _extract_named_constraints(
+    module: nn.Module, memo: Optional[MutableSet[Interval]] = None, prefix: str = ""
+) -> Iterator[tuple[str, Interval]]:
     if memo is None:
         memo = set()
-    if hasattr(module, "_constraints"):
+    if isinstance(module, Module):
         for name, constraint in module._constraints.items():
             if constraint is not None and constraint not in memo:
                 memo.add(constraint)
@@ -540,43 +589,3 @@ def _extract_named_constraints(module, memo=None, prefix=""):
         submodule_prefix = prefix + ("." if prefix else "") + mname
         for name, constraint in _extract_named_constraints(module_, memo=memo, prefix=submodule_prefix):
             yield name, constraint
-
-
-class RandomModuleMixin(object):
-    def initialize(self, **kwargs):
-        """
-        Set a value for a parameter
-
-        kwargs: (param_name, value) - parameter to initialize.
-        Can also initialize recursively by passing in the full name of a
-        parameter. For example if model has attribute model.likelihood,
-        we can initialize the noise with either
-        `model.initialize(**{'likelihood.noise': 0.1})`
-        or
-        `model.likelihood.initialize(noise=0.1)`.
-        The former method would allow users to more easily store the
-        initialization values as one object.
-
-        Value can take the form of a tensor, a float, or an int
-        """
-        for name, value in kwargs.items():
-            if not torch.is_tensor(value):
-                raise RuntimeError("Initialize in RandomModules can only be done with tensor values.")
-
-            names = name.rsplit(".")
-            if len(names) > 1:
-                mod_name, param_name = names
-                mod = operator.attrgetter(mod_name)(self)
-            else:
-                mod, param_name = self, name
-
-            old_param = getattr(mod, param_name)
-            is_property = hasattr(type(self), name) and isinstance(getattr(type(self), name), property)
-            if not isinstance(old_param, torch.nn.Parameter) or is_property:
-                # Presumably we're calling a getter that will call initialize again on the actual parameter.
-                setattr(mod, param_name, value.expand(old_param.shape))
-            else:
-                delattr(mod, param_name)
-                setattr(mod, param_name, value.expand(old_param.shape))
-
-        return self

From 02ac961c6e5b6a3a1280c4372f07361f294a1a96 Mon Sep 17 00:00:00 2001
From: Christopher Yeh
Date: Tue, 19 Nov 2024 05:50:31 +0000
Subject: [PATCH 2/4] Undo removal of Module.register_parameter()

---
 gpytorch/module.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/gpytorch/module.py b/gpytorch/module.py
index 3c8b9bce1..57550755a 100644
--- a/gpytorch/module.py
+++ b/gpytorch/module.py
@@ -237,6 +237,18 @@ def named_variational_parameters(self):
     def register_added_loss_term(self, name):
         self._added_loss_terms[name] = None

+    def register_parameter(self, name: str, parameter: Optional[nn.Parameter]) -> None:
+        r"""
+        Adds a parameter to the module. The parameter can be accessed as an attribute using the given name.
+
+        Args:
+            name:
+                The name of the parameter
+            parameter:
+                The parameter
+        """
+        super().register_parameter(name, parameter)
+
     def register_prior(
         self,
         name: str,

From 73a63c3d44331a3c2582ed3fe43ec4067972cbbf Mon Sep 17 00:00:00 2001
From: Christopher Yeh
Date: Wed, 20 Nov 2024 21:01:42 +0000
Subject: [PATCH 3/4] Fix typing annotations in gpytorch.Module

---
 gpytorch/module.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/gpytorch/module.py b/gpytorch/module.py
index 57550755a..e4bfeda4c 100644
--- a/gpytorch/module.py
+++ b/gpytorch/module.py
@@ -111,7 +111,7 @@ def added_loss_terms(self):
     def forward(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
         raise NotImplementedError

-    def constraints(self):
+    def constraints(self) -> Iterator[Interval]:
         for _, constraint in self.named_constraints():
             yield constraint

@@ -443,7 +443,7 @@ def pyro_sample_from_prior(self) -> Module:
         new_module = self.to_pyro_random_module()
         return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")

-    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[str], prefix: str) -> None:
+    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[Prior], prefix: str) -> None:
         """
         Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro
         sampling mechanism.
@@ -516,7 +516,7 @@ def _set_strict(module: nn.Module, value: bool) -> None:


 def _pyro_sample_from_prior(
-    module: NnModuleSelf, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+    module: NnModuleSelf, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
 ) -> NnModuleSelf:
     try:
         import pyro
@@ -546,7 +546,7 @@ def _pyro_sample_from_prior(


 def _pyro_load_from_samples(
-    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[str]] = None, prefix: str = ""
+    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
 ) -> None:
     if memo is None:
         memo = set()

From 3c0a274b9746ce6b1afe7fc4e9707b1e5326600a Mon Sep 17 00:00:00 2001
From: Christopher Yeh
Date: Tue, 3 Dec 2024 00:56:53 +0000
Subject: [PATCH 4/4] Remove return value from SettingClosure

---
 gpytorch/kernels/cosine_kernel.py      |  2 +-
 gpytorch/kernels/cylindrical_kernel.py | 35 +++++++++++++++-----------
 gpytorch/kernels/hamming_kernel.py     | 18 ++++++-------
 gpytorch/kernels/kernel.py             |  4 +--
 gpytorch/kernels/scale_kernel.py       |  2 +-
 gpytorch/module.py                     |  8 +++---
 6 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/gpytorch/kernels/cosine_kernel.py b/gpytorch/kernels/cosine_kernel.py
index 11add6f2f..a688eba67 100644
--- a/gpytorch/kernels/cosine_kernel.py
+++ b/gpytorch/kernels/cosine_kernel.py
@@ -91,7 +91,7 @@ def period_length(self):

     @period_length.setter
     def period_length(self, value):
-        return self._set_period_length(value)
+        self._set_period_length(value)

     def _set_period_length(self, value):
         if not torch.is_tensor(value):
diff --git a/gpytorch/kernels/cylindrical_kernel.py b/gpytorch/kernels/cylindrical_kernel.py
index 48f24958c..ea66c956a 100644
--- a/gpytorch/kernels/cylindrical_kernel.py
+++ b/gpytorch/kernels/cylindrical_kernel.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3

-from typing import Optional
+from typing import Optional, Union

 import torch
+from torch import Tensor

 from .. import settings
 from ..constraints import Interval, Positive
@@ -94,39 +95,45 @@ def __init__(
         self.register_prior("beta_prior", beta_prior, lambda m: m.beta, lambda m, v: m._set_beta(v))

     @property
-    def angular_weights(self) -> torch.Tensor:
+    def angular_weights(self) -> Tensor:
         return self.raw_angular_weights_constraint.transform(self.raw_angular_weights)

     @angular_weights.setter
-    def angular_weights(self, value: torch.Tensor) -> None:
+    def angular_weights(self, value: Tensor) -> None:
         if not torch.is_tensor(value):
             value = torch.tensor(value)
         self.initialize(raw_angular_weights=self.raw_angular_weights_constraint.inverse_transform(value))

     @property
-    def alpha(self) -> torch.Tensor:
+    def alpha(self) -> Tensor:
         return self.raw_alpha_constraint.transform(self.raw_alpha)

     @alpha.setter
-    def alpha(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.tensor(value)
+    def alpha(self, value: Tensor) -> None:
+        self._set_alpha(value)

+    def _set_alpha(self, value: Union[Tensor, float]) -> None:
+        # Used by the alpha_prior
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value).to(self.raw_alpha)
         self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

     @property
-    def beta(self) -> torch.Tensor:
+    def beta(self) -> Tensor:
         return self.raw_beta_constraint.transform(self.raw_beta)

     @beta.setter
-    def beta(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.tensor(value)
+    def beta(self, value: Tensor) -> None:
+        self._set_beta(value)

+    def _set_beta(self, value: Union[Tensor, float]) -> None:
+        # Used by the beta_prior
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value).to(self.raw_beta)
         self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

-    def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = False, **params) -> torch.Tensor:
+    def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> Tensor:
         x1_, x2_ = x1.clone(), x2.clone()

         # Jitter datapoints that are exactly 0
@@ -156,12 +163,12 @@ def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params
         radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
         return radial_kernel.mul(angular_kernel)

-    def kuma(self, x: torch.Tensor) -> torch.Tensor:
+    def kuma(self, x: Tensor) -> Tensor:
         alpha = self.alpha.view(*self.batch_shape, 1, 1)
         beta = self.beta.view(*self.batch_shape, 1, 1)

         res = 1 - (1 - x.pow(alpha) + self.eps).pow(beta)
         return res

-    def num_outputs_per_input(self, x1: torch.Tensor, x2: torch.Tensor) -> int:
+    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
         return self.radial_base_kernel.num_outputs_per_input(x1, x2)
diff --git a/gpytorch/kernels/hamming_kernel.py b/gpytorch/kernels/hamming_kernel.py
index 6a28a2aa9..d942872b8 100644
--- a/gpytorch/kernels/hamming_kernel.py
+++ b/gpytorch/kernels/hamming_kernel.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import nn, Tensor
@@ -95,13 +95,13 @@ def _alpha_param(self, m: Kernel) -> Tensor:
         # Used by the alpha_prior
         return m.alpha

-    def _alpha_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _alpha_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
         # Used by the alpha_prior
-        return m._set_alpha(v)
+        m._set_alpha(v)

-    def _set_alpha(self, value: Tensor):
+    def _set_alpha(self, value: Union[Tensor, float]) -> None:
         # Used by the alpha_prior
-        if not torch.is_tensor(value):
+        if not isinstance(value, Tensor):
             value = torch.as_tensor(value).to(self.raw_alpha)
         self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

@@ -117,13 +117,13 @@ def _beta_param(self, m: Kernel) -> Tensor:
         # Used by the beta_prior
         return m.beta

-    def _beta_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _beta_closure(self, m: Kernel, v: Union[Tensor, float]) -> None:
         # Used by the beta_prior
-        return m._set_beta(v)
+        m._set_beta(v)

-    def _set_beta(self, value: Tensor):
+    def _set_beta(self, value: Union[Tensor, float]) -> None:
         # Used by the beta_prior
-        if not torch.is_tensor(value):
+        if not isinstance(value, Tensor):
             value = torch.as_tensor(value).to(self.raw_beta)
         self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py
index 67e576db3..0a4c49efa 100644
--- a/gpytorch/kernels/kernel.py
+++ b/gpytorch/kernels/kernel.py
@@ -216,9 +216,9 @@ def _lengthscale_param(self, m: Kernel) -> Tensor:
         # Used by the lengthscale_prior
         return m.lengthscale

-    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
+    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> None:
         # Used by the lengthscale_prior
-        return m._set_lengthscale(v)
+        m._set_lengthscale(v)

     def _set_lengthscale(self, value: Tensor):
         # Used by the lengthscale_prior
diff --git a/gpytorch/kernels/scale_kernel.py b/gpytorch/kernels/scale_kernel.py
index 520913265..fdfadb0af 100644
--- a/gpytorch/kernels/scale_kernel.py
+++ b/gpytorch/kernels/scale_kernel.py
@@ -90,7 +90,7 @@ def _outputscale_param(self, m):
         return m.outputscale

     def _outputscale_closure(self, m, v):
-        return m._set_outputscale(v)
+        m._set_outputscale(v)

     @property
     def outputscale(self):
diff --git a/gpytorch/module.py b/gpytorch/module.py
index e4bfeda4c..e5081d878 100644
--- a/gpytorch/module.py
+++ b/gpytorch/module.py
@@ -21,8 +21,8 @@
 ModuleSelf = TypeVar("ModuleSelf", bound="Module")  # TODO: replace w/ typing.Self in Python 3.11
 RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin")  # TODO: replace w/ typing.Self in Python 3.11

-Closure = Callable[[nn.Module], Tensor]
-SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], ModuleSelf]
+Closure = Callable[[NnModuleSelf], Tensor]
+SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], None]
 SamplesDict = Mapping[str, Union[Tensor, float]]

@@ -294,8 +294,8 @@ def closure_new(module: nn.Module) -> Tensor:

         if setting_closure is not None:
             raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure")

-        def setting_closure_new(module: ModuleSelf, val: Union[Tensor, float]) -> ModuleSelf:
-            return module.initialize(**{param: val})
+        def setting_closure_new(module: Module, val: Union[Tensor, float]) -> None:
+            module.initialize(**{param: val})

         setting_closure = setting_closure_new
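
Usage sketch (illustrative, not part of the patch series): the annotated
`register_prior` accepts either a parameter/property name or a pair of
closures. A minimal sketch, assuming gpytorch's stock `RBFKernel` (which
exposes a `lengthscale` property and a `_set_lengthscale` helper, as the
kernel.py hunk above shows) and `GammaPrior` from `gpytorch.priors`:

    from gpytorch.kernels import RBFKernel
    from gpytorch.priors import GammaPrior

    kernel = RBFKernel()

    # String form: `param_or_closure` names a parameter (or property) on the
    # module. Module builds both the getter Closure and the SettingClosure
    # (via `initialize`) internally, so passing `setting_closure` here raises.
    kernel.register_prior("lengthscale_prior", GammaPrior(3.0, 6.0), "lengthscale")

    # Closure form: the Closure maps the module to a Tensor; after PATCH 4,
    # the SettingClosure mutates the module and returns None.
    kernel.register_prior(
        "lengthscale_prior2",
        GammaPrior(3.0, 6.0),
        lambda m: m.lengthscale,
        lambda m, v: m._set_lengthscale(v),
    )

    # Sampling goes through the setting closure and updates the underlying
    # raw parameter in place.
    kernel.sample_from_prior("lengthscale_prior")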
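Likewise, a sketch of the Pyro round trip that `pyro_sample_from_prior` and
`pyro_load_from_samples` are typed for (assumes pyro is installed; `model`,
`train_x`, and `train_y` are placeholders for a GP model with registered
priors and its training data):

    import pyro
    from pyro.infer.mcmc import MCMC, NUTS

    def pyro_model(x, y):
        # Returns a copy of `model` whose parameters were drawn with
        # pyro.sample primitives; the original module is left untouched.
        sampled_model = model.pyro_sample_from_prior()
        output = sampled_model.likelihood(sampled_model(x))
        pyro.sample("obs", output, obs=y)

    mcmc = MCMC(NUTS(pyro_model), num_samples=100)
    mcmc.run(train_x, train_y)

    # The keys of get_samples() are *prior* names (e.g.
    # "covar_module.raw_outputscale_prior"), matching the SamplesDict alias;
    # loading turns `model` into a batch module with one entry per sample.
    model.pyro_load_from_samples(mcmc.get_samples())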