Skip to content

Commit

Permalink
Feature/cpu replay buffer (#187)
Browse files Browse the repository at this point in the history
* added cpu replay buffer code

* fixed lazystate

* added tests and docstrings, changed cpu buffer params

* fix whitespace

Co-authored-by: Chris Nota <cpnota@gmail.com>
  • Loading branch information
benblack769 and cpnota authored Dec 10, 2020
1 parent 8f65a70 commit d2dc3ab
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 18 deletions.
58 changes: 55 additions & 3 deletions all/bodies/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,50 @@ def __init__(self, agent, size=4, lazy=False):
self._frames = []
self._size = size
self._lazy = lazy
self._to_cache = TensorDeviceCache()

def process_state(self, state):
if not self._frames:
self._frames = [state.observation] * self._size
else:
self._frames = self._frames[1:] + [state.observation]
if self._lazy:
return LazyState.from_state(state, self._frames)
return LazyState.from_state(state, self._frames, self._to_cache)
if isinstance(state, StateArray):
return state.update('observation', torch.cat(self._frames, dim=1))
return state.update('observation', torch.cat(self._frames, dim=0))


class TensorDeviceCache:
'''
To efficiently implement device trasfer of lazy states, this class
caches the transfered tensor so that it is not copied multiple times.
'''
def __init__(self, max_size=16):
self.max_size = max_size
self.cache_data = []

def convert(self, value, device):
cached = None
for el in self.cache_data:
if el[0] is value:
cached = el[1]
break
if cached is not None and cached.device == torch.device(device):
new_v = cached
else:
new_v = value.to(device)
self.cache_data.append((value, new_v))
if len(self.cache_data) > self.max_size:
self.cache_data.pop(0)
return new_v


class LazyState(State):
@classmethod
def from_state(cls, state, frames):
state = LazyState(state, device=state.device)
def from_state(cls, state, frames, to_cache):
state = LazyState(state, device=frames[0].device)
state.to_cache = to_cache
state['observation'] = frames
return state

Expand All @@ -36,3 +63,28 @@ def __getitem__(self, key):
return v
return torch.cat(dict.__getitem__(self, key), dim=0)
return super().__getitem__(key)

def update(self, key, value):
x = {}
for k in self.keys():
if not k == key:
x[k] = super().__getitem__(k)
x[key] = value
state = LazyState(x, device=self.device)
state.to_cache = self.to_cache
return state

def to(self, device):
if device == self.device:
return self
x = {}
for key, value in self.items():
if key == 'observation':
x[key] = [self.to_cache.convert(v, device) for v in value]
# x[key] = [v.to(device) for v in value]#torch.cat(value,axis=0).to(device)
elif torch.is_tensor(value):
x[key] = value.to(device)
else:
x[key] = value
state = LazyState.from_state(x, x['observation'], self.to_cache)
return state
23 changes: 18 additions & 5 deletions all/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class State(dict):
device (string):
The torch device on which component tensors are stored.
"""

def __init__(self, x, device='cpu', **kwargs):
if not isinstance(x, dict):
x = {'observation': x}
Expand Down Expand Up @@ -68,11 +67,15 @@ def array(cls, list_of_states):
for key in list_of_states[0].keys():
v = list_of_states[0][key]
try:
if torch.is_tensor(v):
if isinstance(v, list) and len(v) > 0 and torch.is_tensor(v[0]):
x[key] = torch.stack([torch.stack(state[key]) for state in list_of_states])
elif torch.is_tensor(v):
x[key] = torch.stack([state[key] for state in list_of_states])
else:
x[key] = torch.tensor([state[key] for state in list_of_states], device=device)
except BaseException:
except ValueError:
pass
except KeyError:
pass
return StateArray(x, shape, device=device)

Expand Down Expand Up @@ -188,6 +191,17 @@ def from_gym(cls, state, device='cpu', dtype=np.float32):
x[key] = info[key]
return State(x, device=device)

def to(self, device):
if device == self.device:
return self
x = {}
for key, value in self.items():
if torch.is_tensor(value):
x[key] = value.to(device)
else:
x[key] = value
return type(self)(x, device=device, shape=self._shape)

@property
def observation(self):
"""A tensor containing the current observation."""
Expand Down Expand Up @@ -246,7 +260,6 @@ class StateArray(State):
device (string):
The torch device on which component tensors are stored.
"""

def __init__(self, x, shape, device='cpu', **kwargs):
if not isinstance(x, dict):
x = {'observation': x}
Expand Down Expand Up @@ -350,7 +363,7 @@ def __getitem__(self, key):
for (k, v) in self.items():
try:
d[k] = v[key]
except BaseException:
except KeyError:
pass
return self.__class__(d, shape, device=self.device)
try:
Expand Down
7 changes: 7 additions & 0 deletions all/core/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def test_apply_done(self):
self.assertEqual(output.shape, (5, 3))
self.assertEqual(output.sum().item(), 0)

def test_to_device(self):
observation = torch.randn(3, 4)
state = State(observation, device=torch.device('cpu'))
state_cpu = state.to("cpu")
self.assertTrue(torch.equal(state['observation'], state_cpu['observation']))
self.assertFalse(state is state_cpu)


class StateArrayTest(unittest.TestCase):
def test_constructor_defaults(self):
Expand Down
21 changes: 13 additions & 8 deletions all/memory/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ def update_priorities(self, indexes, td_errors):
# Adapted from:
# /~https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
class ExperienceReplayBuffer(ReplayBuffer):
def __init__(self, size, device=torch.device('cpu')):
def __init__(self, size, device='cpu', store_device=None):
self.buffer = []
self.capacity = int(size)
self.pos = 0
self.device = device
self.device = torch.device(device)
if store_device is None:
store_device = self.device
self.store_device = torch.device(store_device)

def store(self, state, action, next_state):
if state is not None and not state.done:
state = state.to(self.store_device)
next_state = next_state.to(self.store_device)
self._add((state, action, next_state))

def sample(self, batch_size):
Expand All @@ -49,12 +54,12 @@ def _add(self, sample):
self.pos = (self.pos + 1) % self.capacity

def _reshape(self, minibatch, weights):
states = State.array([sample[0] for sample in minibatch])
states = State.array([sample[0] for sample in minibatch]).to(self.device)
if torch.is_tensor(minibatch[0][1]):
actions = torch.stack([sample[1] for sample in minibatch])
actions = torch.stack([sample[1] for sample in minibatch]).to(self.device)
else:
actions = torch.tensor([sample[1] for sample in minibatch], device=self.device)
next_states = State.array([sample[2] for sample in minibatch])
next_states = State.array([sample[2] for sample in minibatch]).to(self.device)
return (states, actions, next_states.reward, next_states, weights)

def __len__(self):
Expand All @@ -71,9 +76,10 @@ def __init__(
alpha=0.6,
beta=0.4,
epsilon=1e-5,
device=torch.device('cpu')
device=torch.device('cpu'),
store_device=None
):
super().__init__(buffer_size, device=device)
super().__init__(buffer_size, device=device, store_device=store_device)

assert alpha >= 0
self._alpha = alpha
Expand Down Expand Up @@ -145,7 +151,6 @@ def _sample_proportional(self, batch_size):

class NStepReplayBuffer(ReplayBuffer):
'''Converts any ReplayBuffer into an NStepReplayBuffer'''

def __init__(
self,
steps,
Expand Down
17 changes: 15 additions & 2 deletions all/memory/replay_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@


class TestExperienceReplayBuffer(unittest.TestCase):
def setUp(self):
def test_run(self):
np.random.seed(1)
random.seed(1)
torch.manual_seed(1)
self.replay_buffer = ExperienceReplayBuffer(5)

def test_run(self):
states = torch.arange(0, 20)
actions = torch.arange(0, 20).view((-1, 1))
rewards = torch.arange(0, 20)
Expand Down Expand Up @@ -51,6 +50,20 @@ def test_run(self):
)
np.testing.assert_array_equal(expected_weights, np.vstack(actual_weights))

def test_store_device(self):
if torch.cuda.is_available():
self.replay_buffer = ExperienceReplayBuffer(5, device='cuda', store_device='cpu')

states = torch.arange(0, 20).to('cuda')
actions = torch.arange(0, 20).view((-1, 1)).to('cuda')
rewards = torch.arange(0, 20).to('cuda')
state = State(states[0])
next_state = State(states[1], reward=rewards[1])
self.replay_buffer.store(state, actions[0], next_state)
sample = self.replay_buffer.sample(3)
self.assertEqual(sample[0].device, torch.device('cuda'))
self.assertEqual(self.replay_buffer.buffer[0][0].device, torch.device('cpu'))


class TestPrioritizedReplayBuffer(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit d2dc3ab

Please sign in to comment.