Commit

isort autoflake
vwxyzjn committed Jan 26, 2022
1 parent 0e37a51 commit e794ac9
Showing 11 changed files with 745 additions and 344 deletions.
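
The commit title points to a pure code-hygiene pass: autoflake strips imports that are no longer referenced (for example the dropped from jpype.types import JArray, JInt in ppo_gridnet.py), and isort regroups and alphabetizes the ones that remain. A plausible way to reproduce such a pass from the repository root is sketched below; the target path and flags are assumptions, not something this commit records.

import subprocess

# Hypothetical reproduction of the formatting pass (path and flags are assumptions).
# autoflake removes imports and variables that are no longer referenced.
subprocess.run(
    ["autoflake", "--in-place", "--remove-all-unused-imports", "--recursive", "experiments/"],
    check=True,
)
# isort then groups and sorts the remaining imports, which is what reorders
# import subprocess / import time and isolates the gym_microrts imports.
subprocess.run(["isort", "experiments/"], check=True)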
181 changes: 90 additions & 91 deletions experiments/league.py

Large diffs are not rendered by default.

40 changes: 25 additions & 15 deletions experiments/ppo_gridnet.py
@@ -3,8 +3,8 @@
import argparse
import os
import random
import time
import subprocess
import time
from distutils.util import strtobool

import numpy as np
@@ -13,12 +13,14 @@
import torch.nn as nn
import torch.optim as optim
from gym.spaces import MultiDiscrete
from gym_microrts import microrts_ai
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder
from stable_baselines3.common.vec_env import (VecEnvWrapper, VecMonitor,
VecVideoRecorder)
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from gym_microrts import microrts_ai
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv


def parse_args():
# fmt: off
@@ -84,7 +86,7 @@ def parse_args():
parser.add_argument('--anneal-lr', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
help="Toggle learning rate annealing for policy and value networks")
parser.add_argument('--clip-vloss', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.')
help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
parser.add_argument('--num-models', type=int, default=200,
help='the number of models saved')

@@ -153,7 +155,6 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
return layer



class Agent(nn.Module):
def __init__(self, envs, mapsize=16 * 16):
super(Agent, self).__init__()
@@ -190,7 +191,7 @@ def __init__(self, envs, mapsize=16 * 16):
nn.ReLU(),
layer_init(nn.Linear(128, 1), std=1),
)
self.register_buffer('mask_value', torch.tensor(-1e8))
self.register_buffer("mask_value", torch.tensor(-1e8))

def get_action_and_value(self, x, action=None, invalid_action_masks=None, envs=None, device=None):
hidden = self.encoder(x)
@@ -227,7 +228,6 @@ def get_value(self, x):
return self.critic(self.encoder(x))



if __name__ == "__main__":
args = parse_args()

@@ -307,7 +307,6 @@ def get_value(self, x):

## CRASH AND RESUME LOGIC:
starting_update = 1
from jpype.types import JArray, JInt

if args.prod_mode and wandb.run.resumed:
starting_update = run.summary.get("charts/update") + 1
@@ -410,7 +409,7 @@ def get_value(self, x):
b_values = values.reshape(-1)
b_invalid_action_masks = invalid_action_masks.reshape((-1,) + invalid_action_shape)

# Optimizaing the policy and value network
# Optimizing the policy and value network
inds = np.arange(
args.batch_size,
)
@@ -458,13 +457,24 @@ def get_value(self, x):

## CRASH AND RESUME LOGIC:
if args.prod_mode:
if (update-1) % args.save_frequency == 0:
if (update - 1) % args.save_frequency == 0:
if not os.path.exists(f"models/{experiment_name}"):
os.makedirs(f"models/{experiment_name}")
torch.save(agent.state_dict(), f"models/{experiment_name}/agent.pt")
torch.save(agent.state_dict(), f"models/{experiment_name}/{global_step}.pt")
wandb.save(f"models/{experiment_name}/agent.pt", base_path=f"models/{experiment_name}", policy="now")
subprocess.Popen(["python", "league.py", "--evals", f"models/{experiment_name}/{global_step}.pt", "--update-db", "false", "--cuda", "false"])
subprocess.Popen(
[
"python",
"league.py",
"--evals",
f"models/{experiment_name}/{global_step}.pt",
"--update-db",
"false",
"--cuda",
"false",
]
)
eval_queue += [f"models/{experiment_name}/{global_step}.pt"]
print(f"Evaluating models/{experiment_name}/{global_step}.pt")

@@ -484,8 +494,8 @@ def get_value(self, x):
trueskill_data = {
"name": league.loc[model_path].name,
"mu": league.loc[model_path]["mu"],
"sigma":league.loc[model_path]["sigma"],
"trueskill": league.loc[model_path]["trueskill"]
"sigma": league.loc[model_path]["sigma"],
"trueskill": league.loc[model_path]["trueskill"],
}
trueskill_df = trueskill_df.append(trueskill_data, ignore_index=True)
wandb.log({"trueskill": wandb.Table(dataframe=trueskill_df)})
@@ -494,7 +504,7 @@ def get_value(self, x):
trueskill_step_df = trueskill_step_df.append(trueskill_data, ignore_index=True)
preset_trueskill_step_df_clone = preset_trueskill_step_df.copy()
preset_trueskill_step_df_clone["step"] = model_global_step
trueskill_step_df = trueskill_step_df.append(preset_trueskill_step_df_clone, ignore_index=True)
trueskill_step_df = trueskill_step_df.append(preset_trueskill_step_df_clone, ignore_index=True)
wandb.log({"trueskill_step": wandb.Table(dataframe=trueskill_step_df)})

# TRY NOT TO MODIFY: record rewards for plotting purposes
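The import reshuffling above is characteristic of what isort does automatically; a minimal sketch of its programmatic API is shown below as an illustration (the exact output depends on the project's isort configuration, so treat it as an assumption).

import isort

# isort.code() returns the given source with its imports sorted and grouped.
messy = "import time\nimport subprocess\n"
print(isort.code(messy))
# Expected result (standard-library imports alphabetized):
# import subprocess
# import time
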
16 changes: 7 additions & 9 deletions experiments/ppo_gridnet_eval.py
@@ -8,15 +8,14 @@

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.spaces import MultiDiscrete
from gym_microrts import microrts_ai
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from ppo_gridnet import Agent, MicroRTSStatsRecorder
from stable_baselines3.common.vec_env import VecMonitor, VecVideoRecorder
from torch.utils.tensorboard import SummaryWriter

from gym_microrts import microrts_ai # noqa
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv


def parse_args():
@@ -101,7 +100,6 @@ def parse_args():
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic


ais = []
if args.ai:
ais = [eval(f"microrts_ai.{args.ai}")]
@@ -163,7 +161,7 @@ def parse_args():
# ALGO LOGIC: put action logic here
with torch.no_grad():
invalid_action_masks[step] = torch.tensor(np.array(envs.get_action_mask())).to(device)

if args.ai:
action, logproba, _, _, vs = agent.get_action_and_value(
next_obs, envs=envs, invalid_action_masks=invalid_action_masks[step], device=device
@@ -173,7 +171,7 @@ def parse_args():
p2_obs = next_obs[1::2]
p1_mask = invalid_action_masks[step][::2]
p2_mask = invalid_action_masks[step][1::2]

p1_action, _, _, _, _ = agent.get_action_and_value(
p1_obs, envs=envs, invalid_action_masks=p1_mask, device=device
)