diff --git a/experiments/ppo_gridnet_large.py b/experiments/ppo_gridnet_large.py index 2499babe..ccade4a7 100644 --- a/experiments/ppo_gridnet_large.py +++ b/experiments/ppo_gridnet_large.py @@ -203,6 +203,14 @@ def __init__(self, envs, mapsize=16 * 16): self.register_buffer("mask_value", torch.tensor(-1e8)) def get_action_and_value(self, x, action=None, invalid_action_masks=None, envs=None, device=None): + """ + :return: + (1) action (shape = [num_envs, width*height, 7], where 7 = dimensionality of per-unit action) + (2) log probability of action (shape = [num_envs]) + (3) entropy (shape = [num_envs]) + (4) invalid action masks + (5) Critic's prediction + """ hidden = self.encoder(x) logits = self.actor(hidden) grid_logits = logits.reshape(-1, envs.action_plane_space.nvec.sum()) diff --git a/gym_microrts/envs/vec_env.py b/gym_microrts/envs/vec_env.py index 244e56b4..92094065 100644 --- a/gym_microrts/envs/vec_env.py +++ b/gym_microrts/envs/vec_env.py @@ -279,7 +279,13 @@ def close(self): jpype.shutdownJVM() def get_action_mask(self): + """ + :return: Mask for action types and action parameters, + of shape [num_envs, map height * map width, action types + params] + """ + # action_mask shape: [num_envs, map height, map width, 1 + action types + params] action_mask = np.array(self.vec_client.getMasks(0)) + # self.source_unit_mask shape: [num_envs, map height * map width] self.source_unit_mask = action_mask[:, :, :, 0].reshape(self.num_envs, -1) action_type_and_parameter_mask = action_mask[:, :, :, 1:].reshape(self.num_envs, self.height * self.width, -1) return action_type_and_parameter_mask diff --git a/gym_microrts/microrts b/gym_microrts/microrts index 77411e7d..515ceff9 160000 --- a/gym_microrts/microrts +++ b/gym_microrts/microrts @@ -1 +1 @@ -Subproject commit 77411e7d133820cd199a91382474e0f1bb3b7316 +Subproject commit 515ceff955611ad32a726756bb0c96782978126d