update gridnet example
vwxyzjn committed Jul 5, 2021
1 parent 640a6b5 commit c09f3ed
Showing 2 changed files with 228 additions and 451 deletions.
experiments/ppo_gridnet.py (54 changes: 17 additions & 37 deletions)

@@ -45,11 +45,13 @@ def parse_args():
help="the entity (team) of wandb's project")

# Algorithm specific arguments
parser.add_argument('--partial-obs', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
help='if toggled, the game will have partial observability')
parser.add_argument('--n-minibatch', type=int, default=4,
help='the number of mini batch')
parser.add_argument('--num-bot-envs', type=int, default=24,
parser.add_argument('--num-bot-envs', type=int, default=0,
help='the number of bot game environment; 16 bot envs measn 16 games')
parser.add_argument('--num-selfplay-envs', type=int, default=0,
parser.add_argument('--num-selfplay-envs', type=int, default=24,
help='the number of self play envs; 16 self play envs means 8 games')
parser.add_argument('--num-steps', type=int, default=256,
help='the number of steps per game environment')
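
A note on the new --partial-obs flag: with nargs='?' and const=True, the flag can be passed bare, given an explicit value, or omitted entirely. A minimal standalone sketch of that parsing behavior, assuming the same strtobool import the script already uses:

    from distutils.util import strtobool
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--partial-obs', type=lambda x: bool(strtobool(x)),
                        default=True, nargs='?', const=True)

    print(parser.parse_args([]).partial_obs)                           # True (default)
    print(parser.parse_args(['--partial-obs']).partial_obs)           # True (const)
    print(parser.parse_args(['--partial-obs', 'false']).partial_obs)  # False (parsed)
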
@@ -93,10 +95,6 @@ def parse_args():


 class MicroRTSStatsRecorder(VecEnvWrapper):
-    def __init__(self, env, gamma):
-        super().__init__(env)
-        self.gamma = gamma
-
     def reset(self):
         obs = self.venv.reset()
         self.raw_rewards = [[] for _ in range(self.num_envs)]
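
Why the whole __init__ can go: the recorder only buffers raw reward components per environment, so the unused gamma field is dropped and construction falls through to VecEnvWrapper's own constructor. A minimal sketch of the slimmed-down wrapper, assuming stable-baselines3's VecEnvWrapper and an info key named "raw_rewards" (both assumptions for illustration):

    from stable_baselines3.common.vec_env import VecEnvWrapper

    class MicroRTSStatsRecorder(VecEnvWrapper):
        # No __init__ needed: VecEnvWrapper(env) already stores the wrapped venv.

        def reset(self):
            obs = self.venv.reset()
            # One raw-reward buffer per vectorized environment.
            self.raw_rewards = [[] for _ in range(self.num_envs)]
            return obs

        def step_wait(self):
            obs, rews, dones, infos = self.venv.step_wait()
            for i in range(self.num_envs):
                # Assumed key: accumulate the unweighted reward vector if present.
                if "raw_rewards" in infos[i]:
                    self.raw_rewards[i].append(infos[i]["raw_rewards"])
            return obs, rews, dones, infos
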
@@ -161,12 +159,14 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
     return layer
 
 
-class Encoder(nn.Module):
-    def __init__(self, input_channels):
-        super().__init__()
-        self._encoder = nn.Sequential(
+class Agent(nn.Module):
+    def __init__(self, envs, mapsize=16 * 16):
+        super(Agent, self).__init__()
+        self.mapsize = mapsize
+        h, w, c = envs.observation_space.shape
+        self.encoder = nn.Sequential(
             Transpose((0, 3, 1, 2)),
-            layer_init(nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)),
+            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
             nn.MaxPool2d(3, stride=2, padding=1),
             nn.ReLU(),
             layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
@@ -178,37 +178,16 @@ def __init__(self, input_channels):
             layer_init(nn.Conv2d(128, 256, kernel_size=3, padding=1)),
             nn.MaxPool2d(3, stride=2, padding=1),
         )
-
-    def forward(self, x):
-        return self._encoder(x)
-
-
-class Decoder(nn.Module):
-    def __init__(self, output_channels):
-        super().__init__()
-
-        self.deconv = nn.Sequential(
+        self.actor = nn.Sequential(
             layer_init(nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)),
             nn.ReLU(),
             layer_init(nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)),
             nn.ReLU(),
             layer_init(nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)),
             nn.ReLU(),
-            layer_init(nn.ConvTranspose2d(32, output_channels, 3, stride=2, padding=1, output_padding=1)),
+            layer_init(nn.ConvTranspose2d(32, 78, 3, stride=2, padding=1, output_padding=1)),
             Transpose((0, 2, 3, 1)),
         )
-
-    def forward(self, x):
-        return self.deconv(x)
-
-
-class Agent(nn.Module):
-    def __init__(self, mapsize=16 * 16):
-        super(Agent, self).__init__()
-        self.mapsize = mapsize
-        h, w, c = envs.observation_space.shape
-        self.encoder = Encoder(c)
-        self.actor = Decoder(78)
         self.critic = nn.Sequential(
             nn.Flatten(),
             layer_init(nn.Linear(256, 128), std=1),
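
A shape check on the merged Agent (a reading of the diff, not text from the commit): passing envs into the constructor lets the input planes c track the observation space, which matters because partial observability changes the number of feature planes; and the critic's hard-coded Linear(256, 128) works because a 16x16 map collapses to 1x1 spatially after the encoder's four stride-2 poolings:

    def pooled(size: int) -> int:
        # Output size of nn.MaxPool2d(kernel_size=3, stride=2, padding=1).
        return (size + 2 * 1 - 3) // 2 + 1

    s = 16
    for _ in range(4):  # the encoder stacks four conv + pool stages
        s = pooled(s)   # 16 -> 8 -> 4 -> 2 -> 1
    print(s)  # 1, so the encoder emits (N, 256, 1, 1) and Flatten yields 256 features

    # The actor mirrors this with four stride-2 ConvTranspose2d layers,
    # growing 1 -> 2 -> 4 -> 8 -> 16 and ending at 78 action planes per cell.
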
@@ -282,24 +261,25 @@ def get_value(self, x):
 envs = MicroRTSGridModeVecEnv(
     num_selfplay_envs=args.num_selfplay_envs,
     num_bot_envs=args.num_bot_envs,
+    partial_obs=args.partial_obs,
     max_steps=2000,
     render_theme=2,
-    ai2s=[microrts_ai.coacAI for _ in range(args.num_bot_envs - 6)]
+    ai2s=[microrts_ai.randomAI for _ in range(args.num_bot_envs - 6)]
     + [microrts_ai.randomBiasedAI for _ in range(min(args.num_bot_envs, 2))]
     + [microrts_ai.lightRushAI for _ in range(min(args.num_bot_envs, 2))]
     + [microrts_ai.workerRushAI for _ in range(min(args.num_bot_envs, 2))],
     map_path="maps/16x16/basesWorkers16x16.xml",
     reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0]),
 )
-envs = MicroRTSStatsRecorder(envs, args.gamma)
+envs = MicroRTSStatsRecorder(envs)
 envs = VecMonitor(envs)
 if args.capture_video:
     envs = VecVideoRecorder(
         envs, f"videos/{experiment_name}", record_video_trigger=lambda x: x % 1000000 == 0, video_length=2000
     )
 assert isinstance(envs.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
 
-agent = Agent().to(device)
+agent = Agent(envs).to(device)
 optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
 if args.anneal_lr:
     # /~https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
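
With the new defaults (num_bot_envs=0, num_selfplay_envs=24), every opponent list above is empty, so all 24 environments are self-play (12 mirrored games, since two self-play envs share one game). A tiny sketch of that list arithmetic, with strings standing in for the microrts_ai callables:

    num_bot_envs = 0  # the new default

    ai2s = (
        ["randomAI" for _ in range(num_bot_envs - 6)]             # range(-6) is empty
        + ["randomBiasedAI" for _ in range(min(num_bot_envs, 2))]
        + ["lightRushAI" for _ in range(min(num_bot_envs, 2))]
        + ["workerRushAI" for _ in range(min(num_bot_envs, 2))]
    )
    print(ai2s)  # [] -- no scripted opponents are constructed
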