Feature/cpu replay buffer #187

Merged
merged 5 commits into develop from feature/cpu_replay_buffer
Dec 10, 2020

Conversation

benblack769
Collaborator

This gives the replay buffers a store_device parameter, which lets a buffer keep its stored data on a different device from the one it uses for input and output. The main use case is keeping the replay buffer in CPU memory while the network and training live in GPU memory. I measured the performance hit of doing this to be about 10-15% on DQN.

Right now there is no parameter for store_device on the presets, though perhaps there should be.
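
A minimal sketch of the idea described above, not the code from this PR: a ring buffer that moves tensors to store_device when storing and moves each sampled batch to device when sampling. The class name SketchReplayBuffer, its fields, and the four-tensor transition layout are assumptions made for illustration only.

```python
import random

import torch


class SketchReplayBuffer:
    """Hypothetical ring buffer: stores on `store_device`, samples onto `device`."""

    def __init__(self, size, device="cpu", store_device=None):
        self.size = size
        self.device = device
        # Assumption: fall back to the sampling device when none is given.
        self.store_device = device if store_device is None else store_device
        self.buffer = []
        self.pos = 0

    def store(self, state, action, reward, next_state):
        # Move every tensor to the storage device before keeping it.
        item = tuple(
            t.to(self.store_device) for t in (state, action, reward, next_state)
        )
        if len(self.buffer) < self.size:
            self.buffer.append(item)
        else:
            self.buffer[self.pos] = item  # overwrite the oldest entry
        self.pos = (self.pos + 1) % self.size

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        # Stack each field on the storage device, then copy the batch once.
        return tuple(torch.stack(field).to(self.device) for field in zip(*batch))

    def __len__(self):
        return len(self.buffer)


# Usage: train on the GPU when one is available, but keep storage on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
buffer = SketchReplayBuffer(size=1000, device=device, store_device="cpu")
for i in range(8):
    buffer.store(
        torch.randn(4, device=device),
        torch.tensor(i % 2, device=device),
        torch.tensor(1.0, device=device),
        torch.randn(4, device=device),
    )
states, actions, rewards, next_states = buffer.sample(4)
```

Stacking before the transfer means each sampled field crosses devices as a single tensor rather than one small copy per transition.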

Owner

@cpnota cpnota left a comment

Looks really good overall; this should be very helpful on many machines! The design seems correct. I left a few comments inline, and I also have a few high-level comments:

  1. Could you extend the unit tests to test some of this functionality? In particular, state_test.py and replay_buffer_test.py. FrameStack also really should have unit tests, but I won't hold you to that as they were already missing.
  2. It might be useful to add docstrings in certain places in order to explain the usage.
  3. As you commented, for this to be useful, the presets need to be modified. My thought is that a hyperparameter, something like cpu_replay_buffer=False, could be added. One of the recent PRs added the ability to set hyperparameters from the command line.
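
One way the cpu_replay_buffer hyperparameter suggested in point 3 could map onto store_device is sketched below. The names default_hyperparameters, replay_buffer_kwargs, and replay_buffer_size are hypothetical and do not reflect the library's actual preset API.

```python
# Hypothetical defaults; not the library's actual preset hyperparameters.
default_hyperparameters = {
    "replay_buffer_size": 100000,
    "cpu_replay_buffer": False,
}


def replay_buffer_kwargs(device, **overrides):
    """Build constructor arguments for a replay buffer from hyperparameters."""
    unknown = set(overrides) - set(default_hyperparameters)
    if unknown:
        raise KeyError(f"unknown hyperparameters: {sorted(unknown)}")
    hyperparameters = {**default_hyperparameters, **overrides}
    # Store on the CPU only when asked; otherwise storage follows the device.
    store_device = "cpu" if hyperparameters["cpu_replay_buffer"] else device
    return {
        "size": hyperparameters["replay_buffer_size"],
        "device": device,
        "store_device": store_device,
    }


print(replay_buffer_kwargs("cuda"))
print(replay_buffer_kwargs("cuda", cpu_replay_buffer=True))
```

Keeping the flag boolean, as proposed in the review, leaves the default behavior unchanged for existing presets.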

            self.from_device = value.device
        if self.to_device is None:
            self.to_device = device
        if self.from_device != value.device or self.to_device != device:
Owner

I find the usage a little confusing: device has only one valid value if to_device is set, but can be set to anything if to_device is None. Is there any reason not to choose one usage or the other?

Collaborator Author

Ok, I greatly simplified this logic.

@@ -291,7 +304,7 @@ def as_output(self, tensor):
         return tensor.view((*self.shape, *tensor.shape[1:]))

     def apply_mask(self, tensor):
-        return tensor * self.mask.unsqueeze(-1)
+        return tensor * self.mask.unsqueeze(-1) # pylint: disable=no-member
Owner

@cpnota cpnota Dec 1, 2020

Don't need the pylint disable because we changed the linter

Collaborator Author

Removed

@@ -23,14 +23,17 @@ def update_priorities(self, indexes, td_errors):
 # Adapted from:
 # https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
 class ExperienceReplayBuffer(ReplayBuffer):
-    def __init__(self, size, device=torch.device('cpu')):
+    def __init__(self, size, device='cpu', store_device='cpu'):
Owner

Perhaps store_device could default to None, in which case store_device would just be set to device?

Collaborator Author

Changed
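
The default suggested in this thread can be sketched as follows. The helper name resolve_store_device is hypothetical, and the merged code may implement the fallback differently.

```python
from typing import Optional


def resolve_store_device(device: str, store_device: Optional[str] = None) -> str:
    """Return the storage device, falling back to `device` when none is given."""
    return device if store_device is None else store_device


print(resolve_store_device("cuda"))         # storage follows the sampling device
print(resolve_store_device("cuda", "cpu"))  # explicit CPU storage
```

With None as the default, callers who never pass store_device get a buffer that stores and samples on the same device.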

        if isinstance(state, StateArray):
            return state.update('observation', torch.cat(self._frames, dim=1))
        return state.update('observation', torch.cat(self._frames, dim=0))


class ToCache:
Owner

Perhaps it would be good to add a docstring explaining the purpose of this? It probably would not be immediately obvious to somebody reading the code for the first time.

Collaborator Author

Added a docstring

@benblack769
Collaborator Author

  1. Added unit tests to state_test.py and replay_buffer_test.py
  2. Added docstrings to the cache (and changed the name)
  3. Haven't changed the presets yet.

@cpnota cpnota merged commit d2dc3ab into develop Dec 10, 2020
@benblack769
Collaborator Author

Thanks

@cpnota cpnota deleted the feature/cpu_replay_buffer branch April 12, 2022 21:39
2 participants