agent.py · 110 lines (98 loc) · 4.38 KB
import torch
from torch.nn.functional import mse_loss
from all.agents import Agent


class ModelBasedDQN(Agent):
    """
    Model-Based DQN.

    This is a simplified model predictive control algorithm based on DQN.
    The purpose of this agent is to demonstrate how the autonomous-learning-library
    can be used to build new types of agents from scratch, while reusing many
    of the useful features of the library. This agent selects actions by predicting
    future states conditioned on each possible action and choosing the action
    with the highest expected return based on these predictions. It trains on a
    replay buffer in a style similar to DQN.

    Args:
        f (FeatureNetwork): Shared feature layers.
        v (VNetwork): State-value function head.
        r (QNetwork): Reward prediction head.
        g (Approximation): Transition model.
        replay_buffer (ReplayBuffer): Experience replay buffer.
        discount_factor (float): Discount factor for future rewards.
        minibatch_size (int): The number of experiences to sample in each training update.
        replay_start_size (int): Number of experiences that must be in the replay buffer before training begins.
    """

    def __init__(self,
                 f,  # shared feature representation
                 v,  # state-value head
                 r,  # reward prediction head
                 g,  # transition model head
                 replay_buffer,
                 discount_factor=0.99,
                 minibatch_size=32,
                 replay_start_size=5000,
                 ):
        # objects
        self.f = f
        self.v = v
        self.r = r
        self.g = g
        self.replay_buffer = replay_buffer
        # hyperparameters
        self.discount_factor = discount_factor
        self.minibatch_size = minibatch_size
        self.replay_start_size = replay_start_size
        # private
        self._state = None
        self._action = None

    def act(self, state):
        # store the most recent transition in the replay buffer
        self.replay_buffer.store(self._state, self._action, state)
        # perform a training update
        self._train()
        # cache the current state and choose the next action using the model
        self._state = state
        self._action = self._choose_action(state)
        return self._action

    def _choose_action(self, state):
        """
        Choose the best action in the current state using our model predictions.
        Note that every call below uses .no_grad(), which runs the forward pass
        with gradient tracking disabled.
        """
        features = self.f.no_grad(state)
        predicted_rewards = self.r.no_grad(features)
        predicted_next_states = self.g.no_grad(features)
        predicted_next_values = self.v.no_grad(self.f.no_grad(predicted_next_states))
        predicted_returns = predicted_rewards + self.discount_factor * predicted_next_values
        return torch.argmax(predicted_returns, dim=-1)
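
    # Illustrative arithmetic for the selection rule above (hypothetical numbers):
    # with discount_factor = 0.99, predicted_rewards = [0.1, -0.2] and
    # predicted_next_values = [1.0, 2.0] for two candidate actions, the predicted
    # returns are [0.1 + 0.99 * 1.0, -0.2 + 0.99 * 2.0] = [1.09, 1.78],
    # so argmax selects action 1.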

    def _train(self):
        """Update the agent."""
        if len(self.replay_buffer) > self.replay_start_size:
            # sample transitions from the buffer
            (states, actions, rewards, next_states, _) = self.replay_buffer.sample(self.minibatch_size)
            # forward pass
            features = self.f(states)
            predicted_values = self.v(features)
            predicted_rewards = self.r(features, actions)
            predicted_next_states = self.g(features, actions)
            # compute target values
            target_values = rewards + self.discount_factor * self.v.target(self.f.target(next_states))
            # compute losses
            value_loss = mse_loss(predicted_values, target_values)
            reward_loss = mse_loss(predicted_rewards, rewards)
            generator_loss = mse_loss(predicted_next_states.observation, next_states.observation.float())
            # backward passes
            self.v.reinforce(value_loss)
            self.r.reinforce(reward_loss)
            self.g.reinforce(generator_loss)
            self.f.reinforce()
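

# Illustrative construction sketch: one plausible way to assemble a ModelBasedDQN
# from pre-built heads `f`, `v`, `r`, and `g` (the FeatureNetwork, VNetwork,
# QNetwork, and Approximation objects named in the docstring above).
# `ExperienceReplayBuffer` and its constructor signature are assumptions and may
# differ between versions of the autonomous-learning-library; the hyperparameter
# values are placeholders, not recommendations.
def make_model_based_dqn(f, v, r, g, device='cpu'):
    from all.memory import ExperienceReplayBuffer
    replay_buffer = ExperienceReplayBuffer(100000, device=device)
    return ModelBasedDQN(
        f, v, r, g,
        replay_buffer,
        discount_factor=0.99,
        minibatch_size=32,
        replay_start_size=5000,
    )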


class ModelBasedTestAgent(Agent):
    """
    Evaluation-mode counterpart of ModelBasedDQN: it selects actions with the
    same model-based rule, but does not store transitions or train.
    """

    def __init__(self, f, v, r, g, discount_factor=0.99):
        self.f = f
        self.v = v
        self.r = r
        self.g = g
        self.discount_factor = discount_factor

    def act(self, state):
        features = self.f.eval(state)
        predicted_rewards = self.r.eval(features)
        predicted_next_states = self.g.eval(features)
        predicted_next_values = self.v.eval(self.f.eval(predicted_next_states))
        predicted_returns = predicted_rewards + self.discount_factor * predicted_next_values
        return torch.argmax(predicted_returns, dim=-1)
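

# Illustrative evaluation sketch: a minimal loop for running the greedy
# ModelBasedTestAgent defined above. The environment interface here is an
# assumption: `env` is taken to expose `reset()` and `step(action)`, each
# returning a State-like object with `reward` and `done` attributes; in
# practice the library's experiment utilities would drive this loop instead.
def evaluate(test_agent, env, episodes=10):
    returns = []
    for _ in range(episodes):
        state = env.reset()
        episode_return = 0.
        while not state.done:
            action = test_agent.act(state)
            state = env.step(action)
            episode_return += state.reward
        returns.append(episode_return)
    return returns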