Store replay buffers in CPU only
Summary:
Previously, every tensor added to a tensor-based replay buffer was moved to the active device, so the entire replay buffer ended up stored there. For large replay buffers, this can exceed the GPU's memory.

The replay buffer is now stored on the CPU, and only the batches sampled from it are moved to the active device.
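
As a rough, self-contained sketch of the new behavior (ToyTransition and ToyCPUReplayBuffer are illustrative stand-ins, not Pearl's actual classes): transitions are kept on the CPU when pushed, and only sampled batches are moved to device_for_batches.

import random
from typing import List, NamedTuple

import torch


class ToyTransition(NamedTuple):
    state: torch.Tensor
    action: torch.Tensor
    reward: torch.Tensor


class ToyCPUReplayBuffer:
    def __init__(self) -> None:
        self.memory: List[ToyTransition] = []
        # Device on which *sampled batches* are placed; the stored
        # transitions themselves always remain on the CPU.
        self.device_for_batches: torch.device = torch.device("cpu")

    def push(self, state, action, reward) -> None:
        # Store on the CPU regardless of where the inputs live.
        self.memory.append(
            ToyTransition(
                state=torch.as_tensor(state).cpu().unsqueeze(0),
                action=torch.as_tensor(action).cpu().unsqueeze(0),
                reward=torch.as_tensor([reward], dtype=torch.float32).cpu(),
            )
        )

    def sample(self, batch_size: int) -> ToyTransition:
        samples = random.sample(self.memory, batch_size)
        batch = ToyTransition(
            state=torch.cat([t.state for t in samples]),
            action=torch.cat([t.action for t in samples]),
            reward=torch.cat([t.reward for t in samples]),
        )
        # Only the sampled batch is moved to the active device.
        return ToyTransition(*(x.to(self.device_for_batches) for x in batch))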

Reviewed By: yiwan-rl

Differential Revision: D54625530

fbshipit-source-id: 9afe8e65c33077621a527b3a0cc061244854119f
rodrigodesalvobraz authored and facebook-github-bot committed May 8, 2024
1 parent ea3f757 commit 7fa88fc
Showing 12 changed files with 191 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pearl/pearl_agent.py
@@ -114,7 +114,7 @@ def __init__(
self.replay_buffer.is_action_continuous = (
self.policy_learner.is_action_continuous
)
self.replay_buffer.device = self.device
self.replay_buffer.device_for_batches = self.device

# check that all components of the agent are compatible with each other
pearl_agent_compatibility_check(
26 changes: 25 additions & 1 deletion pearl/policy_learners/sequential_decision_making/ppo.py
@@ -184,6 +184,17 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
torch.cat(action_list)
)

# Transitions in the replay buffer memory are stored on the CPU
# (only sampled batches are moved to the device in use,
# which is kept in replay_buffer.device_for_batches).
# To use them in expressions involving the models,
# we must first move them to that device.
history_summary_batch = history_summary_batch.to(
replay_buffer.device_for_batches
)
action_representation_batch = action_representation_batch.to(
replay_buffer.device_for_batches
)

state_values = self._critic(history_summary_batch).detach()
action_probs = (
self._actor.get_action_prob(
@@ -193,16 +204,28 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
.detach()
.unsqueeze(-1)
)

# Transitions in the replay buffer memory are stored on the CPU
# (only sampled batches are moved to the device in use,
# which is kept in replay_buffer.device_for_batches).
# To use them in expressions involving the critic,
# we must first move them to that device.
next_state = replay_buffer.memory[-1].next_state
assert next_state is not None
next_state_in_device = next_state.to(replay_buffer.device_for_batches)

# Obtain the value of the most recent state stored in the replay buffer.
# This value is used to compute the generalized advantage estimation (gae)
# and the truncated lambda return for all states in the replay buffer.
next_value = self._critic(
self._history_summarization_module(replay_buffer.memory[-1].next_state)
self._history_summarization_module(next_state_in_device)
).detach()[
0
] # shape (1,)
gae = torch.tensor([0.0]).to(state_values.device)
for i, transition in enumerate(reversed(replay_buffer.memory)):
original_transition_device = transition.device
transition.to(state_values.device)
td_error = (
transition.reward
+ self._discount_factor * next_value * (~transition.terminated)
@@ -222,3 +245,4 @@ def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
# action probabilities from the current policy
transition.action_probs = action_probs[i]
next_value = state_values[i]
transition.to(original_transition_device)
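
For reference, the (partly elided) loop above walks the CPU-stored transitions backwards, temporarily moving each one to the compute device, and accumulates generalized advantage estimates and truncated lambda returns. A self-contained sketch of that recursion on plain tensors (a hypothetical helper, not Pearl's Transition-based code) might look like this:

import torch


def gae_and_lambda_returns(rewards, values, next_value, terminated, gamma=0.99, lam=0.95):
    # rewards, values, terminated: 1-D tensors of length T (oldest to newest),
    # with terminated a bool tensor; next_value: scalar tensor bootstrapping
    # the value of the most recent next_state, as in the code above.
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = torch.zeros(())
    for t in reversed(range(T)):
        not_done = (~terminated[t]).float()
        td_error = rewards[t] + gamma * next_value * not_done - values[t]
        gae = td_error + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Truncated lambda returns are the advantages plus the value baseline.
    return advantages, advantages + values
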
20 changes: 18 additions & 2 deletions pearl/policy_learners/sequential_decision_making/reinforce.py
@@ -141,9 +141,25 @@ def learn(self, replay_buffer: ReplayBuffer) -> Dict[str, Any]:
assert type(replay_buffer) is OnPolicyReplayBuffer
assert len(replay_buffer.memory) > 0
# compute return for all states in the buffer

# Transitions in the replay buffer memory are stored on the CPU
# (only sampled batches are moved to the device in use,
# which is kept in replay_buffer.device_for_batches).
# To use them in expressions involving the critic,
# we must first move them to that device.
next_state = replay_buffer.memory[-1].next_state
terminated = replay_buffer.memory[-1].terminated
assert next_state is not None
assert terminated is not None
next_state_in_device = next_state.to(replay_buffer.device_for_batches)
terminated_in_device = terminated.to(replay_buffer.device_for_batches)

cum_reward = self._critic(
self._history_summarization_module(replay_buffer.memory[-1].next_state)
).detach() * (~replay_buffer.memory[-1].terminated)
self._history_summarization_module(next_state_in_device)
).detach() * (~terminated_in_device)

# move cum_reward to CPU to process CPU-stored transitions
cum_reward = cum_reward.cpu()
for transition in reversed(replay_buffer.memory):
cum_reward += transition.reward
assert isinstance(transition, OnPolicyTransition)
@@ -57,9 +57,9 @@ def push(
self.memory.append(
Transition(
state=self._process_single_state(state),
action=action,
action=self._process_single_action(action),
reward=self._process_single_reward(reward),
).to(self.device)
)
)

def sample(self, batch_size: int) -> TransitionBatch:
@@ -68,4 +68,4 @@ def sample(self, batch_size: int) -> TransitionBatch:
state=torch.cat([x.state for x in samples]),
action=torch.stack([x.action for x in samples]),
reward=torch.cat([x.reward for x in samples]),
).to(self.device)
).to(self.device_for_batches)
30 changes: 20 additions & 10 deletions pearl/replay_buffers/examples/single_transition_replay_buffer.py
@@ -16,6 +16,7 @@
from pearl.api.state import SubjectiveState

from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.utils.device import get_default_device


# Preferred to define inside class but that is not working. Pending discussion.
@@ -32,16 +32,23 @@
]


def to_default_device_if_tensor(obj: object) -> object:
if isinstance(obj, torch.Tensor):
return obj.to(get_default_device())
else:
return obj


class SingleTransitionReplayBuffer(ReplayBuffer):
def __init__(self) -> None:
self._transition: Optional[SingleTransition] = None

@property
def device(self) -> torch.device:
def device_for_batches(self) -> torch.device:
raise ValueError("SingleTransitionReplayBuffer does not have a device.")

@device.setter
def device(self, new_device: torch.device) -> None:
@device_for_batches.setter
def device_for_batches(self, new_device_for_batches: torch.device) -> None:
pass

def push(
@@ -56,16 +64,18 @@ def push(
max_number_actions: Optional[int] = None,
cost: Optional[float] = None,
) -> None:
self._transition = (
state,
action,
reward,
next_state,
# TODO: we use pyre-ignore here because tabular Q learning does not use tensors
# like other policy learners. It should be converted to do so.
self._transition = ( # pyre-ignore
to_default_device_if_tensor(state),
to_default_device_if_tensor(action),
to_default_device_if_tensor(reward),
to_default_device_if_tensor(next_state),
curr_available_actions,
next_available_actions,
terminated,
to_default_device_if_tensor(terminated),
max_number_actions,
cost,
to_default_device_if_tensor(cost),
)

def sample(self, batch_size: int) -> List[SingleTransition]:
21 changes: 18 additions & 3 deletions pearl/replay_buffers/replay_buffer.py
@@ -18,19 +18,34 @@


class ReplayBuffer(ABC):
"""
A base class for all replay buffers.
Replay buffers store transitions collected from an agent's experience,
and batches of those transitions can be sampled to train the agent.
They are stored in the CPU since they may grow quite large,
but contain a property `device` which specifies where
batches are stored.
"""

def __init__(self) -> None:
super().__init__()
self._is_action_continuous: bool = False
self._has_cost_available: bool = False

@property
@abstractmethod
def device(self) -> torch.device:
def device_for_batches(self) -> torch.device:
"""
The device on which _batches_ are stored
(the replay buffer is always stored in the CPU).
"""
pass

@device.setter
@device_for_batches.setter
@abstractmethod
def device(self, new_device: torch.device) -> None:
def device_for_batches(self, new_device_for_batches: torch.device) -> None:
pass

@abstractmethod
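
Since device_for_batches is declared as an abstract property with an abstract setter, every concrete replay buffer must provide both. A minimal sketch of the expected pattern in a hypothetical subclass (the real implementation for tensor-based buffers appears in tensor_based_replay_buffer.py below):

import torch

from pearl.replay_buffers.replay_buffer import ReplayBuffer


class MyReplayBuffer(ReplayBuffer):  # hypothetical subclass for illustration
    def __init__(self) -> None:
        super().__init__()
        # Batches default to the CPU until the agent assigns its own device.
        self._device_for_batches: torch.device = torch.device("cpu")

    @property
    def device_for_batches(self) -> torch.device:
        return self._device_for_batches

    @device_for_batches.setter
    def device_for_batches(self, new_device_for_batches: torch.device) -> None:
        self._device_for_batches = new_device_for_batches

    # The remaining abstract methods (e.g. push and sample) would also have
    # to be implemented, storing transitions on the CPU and moving sampled
    # batches to self._device_for_batches.
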
@@ -65,5 +65,5 @@ def push(
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
cost=self._process_single_cost(cost),
).to(self.device)
)
)
@@ -77,7 +77,7 @@ def push(
next_available_actions=self.cache.next_available_actions,
next_unavailable_actions_mask=self.cache.next_unavailable_actions_mask,
terminated=self.cache.terminated,
).to(self.device)
)
)
if not terminated:
# save current push into cache
@@ -91,7 +91,7 @@ def push(
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
).to(self.device)
)
else:
# for terminal state, push directly
self.memory.append(
@@ -107,5 +107,5 @@ def push(
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
terminated=self._process_single_terminated(terminated),
).to(self.device)
)
)
@@ -104,7 +104,7 @@ def push(
next_available_actions=None,
next_unavailable_actions_mask=None,
terminated=self._process_single_terminated(terminated),
).to(self.device)
)
)

def _create_transition_batch(
@@ -146,4 +146,4 @@ def helper(
transition_batch, **on_policy_attrs
)

return transition_batch
return transition_batch.to(self.device_for_batches)
31 changes: 15 additions & 16 deletions pearl/replay_buffers/tensor_based_replay_buffer.py
@@ -43,38 +43,38 @@ def __init__(
self._has_next_action = has_next_action
self._has_next_available_actions = has_next_available_actions
self.has_cost_available = has_cost_available
self._device: torch.device = get_default_device()
self._device_for_batches: torch.device = get_default_device()

@property
def device(self) -> torch.device:
return self._device
def device_for_batches(self) -> torch.device:
return self._device_for_batches

@device.setter
def device(self, value: torch.device) -> None:
self._device = value
@device_for_batches.setter
def device_for_batches(self, new_device_for_batches: torch.device) -> None:
self._device_for_batches = new_device_for_batches

def _process_single_state(self, state: SubjectiveState) -> torch.Tensor:
if isinstance(state, torch.Tensor):
return state.clone().detach().to(self._device).unsqueeze(0)
return state.to(get_default_device()).clone().detach().unsqueeze(0)
else:
return torch.tensor(state, device=self._device).unsqueeze(0)
return torch.tensor(state).unsqueeze(0)

def _process_single_action(self, action: Action) -> torch.Tensor:
if isinstance(action, torch.Tensor):
return action.clone().detach().to(self._device).unsqueeze(0)
return action.to(get_default_device()).clone().detach().unsqueeze(0)
else:
return torch.tensor(action, device=self._device).unsqueeze(0)
return torch.tensor(action).unsqueeze(0)

def _process_single_reward(self, reward: Reward) -> torch.Tensor:
return torch.tensor([reward], device=self._device)
return torch.tensor([reward])

def _process_single_cost(self, cost: Optional[float]) -> Optional[torch.Tensor]:
if cost is None:
return None
return torch.tensor([cost], device=self._device)
return torch.tensor([cost])

def _process_single_terminated(self, terminated: bool) -> torch.Tensor:
return torch.tensor([terminated], device=self._device) # (1,)
return torch.tensor([terminated]) # (1,)

"""
This function is only used for discrete action space.
@@ -116,7 +116,6 @@ def _create_action_tensor_and_mask(

available_actions_tensor_with_padding = torch.zeros(
(1, max_number_actions, available_action_space.action_dim),
device=self._device,
dtype=torch.float32,
) # (1 x action_space_size x action_dim)
available_actions_tensor = available_action_space.actions_batch
Expand All @@ -125,7 +124,7 @@ def _create_action_tensor_and_mask(
)

unavailable_actions_mask = torch.zeros(
(1, max_number_actions), device=self._device
(1, max_number_actions)
) # (1 x action_space_size)
unavailable_actions_mask[0, available_action_space.n :] = 1
unavailable_actions_mask = unavailable_actions_mask.bool()
@@ -253,4 +252,4 @@ def _create_transition_batch(
next_unavailable_actions_mask=next_unavailable_actions_mask_batch,
terminated=terminated_batch,
cost=cost_batch,
).to(self.device)
).to(self.device_for_batches)
@@ -148,7 +148,7 @@ def offline_learning(
set_seed(seed=seed)

# have batches sampled from the replay buffer placed on the offline agent's device
data_buffer.device = offline_agent.device
data_buffer.device_for_batches = offline_agent.device

# training loop
for i in range(training_epochs):