diff --git a/pearl/neural_networks/common/epistemic_neural_networks.py b/pearl/neural_networks/common/epistemic_neural_networks.py index 362c303..11472c5 100644 --- a/pearl/neural_networks/common/epistemic_neural_networks.py +++ b/pearl/neural_networks/common/epistemic_neural_networks.py @@ -172,6 +172,7 @@ def generate_params_buffers(self) -> None: """ Generate parameters and buffers for the priornet. """ + # pyre-fixme[6]: For 1st argument expected `List[Module]` but got `ModuleList`. self.params, self.buffers = torch.func.stack_module_state(self.models) def call_single_model( diff --git a/pearl/neural_networks/common/utils.py b/pearl/neural_networks/common/utils.py index ecfac83..15481cd 100644 --- a/pearl/neural_networks/common/utils.py +++ b/pearl/neural_networks/common/utils.py @@ -226,6 +226,8 @@ def wrapper( ) -> torch.Tensor: return torch.func.functional_call(models[0], (params, buffers), data) + # pyre-fixme[6]: For 1st argument expected `List[Module]` but got + # `Union[List[Module], ModuleList]`. params, buffers = stack_module_state(models) values = torch.vmap(wrapper)(params, buffers, features).view( (-1, batch_size)