forked from openai/baselines
-
Notifications
You must be signed in to change notification settings - Fork 724
/
Copy pathtest_replay_buffer.py
72 lines (58 loc) · 2.44 KB
/
test_replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from stable_baselines.common.buffers import ReplayBuffer, PrioritizedReplayBuffer
def test_extend_uniform():
    """Check that ``ReplayBuffer.extend`` matches one-by-one ``add`` calls.

    Sixteen random transitions are pushed into one buffer through repeated
    ``add`` calls and into a second buffer through a single ``extend`` call.
    Both buffers must then report the same length and hold equal values in
    each of the five fields of every stored entry.
    """
    nvals = 16
    capacity = 32  # larger than nvals, so nothing should be overwritten

    # One list per transition field, sampled in the same order they are
    # later handed to the buffers.
    obs_list = [np.random.rand(2, 2) for _ in range(nvals)]
    action_list = [np.random.rand(2) for _ in range(nvals)]
    reward_list = [np.random.rand() for _ in range(nvals)]
    next_obs_list = [np.random.rand(2, 2) for _ in range(nvals)]
    done_list = [np.random.randint(0, 2) for _ in range(nvals)]
    columns = [obs_list, action_list, reward_list, next_obs_list, done_list]

    baseline = ReplayBuffer(capacity)
    ext = ReplayBuffer(capacity)

    # Reference buffer: one add() per transition.
    for transition in zip(*columns):
        baseline.add(*transition)

    # Buffer under test: one extend() with every field stacked into an array.
    ext.extend(*map(np.array, columns))
    assert len(baseline) == len(ext)

    # Every stored entry must agree field by field.
    for idx in range(nvals):
        expected = baseline.storage[idx]
        actual = ext.storage[idx]
        for field in range(5):
            matches = expected[field] == actual[field]
            # Array-valued fields (the samples built with np.random.rand(...)
            # shapes above) compare element-wise and need np.all; scalar
            # fields such as reward and done yield a single truth value.
            assert np.all(matches) if isinstance(matches, np.ndarray) else matches
def test_extend_prioritized():
    """Check that ``PrioritizedReplayBuffer.extend`` matches repeated ``add``.

    Sixteen random transitions are pushed into one prioritized buffer through
    repeated ``add`` calls and into a second one through a single ``extend``
    call. Both buffers must then report the same length, hold equal values in
    each of the five fields of every stored entry, and have equal ``_value``
    contents in their ``_it_min`` and ``_it_sum`` members.
    """
    nvals = 16
    capacity = 32  # larger than nvals, so nothing should be overwritten
    alpha = 0.99

    # One list per transition field, sampled in the same order they are
    # later handed to the buffers.
    obs_list = [np.random.rand(2, 2) for _ in range(nvals)]
    action_list = [np.random.rand(2) for _ in range(nvals)]
    reward_list = [np.random.rand() for _ in range(nvals)]
    next_obs_list = [np.random.rand(2, 2) for _ in range(nvals)]
    done_list = [np.random.randint(0, 2) for _ in range(nvals)]
    columns = [obs_list, action_list, reward_list, next_obs_list, done_list]

    baseline = PrioritizedReplayBuffer(capacity, alpha)
    ext = PrioritizedReplayBuffer(capacity, alpha)

    # Reference buffer: one add() per transition.
    for transition in zip(*columns):
        baseline.add(*transition)

    # Buffer under test: one extend() with every field stacked into an array.
    ext.extend(*map(np.array, columns))
    assert len(baseline) == len(ext)

    # Every stored entry must agree field by field.
    for idx in range(nvals):
        expected = baseline.storage[idx]
        actual = ext.storage[idx]
        for field in range(5):
            matches = expected[field] == actual[field]
            # Array-valued fields (the samples built with np.random.rand(...)
            # shapes above) compare element-wise and need np.all; scalar
            # fields such as reward and done yield a single truth value.
            assert np.all(matches) if isinstance(matches, np.ndarray) else matches

    # The priority bookkeeping must match as well (presumably segment trees
    # backing min/sum queries -- confirm against the buffer implementation).
    min_matches = baseline._it_min._value == ext._it_min._value
    assert min_matches.all()
    sum_matches = baseline._it_sum._value == ext._it_sum._value
    assert sum_matches.all()