Convert directory fbcode/pearl to use the Ruff Formatter
Summary:
Converts the specified directory to use the Ruff formatter in pyfmt.


If this diff causes merge conflicts when rebasing, please run
`hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt`
on your diff, and amend any changes before rebasing onto latest.
That should help reduce or eliminate any merge conflicts.
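
For reference, a minimal sketch of that workflow on a Mercurial/Sapling checkout (the `hg amend` step and the rebase destination below are assumptions; adjust to your setup):

```sh
# Re-run pyfmt (backed by Ruff) on every Python file touched by the current commit.
hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt

# Fold the formatting changes into the current commit, then rebase onto the latest base revision.
hg amend
hg rebase -d <latest-base-revision>
```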

allow-large-files

Reviewed By: amyreese

Differential Revision: D66370895

fbshipit-source-id: 003724df9eb7f9ac43246f86cc7e4cbdfc8ddcbe
Thomas Polasek authored and facebook-github-bot committed Nov 25, 2024
1 parent 010d4f8 commit 5cb0d76
Showing 32 changed files with 43 additions and 75 deletions.
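
Most hunks below are mechanical restyling with no behavior change: deleting stray blank lines after function signatures and module docstrings, collapsing short trailing calls that were split across lines, and parenthesizing the right-hand side of long subscript assignments instead of wrapping the left-hand side. As a rough, self-contained illustration (the tensors and dict here are invented, not code from the Pearl repository), the last two rewrites look like this:

```python
# Illustrative sketch only; not from the Pearl codebase.
import torch

q_values = torch.arange(12.0).reshape(3, 4)  # (batch_size, num_actions)
action_idx = torch.tensor([[0], [2], [1]])   # (batch_size, 1)

# Before Ruff: a short trailing .view(...) call split across several lines.
values_before = torch.gather(q_values, 1, action_idx).view(
    -1
)  # shape: (batch_size)

# After Ruff: the same chain collapsed onto one line.
values_after = torch.gather(q_values, 1, action_idx).view(-1)  # shape: (batch_size)

# Before Ruff: a long subscript assignment wrapped on the left-hand side.
config = {"action_representation_module_args": {}}
config["action_representation_module_args"][
    "max_number_actions"
] = q_values.shape[1]

# After Ruff: the right-hand side is parenthesized instead.
config["action_representation_module_args"]["max_number_actions"] = (
    q_values.shape[1]
)
```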
7 changes: 2 additions & 5 deletions pearl/neural_networks/common/value_networks.py
@@ -11,7 +11,6 @@
This module defines several types of value neural networks.
"""


from abc import ABC
from typing import Any, List, Optional

@@ -86,14 +85,12 @@ def __init__(
output_channels_list: list[int],
strides: list[int],
paddings: list[int],
hidden_dims_fully_connected: None | (
list[int]
) = None, # hidden dims for fully connected layers
hidden_dims_fully_connected: None
| (list[int]) = None, # hidden dims for fully connected layers
use_batch_norm_conv: bool = False,
use_batch_norm_fully_connected: bool = False,
output_dim: int = 1, # dimension of the final output
) -> None:

assert (
len(kernel_sizes)
== len(output_channels_list)
(another changed file; name not shown)
@@ -11,7 +11,6 @@
This module defines several types of actor neural networks.
"""


from typing import List, Optional, Tuple, Union

import torch
(another changed file; name not shown)
@@ -220,7 +220,6 @@ def get_q_value_distribution(
state_batch: Tensor,
action_batch: Tensor,
) -> Tensor:

x = torch.cat([state_batch, action_batch], dim=-1)
return self.forward(x)

@@ -375,9 +374,7 @@ def get_q_values(
# q value of (state, action) pair of interest
state_action_values = torch.gather(
values_state_available_actions, 1, action_idx
).view(
-1
) # shape: (batch_size)
).view(-1) # shape: (batch_size)
return state_action_values


@@ -399,7 +396,6 @@ def __init__(
hidden_dims: list[int] | None,
output_dim: int = 1,
) -> None:

super().__init__()

"""
@@ -479,7 +475,6 @@ def __init__(
state_hidden_dims: list[int] | None = None,
action_hidden_dims: list[int] | None = None,
) -> None:

super().__init__(
state_input_dim=state_dim,
action_input_dim=action_dim,
6 changes: 5 additions & 1 deletion pearl/pearl_agent.py
@@ -156,7 +156,11 @@ def act(self, exploit: bool = False) -> Action:
safe_action_space.to(self.device)

action = self.policy_learner.act(
subjective_state_to_be_used, safe_action_space, exploit=exploit # pyre-fixme[6]
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Optional[Tensor]`.
subjective_state_to_be_used,
safe_action_space,
exploit=exploit, # pyre-fixme[6]
)

self._latest_action = action
(another changed file; name not shown)
@@ -154,7 +154,6 @@ def set_history_summarization_module(
self._history_summarization_module = value

def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:

# get scores for logging purpose
ucb_scores = self.get_scores(batch.state).mean()

1 change: 0 additions & 1 deletion pearl/policy_learners/policy_learner.py
@@ -238,7 +238,6 @@ def __init__(
action_representation_module: ActionRepresentationModule | None = None,
**options: Any,
) -> None:

super().__init__(
on_policy=on_policy,
is_action_continuous=is_action_continuous,
(another changed file; name not shown)
@@ -90,9 +90,8 @@ def __init__(
on_policy: bool = False,
action_representation_module: ActionRepresentationModule | None = None,
actor_network_instance: ActorNetwork | None = None,
critic_network_instance: None | (
ValueNetwork | QValueNetwork | nn.Module
) = None,
critic_network_instance: None
| (ValueNetwork | QValueNetwork | nn.Module) = None,
) -> None:
super().__init__(
on_policy=on_policy,
2 changes: 0 additions & 2 deletions pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -98,7 +98,6 @@ def __init__(
)

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:

# sample a batch of actions from the actor network; shape (batch_size, action_dim)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
action_batch = self._actor.sample_action(batch.state)
@@ -116,7 +115,6 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
return loss

def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

with torch.no_grad():
# sample a batch of next actions from target actor network;
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
(another changed file; name not shown)
@@ -167,7 +167,6 @@ def get_next_state_values(
def _prepare_next_state_action_batch(
self, batch: TransitionBatch
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:

# This function outputs tensors:
# - next_state_batch: (batch_size x action_space_size x state_dim).
# - next_available_actions_batch: (batch_size x action_space_size x action_dim).
(another changed file; name not shown)
@@ -50,9 +50,7 @@ def get_next_state_values(

next_state_action_values = self._Q.get_q_values(
next_state_batch_repeated, next_available_actions_batch
).view(
(batch_size, -1)
) # (batch_size x action_space_size)
).view((batch_size, -1)) # (batch_size x action_space_size)
# Make sure that unavailable actions' Q values are assigned to -inf
next_state_action_values[next_unavailable_actions_mask_batch] = -float("inf")

(another changed file; name not shown)
@@ -188,7 +188,6 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
}

def _value_loss(self, batch: TransitionBatch) -> torch.Tensor:

with torch.no_grad():
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
q1, q2 = self._critic_target.get_q_values(batch.state, batch.action)
4 changes: 1 addition & 3 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -262,9 +262,7 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
# and the truncated lambda return for all states in the replay buffer.
next_value = self._critic(
self._history_summarization_module(next_state_in_device)
).detach()[
0
] # shape (1,)
).detach()[0] # shape (1,)
gae = torch.tensor([0.0]).to(state_values.device)
for i, transition in enumerate(reversed(replay_buffer.memory)):
original_transition_device = transition.device
(another changed file; name not shown)
@@ -122,9 +122,7 @@ def _get_next_state_quantiles(
# `get_q_values_under_risk_metric`.
next_state_action_values = self.safety_module.get_q_values_under_risk_metric(
next_state_batch_repeated, next_available_actions_batch, self._Q_target
).view(
batch_size, -1
) # shape: (batch_size, action_space_size)
).view(batch_size, -1) # shape: (batch_size, action_space_size)

# make sure that unavailable actions' Q values are assigned to -inf
next_state_action_values[next_unavailable_actions_mask_batch] = -float("inf")
(another changed file; name not shown)
@@ -220,9 +220,7 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
batch, batch_size
) * self._discount_factor * (1 - batch.terminated.float()).unsqueeze(
-1
) + batch.reward.unsqueeze(
-1
)
) + batch.reward.unsqueeze(-1)

"""
Step 3: pairwise distributional quantile loss:
(another changed file; name not shown)
@@ -118,7 +118,6 @@ def reset(self, action_space: ActionSpace) -> None:
self.scheduler.step()

def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

reward_batch = batch.reward # (batch_size)
terminated_batch = batch.terminated # (batch_size)

(another changed file; name not shown)
@@ -151,7 +151,6 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
return actor_critic_loss

def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:

reward_batch = batch.reward # shape: (batch_size)
terminated_batch = batch.terminated # shape: (batch_size)

(another changed file; name not shown)
@@ -122,7 +122,6 @@ def learn(
self,
replay_buffer: ReplayBuffer,
) -> dict[str, Any]:

# We know the sampling result from SingleTransitionReplayBuffer
# is a list with a single tuple.
transitions = replay_buffer.sample(1)
1 change: 0 additions & 1 deletion pearl/policy_learners/sequential_decision_making/td3.py
@@ -252,7 +252,6 @@ def __init__(
self._behavior_policy: torch.nn.Module = behavior_policy

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:

# sample a batch of actions from the actor network; shape (batch_size, action_dim)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
action_batch = self._actor.sample_action(batch.state)
1 change: 0 additions & 1 deletion pearl/replay_buffers/tensor_based_replay_buffer.py
@@ -289,7 +289,6 @@ def _create_transition_batch(
transitions: list[Transition],
is_action_continuous: bool,
) -> TransitionBatch:

if len(transitions) == 0:
return TransitionBatch(
state=torch.empty(0),
5 changes: 1 addition & 4 deletions pearl/safety_modules/reward_constrained_safety_module.py
@@ -158,7 +158,6 @@ def constraint_lambda_update(
def cost_critic_learn_batch(
self, batch: TransitionBatch, policy_learner: PolicyLearner
) -> dict[str, Any]:

with torch.no_grad():
# sample next_action from actor's target network; shape (batch_size, action_dim)
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
@@ -179,9 +178,7 @@ def cost_critic_learn_batch(
# cost + gamma * (min{Qtarget_1(s', a from target actor network),
# Qtarget_2(s', a from target actor network)})
expected_state_action_values = (
next_q
* self.cost_discount_factor
* (1 - batch.terminated.float())
next_q * self.cost_discount_factor * (1 - batch.terminated.float())
# pyre-fixme[58]: `+` is not supported for operand types `Tensor` and
# `Optional[Tensor]`.
) + batch.cost # (batch_size)
2 changes: 1 addition & 1 deletion pearl/user_envs/envs/bandit.py
@@ -44,7 +44,7 @@ def reset(
self,
*,
seed: int | None = None,
options: dict[str, float] | None = None
options: dict[str, float] | None = None,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> tuple[np.ndarray, dict[str, float]]:
super().reset(seed=seed)
(another changed file; name not shown)
@@ -76,7 +76,6 @@ def create_offline_data(
agent.reset(observation, action_space)
done = False
while not done:

# exploit is explicitly set to False as we want exploration during data collection with
# standard benchmark environments like Gym, MuJoCo etc.
action = agent.act(exploit=False)
6 changes: 5 additions & 1 deletion pearl/utils/functional_utils/python_utils.py
@@ -44,7 +44,11 @@ def first_item(i: Iterable[V]) -> V | None:


def find_argument(
kwarg_key: str, arg_type: type[ArgType], *args, **kwargs # pyre-ignore
kwarg_key: str,
arg_type: type[ArgType],
# pyre-fixme[2]: Parameter must be annotated.
*args,
**kwargs, # pyre-ignore
) -> ArgType | None:
"""
Finds the first argument in args and kwargs that either has type `arg_type` or
(another changed file; name not shown)
@@ -156,9 +156,7 @@ def online_learning(
info_period = {}
if number_of_steps is not None and old_total_steps // record_period < (
total_steps
) // (
record_period
): # record average info value every record_period steps
) // (record_period): # record average info value every record_period steps
for key in info_period:
info.setdefault(key, []).append(np.mean(info_period[key]))
info_period = {}
(another changed file; name not shown)
@@ -43,7 +43,6 @@ def __init__(
target_column: int = 0,
ind_to_drop: list[int] | None = None,
) -> None:

if ind_to_drop is None:
ind_to_drop = []

@@ -104,11 +103,9 @@ def __init__(

# Set observation space and observation dimension
self.observation_dim: int = tensor.size()[1] - 1 # 0th index is the target
self._observation_space: Space = (
BoxSpace( # Box space with low for each dimension = -inf, high for each dimension = inf
high=torch.full((self.observation_dim,), float("inf")),
low=torch.full((self.observation_dim,), float("-inf")),
)
self._observation_space: Space = BoxSpace( # Box space with low for each dimension = -inf, high for each dimension = inf
high=torch.full((self.observation_dim,), float("inf")),
low=torch.full((self.observation_dim,), float("-inf")),
)

# Set noise to be added to reward
@@ -157,8 +154,9 @@ def reset(self, seed: int | None = None) -> tuple[Observation, ActionSpace]:
Provides the observation and action space to the agent.
"""
data_point = next(iter(self.dataloader_tr))
label, observation = data_point[1].to(self.device), data_point[0].to(
self.device
label, observation = (
data_point[1].to(self.device),
data_point[0].to(self.device),
)
self._observation = torch.squeeze(observation)
self._current_label = label.to(int)
(another changed file; name not shown)
@@ -21,6 +21,7 @@
- one for discrete action space
- one for continuous action space
"""

import math
import random
from abc import abstractmethod
@@ -73,7 +74,6 @@ def observation_space(self) -> BoxSpace:
return observation_space

def reset(self, seed: int | None = None) -> tuple[torch.Tensor, ActionSpace]:

# reset (x, y) for agent position
self._agent_position = (
self._width / 2,
1 change: 0 additions & 1 deletion pearl/utils/replay_buffer_utils.py
@@ -52,7 +52,6 @@ def make_replay_buffer_class_for_specific_transition_types(
# We define a local class using the given transition types,
# and that will be returned as the result.
class ReplayBufferForGivenTransitionTypes(TensorBasedReplayBuffer):

# This statement is one reason why making this a generic class does not work;
# if this is a generic class on TransitionType, then this function call passes
# the TypeVar, rather than the value of the TypeVar, as an argument,
19 changes: 10 additions & 9 deletions pearl/utils/scripts/benchmark.py
@@ -14,6 +14,7 @@
(make sure conda and related packages have been installed with
./utils/scripts/meta_only/setup_conda_pearl_on_devserver.sh)
"""

import os
import warnings
from typing import List
@@ -149,18 +150,18 @@ def evaluate_single(
if method["action_representation_module"].__name__ in [
"OneHotActionTensorRepresentationModule",
]:
method["action_representation_module_args"][
"max_number_actions"
] = env.action_space.n
method["action_representation_module_args"]["max_number_actions"] = (
env.action_space.n
)
if method["action_representation_module"].__name__ in [
"IdentityActionRepresentationModule"
]:
method["action_representation_module_args"][
"max_number_actions"
] = env.action_space.n
method["action_representation_module_args"][
"representation_dim"
] = env.action_space.action_dim
method["action_representation_module_args"]["max_number_actions"] = (
env.action_space.n
)
method["action_representation_module_args"]["representation_dim"] = (
env.action_space.action_dim
)
policy_learner_args["action_representation_module"] = method[
"action_representation_module"
](**method["action_representation_module_args"])
(remaining changed files not loaded in this view)