Update testing #946

Merged · 39 commits · Apr 18, 2023

Commits
8601497
Update hanabi to use shimmy[openspiel] instead of Hanabi Learning Env
elliottower Apr 2, 2023
fb2fc0f
Fix render test to use action_space.sample() for action mask
elliottower Apr 2, 2023
57d330c
Remove hanabi learning env req, shimmy as 'all' req
elliottower Apr 2, 2023
3c3971b
Remove unnecessary check for info[action_mask], fix old citation
elliottower Apr 4, 2023
21f55ab
Fix typos and clean up code from PR feedback
elliottower Apr 5, 2023
b3fe068
Merge branch 'master' into hanabi-shimmy
elliottower Apr 5, 2023
07009c8
Update api_test.py
elliottower Apr 7, 2023
4ebc2b6
Temporary fix for rendering not implemented error
elliottower Apr 7, 2023
91b880e
Pre-commit
elliottower Apr 7, 2023
9141084
Clean up action masking code per jet's suggestions
elliottower Apr 7, 2023
a7e037f
fix typo in observation dict handling, fix dtype mismatch error
elliottower Apr 7, 2023
f043896
Pre-commit
elliottower Apr 7, 2023
1512ced
Re-write seed test to match gymnasium/shimmy seed tests
elliottower Apr 8, 2023
b968529
Used same function names and functionality as old seed_test
elliottower Apr 8, 2023
c0fa964
Fixed seed test call for hanabi (unsupported)
elliottower Apr 8, 2023
8785eb6
Attempt to fix generated_agents variable env tests
elliottower Apr 8, 2023
b368d31
Fix typos and add temp code to debug issue
elliottower Apr 10, 2023
b1c17a2
Fix variable env test and add obs seeding to check_env_deterministic
elliottower Apr 11, 2023
e55cd71
Merge branch 'Farama-Foundation:master' into hanabi-shimmy
elliottower Apr 11, 2023
b18a16f
Remove temporary fix in seed test code
elliottower Apr 11, 2023
ae13b4e
Disable failing waterworld tests, add comment explaining bug
elliottower Apr 12, 2023
6145de3
Merge branch 'hanabi-shimmy' of github.com:elliottower/PettingZoo int…
elliottower Apr 12, 2023
a55b1f3
Revert hanabi and documentation change, testing only branch
elliottower Apr 13, 2023
73ed380
Change hanabi v5 references to v4
elliottower Apr 13, 2023
9a569ca
Merge branch 'Farama-Foundation:master' into testing
elliottower Apr 13, 2023
68a4ad9
Revert accidental hanabi related changes
elliottower Apr 13, 2023
7605e9b
Exclude waterworld (fails more rigorous seed test)
elliottower Apr 13, 2023
cf1cde6
Update variable env test, revert changed seed for seed test
elliottower Apr 13, 2023
5196a75
fix test
jjshoots Apr 17, 2023
2028309
Merge pull request #1 from jjshoots/fix_test
elliottower Apr 17, 2023
bc2b179
Merge branch 'Farama-Foundation:master' into testing
elliottower Apr 17, 2023
cb53283
Remove skip from passing tests, refactor for easier debugging
elliottower Apr 17, 2023
ee72291
Ensure asserts are done for last step of environment in seed_test
elliottower Apr 17, 2023
e284956
revert rlcard float32 change, match other envs
elliottower Apr 17, 2023
a05eec0
Revert change for texas holdem, rlcard envs use float32
elliottower Apr 17, 2023
00166ef
fix test
jjshoots Apr 18, 2023
965e843
pre commit
jjshoots Apr 18, 2023
138a055
Merge pull request #2 from jjshoots/fix_test
elliottower Apr 18, 2023
aebad1f
Update all_parameter_combs_test.py
elliottower Apr 18, 2023
2 changes: 1 addition & 1 deletion pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py
@@ -151,7 +151,7 @@ def __init__(self, num_players=2, render_mode=None):
                     ),
                     [100, 100],
                 ),
-                dtype=np.float64,
+                dtype=np.float32,
             ),
             "action_mask": spaces.Box(
                 low=0, high=1, shape=(self.env.num_actions,), dtype=np.int8
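
The updated api_test asserts that a Box observation space's dtype matches the dtype of the observations the env actually returns, which is why this env moves to float32. A minimal sketch of that invariant, with an arbitrary shape rather than the env's real one:

import numpy as np
from gymnasium import spaces

box = spaces.Box(low=0.0, high=100.0, shape=(4,), dtype=np.float32)
obs = np.zeros(4, dtype=np.float32)
assert box.dtype == obs.dtype  # a float64 observation would fail this check
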
34 changes: 19 additions & 15 deletions pettingzoo/test/api_test.py
@@ -1,4 +1,3 @@
-import random
 import re
 import warnings
 from collections import defaultdict
@@ -316,8 +315,6 @@ def play_test(env, observation_0, num_cycles):
     """
     env.reset()

-    terminated = {agent: False for agent in env.agents}
-    truncated = {agent: False for agent in env.agents}
     live_agents = set(env.agents[:])
     has_finished = set()
     generated_agents = set()
@@ -333,10 +330,13 @@
             prev_observe, reward, terminated, truncated, info = env.last()
             if terminated or truncated:
                 action = None
-            elif isinstance(prev_observe, dict) and "action_mask" in prev_observe:
-                action = random.choice(np.flatnonzero(prev_observe["action_mask"]).tolist())
             else:
-                action = env.action_space(agent).sample()
+                mask = (
+                    prev_observe.get("action_mask")
+                    if isinstance(prev_observe, dict)
+                    else None
+                )
+                action = env.action_space(agent).sample(mask)

             if agent not in live_agents:
                 live_agents.add(agent)
@@ -382,15 +382,17 @@
         if not env.agents:
             break

-        if isinstance(env.observation_space(agent), gymnasium.spaces.Box):
-            assert env.observation_space(agent).dtype == prev_observe.dtype
         assert env.observation_space(agent).contains(
             prev_observe
         ), "Out of bounds observation: " + str(prev_observe)

-        assert env.observation_space(agent).contains(
-            prev_observe
-        ), "Agent's observation is outside of it's observation space"
+        if isinstance(env.observation_space(agent), gymnasium.spaces.Box):
+            assert env.observation_space(agent).dtype == prev_observe.dtype
+        elif isinstance(env.observation_space(agent), gymnasium.spaces.Dict):
+            assert (
+                env.observation_space(agent)["observation"].dtype
+                == prev_observe["observation"].dtype
+            )
         test_observation(prev_observe, observation_0, str(env.unwrapped))
         if not isinstance(env.infos[env.agent_selection], dict):
             warnings.warn(
@@ -407,10 +409,9 @@
         obs, reward, terminated, truncated, info = env.last()
         if terminated or truncated:
             action = None
-        elif isinstance(obs, dict) and "action_mask" in obs:
-            action = random.choice(np.flatnonzero(obs["action_mask"]).tolist())
         else:
-            action = env.action_space(agent).sample()
+            mask = obs.get("action_mask") if isinstance(obs, dict) else None
+            action = env.action_space(agent).sample(mask)
     assert isinstance(terminated, bool), "terminated from last is not True or False"
     assert isinstance(truncated, bool), "terminated from last is not True or False"
     assert (
@@ -439,7 +440,7 @@ def test_action_flexibility(env):
         if terminated or truncated:
             action = None
         elif isinstance(obs, dict) and "action_mask" in obs:
-            action = random.choice(np.flatnonzero(obs["action_mask"]).tolist())
+            action = env.action_space(agent).sample(obs["action_mask"])
         else:
             action = 0
         env.step(action)
@@ -481,6 +482,9 @@ def progress_report(msg):

     env.reset()
     observation_0, *_ = env.last()
+    if isinstance(observation_0, dict) and "observation" in observation_0:
+        observation_0 = observation_0["observation"]
+
     test_observation(observation_0, observation_0, str(env.unwrapped))

     non_observe, *_ = env.last(observe=False)
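
Throughout this diff, manual random.choice over np.flatnonzero is replaced with Gymnasium's built-in masked sampling. A minimal sketch of Discrete.sample with an int8 mask (the space size and mask values here are illustrative):

import numpy as np
from gymnasium import spaces

space = spaces.Discrete(5)
space.seed(42)  # seeded sampling is what lets two runs pick identical actions
mask = np.array([0, 1, 0, 1, 1], dtype=np.int8)  # 1 marks a legal action
action = space.sample(mask)
assert mask[action] == 1  # the sampled action is always legal
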
13 changes: 13 additions & 0 deletions pettingzoo/test/example_envs/generated_agents_env_v0.py
@@ -75,12 +75,25 @@ def reset(self, seed=None, options=None):
         self.truncations = {}
         self.infos = {}
         self.num_steps = 0

+        self._obs_spaces = {}
+        self._act_spaces = {}
+        self.types = []
+        self._agent_counters = {}
+        for i in range(3):
+            self.add_type()
+        for i in range(5):
+            self.add_agent(self.np_random.choice(self.types))
+
         self._agent_selector = agent_selector(self.agents)
         self.agent_selection = self._agent_selector.reset()

+        # seed observation and action spaces
+        for i, agent in enumerate(self.agents):
+            self.observation_space(agent).seed(seed)
+        for i, agent in enumerate(self.agents):
+            self.action_space(agent).seed(seed)
+
     def seed(self, seed=None):
         self.np_random, _ = gymnasium.utils.seeding.np_random(seed)
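
The block above seeds every per-agent space with the reset seed so that Space.sample() produces the same stream in two identically seeded environments, which is what the rewritten seed test checks. A standalone sketch of the pattern, with made-up agent names and spaces:

import numpy as np
from gymnasium import spaces

def make_spaces():
    # per-agent observation spaces (illustrative, not the example env's)
    return {f"a_{i}": spaces.Box(-1.0, 1.0, (3,), np.float32) for i in range(2)}

def reset_and_sample(agent_spaces, seed):
    for space in agent_spaces.values():
        space.seed(seed)  # same seed -> same sample stream
    return {agent: space.sample() for agent, space in agent_spaces.items()}

first = reset_and_sample(make_spaces(), 42)
second = reset_and_sample(make_spaces(), 42)
assert all(np.array_equal(first[a], second[a]) for a in first)
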
19 changes: 18 additions & 1 deletion pettingzoo/test/example_envs/generated_agents_parallel_v0.py
@@ -67,10 +67,27 @@ def add_agent(self, type):
     def reset(self, seed=None, options=None):
         if seed is not None:
             self.seed(seed=seed)
-        self.agents = []
         self.num_steps = 0

+        # Reset spaces and types
+        self._obs_spaces = {}
+        self._act_spaces = {}
+        self.types = []
+        self._agent_counters = {}
+        for i in range(3):
+            self.add_type()
+
+        # Add agents
+        self.agents = []
+        for i in range(5):
+            self.add_agent(self.np_random.choice(self.types))
+
+        # seed observation and action spaces
+        for i, agent in enumerate(self.agents):
+            self.observation_space(agent).seed(seed)
+        for i, agent in enumerate(self.agents):
+            self.action_space(agent).seed(seed)
+
         return {agent: self.observe(agent) for agent in self.agents}

     def seed(self, seed=None):
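
Both example envs also draw agent generation from self.np_random, which reset wires to Gymnasium's seeding helper, so the set of generated agents is itself reproducible. A small sketch of that helper (the type names are hypothetical):

from gymnasium.utils import seeding

rng1, _ = seeding.np_random(42)
rng2, _ = seeding.np_random(42)
types = ["type_a", "type_b", "type_c"]  # hypothetical agent type names
picks1 = [str(rng1.choice(types)) for _ in range(5)]
picks2 = [str(rng2.choice(types)) for _ in range(5)]
assert picks1 == picks2  # same seed, same generated agents
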
4 changes: 1 addition & 3 deletions pettingzoo/test/render_test.py
@@ -1,5 +1,3 @@
-import random
-
 import numpy as np

@@ -14,7 +12,7 @@ def collect_render_results(env):
         if terminated or truncated:
             action = None
         elif isinstance(obs, dict) and "action_mask" in obs:
-            action = random.choice(np.flatnonzero(obs["action_mask"]).tolist())
+            action = env.action_space(agent).sample(obs["action_mask"])
         else:
             action = env.action_space(agent).sample()
         env.step(action)
181 changes: 100 additions & 81 deletions pettingzoo/test/seed_test.py
@@ -1,109 +1,128 @@
-import hashlib
-import pickle
-import random
-import warnings
-
 import numpy as np
-
-from pettingzoo.utils import parallel_to_aec
-
-
-def hash(val):
-    val = pickle.dumps(val)
-    hasher = hashlib.md5()
-    hasher.update(val)
-    return hasher.hexdigest()
-
-
-def calc_hash(new_env, rand_issue, max_env_iters):
-    cur_hashes = []
-    sampler = random.Random(42)
-    for i in range(3):
-        new_env.reset(seed=i)
-        for j in range(rand_issue + 1):
-            random.randint(0, 1000)
-            np.random.normal(size=100)
-        for agent in new_env.agent_iter(max_env_iters):
-            obs, rew, terminated, truncated, info = new_env.last()
-            if terminated or truncated:
-                action = None
-            elif isinstance(obs, dict) and "action_mask" in obs:
-                action = sampler.choice(np.flatnonzero(obs["action_mask"]).tolist())
-            else:
-                action = new_env.action_space(agent).sample()
-            new_env.step(action)
-            cur_hashes.append(agent)
-            cur_hashes.append(hash_obsevation(obs))
-            cur_hashes.append(float(rew))
-
-    return hash(tuple(cur_hashes))
-
-
-def hash_obsevation(obs):
-    try:
-        val = hash(obs.tobytes())
-        return val
-    except AttributeError:
-        try:
-            return hash(obs)
-        except TypeError:
-            warnings.warn("Observation not an int or an Numpy array")
-            return 0
+from gymnasium.utils.env_checker import data_equivalence


 def seed_action_spaces(env):
-    if hasattr(env, "possible_agents"):
-        for i, agent in enumerate(env.possible_agents):
+    if hasattr(env, "agents"):
+        for i, agent in enumerate(env.agents):
             env.action_space(agent).seed(42 + i)


+def seed_observation_spaces(env):
+    if hasattr(env, "agents"):
+        for i, agent in enumerate(env.agents):
+            env.observation_space(agent).seed(42 + i)
+
+
 def check_environment_deterministic(env1, env2, num_cycles):
-    """
-    env1 and env2 should be seeded environments
-
-    returns a bool: true if env1 and env2 execute the same way
-    """
+    """Check that two AEC environments execute the same way."""
     env1.reset(seed=42)
     env2.reset(seed=42)

-    # seeds action space so that actions are deterministic
+    # seed action spaces to ensure sampled actions are the same
     seed_action_spaces(env1)
     seed_action_spaces(env2)

-    num_agents = max(1, len(getattr(env1, "possible_agents", [])))
-
-    # checks deterministic behavior if seed is set
-    hashes = []
-    num_seeds = 2
-    max_env_iters = num_cycles * num_agents
-    envs = [env1, env2]
-    for x in range(num_seeds):
-        hashes.append(calc_hash(envs[x], x, max_env_iters))
-
-    return all(hashes[0] == h for h in hashes)
+    # seed observation spaces to ensure first observation is the same
+    seed_observation_spaces(env1)
+    seed_observation_spaces(env2)
+
+    iter = 0
+    max_env_iters = num_cycles * len(env1.agents)
+
+    for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()):
+        assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}"
+
+        obs1, reward1, termination1, truncation1, info1 = env1.last()
+        obs2, reward2, termination2, truncation2, info2 = env2.last()
+
+        if termination1 or truncation1 or termination2 or truncation2:
+            break
+
+        assert data_equivalence(obs1, obs2), "Incorrect observation"
+        assert data_equivalence(reward1, reward2), "Incorrect reward."
+        assert data_equivalence(termination1, termination2), "Incorrect termination."
+        assert data_equivalence(truncation1, truncation2), "Incorrect truncation."
+        assert data_equivalence(info1, info2), "Incorrect info."
+
+        mask1 = obs1.get("action_mask") if isinstance(obs1, dict) else None
+        mask2 = obs2.get("action_mask") if isinstance(obs2, dict) else None
+        assert data_equivalence(mask1, mask2), f"Incorrect action mask: {mask1} {mask2}"
+
+        action1 = env1.action_space(agent1).sample(mask1)
+        action2 = env2.action_space(agent2).sample(mask2)
+
+        assert data_equivalence(
+            action1, action2
+        ), f"Incorrect actions: {action1} {action2}"
+
+        env1.step(action1)
+        env2.step(action2)
+
+        iter += 1
+
+        if iter >= max_env_iters:
+            break
+
+    env1.close()
+    env2.close()


-def test_environment_reset_deterministic(env1, num_cycles):
-    seed_action_spaces(env1)
-    hash1 = calc_hash(env1, 1, num_cycles)
-    env1.reset(seed=42)
-    seed_action_spaces(env1)
-    hash2 = calc_hash(env1, 2, num_cycles)
-    assert hash1 == hash2, "environments kept state after and reset(seed)"
+def check_environment_deterministic_parallel(env1, env2, num_cycles):
+    """Check that two parallel environments execute the same way."""
+    env1.reset(seed=42)
+    env2.reset(seed=42)
+
+    # seed action spaces to ensure sampled actions are the same
+    seed_action_spaces(env1)
+    seed_action_spaces(env2)
+
+    # seed observation spaces to ensure first observation is the same
+    seed_observation_spaces(env1)
+    seed_observation_spaces(env2)
+
+    iter = 0
+    max_env_iters = num_cycles * len(env1.agents)
+
+    while env1.agents:
+        actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents}
+        actions2 = {agent: env2.action_space(agent).sample() for agent in env2.agents}
+
+        assert data_equivalence(actions1, actions2), "Incorrect action seeding"
+
+        obs1, rewards1, terminations1, truncations1, infos1 = env1.step(actions1)
+        obs2, rewards2, terminations2, truncations2, infos2 = env2.step(actions2)
+
+        iter += 1
+
+        assert data_equivalence(obs1, obs2), "Incorrect observations"
+        assert data_equivalence(rewards1, rewards2), "Incorrect values for rewards"
+        assert data_equivalence(terminations1, terminations2), "Incorrect terminations."
+        assert data_equivalence(truncations1, truncations2), "Incorrect truncations"
+        assert data_equivalence(infos1, infos2), "Incorrect infos"
+
+        if iter >= max_env_iters or any(terminations1) or any(truncations1):
+            break
+
+    env1.close()
+    env2.close()


-def seed_test(env_constructor, num_cycles=10, test_kept_state=True):
+def seed_test(env_constructor, num_cycles=10):
     env1 = env_constructor()
-    if test_kept_state:
-        test_environment_reset_deterministic(env1, num_cycles)
     env2 = env_constructor()

-    assert check_environment_deterministic(
-        env1, env2, num_cycles
-    ), "The environment gives different results on multiple runs when initialized with the same seed. This is usually a sign that you are using np.random or random modules directly, which uses a global random state."
+    check_environment_deterministic(env1, env2, 500)


-def parallel_seed_test(parallel_env_fn, num_cycles=10, test_kept_state=True):
-    def aec_env_fn():
-        parallel_env = parallel_env_fn()
-        env = parallel_to_aec(parallel_env)
-        return env
-
-    seed_test(aec_env_fn, num_cycles, test_kept_state)
+def parallel_seed_test(parallel_env_fn):
+    env1 = parallel_env_fn()
+    env2 = parallel_env_fn()
+
+    check_environment_deterministic_parallel(env1, env2, 500)
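
The rewritten test replaces pickle/md5 hashing with data_equivalence from gymnasium.utils.env_checker, which recursively compares dicts, sequences, and numpy arrays, so one helper covers observations, rewards, infos, and action masks alike. A small sketch with an illustrative observation dict:

import numpy as np
from gymnasium.utils.env_checker import data_equivalence

obs1 = {"observation": np.zeros(3, dtype=np.float32),
        "action_mask": np.array([1, 0], dtype=np.int8)}
obs2 = {"observation": np.zeros(3, dtype=np.float32),
        "action_mask": np.array([1, 0], dtype=np.int8)}
assert data_equivalence(obs1, obs2)

obs2["action_mask"][0] = 0
assert not data_equivalence(obs1, obs2)  # any differing leaf fails the whole comparison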