-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrainDDQN.py
50 lines (42 loc) · 1.14 KB
/
TrainDDQN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from absl import app
from sls import Env, QLRunner
from sls.agents import *
# Run settings for DDQN training; main() reads every key below to build the
# agent, the environment and the runner.
_CONFIG = {
    'episodes': 1000,  # forwarded to runner.run(episodes=...)
    'screen_size': 16,  # given to both the agent and Env
    'minimap_size': 16,  # given to Env only
    'visualize': False,  # given to Env only
    'train': True,  # given to both the agent and QLRunner
    # Agent class instantiated in main(); presumably supplied by the
    # `from sls.agents import *` wildcard import.
    'agent': DDQNAgent,
    # NOTE(review): placeholder path -- confirm the real model location
    # before relying on it for loading.
    'load_path': './models/...',
    'num_scores_average': 50,  # given to QLRunner only
    # Given to the agent only; the original author noted 0.8 as the best
    # value found.
    'discount_factor': 0.8,
    'sarsa': False,  # given to QLRunner only
    'exploration': 'epsilon_greedy',  # given to both the agent and QLRunner
    'file_format': '.h5',  # given to QLRunner only
}
def main(unused_argv):
    """Entry point invoked by ``absl.app.run``.

    Instantiates the agent class stored in ``_CONFIG['agent']``, an ``Env``
    and a ``QLRunner`` from the shared ``_CONFIG`` settings, then calls
    ``runner.run`` for ``_CONFIG['episodes']`` episodes.

    Args:
        unused_argv: Positional command-line arguments handed over by
            ``absl.app.run``; not read here.
    """
    cfg = _CONFIG

    def pick(*keys):
        # Collect the named settings into a keyword-argument dict.
        return {key: cfg[key] for key in keys}

    agent = cfg['agent'](
        **pick('train', 'screen_size', 'discount_factor', 'exploration')
    )
    env = Env(**pick('screen_size', 'minimap_size', 'visualize'))
    runner = QLRunner(
        agent=agent,
        env=env,
        **pick(
            'train',
            'load_path',
            'num_scores_average',
            'sarsa',
            'exploration',
            'file_format',
        ),
    )
    runner.run(episodes=cfg['episodes'])


# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    app.run(main)