From 5dbaa9af46b80136fcbb1a000b3ec8b17208d15c Mon Sep 17 00:00:00 2001 From: Elliot Tower <32176771+elliottower@users.noreply.github.com> Date: Fri, 31 Mar 2023 12:17:31 -0400 Subject: [PATCH] Add pickle tests (#53) --- bin/dm_lab.Dockerfile | 2 + scripts/install_dm_lab.sh | 2 + setup.py | 8 +- shimmy/bsuite_compatibility.py | 4 +- shimmy/dm_control_compatibility.py | 6 +- shimmy/dm_control_multiagent_compatibility.py | 6 +- shimmy/dm_lab_compatibility.py | 4 + shimmy/meltingpot_compatibility.py | 5 +- shimmy/openspiel_compatibility.py | 5 +- tests/test_atari.py | 63 ++++++++++++- tests/test_bsuite.py | 56 +++++++++++ tests/test_dm_control.py | 31 ++++++ tests/test_dm_control_multi_agent.py | 94 +++++++++++++++++++ tests/test_dm_lab.py | 94 +++++++++++++++++-- tests/test_meltingpot.py | 59 ++++++++++-- tests/test_openspiel.py | 50 ++++++++++ 16 files changed, 459 insertions(+), 30 deletions(-) diff --git a/bin/dm_lab.Dockerfile b/bin/dm_lab.Dockerfile index 96e735a5..59772441 100644 --- a/bin/dm_lab.Dockerfile +++ b/bin/dm_lab.Dockerfile @@ -52,3 +52,5 @@ RUN git clone https://github.com/deepmind/lab.git \ && rm -rf lab ENTRYPOINT ["/usr/local/shimmy/bin/docker_entrypoint"] + +RUN ls diff --git a/scripts/install_dm_lab.sh b/scripts/install_dm_lab.sh index afd5f7b0..99804685 100644 --- a/scripts/install_dm_lab.sh +++ b/scripts/install_dm_lab.sh @@ -32,9 +32,11 @@ fi pip3 install numpy +# TODO: fix installation issues on MacOS # Build if [ ! -d "lab" ]; then git clone https://github.com/deepmind/lab.git +fi cd lab echo 'build --cxxopt=-std=c++17' > .bazelrc bazel build -c opt //python/pip_package:build_pip_package diff --git a/setup.py b/setup.py index e1beafc4..59a4dd6e 100644 --- a/setup.py +++ b/setup.py @@ -41,11 +41,11 @@ def get_version(): "dm-control>=1.0.10", "imageio", "h5py>=3.7.0", - "pettingzoo>=1.22.4", + "pettingzoo>=1.22.3", ], - "dm-lab": [], - "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22.4"], - "meltingpot": ["pettingzoo>=1.22.4"], + "dm-lab": ["dm-env>=1.6"], + "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22.3"], + "meltingpot": ["pettingzoo>=1.22.3"], "bsuite": ["bsuite>=0.3.5"], } extras["all"] = list({lib for libs in extras.values() for lib in libs}) diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index 9f3e8928..c5e2d291 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -8,6 +8,7 @@ from bsuite.environments import Environment from gymnasium.core import ObsType from gymnasium.error import UnsupportedMode +from gymnasium.utils import EzPickle from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space @@ -17,7 +18,7 @@ np.int = int # pyright: ignore[reportGeneralTypeIssues] -class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): +class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle): """A compatibility wrapper that converts a BSuite environment into a gymnasium environment. Note: @@ -33,6 +34,7 @@ def __init__( render_mode: str | None = None, ): """Initialises the environment with a render mode along with render information.""" + EzPickle.__init__(self, env, render_mode) self._env = env self.observation_space = dm_spec2gym_space(env.observation_spec()) diff --git a/shimmy/dm_control_compatibility.py b/shimmy/dm_control_compatibility.py index 314a1d26..acaf13ba 100644 --- a/shimmy/dm_control_compatibility.py +++ b/shimmy/dm_control_compatibility.py @@ -16,6 +16,7 @@ from dm_control.rl import control from gymnasium.core import ObsType from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer +from gymnasium.utils import EzPickle from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space @@ -27,7 +28,7 @@ class EnvType(Enum): RL_CONTROL = 1 -class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): +class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle): """This compatibility wrapper converts a dm-control environment into a gymnasium environment. Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics. @@ -57,6 +58,9 @@ def __init__( camera_id: int = 0, ): """Initialises the environment with a render mode along with render information.""" + EzPickle.__init__( + self, env, render_mode, render_height, render_width, camera_id + ) self._env = env self.env_type = self._find_env_type(env) diff --git a/shimmy/dm_control_multiagent_compatibility.py b/shimmy/dm_control_multiagent_compatibility.py index 30ce7ae2..4d0301cf 100644 --- a/shimmy/dm_control_multiagent_compatibility.py +++ b/shimmy/dm_control_multiagent_compatibility.py @@ -10,6 +10,7 @@ import gymnasium import numpy as np from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer +from gymnasium.utils import EzPickle from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space @@ -62,7 +63,7 @@ def _unravel_ma_timestep( ) -class DmControlMultiAgentCompatibilityV0(ParallelEnv): +class DmControlMultiAgentCompatibilityV0(ParallelEnv, EzPickle): """This compatibility wrapper converts multi-agent dm-control environments, primarily soccer, into a Pettingzoo environment. Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, @@ -84,7 +85,8 @@ def __init__( env (dm_env.Environment): dm control multi-agent environment render_mode (Optional[str]): render_mode """ - super().__init__() + EzPickle.__init__(self, env=env, render_mode=render_mode) + ParallelEnv.__init__(self) self._env = env self.render_mode = render_mode diff --git a/shimmy/dm_lab_compatibility.py b/shimmy/dm_lab_compatibility.py index ef44f0f0..e3e7700d 100644 --- a/shimmy/dm_lab_compatibility.py +++ b/shimmy/dm_lab_compatibility.py @@ -48,6 +48,10 @@ def reset( self._env.reset(seed=seed) info = {} + if seed is not None: + print( + "Warning: DM-lab environments must be seeded in initialization, rather than with reset(seed)." + ) return ( self._env.observations(), info, diff --git a/shimmy/meltingpot_compatibility.py b/shimmy/meltingpot_compatibility.py index 6b6a930e..70edbcb4 100644 --- a/shimmy/meltingpot_compatibility.py +++ b/shimmy/meltingpot_compatibility.py @@ -5,20 +5,20 @@ and modified to modern pettingzoo API """ # pyright: reportOptionalSubscript=false - +# isort: skip_file from __future__ import annotations import functools from typing import Optional import gymnasium -import meltingpot.python import numpy as np import pygame from gymnasium.utils.ezpickle import EzPickle from ml_collections import config_dict from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv +import meltingpot.python import shimmy.utils.meltingpot as utils @@ -89,6 +89,7 @@ def __init__( for index in range(self._num_players) ] self.agents = [agent for agent in self.possible_agents] + self.num_cycles = 0 # Set up pygame rendering if self.render_mode == "human": diff --git a/shimmy/openspiel_compatibility.py b/shimmy/openspiel_compatibility.py index 839bc7b0..a83ac828 100644 --- a/shimmy/openspiel_compatibility.py +++ b/shimmy/openspiel_compatibility.py @@ -8,11 +8,11 @@ import pettingzoo as pz import pyspiel from gymnasium import spaces -from gymnasium.utils import seeding +from gymnasium.utils import EzPickle, seeding from pettingzoo.utils.env import AgentID, ObsType -class OpenspielCompatibilityV0(pz.AECEnv): +class OpenspielCompatibilityV0(pz.AECEnv, EzPickle): """This compatibility wrapper converts an openspiel environment into a pettingzoo environment. OpenSpiel is a collection of environments and algorithms for research in general reinforcement learning @@ -35,6 +35,7 @@ def __init__( game (pyspiel.Game): game render_mode (Optional[str]): render_mode """ + EzPickle.__init__(self, game, render_mode) super().__init__() self.game = game self.possible_agents = [ diff --git a/tests/test_atari.py b/tests/test_atari.py index d5c756d7..e647fd40 100644 --- a/tests/test_atari.py +++ b/tests/test_atari.py @@ -1,4 +1,5 @@ """Tests the ale-py environments are correctly registered.""" +import pickle import warnings import gymnasium as gym @@ -7,7 +8,7 @@ from ale_py.roms import utils as rom_utils from gymnasium.envs.registration import registry from gymnasium.error import Error -from gymnasium.utils.env_checker import check_env +from gymnasium.utils.env_checker import check_env, data_equivalence from shimmy.utils.envs_configs import ALL_ATARI_GAMES @@ -47,3 +48,63 @@ def test_atari_envs(env_id): assert isinstance(warning_message.message, Warning) if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: raise Error(f"Unexpected warning: {warning_message.message}") + + +@pytest.mark.parametrize( + "env_id", + [ + env_id + for env_id, env_spec in registry.items() + if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv" + ], +) +def test_atari_pickle(env_id): + """Tests the atari envs, as there are 1000 possible environment, we only test the Pong variants.""" + env_1 = gym.make(env_id) + env_2 = pickle.loads(pickle.dumps(env_1)) + + obs_1, info_1 = env_1.reset(seed=42) + obs_2, info_2 = env_2.reset(seed=42) + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = int(env_1.action_space.sample()) + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() + + +@pytest.mark.parametrize( + "env_id", + [ + env_id + for env_id, env_spec in registry.items() + if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv" + ], +) +def test_atari_seeding(env_id): + """Tests the seeding of the atari conversion wrapper.""" + env_1 = gym.make(env_id) + env_2 = gym.make(env_id) + + obs_1, info_1 = env_1.reset(seed=42) + obs_2, info_2 = env_2.reset(seed=42) + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = int(env_1.action_space.sample()) + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index 820827fd..d8be2582 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -1,4 +1,5 @@ """Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs.""" +import pickle import warnings import bsuite @@ -109,3 +110,58 @@ def test_seeding(env_id): env_1.close() env_2.close() + + +# Without EzPickle:_register_bsuite_envs.._make_bsuite_env cannot be pickled +# With EzPickle: maximum recursion limit reached +FAILING_PICKLE_ENVS = [ + "bsuite/bandit_noise-v0", + "bsuite/bandit_scale-v0", + "bsuite/cartpole-v0", + "bsuite/cartpole_noise-v0", + "bsuite/cartpole_scale-v0", + "bsuite/cartpole_swingup-v0", + "bsuite/catch_noise-v0", + "bsuite/catch_scale-v0", + "bsuite/mnist_noise-v0", + "bsuite/mnist_scale-v0", + "bsuite/mountain_car_noise-v0", + "bsuite/mountain_car_scale-v0", +] + +PASSING_PICKLE_ENVS = [ + "bsuite/mnist-v0", + "bsuite/umbrella_length-v0", + "bsuite/discounting_chain-v0", + "bsuite/deep_sea-v0", + "bsuite/umbrella_distract-v0", + "bsuite/catch-v0", + "bsuite/memory_len-v0", + "bsuite/mountain_car-v0", + "bsuite/memory_size-v0", + "bsuite/deep_sea_stochastic-v0", + "bsuite/bandit-v0", +] + + +@pytest.mark.parametrize("env_id", PASSING_PICKLE_ENVS) +def test_pickle(env_id): + """Test that pickling works.""" + env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) + env_2 = pickle.loads(pickle.dumps(env_1)) + + obs_1, info_1 = env_1.reset(seed=42) + obs_2, info_2 = env_2.reset(seed=42) + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = int(env_1.action_space.sample()) + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() diff --git a/tests/test_dm_control.py b/tests/test_dm_control.py index c050d5a8..29c90b82 100644 --- a/tests/test_dm_control.py +++ b/tests/test_dm_control.py @@ -1,4 +1,5 @@ """Tests the functionality of the DmControlCompatibility Wrapper on dm_control envs.""" +import pickle import warnings from typing import Callable @@ -82,6 +83,36 @@ def test_seeding(env_id): env_1 = gym.make(env_id) env_2 = gym.make(env_id) + if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic): + # LQR fails this test currently. + return + + obs_1, info_1 = env_1.reset(seed=42) + obs_2, info_2 = env_2.reset(seed=42) + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(10): + actions = env_1.action_space.sample() + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() + + +@pytest.mark.skip( + reason="Fatal Python error: Segmentation fault (with or without EzPickle)" +) +@pytest.mark.parametrize("env_id", DM_CONTROL_ENV_IDS[0]) +def test_pickle(env_id): + """Test that dm-control seeding works.""" + env_1 = gym.make(env_id) + env_2 = pickle.loads(pickle.dumps(env_1)) + if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic): # LQR fails this test currently. return diff --git a/tests/test_dm_control_multi_agent.py b/tests/test_dm_control_multi_agent.py index 923f6c64..406e6b8f 100644 --- a/tests/test_dm_control_multi_agent.py +++ b/tests/test_dm_control_multi_agent.py @@ -1,7 +1,9 @@ """Tests the multi-agent dm-control soccer environment.""" +import pickle import pytest from dm_control.locomotion import soccer as dm_soccer +from gymnasium.utils.env_checker import data_equivalence from pettingzoo.test import parallel_api_test from shimmy.dm_control_multiagent_compatibility import ( @@ -32,3 +34,95 @@ def test_check_env(walker_type): parallel_api_test(env) env.close() + + +@pytest.mark.parametrize("walker_type", WALKER_TYPES) +def test_seeding(walker_type): + """Tests the seeding of the openspiel conversion wrapper.""" + # load envs + env1 = dm_soccer.load( + team_size=2, + time_limit=10.0, + disable_walker_contacts=False, + enable_field_box=True, + terminate_on_goal=False, + walker_type=walker_type, + ) + env2 = dm_soccer.load( + team_size=2, + time_limit=10.0, + disable_walker_contacts=False, + enable_field_box=True, + terminate_on_goal=False, + walker_type=walker_type, + ) + + # convert the environment + env1 = DmControlMultiAgentCompatibilityV0(env1, render_mode=None) + env2 = DmControlMultiAgentCompatibilityV0(env2, render_mode=None) + + env1.reset(seed=42) + env2.reset(seed=42) + + for agent in env1.possible_agents: + env1.action_space(agent).seed(42) + env2.action_space(agent).seed(42) + + while env1.agents: + actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents} + actions2 = {agent: env2.action_space(agent).sample() for agent in env2.agents} + + assert data_equivalence(actions1, actions2), "Incorrect action seeding" + + obs1, rewards1, terminations1, truncations1, infos1 = env1.step(actions1) + obs2, rewards2, terminations2, truncations2, infos2 = env2.step(actions2) + + assert not data_equivalence( + obs1, obs2 + ), "Observations are expected to be slightly different (ball position/velocity)" + assert data_equivalence(rewards1, rewards2), "Incorrect values for rewards" + assert data_equivalence(terminations1, terminations2), "Incorrect terminations." + assert data_equivalence(truncations1, truncations2), "Incorrect truncations" + assert data_equivalence(infos1, infos2), "Incorrect infos" + env1.close() + env2.close() + + +@pytest.mark.skip(reason="Cannot pickle weakref objects used in dm_soccer envs.") +@pytest.mark.parametrize("walker_type", WALKER_TYPES) +def test_pickle(walker_type): + """Tests the seeding of the openspiel conversion wrapper.""" + env1 = dm_soccer.load( + team_size=2, + time_limit=10.0, + disable_walker_contacts=False, + enable_field_box=True, + terminate_on_goal=False, + walker_type=walker_type, + ) + env1 = DmControlMultiAgentCompatibilityV0(env1, render_mode=None) + env2 = pickle.loads(pickle.dumps(env1)) + + env1.reset(seed=42) + env2.reset(seed=42) + + for agent in env1.possible_agents: + env1.action_space(agent).seed(42) + env2.action_space(agent).seed(42) + + while env1.agents: + actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents} + actions2 = {agent: env2.action_space(agent).sample() for agent in env2.agents} + + assert data_equivalence(actions1, actions2), "Incorrect action seeding" + + obs1, rewards1, terminations1, truncations1, infos1 = env1.step(actions1) + obs2, rewards2, terminations2, truncations2, infos2 = env2.step(actions2) + + assert data_equivalence(obs1, obs2), "Incorrect observations" + assert data_equivalence(rewards1, rewards2), "Incorrect values for rewards" + assert data_equivalence(terminations1, terminations2), "Incorrect terminations." + assert data_equivalence(truncations1, truncations2), "Incorrect truncations" + assert data_equivalence(infos1, infos2), "Incorrect infos" + env1.close() + env2.close() diff --git a/tests/test_dm_lab.py b/tests/test_dm_lab.py index 27b3dfb4..49cc81ce 100644 --- a/tests/test_dm_lab.py +++ b/tests/test_dm_lab.py @@ -1,25 +1,103 @@ """Tests the multi-agent dm-control soccer environment.""" +# pyright: reportUndefinedVariable=false +# flake8: noqa F821 +import pickle -import gymnasium import pytest -from gymnasium.utils.env_checker import check_env +from gymnasium.utils.env_checker import check_env, data_equivalence from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 +pytest.importorskip("deepmind_lab") + +LEVEL_NAMES = [ + "lt_chasm", + "lt_hallway_slope", + "lt_horseshoe_color", + "lt_space_bounce_hard", + "nav_maze_random_goal_01", + "nav_maze_random_goal_02", + "nav_maze_random_goal_03", + "nav_maze_static_01", + "nav_maze_static_02", + "nav_maze_static_03", + "seekavoid_arena_01", + "stairway_to_melon", +] -@pytest.mark.skip(reason="no way of currently testing this") -def test_check_env(): - """Check that environment pass the gym check_env.""" - import deepmind_lab +@pytest.mark.skip("DM lab tests are not currently possible.") +@pytest.mark.parametrize("level_name", LEVEL_NAMES[0]) +def test_check_env(level_name): + """Check that environment pass the gym check_env.""" observations = ["RGBD"] config = {"width": "640", "height": "480", "botCount": "2"} renderer = "hardware" - env = deepmind_lab.Lab("lt_chasm", observations, config=config, renderer=renderer) - + env = deepmind_lab.Lab(level_name, observations, config=config, renderer=renderer) env = DmLabCompatibilityV0(env) check_env(env) env.close() + + +@pytest.mark.skip("DM lab seed tests are not currently possible.") +@pytest.mark.parametrize("level_name", LEVEL_NAMES[0]) +def test_seeding(level_name): + """Checks that the environment can be properly seeded.""" + observations = ["RGBD"] + config = {"width": "640", "height": "480", "botCount": "2", "random_seed": "42"} + renderer = "hardware" + + env_1 = deepmind_lab.Lab(level_name, observations, config=config, renderer=renderer) + env_1 = DmLabCompatibilityV0(env_1) + + env_2 = deepmind_lab.Lab(level_name, observations, config=config, renderer=renderer) + env_2 = DmLabCompatibilityV0(env_2) + + obs_1, info_1 = env_1.reset() + obs_2, info_2 = env_2.reset() + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = env_1.action_space.sample() + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + # assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() + + +@pytest.mark.skip("DM lab pickle tests are not currently possible.") +@pytest.mark.parametrize("level_name", LEVEL_NAMES[0]) +def test_pickle(level_name): + """Checks that the environment can be saved and loaded by pickling.""" + observations = ["RGBD"] + config = {"width": "640", "height": "480", "botCount": "2", "random_seed": "42"} + renderer = "hardware" + + env_1 = deepmind_lab.Lab(level_name, observations, config=config, renderer=renderer) + env_1 = DmLabCompatibilityV0(env_1) + + env_2 = pickle.loads(pickle.dumps(env_1)) + + obs_1, info_1 = env_1.reset() + obs_2, info_2 = env_2.reset() + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = env_1.action_space.sample() + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + # assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() diff --git a/tests/test_meltingpot.py b/tests/test_meltingpot.py index 01f17d5d..dde624e4 100644 --- a/tests/test_meltingpot.py +++ b/tests/test_meltingpot.py @@ -1,16 +1,19 @@ """Tests the functionality of the MeltingPotCompatibility wrapper on meltingpot substrates.""" +# pyright: reportUndefinedVariable=false +# flake8: noqa F821 E402 +import pickle + import pytest from gymnasium.utils.env_checker import data_equivalence from pettingzoo.test import parallel_api_test pytest.importorskip("meltingpot") -import meltingpot # noqa: E402 -import meltingpot.python # noqa: E402 -from meltingpot.python.configs.substrates import SUBSTRATES # noqa: E402 -from ml_collections import config_dict # noqa: E402 +import meltingpot.python +from meltingpot.python.configs.substrates import SUBSTRATES +from ml_collections import config_dict -from shimmy.meltingpot_compatibility import MeltingPotCompatibilityV0 # noqa: E402 +from shimmy.meltingpot_compatibility import MeltingPotCompatibilityV0 @pytest.mark.skip( @@ -26,10 +29,9 @@ def test_seeding(substrate_name): env1.reset(seed=42) env2.reset(seed=42) - a_space1 = env1.action_space(env1.agents[0]) - a_space1.seed(42) - a_space2 = env2.action_space(env2.agents[0]) - a_space2.seed(42) + for agent in env1.possible_agents: + env1.action_space(agent).seed(42) + env2.action_space(agent).seed(42) while env1.agents: actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents} @@ -45,6 +47,8 @@ def test_seeding(substrate_name): assert data_equivalence(terminations1, terminations2), "Incorrect terminations." assert data_equivalence(truncations1, truncations2), "Incorrect truncations" assert data_equivalence(infos1, infos2), "Incorrect infos" + env1.close() + env2.close() @pytest.mark.parametrize("substrate_name", SUBSTRATES) @@ -59,6 +63,7 @@ def test_substrate(substrate_name): while env.agents: actions = {agent: env.action_space(agent).sample() for agent in env.agents} observations, rewards, terminations, truncations, infos = env.step(actions) + env.close() def test_custom_substrate(): @@ -89,6 +94,7 @@ def test_custom_substrate(): actions = {agent: env.action_space(agent).sample() for agent in env.agents} env.step(actions) env.render() + env.close() @pytest.mark.parametrize("substrate_name", SUBSTRATES) @@ -101,3 +107,38 @@ def test_rendering(substrate_name): actions = {agent: env.action_space(agent).sample() for agent in env.agents} env.step(actions) env.render() + + +@pytest.mark.skip( + reason="Melting Pot environments are stochastic and do not currently support seeding." +) +@pytest.mark.parametrize("substrate_name", SUBSTRATES) +def test_pickle(substrate_name): + """Test that environments can be saved and loaded with pickle.""" + # load and convert the envs + env1 = MeltingPotCompatibilityV0(substrate_name=substrate_name, render_mode=None) + env2 = pickle.loads(pickle.dumps(env1)) + + env1.reset(seed=42) + env2.reset(seed=42) + + for agent in env1.possible_agents: + env1.action_space(agent).seed(42) + env2.action_space(agent).seed(42) + + while env1.agents: + actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents} + actions2 = {agent: env2.action_space(agent).sample() for agent in env2.agents} + + assert data_equivalence(actions1, actions2), "Incorrect action seeding" + + obs1, rewards1, terminations1, truncations1, infos1 = env1.step(actions1) + obs2, rewards2, terminations2, truncations2, infos2 = env2.step(actions2) + + assert data_equivalence(obs1, obs2), "Incorrect observations" + assert data_equivalence(rewards1, rewards2), "Incorrect values for rewards" + assert data_equivalence(terminations1, terminations2), "Incorrect terminations." + assert data_equivalence(truncations1, truncations2), "Incorrect truncations" + assert data_equivalence(infos1, infos2), "Incorrect infos" + env1.close() + env2.close() diff --git a/tests/test_openspiel.py b/tests/test_openspiel.py index c66254ce..9e870b4d 100644 --- a/tests/test_openspiel.py +++ b/tests/test_openspiel.py @@ -1,4 +1,6 @@ """Tests the functionality of the OpenspielWrapper on openspiel envs.""" +import pickle + import numpy as np import pyspiel import pytest @@ -177,3 +179,51 @@ def test_seeding(game): env1.step(action1) env2.step(action2) + env1.close() + env2.close() + + +@pytest.mark.parametrize("game", _PASSING_GAMES) +def test_pickle(game): + """Tests the seeding of the openspiel conversion wrapper.""" + env1 = pyspiel.load_game(game) + env1 = OpenspielCompatibilityV0(env1, render_mode=None) + + env2 = pickle.loads(pickle.dumps(env1)) + + assert data_equivalence( + env1.reset(seed=42), env2.reset(seed=42) + ), "Incorrect return on reset()" + + agent1 = env1.agent_selection + agent2 = env2.agent_selection + assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}" + + a_space1 = env1.action_space(agent1) + a_space1.seed(42) + a_space2 = env2.action_space(agent2) + a_space2.seed(42) + + for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()): + assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}" + + obs1, rew1, term1, trunc1, info1 = env1.last() + obs2, rew2, term2, trunc2, info2 = env2.last() + + assert data_equivalence(obs1, obs2), f"Incorrect observations: {obs1} {obs2}" + assert data_equivalence(rew1, rew2), f"Incorrect rewards: {rew1} {rew2}" + assert data_equivalence(term1, term2), f"Incorrect terms: {term1} {term2}" + assert data_equivalence(trunc1, trunc2), f"Incorrect truncs: {trunc1} {trunc2}" + assert data_equivalence(info1, info2), f"Incorrect info: {info1} {info2}" + + action1 = a_space1.sample(mask=info1["action_mask"]) + action2 = a_space2.sample(mask=info2["action_mask"]) + + assert data_equivalence( + action1, action2 + ), f"Incorrect actions: {action1} {action2}" + + env1.step(action1) + env2.step(action2) + env1.close() + env2.close()