Feature/cpu replay buffer #187
Conversation
Looks really good overall, this should be very helpful on many machines! The design seems correct. I left a few comments inline, and also have a few high-level comments:
- Could you extend the unit tests to cover some of this functionality? In particular, `state_test.py` and `replay_buffer_test.py`. `FrameStack` also really should have unit tests, but I won't hold you to that, as they were already missing.
- It might be useful to add docstrings in certain places in order to explain the usage.
- As you commented, for this to be useful, the presets need to be modified. My thought is that a hyperparameter, something like `cpu_replay_buffer=False`, could be added; a rough sketch of how that might look follows this list. One of the recent PRs added the ability to set hyperparameters from the command line.
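For concreteness, a minimal sketch of how such a flag might select the storage device (the helper name and import path are illustrative, not the library's actual preset API; `ExperienceReplayBuffer`'s signature is taken from the diff below):

```python
from all.memory import ExperienceReplayBuffer  # assumed import path

def make_replay_buffer(size, device='cuda', cpu_replay_buffer=False):
    # Hypothetical helper: pick the storage device from the flag while
    # input and output stay on the training device.
    store_device = 'cpu' if cpu_replay_buffer else device
    return ExperienceReplayBuffer(size, device=device, store_device=store_device)
```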
all/bodies/vision.py (Outdated)
```python
self.from_device = value.device
if self.to_device is None:
    self.to_device = device
if self.from_device != value.device or self.to_device != device:
```
I find the usage a little confusing... `device` only has one valid value if `to_device` is set, but can be set to anything if `to_device` is `None`. Is there any reason not to choose one usage or the other?
Ok, I greatly simplified this logic.
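For illustration, one way such device bookkeeping can be pinned down on first use (a hypothetical sketch with an assumed method name, not necessarily the change made in this PR):

```python
def convert(self, value, device):
    # Hypothetical simplification: fix both devices the first time a value
    # is seen, instead of re-deriving and re-validating them on every call.
    if self.to_device is None:
        self.from_device = value.device
        self.to_device = device
    return value.to(device)
```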
all/core/state.py (Outdated)
```diff
@@ -291,7 +304,7 @@ def as_output(self, tensor):
         return tensor.view((*self.shape, *tensor.shape[1:]))

     def apply_mask(self, tensor):
-        return tensor * self.mask.unsqueeze(-1)
+        return tensor * self.mask.unsqueeze(-1)  # pylint: disable=no-member
```
Don't need the pylint disable because we changed the linter
Removed
all/memory/replay_buffer.py (Outdated)
```diff
@@ -23,14 +23,17 @@ def update_priorities(self, indexes, td_errors):
 # Adapted from:
 # /~https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
 class ExperienceReplayBuffer(ReplayBuffer):
-    def __init__(self, size, device=torch.device('cpu')):
+    def __init__(self, size, device='cpu', store_device='cpu'):
```
Perhaps `store_device` could default to `None`, in which case `store_device` would just be set to `device`?
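A minimal sketch of that default (the signature follows the diff above; the body is illustrative):

```python
def __init__(self, size, device='cpu', store_device=None):
    # Fall back to the I/O device when no separate storage device is given.
    if store_device is None:
        store_device = device
    self.device = device
    self.store_device = store_device
```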
Changed
all/bodies/vision.py (Outdated)
```python
if isinstance(state, StateArray):
    return state.update('observation', torch.cat(self._frames, dim=1))
return state.update('observation', torch.cat(self._frames, dim=0))


class ToCache:
```
Perhaps it would be good to add a docstring explaining the purpose of this? It probably would not be immediately obvious to somebody reading the code for the first time.
Added a docstring
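For readers skimming the thread, a docstring along these lines would capture the intent implied by the device-conversion logic above (illustrative wording only, not necessarily what was added):

```python
class ToCache:
    """Move tensors between a storage device and a compute device.

    Illustrative wording: the class remembers a (from_device, to_device)
    pair so that values kept on, e.g., the CPU can be converted back to
    the GPU when they are read out again.
    """
```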
Thanks
This gives the replay buffers a `store_device` parameter, which allows them to store their contents on a different device than the one they expect for input and output. The main use case is storing the replay buffer in CPU memory while the network and training live in GPU memory. I measured the performance hit of doing this to be about 10-15% on DQN.

Right now there is no parameter for `store_device` on the presets, though perhaps there should be.
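Under that description, usage would look something like the following (assumed import path; the constructor signature matches the diff above):

```python
from all.memory import ExperienceReplayBuffer  # assumed import path

# Keep up to one million transitions in CPU RAM while training on the GPU.
buffer = ExperienceReplayBuffer(
    1000000,
    device='cuda',       # tensors are accepted and returned on the GPU
    store_device='cpu',  # but live in CPU memory in between
)
```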