Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickling tests, adapt all envs to be picklable #928

Merged
merged 24 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4001ab1
Add pickling tests for all envs in API test, adapt waterworld to be p…
elliottower Mar 29, 2023
e9476db
Refactor pickle test to be separate test (not in API test)
elliottower Mar 30, 2023
27a90f4
Update init __version__ to 1.23.0, tutorials rollback to 1.22.3
elliottower Mar 30, 2023
af6902a
Update CleanRL tutorial requirements.txt after testing
elliottower Mar 30, 2023
ad111a0
Add pickle test file (not committed by mistake)
elliottower Mar 31, 2023
d29cd32
Merge branch 'master' into pickle-test
elliottower Apr 6, 2023
32ed2e1
Remove unnecessary changes, check waterworld/mpe
elliottower Apr 8, 2023
bdb1b14
Merge remote-tracking branch 'upstream/master' into pickle-test
elliottower Apr 10, 2023
473a797
Use keyword args for EzPickle, replace all relative imports with global
elliottower Apr 11, 2023
a2c420c
Fix typo with KAZ ezpickle
elliottower Apr 11, 2023
fd65b83
Fix typo with KAZ ezpickle
elliottower Apr 11, 2023
4409525
Merge branch 'Farama-Foundation:master' into pickle-test
elliottower Apr 13, 2023
010a4e7
Fix super() calls with ezpickle
elliottower Apr 18, 2023
8b5205a
test removing simple_comm_v2 from failing envs
elliottower Apr 18, 2023
44a6186
fix SimpleEnv init, add absolute import
elliottower Apr 18, 2023
808086b
Re-remove simple world comm v2 from pickle test (fails)
elliottower Apr 18, 2023
a80a36f
remove pickle human env and fully remove simple world comm test
elliottower Apr 21, 2023
e56961f
Fix typo in failing env names
elliottower Apr 21, 2023
4759c09
Fix typos in param combs test
elliottower Apr 21, 2023
8d6237d
Update simple_world_comm.py
RedTachyon Apr 21, 2023
a96e439
Update simple_world_comm.py
RedTachyon Apr 21, 2023
e822070
Reactivate tests
RedTachyon Apr 21, 2023
27a6cc1
Merge branch 'Farama-Foundation:master' into pickle-test
elliottower Apr 21, 2023
1beceaa
Fix action space seeding in pickle test
elliottower Apr 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pettingzoo/mpe/simple/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"""

import numpy as np
from gymnasium.utils import EzPickle

from pettingzoo.utils.conversions import parallel_wrapper_fn

Expand All @@ -51,8 +52,9 @@
from .._mpe_utils.simple_env import SimpleEnv, make_env
elliottower marked this conversation as resolved.
Show resolved Hide resolved


class raw_env(SimpleEnv):
class raw_env(SimpleEnv, EzPickle):
def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None):
EzPickle.__init__(self, max_cycles, continuous_actions, render_mode)
elliottower marked this conversation as resolved.
Show resolved Hide resolved
scenario = Scenario()
world = scenario.make_world()
super().__init__(
elliottower marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
5 changes: 4 additions & 1 deletion pettingzoo/sisl/waterworld/waterworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@

"""

from gymnasium.utils import EzPickle

from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from pettingzoo.utils.conversions import parallel_wrapper_fn
Expand All @@ -154,7 +156,7 @@ def env(**kwargs):
parallel_env = parallel_wrapper_fn(env)


class raw_env(AECEnv):
class raw_env(AECEnv, EzPickle):
metadata = {
"render_modes": ["human", "rgb_array"],
"name": "waterworld_v4",
Expand All @@ -163,6 +165,7 @@ class raw_env(AECEnv):
}

def __init__(self, *args, **kwargs):
EzPickle.__init__(self, *args, **kwargs)
super().__init__()
self.env = _env(*args, **kwargs)

Expand Down
141 changes: 141 additions & 0 deletions test/pickle_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import pickle

import pytest
from gymnasium.utils.env_checker import data_equivalence

from .all_modules import all_environments

ALL_ENVS = list(all_environments.items())
FAILING_ENV_NAMES = ["mpe/simple_world_comm_v2"]
elliottower marked this conversation as resolved.
Show resolved Hide resolved
PASSING_ENVS = [
(name, env_module)
for (name, env_module) in ALL_ENVS
if name not in FAILING_ENV_NAMES
]


@pytest.mark.parametrize(("name", "env_module"), PASSING_ENVS)
def test_pickle_env(name, env_module):
env1 = env_module.env(render_mode=None)
env2 = pickle.loads(pickle.dumps(env1))

env1.reset(seed=42)
env2.reset(seed=42)

agent1 = env1.agents[0]
agent2 = env2.agents[0]

a_space1 = env1.action_space(agent1)
a_space1.seed(42)
a_space2 = env2.action_space(agent2)
a_space2.seed(42)

iter = 0
for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()):
if iter > 10:
break
assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}"

obs1, rew1, term1, trunc1, info1 = env1.last()
obs2, rew2, term2, trunc2, info2 = env2.last()

if name == "mpe/simple_world_comm_v2":
print("Test")

if term1 or term2 or trunc1 or trunc2:
break

assert data_equivalence(obs1, obs2), f"Incorrect observations: {obs1} {obs2}"
assert data_equivalence(rew1, rew2), f"Incorrect rewards: {rew1} {rew2}"
assert data_equivalence(term1, term2), f"Incorrect terms: {term1} {term2}"
assert data_equivalence(trunc1, trunc2), f"Incorrect truncs: {trunc1} {trunc2}"
assert data_equivalence(info1, info2), f"Incorrect info: {info1} {info2}"

mask = None
if "action_mask" in info1:
mask = info1["action_mask"]

if isinstance(obs1, dict) and "action_mask" in obs1:
mask = obs1["action_mask"]

action1 = a_space1.sample(mask=mask)
action2 = a_space2.sample(mask=mask)

assert data_equivalence(
action1, action2
), f"Incorrect actions: {action1} {action2}"

env1.step(action1)
env2.step(action2)
iter += 1
env1.close()
env2.close()


ALL_ENVS = list(all_environments.items())
FAILING_ENV_NAMES = []
PASSING_ENVS = [
(name, env_module)
for (name, env_module) in ALL_ENVS
if name not in FAILING_ENV_NAMES
]


@pytest.mark.skip(
reason="pickling pygame rendered envs does not work well, video system will not be initialized."
elliottower marked this conversation as resolved.
Show resolved Hide resolved
)
@pytest.mark.parametrize(("name", "env_module"), PASSING_ENVS)
def test_pickle_env_human(name, env_module):
env1 = env_module.env(render_mode="human")
env2 = pickle.loads(pickle.dumps(env1))

env1.reset(seed=42)
env2.reset(seed=42)

agent1 = env1.agents[0]
agent2 = env2.agents[0]

a_space1 = env1.action_space(agent1)
a_space1.seed(42)
a_space2 = env2.action_space(agent2)
a_space2.seed(42)

iter = 0
for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()):
if iter > 5:
break
assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}"

obs1, rew1, term1, trunc1, info1 = env1.last()
obs2, rew2, term2, trunc2, info2 = env2.last()

assert data_equivalence(obs1, obs2), f"Incorrect observations: {obs1} {obs2}"
assert data_equivalence(rew1, rew2), f"Incorrect rewards: {rew1} {rew2}"
assert data_equivalence(term1, term2), f"Incorrect terms: {term1} {term2}"
assert data_equivalence(trunc1, trunc2), f"Incorrect truncs: {trunc1} {trunc2}"
assert data_equivalence(info1, info2), f"Incorrect info: {info1} {info2}"

if name == "mpe/simple_world_comm_v2":
print("Test")
mask = None
if "action_mask" in info1:
mask = info1["action_mask"]

if isinstance(obs1, dict) and "action_mask" in obs1:
mask = obs1["action_mask"]

action1 = a_space1.sample(mask=mask)
action2 = a_space2.sample(mask=mask)

assert data_equivalence(
action1, action2
), f"Incorrect actions: {action1} {action2}"

if term1 or term2 or trunc1 or trunc2:
break

env1.step(action1)
env2.step(action2)
iter += 1
env1.close()
env2.close()