PPO.py
import torch
import numpy as np
import sys
import gym
sys.path.append("./")
from base_net.model import *
from torch import nn, optim
import torch.nn.functional as F
from torch.distributions import Categorical
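# The classes below come from base_net/model.py, which is not shown in this file.
# Their interfaces, as inferred from how they are used here, are assumed to be:
#   Policy_net(args=(input_size, output_size)) -> nn.Module mapping states to action logits
#   Q_net(args=(input_size, 1))                -> nn.Module mapping states to a scalar state value
#   ReplayBuffer(args=capacity)                -> buffer exposing save_trans(), sample_all_data(), clear()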
class PPO(nn.Module):
    def __init__(self, args):
        super(PPO, self).__init__()
        self.input_size, self.output_size, self.device, self.actor_lr, self.critic_lr = args
        self.actor = Policy_net(args = (self.input_size, self.output_size))
        self.critic = Q_net(args = (self.input_size, 1))
        self.buffer = ReplayBuffer(args = (10000))
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr = self.actor_lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr = self.critic_lr)

    def get_policy_op(self, inputs):
        # Map states to a categorical action distribution.
        policy_op = self.actor(inputs)
        softmax_op = F.softmax(policy_op, dim = -1)
        return softmax_op

    def select_action(self, inputs):
        # Sample an action and return it together with its probability under the
        # current policy (stored as pi_theta_old for the PPO ratio later).
        action_prob = self.get_policy_op(inputs)
        action = Categorical(action_prob).sample().item()
        return action, action_prob.detach().cpu().numpy()[action]

    def save_trans(self, transition):
        self.buffer.save_trans(transition)

    def to_tensor(self, items):
        # Convert a batch of transitions to tensors on the target device, adding a
        # trailing dimension where the update expects column vectors.
        s, a, r, s_next, a_prob, done = items
        s = torch.FloatTensor(s).to(self.device)
        a = torch.LongTensor(a).to(self.device)
        r = torch.FloatTensor(r).to(self.device)
        s_next = torch.FloatTensor(s_next).to(self.device)
        a_prob = torch.FloatTensor(a_prob).to(self.device)
        done = torch.FloatTensor(done).to(self.device)
        return s, a.unsqueeze(-1), r.unsqueeze(-1), s_next, a_prob.unsqueeze(-1), done.unsqueeze(-1)

    # Note: this method shadows nn.Module.train(); that is harmless here because the
    # script never toggles train/eval mode, but keep it in mind when extending the class.
    def train(self, gamma = 0.98, batch_size = 32, k_iters = 3, epsilon_clip = 0.1, lmbda = 0.95):
        # batch_size is unused: every update runs over the whole on-policy batch.
        s, a, r, s_next, a_prob, done = self.to_tensor(self.buffer.sample_all_data())
        for i in range(k_iters):
            # TD target and TD error under the current critic.
            td_target = r + gamma * self.critic(s_next) * (1 - done)
            td_error = td_target - self.critic(s)
            td_error = td_error.detach().cpu().numpy()
            # Generalized Advantage Estimation: accumulate discounted TD errors backwards.
            advantage_ls = []
            advantage = 0.
            for error in td_error[::-1]:
                advantage = gamma * lmbda * advantage + error[0]
                advantage_ls.append([advantage])
            advantage_ls.reverse()
            advantage = torch.FloatTensor(advantage_ls).to(self.device)
            # Probability of the stored actions under the current policy.
            policy_op = self.get_policy_op(s)
            policy_op = policy_op.gather(-1, a)
            # Clipped surrogate objective.
            ratio = torch.exp(torch.log(policy_op) - torch.log(a_prob))
            sur1 = ratio * advantage.detach()
            sur2 = torch.clamp(ratio, 1 - epsilon_clip, 1 + epsilon_clip) * advantage.detach()
            loss_actor = (- torch.min(sur1, sur2)).mean()
            self.optimizer_actor.zero_grad()
            loss_actor.backward()
            self.optimizer_actor.step()
            # Critic regression towards the TD target.
            loss_critic = F.smooth_l1_loss(self.critic(s), td_target.detach())
            self.optimizer_critic.zero_grad()
            loss_critic.backward()
            self.optimizer_critic.step()
        # PPO is on-policy: discard the batch after the update.
        self.buffer.clear()
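# For reference, the update implemented in train() above is the clipped PPO surrogate
#     L_clip(theta) = E_t[ min( rho_t * A_t, clip(rho_t, 1 - eps, 1 + eps) * A_t ) ]
# with probability ratio rho_t = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t), and the
# advantage A_t given by the GAE recursion run backwards over the batch:
#     A_t = delta_t + gamma * lambda * A_{t+1},
#     delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t).
# The critic is regressed onto the TD target r_t + gamma * V(s_{t+1}) * (1 - done_t)
# with a smooth L1 loss.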
'''
PPO test
'''
if __name__ == "__main__":
    # hyperparameters
    lr = 1e-3
    render = False
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Note: this script uses the classic gym API (reset() returns obs, step() returns
    # (obs, reward, done, info)); with gymnasium or gym >= 0.26, reset() returns
    # (obs, info) and step() returns a 5-tuple instead.
    env = gym.make("CartPole-v1")
    model = PPO(args = (4, 2, device, lr, lr)).to(device)
    for epo_i in range(10000):
        obs = env.reset()
        score = 0.
        for step in range(200):
            if render:
                env.render()
            a, a_prob = model.select_action(torch.FloatTensor(obs).to(device))
            obs_next, r, done, info = env.step(a)
            model.save_trans((obs, a, r, obs_next, a_prob, done))
            obs = obs_next
            score += r
            if done:
                break
        # One PPO update per collected episode.
        model.train()
        print("Epoch: {} score: {}".format(epo_i, score))
    env.close()
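# Usage note: running this file directly (python PPO.py) trains on CartPole-v1 with one
# on-policy batch collected per episode. The inner loop caps rollouts at 200 steps, so
# the printed score saturates at 200 even though CartPole-v1's own time limit is 500.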