Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/cpu replay buffer #187

Merged
merged 5 commits into from
Dec 10, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
added cpu replay buffer code
  • Loading branch information
benblack769 committed Nov 12, 2020
commit 5c287a2935aae14ee96cbee3ffe736460f731e9f
50 changes: 48 additions & 2 deletions all/bodies/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,55 @@ def __init__(self, agent, size=4, lazy=False):
self._frames = []
self._size = size
self._lazy = lazy
self._to_cache = ToCache()

def process_state(self, state):
    """Stack the most recent `size` observations into the state's observation.

    Maintains a rolling window of the last `self._size` frames. Returns a
    LazyState (frames kept as a list, concatenated lazily on access) when
    `self._lazy` is set, otherwise an eager state with the frames already
    concatenated.
    """
    if not self._frames:
        # First observation seen: fill the whole window with it.
        self._frames = [state.observation] * self._size
    else:
        # Slide the window: drop the oldest frame, append the newest.
        self._frames = self._frames[1:] + [state.observation]
    if self._lazy:
        # Share this body's ToCache so repeated device transfers of the
        # same frame tensors are reused across LazyStates.
        return LazyState.from_state(state, self._frames, self._to_cache)
    if isinstance(state, StateArray):
        # Batched states: concatenate frames along dim=1 — presumably the
        # channel dimension of (batch, channel, ...) observations; verify.
        return state.update('observation', torch.cat(self._frames, dim=1))
    return state.update('observation', torch.cat(self._frames, dim=0))


class ToCache:
    """Cache for tensor transfers between a fixed pair of devices.

    LazyStates share frame tensors; when the same frame is moved between
    devices repeatedly (e.g. replay-buffer storage on CPU, training on GPU),
    this cache remembers the most recent transfers so the copy is reused
    instead of repeated.

    The first call to `convert` pins `from_device` (from the tensor) and
    `to_device` (from the argument) if they were not given; every later call
    must use the same pair, otherwise a ValueError is raised. At most
    `max_size` entries are kept, evicted oldest-first.
    """

    def __init__(self, from_device=None, to_device=None, max_size=16):
        self.from_device = from_device
        self.to_device = to_device
        self.max_size = max_size
        # List of (original_tensor, transferred_tensor) pairs, oldest first.
        self.cache_data = []

    def convert(self, value, device):
        """Return `value` on `device`, reusing a cached transfer when possible.

        Raises:
            ValueError: if `value.device` or `device` differs from the pair
                this cache was pinned to.
        """
        if self.from_device is None:
            self.from_device = value.device
        if self.to_device is None:
            self.to_device = device
        if self.from_device != value.device or self.to_device != device:
            raise ValueError("bad devices to convert lazystate, must be internally consistent")
        # Linear scan by identity is fine: the cache holds at most
        # `max_size` (default 16) entries.
        cached = next(
            (moved for original, moved in self.cache_data if original is value),
            None,
        )
        if cached is not None:
            return cached
        moved = value.to(device)
        self.cache_data.append((value, moved))
        if len(self.cache_data) > self.max_size:
            self.cache_data.pop(0)
        return moved


class LazyState(State):
@classmethod
def from_state(cls, state, frames, to_cache):
    """Build a LazyState wrapping `state` whose observation is a frame list.

    Args:
        state: the source State; its entries and device are copied.
        frames: list of observation tensors, concatenated lazily on access.
        to_cache: shared ToCache used by `to()` when moving frames
            between devices.
    """
    state = LazyState(state, device=state.device)
    state.to_cache = to_cache
    state['observation'] = frames
    return state

Expand All @@ -36,3 +68,17 @@ def __getitem__(self, key):
return v
return torch.cat(dict.__getitem__(self, key), dim=0)
return super().__getitem__(key)

def to(self, device):
    """Return this state on `device`, routing frame transfers through the cache."""
    if device == self.device:
        return self

    def _move(key, value):
        # The observation is a list of frames; each one goes through the
        # shared ToCache so repeated transfers of the same tensor reuse
        # the cached copy instead of copying again.
        if key == 'observation':
            return [self.to_cache.convert(frame, device) for frame in value]
        if torch.is_tensor(value):
            return value.to(device)
        return value

    moved = {key: _move(key, value) for key, value in self.items()}
    return LazyState(moved, device=device)
25 changes: 19 additions & 6 deletions all/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class State(dict):
device (string):
The torch device on which component tensors are stored.
"""

def __init__(self, x, device='cpu', **kwargs):
if not isinstance(x, dict):
x = {'observation': x}
Expand Down Expand Up @@ -68,11 +67,15 @@ def array(cls, list_of_states):
for key in list_of_states[0].keys():
v = list_of_states[0][key]
try:
if torch.is_tensor(v):
if isinstance(v, list) and len(v) > 0 and torch.is_tensor(v[0]):
x[key] = torch.stack([torch.stack(state[key]) for state in list_of_states])
elif torch.is_tensor(v):
x[key] = torch.stack([state[key] for state in list_of_states])
else:
x[key] = torch.tensor([state[key] for state in list_of_states], device=device)
except BaseException:
except ValueError:
pass
except KeyError:
pass
return StateArray(x, shape, device=device)

Expand Down Expand Up @@ -188,6 +191,17 @@ def from_gym(cls, state, device='cpu', dtype=np.float32):
x[key] = info[key]
return State(x, device=device)

def to(self, device):
    """Return this state on `device` (self when it is already there).

    Tensor entries are transferred; all other entries are carried over
    unchanged. The concrete subclass and shape are preserved.
    """
    if device == self.device:
        return self
    transferred = {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in self.items()
    }
    return type(self)(transferred, device=device, shape=self._shape)

@property
def observation(self):
"""A tensor containing the current observation."""
Expand Down Expand Up @@ -246,7 +260,6 @@ class StateArray(State):
device (string):
The torch device on which component tensors are stored.
"""

def __init__(self, x, shape, device='cpu', **kwargs):
if not isinstance(x, dict):
x = {'observation': x}
Expand Down Expand Up @@ -291,7 +304,7 @@ def as_output(self, tensor):
return tensor.view((*self.shape, *tensor.shape[1:]))

def apply_mask(self, tensor):
    """Multiply `tensor` by this array's mask, broadcasting over the last dim.

    The mask gains a trailing unit dimension so it broadcasts against
    `tensor`'s final axis. (The `# pylint: disable=no-member` seen in the
    diff was dropped per the review — the project changed linters.)
    """
    return tensor * self.mask.unsqueeze(-1)

def flatten(self):
"""
Expand Down Expand Up @@ -350,7 +363,7 @@ def __getitem__(self, key):
for (k, v) in self.items():
try:
d[k] = v[key]
except BaseException:
except KeyError:
pass
return self.__class__(d, shape, device=self.device)
try:
Expand Down
14 changes: 8 additions & 6 deletions all/memory/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ def update_priorities(self, indexes, td_errors):
# Adapted from:
# /~https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
class ExperienceReplayBuffer(ReplayBuffer):
def __init__(self, size, device='cpu', store_device='cpu'):
    """
    Args:
        size (int): maximum number of transitions kept; oldest are
            overwritten once the buffer is full.
        device: device on which sampled minibatches are returned
            (accepts a string or torch.device).
        store_device: device on which stored transitions are kept — e.g.
            'cpu' to hold a large buffer in host memory while training
            on a GPU.
    """
    self.buffer = []
    self.capacity = int(size)
    # Next write position; wraps around once capacity is reached.
    self.pos = 0
    self.device = torch.device(device)
    self.store_device = torch.device(store_device)

def store(self, state, action, next_state):
    """Add a transition to the buffer, moving both states to the storage device.

    Terminal (`done`) and missing states are silently ignored.
    """
    if state is None or state.done:
        return
    self._add((
        state.to(self.store_device),
        action,
        next_state.to(self.store_device),
    ))

def sample(self, batch_size):
Expand All @@ -49,12 +52,12 @@ def _add(self, sample):
self.pos = (self.pos + 1) % self.capacity

def _reshape(self, minibatch, weights):
    """Collate a sampled minibatch into batched tensors on the sample device.

    Transitions may be stored on a different device (see `store_device`);
    the collated states and actions are moved to `self.device` before
    being returned.

    Returns:
        (states, actions, rewards, next_states, weights)
    """
    states = State.array([sample[0] for sample in minibatch]).to(self.device)
    if torch.is_tensor(minibatch[0][1]):
        actions = torch.stack([sample[1] for sample in minibatch]).to(self.device)
    else:
        actions = torch.tensor([sample[1] for sample in minibatch], device=self.device)
    next_states = State.array([sample[2] for sample in minibatch]).to(self.device)
    return (states, actions, next_states.reward, next_states, weights)

def __len__(self):
Expand Down Expand Up @@ -145,7 +148,6 @@ def _sample_proportional(self, batch_size):

class NStepReplayBuffer(ReplayBuffer):
'''Converts any ReplayBuffer into an NStepReplayBuffer'''

def __init__(
self,
steps,
Expand Down