-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/cpu replay buffer #187
Changes from 1 commit
5c287a2
2d7ff07
7f6aa09
033491e
52d22f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,23 +9,55 @@ def __init__(self, agent, size=4, lazy=False): | |
self._frames = [] | ||
self._size = size | ||
self._lazy = lazy | ||
self._to_cache = ToCache() | ||
|
||
def process_state(self, state): | ||
if not self._frames: | ||
self._frames = [state.observation] * self._size | ||
else: | ||
self._frames = self._frames[1:] + [state.observation] | ||
if self._lazy: | ||
return LazyState.from_state(state, self._frames) | ||
return LazyState.from_state(state, self._frames, self._to_cache) | ||
if isinstance(state, StateArray): | ||
return state.update('observation', torch.cat(self._frames, dim=1)) | ||
return state.update('observation', torch.cat(self._frames, dim=0)) | ||
|
||
|
||
class ToCache: | ||
def __init__(self, from_device=None, to_device=None, max_size=16): | ||
self.from_device = from_device | ||
self.to_device = to_device | ||
self.max_size = max_size | ||
self.cache_data = [] | ||
|
||
def convert(self, value, device): | ||
if self.from_device is None: | ||
self.from_device = value.device | ||
if self.to_device is None: | ||
self.to_device = device | ||
if self.from_device != value.device or self.to_device != device: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find the usage a little confusing... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I greatly simplified this logic. |
||
raise ValueError("bad devices to convert lazystate, must be internally consistent") | ||
|
||
cached = None | ||
for el in self.cache_data: | ||
if el[0] is value: | ||
cached = el[1] | ||
break | ||
if cached is not None: | ||
new_v = cached | ||
else: | ||
new_v = value.to(device) | ||
self.cache_data.append((value, new_v)) | ||
if len(self.cache_data) > self.max_size: | ||
self.cache_data.pop(0) | ||
return new_v | ||
|
||
|
||
class LazyState(State): | ||
@classmethod | ||
def from_state(cls, state, frames): | ||
def from_state(cls, state, frames, to_cache): | ||
state = LazyState(state, device=state.device) | ||
state.to_cache = to_cache | ||
state['observation'] = frames | ||
return state | ||
|
||
|
@@ -36,3 +68,17 @@ def __getitem__(self, key): | |
return v | ||
return torch.cat(dict.__getitem__(self, key), dim=0) | ||
return super().__getitem__(key) | ||
|
||
def to(self, device): | ||
if device == self.device: | ||
return self | ||
x = {} | ||
for key, value in self.items(): | ||
if key == 'observation': | ||
x[key] = [self.to_cache.convert(v, device) for v in value] | ||
# x[key] = [v.to(device) for v in value]#torch.cat(value,axis=0).to(device) | ||
elif torch.is_tensor(value): | ||
x[key] = value.to(device) | ||
else: | ||
x[key] = value | ||
return LazyState(x, device=device) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,6 @@ class State(dict): | |
device (string): | ||
The torch device on which component tensors are stored. | ||
""" | ||
|
||
def __init__(self, x, device='cpu', **kwargs): | ||
if not isinstance(x, dict): | ||
x = {'observation': x} | ||
|
@@ -68,11 +67,15 @@ def array(cls, list_of_states): | |
for key in list_of_states[0].keys(): | ||
v = list_of_states[0][key] | ||
try: | ||
if torch.is_tensor(v): | ||
if isinstance(v, list) and len(v) > 0 and torch.is_tensor(v[0]): | ||
x[key] = torch.stack([torch.stack(state[key]) for state in list_of_states]) | ||
elif torch.is_tensor(v): | ||
x[key] = torch.stack([state[key] for state in list_of_states]) | ||
else: | ||
x[key] = torch.tensor([state[key] for state in list_of_states], device=device) | ||
except BaseException: | ||
except ValueError: | ||
pass | ||
except KeyError: | ||
pass | ||
return StateArray(x, shape, device=device) | ||
|
||
|
@@ -188,6 +191,17 @@ def from_gym(cls, state, device='cpu', dtype=np.float32): | |
x[key] = info[key] | ||
return State(x, device=device) | ||
|
||
def to(self, device): | ||
if device == self.device: | ||
return self | ||
x = {} | ||
for key, value in self.items(): | ||
if torch.is_tensor(value): | ||
x[key] = value.to(device) | ||
else: | ||
x[key] = value | ||
return type(self)(x, device=device, shape=self._shape) | ||
|
||
@property | ||
def observation(self): | ||
"""A tensor containing the current observation.""" | ||
|
@@ -246,7 +260,6 @@ class StateArray(State): | |
device (string): | ||
The torch device on which component tensors are stored. | ||
""" | ||
|
||
def __init__(self, x, shape, device='cpu', **kwargs): | ||
if not isinstance(x, dict): | ||
x = {'observation': x} | ||
|
@@ -291,7 +304,7 @@ def as_output(self, tensor): | |
return tensor.view((*self.shape, *tensor.shape[1:])) | ||
|
||
def apply_mask(self, tensor): | ||
return tensor * self.mask.unsqueeze(-1) | ||
return tensor * self.mask.unsqueeze(-1) # pylint: disable=no-member | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't need the pylint disable because we changed the linter There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed |
||
|
||
def flatten(self): | ||
""" | ||
|
@@ -350,7 +363,7 @@ def __getitem__(self, key): | |
for (k, v) in self.items(): | ||
try: | ||
d[k] = v[key] | ||
except BaseException: | ||
except KeyError: | ||
pass | ||
return self.__class__(d, shape, device=self.device) | ||
try: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,14 +23,17 @@ def update_priorities(self, indexes, td_errors): | |
# Adapted from: | ||
# /~https://github.com/Shmuma/ptan/blob/master/ptan/experience.py | ||
class ExperienceReplayBuffer(ReplayBuffer): | ||
def __init__(self, size, device=torch.device('cpu')): | ||
def __init__(self, size, device='cpu', store_device='cpu'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed |
||
self.buffer = [] | ||
self.capacity = int(size) | ||
self.pos = 0 | ||
self.device = device | ||
self.device = torch.device(device) | ||
self.store_device = torch.device(store_device) | ||
|
||
def store(self, state, action, next_state): | ||
if state is not None and not state.done: | ||
state = state.to(self.store_device) | ||
next_state = next_state.to(self.store_device) | ||
self._add((state, action, next_state)) | ||
|
||
def sample(self, batch_size): | ||
|
@@ -49,12 +52,12 @@ def _add(self, sample): | |
self.pos = (self.pos + 1) % self.capacity | ||
|
||
def _reshape(self, minibatch, weights): | ||
states = State.array([sample[0] for sample in minibatch]) | ||
states = State.array([sample[0] for sample in minibatch]).to(self.device) | ||
if torch.is_tensor(minibatch[0][1]): | ||
actions = torch.stack([sample[1] for sample in minibatch]) | ||
actions = torch.stack([sample[1] for sample in minibatch]).to(self.device) | ||
else: | ||
actions = torch.tensor([sample[1] for sample in minibatch], device=self.device) | ||
next_states = State.array([sample[2] for sample in minibatch]) | ||
next_states = State.array([sample[2] for sample in minibatch]).to(self.device) | ||
return (states, actions, next_states.reward, next_states, weights) | ||
|
||
def __len__(self): | ||
|
@@ -145,7 +148,6 @@ def _sample_proportional(self, batch_size): | |
|
||
class NStepReplayBuffer(ReplayBuffer): | ||
'''Converts any ReplayBuffer into an NStepReplayBuffer''' | ||
|
||
def __init__( | ||
self, | ||
steps, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps it would be good to add a docstring explaining the purpose of this? It probably would not be immediately obvious to somebody reading the code for the first time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a docstring