From 6aca84cba127bde3eebcf800d1993afd1676ec5b Mon Sep 17 00:00:00 2001 From: Elliot Tower <32176771+elliottower@users.noreply.github.com> Date: Fri, 21 Apr 2023 17:15:41 -0400 Subject: [PATCH] Add pickling tests, adapt all envs to be picklable (#928) Co-authored-by: Ariel Kwiatkowski --- pettingzoo/atari/base_atari_env.py | 20 +++--- .../atari/basketball_pong/basketball_pong.py | 6 +- pettingzoo/atari/boxing/boxing.py | 6 +- pettingzoo/atari/combat_plane/combat_plane.py | 6 +- pettingzoo/atari/combat_tank/combat_tank.py | 6 +- pettingzoo/atari/double_dunk/double_dunk.py | 6 +- .../entombed_competitive.py | 6 +- .../entombed_cooperative.py | 6 +- pettingzoo/atari/flag_capture/flag_capture.py | 6 +- pettingzoo/atari/foozpong/foozpong.py | 6 +- pettingzoo/atari/ice_hockey/ice_hockey.py | 6 +- pettingzoo/atari/joust/joust.py | 6 +- pettingzoo/atari/mario_bros/mario_bros.py | 6 +- pettingzoo/atari/maze_craze/maze_craze.py | 6 +- pettingzoo/atari/othello/othello.py | 6 +- pettingzoo/atari/pong/pong.py | 6 +- pettingzoo/atari/quadrapong/quadrapong.py | 6 +- .../atari/space_invaders/space_invaders.py | 6 +- pettingzoo/atari/space_war/space_war.py | 6 +- pettingzoo/atari/surround/surround.py | 6 +- pettingzoo/atari/tennis/tennis.py | 6 +- .../atari/video_checkers/video_checkers.py | 6 +- .../atari/volleyball_pong/volleyball_pong.py | 6 +- pettingzoo/atari/warlords/warlords.py | 6 +- .../atari/wizard_of_wor/wizard_of_wor.py | 6 +- .../knights_archers_zombies.py | 28 ++++----- pettingzoo/butterfly/pistonball/pistonball.py | 20 +++--- pettingzoo/classic/hanabi/hanabi.py | 18 +++--- pettingzoo/classic/rlcard_envs/gin_rummy.py | 7 ++- pettingzoo/mpe/simple/simple.py | 19 ++++-- .../mpe/simple_adversary/simple_adversary.py | 18 ++++-- pettingzoo/mpe/simple_crypto/simple_crypto.py | 17 +++-- pettingzoo/mpe/simple_push/simple_push.py | 17 +++-- .../mpe/simple_reference/simple_reference.py | 18 +++--- .../simple_speaker_listener.py | 17 +++-- pettingzoo/mpe/simple_spread/simple_spread.py | 17 +++-- pettingzoo/mpe/simple_tag/simple_tag.py | 22 +++---- .../simple_world_comm/simple_world_comm.py | 25 ++++---- .../sisl/multiwalker/multiwalker_base.py | 2 +- .../sisl/pursuit/utils/discrete_agent.py | 2 +- pettingzoo/sisl/waterworld/waterworld.py | 7 ++- pettingzoo/utils/__init__.py | 18 +++--- pettingzoo/utils/wrappers/__init__.py | 14 ++--- .../utils/wrappers/assert_out_of_bounds.py | 2 +- pettingzoo/utils/wrappers/base_parallel.py | 2 +- pettingzoo/utils/wrappers/capture_stdout.py | 4 +- .../utils/wrappers/clip_out_of_bounds.py | 4 +- pettingzoo/utils/wrappers/order_enforcing.py | 6 +- .../utils/wrappers/terminate_illegal.py | 4 +- test/pickle_test.py | 62 +++++++++++++++++++ 50 files changed, 369 insertions(+), 165 deletions(-) create mode 100644 test/pickle_test.py diff --git a/pettingzoo/atari/base_atari_env.py b/pettingzoo/atari/base_atari_env.py index 3e834d7c4..c609b942c 100644 --- a/pettingzoo/atari/base_atari_env.py +++ b/pettingzoo/atari/base_atari_env.py @@ -49,16 +49,16 @@ def __init__( """ EzPickle.__init__( self, - game, - num_players, - mode_num, - seed, - obs_type, - full_action_space, - env_name, - max_cycles, - render_mode, - auto_rom_install_path, + game=game, + num_players=num_players, + mode_num=mode_num, + seed=seed, + obs_type=obs_type, + full_action_space=full_action_space, + env_name=env_name, + max_cycles=max_cycles, + render_mode=render_mode, + auto_rom_install_path=auto_rom_install_path, ) assert obs_type in ( diff --git a/pettingzoo/atari/basketball_pong/basketball_pong.py b/pettingzoo/atari/basketball_pong/basketball_pong.py index 13490d431..34efa22cd 100644 --- a/pettingzoo/atari/basketball_pong/basketball_pong.py +++ b/pettingzoo/atari/basketball_pong/basketball_pong.py @@ -68,7 +68,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(num_players=2, **kwargs): diff --git a/pettingzoo/atari/boxing/boxing.py b/pettingzoo/atari/boxing/boxing.py index 6750bff12..55fcb0c56 100644 --- a/pettingzoo/atari/boxing/boxing.py +++ b/pettingzoo/atari/boxing/boxing.py @@ -78,7 +78,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/combat_plane/combat_plane.py b/pettingzoo/atari/combat_plane/combat_plane.py index 7cebebfc2..ba56abfb6 100644 --- a/pettingzoo/atari/combat_plane/combat_plane.py +++ b/pettingzoo/atari/combat_plane/combat_plane.py @@ -88,7 +88,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) avaliable_versions = { "bi-plane": 15, diff --git a/pettingzoo/atari/combat_tank/combat_tank.py b/pettingzoo/atari/combat_tank/combat_tank.py index c4d12df24..7870604db 100644 --- a/pettingzoo/atari/combat_tank/combat_tank.py +++ b/pettingzoo/atari/combat_tank/combat_tank.py @@ -86,7 +86,11 @@ import warnings from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(has_maze=True, is_invisible=False, billiard_hit=True, **kwargs): diff --git a/pettingzoo/atari/double_dunk/double_dunk.py b/pettingzoo/atari/double_dunk/double_dunk.py index 97362c77a..88e7e3fb3 100644 --- a/pettingzoo/atari/double_dunk/double_dunk.py +++ b/pettingzoo/atari/double_dunk/double_dunk.py @@ -78,7 +78,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/entombed_competitive/entombed_competitive.py b/pettingzoo/atari/entombed_competitive/entombed_competitive.py index f5fc69e27..08213bf91 100644 --- a/pettingzoo/atari/entombed_competitive/entombed_competitive.py +++ b/pettingzoo/atari/entombed_competitive/entombed_competitive.py @@ -75,7 +75,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/entombed_cooperative/entombed_cooperative.py b/pettingzoo/atari/entombed_cooperative/entombed_cooperative.py index 450f5788c..86602bd8c 100644 --- a/pettingzoo/atari/entombed_cooperative/entombed_cooperative.py +++ b/pettingzoo/atari/entombed_cooperative/entombed_cooperative.py @@ -83,7 +83,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/flag_capture/flag_capture.py b/pettingzoo/atari/flag_capture/flag_capture.py index def0ee405..018c86967 100644 --- a/pettingzoo/atari/flag_capture/flag_capture.py +++ b/pettingzoo/atari/flag_capture/flag_capture.py @@ -72,7 +72,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/foozpong/foozpong.py b/pettingzoo/atari/foozpong/foozpong.py index 295bc2b21..acc1630f3 100644 --- a/pettingzoo/atari/foozpong/foozpong.py +++ b/pettingzoo/atari/foozpong/foozpong.py @@ -73,7 +73,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(num_players=4, **kwargs): diff --git a/pettingzoo/atari/ice_hockey/ice_hockey.py b/pettingzoo/atari/ice_hockey/ice_hockey.py index ec65b6331..e164b1933 100644 --- a/pettingzoo/atari/ice_hockey/ice_hockey.py +++ b/pettingzoo/atari/ice_hockey/ice_hockey.py @@ -71,7 +71,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/joust/joust.py b/pettingzoo/atari/joust/joust.py index 20f4441fd..c67debe32 100644 --- a/pettingzoo/atari/joust/joust.py +++ b/pettingzoo/atari/joust/joust.py @@ -75,7 +75,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/mario_bros/mario_bros.py b/pettingzoo/atari/mario_bros/mario_bros.py index 58521a485..3b8c662d0 100644 --- a/pettingzoo/atari/mario_bros/mario_bros.py +++ b/pettingzoo/atari/mario_bros/mario_bros.py @@ -79,7 +79,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/maze_craze/maze_craze.py b/pettingzoo/atari/maze_craze/maze_craze.py index b567adfe3..e9ea11a51 100644 --- a/pettingzoo/atari/maze_craze/maze_craze.py +++ b/pettingzoo/atari/maze_craze/maze_craze.py @@ -88,7 +88,11 @@ import warnings from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) avaliable_versions = { "robbers": 2, diff --git a/pettingzoo/atari/othello/othello.py b/pettingzoo/atari/othello/othello.py index a14ab6825..2959c1105 100644 --- a/pettingzoo/atari/othello/othello.py +++ b/pettingzoo/atari/othello/othello.py @@ -75,7 +75,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/pong/pong.py b/pettingzoo/atari/pong/pong.py index cea18c64c..1d5ad17f8 100644 --- a/pettingzoo/atari/pong/pong.py +++ b/pettingzoo/atari/pong/pong.py @@ -70,7 +70,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) avaliable_2p_versions = { "classic": 4, diff --git a/pettingzoo/atari/quadrapong/quadrapong.py b/pettingzoo/atari/quadrapong/quadrapong.py index 75f0da427..ca626ad59 100644 --- a/pettingzoo/atari/quadrapong/quadrapong.py +++ b/pettingzoo/atari/quadrapong/quadrapong.py @@ -65,7 +65,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/space_invaders/space_invaders.py b/pettingzoo/atari/space_invaders/space_invaders.py index 83891720b..7ce34d6d7 100644 --- a/pettingzoo/atari/space_invaders/space_invaders.py +++ b/pettingzoo/atari/space_invaders/space_invaders.py @@ -78,7 +78,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env( diff --git a/pettingzoo/atari/space_war/space_war.py b/pettingzoo/atari/space_war/space_war.py index 65e016cd1..e66515e92 100644 --- a/pettingzoo/atari/space_war/space_war.py +++ b/pettingzoo/atari/space_war/space_war.py @@ -72,7 +72,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/surround/surround.py b/pettingzoo/atari/surround/surround.py index 6a42c7596..878888843 100644 --- a/pettingzoo/atari/surround/surround.py +++ b/pettingzoo/atari/surround/surround.py @@ -60,7 +60,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/tennis/tennis.py b/pettingzoo/atari/tennis/tennis.py index b64164c11..e4e61c671 100644 --- a/pettingzoo/atari/tennis/tennis.py +++ b/pettingzoo/atari/tennis/tennis.py @@ -74,7 +74,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/video_checkers/video_checkers.py b/pettingzoo/atari/video_checkers/video_checkers.py index 135469fc7..cf780dc20 100644 --- a/pettingzoo/atari/video_checkers/video_checkers.py +++ b/pettingzoo/atari/video_checkers/video_checkers.py @@ -64,7 +64,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/volleyball_pong/volleyball_pong.py b/pettingzoo/atari/volleyball_pong/volleyball_pong.py index f80857d48..e7e1894f3 100644 --- a/pettingzoo/atari/volleyball_pong/volleyball_pong.py +++ b/pettingzoo/atari/volleyball_pong/volleyball_pong.py @@ -73,7 +73,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(num_players=4, **kwargs): diff --git a/pettingzoo/atari/warlords/warlords.py b/pettingzoo/atari/warlords/warlords.py index de2ecbcd4..9c44960e9 100644 --- a/pettingzoo/atari/warlords/warlords.py +++ b/pettingzoo/atari/warlords/warlords.py @@ -60,7 +60,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/atari/wizard_of_wor/wizard_of_wor.py b/pettingzoo/atari/wizard_of_wor/wizard_of_wor.py index 4b8506722..16ea0b04e 100644 --- a/pettingzoo/atari/wizard_of_wor/wizard_of_wor.py +++ b/pettingzoo/atari/wizard_of_wor/wizard_of_wor.py @@ -66,7 +66,11 @@ import os from glob import glob -from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +from pettingzoo.atari.base_atari_env import ( + BaseAtariEnv, + base_env_wrapper_fn, + parallel_wrapper_fn, +) def raw_env(**kwargs): diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index 93968a890..e89e191af 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -238,20 +238,20 @@ def __init__( ): EzPickle.__init__( self, - spawn_rate, - num_archers, - num_knights, - max_zombies, - max_arrows, - killable_knights, - killable_archers, - pad_observation, - line_death, - max_cycles, - vector_state, - use_typemasks, - sequence_space, - render_mode, + spawn_rate=spawn_rate, + num_archers=num_archers, + num_knights=num_knights, + max_zombies=max_zombies, + max_arrows=max_arrows, + killable_knights=killable_knights, + killable_archers=killable_archers, + pad_observation=pad_observation, + line_death=line_death, + max_cycles=max_cycles, + vector_state=vector_state, + use_typemasks=use_typemasks, + sequence_space=sequence_space, + render_mode=render_mode, ) # variable state space self.sequence_space = sequence_space diff --git a/pettingzoo/butterfly/pistonball/pistonball.py b/pettingzoo/butterfly/pistonball/pistonball.py index bfc38159c..5bb43445c 100644 --- a/pettingzoo/butterfly/pistonball/pistonball.py +++ b/pettingzoo/butterfly/pistonball/pistonball.py @@ -145,16 +145,16 @@ def __init__( ): EzPickle.__init__( self, - n_pistons, - time_penalty, - continuous, - random_drop, - random_rotate, - ball_mass, - ball_friction, - ball_elasticity, - max_cycles, - render_mode, + n_pistons=n_pistons, + time_penalty=time_penalty, + continuous=continuous, + random_drop=random_drop, + random_rotate=random_rotate, + ball_mass=ball_mass, + ball_friction=ball_friction, + ball_elasticity=ball_elasticity, + max_cycles=max_cycles, + render_mode=render_mode, ) self.dt = 1.0 / FPS self.n_pistons = n_pistons diff --git a/pettingzoo/classic/hanabi/hanabi.py b/pettingzoo/classic/hanabi/hanabi.py index 08c218cb3..e770871f1 100644 --- a/pettingzoo/classic/hanabi/hanabi.py +++ b/pettingzoo/classic/hanabi/hanabi.py @@ -287,15 +287,15 @@ def __init__( """ EzPickle.__init__( self, - colors, - ranks, - players, - hand_size, - max_information_tokens, - max_life_tokens, - observation_type, - random_start_player, - render_mode, + colors=colors, + ranks=ranks, + players=players, + hand_size=hand_size, + max_information_tokens=max_information_tokens, + max_life_tokens=max_life_tokens, + observation_type=observation_type, + random_start_player=random_start_player, + render_mode=render_mode, ) # ToDo: Starts diff --git a/pettingzoo/classic/rlcard_envs/gin_rummy.py b/pettingzoo/classic/rlcard_envs/gin_rummy.py index c1d02f550..b660eb70f 100644 --- a/pettingzoo/classic/rlcard_envs/gin_rummy.py +++ b/pettingzoo/classic/rlcard_envs/gin_rummy.py @@ -152,7 +152,12 @@ def __init__( opponents_hand_visible=False, render_mode=None, ): - EzPickle.__init__(self, knock_reward, gin_reward, render_mode) + EzPickle.__init__( + self, + knock_reward=knock_reward, + gin_reward=gin_reward, + render_mode=render_mode, + ) self._opponents_hand_visible = opponents_hand_visible num_planes = 5 if self._opponents_hand_visible else 4 RLCardBase.__init__(self, "gin-rummy", 2, (num_planes, 52)) diff --git a/pettingzoo/mpe/simple/simple.py b/pettingzoo/mpe/simple/simple.py index 1de8f17c2..457bb3f81 100644 --- a/pettingzoo/mpe/simple/simple.py +++ b/pettingzoo/mpe/simple/simple.py @@ -43,19 +43,26 @@ """ import numpy as np +from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - -class raw_env(SimpleEnv): +class raw_env(SimpleEnv, EzPickle): def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): + EzPickle.__init__( + self, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, + ) scenario = Scenario() world = scenario.make_world() - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_adversary/simple_adversary.py b/pettingzoo/mpe/simple_adversary/simple_adversary.py index 2624640cb..d68e6cdef 100644 --- a/pettingzoo/mpe/simple_adversary/simple_adversary.py +++ b/pettingzoo/mpe/simple_adversary/simple_adversary.py @@ -55,19 +55,25 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__(self, N=2, max_cycles=25, continuous_actions=False, render_mode=None): - EzPickle.__init__(self, N, max_cycles, continuous_actions, render_mode) + EzPickle.__init__( + self, + N=N, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, + ) scenario = Scenario() world = scenario.make_world(N) - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_crypto/simple_crypto.py b/pettingzoo/mpe/simple_crypto/simple_crypto.py index 26885f53b..9234943be 100644 --- a/pettingzoo/mpe/simple_crypto/simple_crypto.py +++ b/pettingzoo/mpe/simple_crypto/simple_crypto.py @@ -59,12 +59,11 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - """Simple crypto environment. Scenario: @@ -75,10 +74,16 @@ class raw_env(SimpleEnv, EzPickle): def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): - EzPickle.__init__(self, max_cycles, continuous_actions, render_mode) + EzPickle.__init__( + self, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, + ) scenario = Scenario() world = scenario.make_world() - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_push/simple_push.py b/pettingzoo/mpe/simple_push/simple_push.py index fe73cf5f1..dee55cf7e 100644 --- a/pettingzoo/mpe/simple_push/simple_push.py +++ b/pettingzoo/mpe/simple_push/simple_push.py @@ -50,19 +50,24 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): - EzPickle.__init__(self, max_cycles, continuous_actions, render_mode) + EzPickle.__init__( + self, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, + ) scenario = Scenario() world = scenario.make_world() - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_reference/simple_reference.py b/pettingzoo/mpe/simple_reference/simple_reference.py index 0f03a8373..0acb83098 100644 --- a/pettingzoo/mpe/simple_reference/simple_reference.py +++ b/pettingzoo/mpe/simple_reference/simple_reference.py @@ -56,12 +56,11 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__( @@ -69,17 +68,18 @@ def __init__( ): EzPickle.__init__( self, - local_ratio, - max_cycles, - continuous_actions, - render_mode, + local_ratio=local_ratio, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, ) assert ( 0.0 <= local_ratio <= 1.0 ), "local_ratio is a proportion. Must be between 0 and 1." scenario = Scenario() world = scenario.make_world() - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py b/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py index c1977f2ee..c1e268fb9 100644 --- a/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py +++ b/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py @@ -51,19 +51,24 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): - EzPickle.__init__(self, max_cycles, continuous_actions, render_mode) + EzPickle.__init__( + self, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, + ) scenario = Scenario() world = scenario.make_world() - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_spread/simple_spread.py b/pettingzoo/mpe/simple_spread/simple_spread.py index c6d0b4c15..e5a3ccbfa 100644 --- a/pettingzoo/mpe/simple_spread/simple_spread.py +++ b/pettingzoo/mpe/simple_spread/simple_spread.py @@ -54,12 +54,11 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__( @@ -71,14 +70,20 @@ def __init__( render_mode=None, ): EzPickle.__init__( - self, N, local_ratio, max_cycles, continuous_actions, render_mode + self, + N=N, + local_ratio=local_ratio, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, ) assert ( 0.0 <= local_ratio <= 1.0 ), "local_ratio is a proportion. Must be between 0 and 1." scenario = Scenario() world = scenario.make_world(N) - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_tag/simple_tag.py b/pettingzoo/mpe/simple_tag/simple_tag.py index cee12fd88..f6f078dce 100644 --- a/pettingzoo/mpe/simple_tag/simple_tag.py +++ b/pettingzoo/mpe/simple_tag/simple_tag.py @@ -65,12 +65,11 @@ def bound(x): import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__( @@ -84,16 +83,17 @@ def __init__( ): EzPickle.__init__( self, - num_good, - num_adversaries, - num_obstacles, - max_cycles, - continuous_actions, - render_mode, + num_good=num_good, + num_adversaries=num_adversaries, + num_obstacles=num_obstacles, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + render_mode=render_mode, ) scenario = Scenario() world = scenario.make_world(num_good, num_adversaries, num_obstacles) - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/mpe/simple_world_comm/simple_world_comm.py b/pettingzoo/mpe/simple_world_comm/simple_world_comm.py index 7ddbe3444..94e3a7f4f 100644 --- a/pettingzoo/mpe/simple_world_comm/simple_world_comm.py +++ b/pettingzoo/mpe/simple_world_comm/simple_world_comm.py @@ -76,12 +76,11 @@ import numpy as np from gymnasium.utils import EzPickle +from pettingzoo.mpe._mpe_utils.core import Agent, Landmark, World +from pettingzoo.mpe._mpe_utils.scenario import BaseScenario +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env from pettingzoo.utils.conversions import parallel_wrapper_fn -from .._mpe_utils.core import Agent, Landmark, World -from .._mpe_utils.scenario import BaseScenario -from .._mpe_utils.simple_env import SimpleEnv, make_env - class raw_env(SimpleEnv, EzPickle): def __init__( @@ -97,19 +96,21 @@ def __init__( ): EzPickle.__init__( self, - num_good, - num_adversaries, - num_obstacles, - max_cycles, - num_forests, - continuous_actions, - render_mode, + num_good=num_good, + num_adversaries=num_adversaries, + num_obstacles=num_obstacles, + num_food=num_food, + max_cycles=max_cycles, + num_forests=num_forests, + continuous_actions=continuous_actions, + render_mode=render_mode, ) scenario = Scenario() world = scenario.make_world( num_good, num_adversaries, num_obstacles, num_food, num_forests ) - super().__init__( + SimpleEnv.__init__( + self, scenario=scenario, world=world, render_mode=render_mode, diff --git a/pettingzoo/sisl/multiwalker/multiwalker_base.py b/pettingzoo/sisl/multiwalker/multiwalker_base.py index ec21d6d77..8f3165e60 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker_base.py +++ b/pettingzoo/sisl/multiwalker/multiwalker_base.py @@ -16,7 +16,7 @@ from gymnasium.utils import seeding from pygame import gfxdraw -from .._utils import Agent +from pettingzoo.sisl._utils import Agent MAX_AGENTS = 40 diff --git a/pettingzoo/sisl/pursuit/utils/discrete_agent.py b/pettingzoo/sisl/pursuit/utils/discrete_agent.py index b22e805e5..1c8b9fb0f 100644 --- a/pettingzoo/sisl/pursuit/utils/discrete_agent.py +++ b/pettingzoo/sisl/pursuit/utils/discrete_agent.py @@ -1,7 +1,7 @@ import numpy as np from gymnasium import spaces -from ..._utils import Agent +from pettingzoo.sisl._utils import Agent ################################################################# # Implements the Single 2D Agent Dynamics diff --git a/pettingzoo/sisl/waterworld/waterworld.py b/pettingzoo/sisl/waterworld/waterworld.py index 3564c1592..f0e25f007 100755 --- a/pettingzoo/sisl/waterworld/waterworld.py +++ b/pettingzoo/sisl/waterworld/waterworld.py @@ -136,6 +136,8 @@ """ +from gymnasium.utils import EzPickle + from pettingzoo import AECEnv from pettingzoo.utils import agent_selector, wrappers from pettingzoo.utils.conversions import parallel_wrapper_fn @@ -154,7 +156,7 @@ def env(**kwargs): parallel_env = parallel_wrapper_fn(env) -class raw_env(AECEnv): +class raw_env(AECEnv, EzPickle): metadata = { "render_modes": ["human", "rgb_array"], "name": "waterworld_v4", @@ -163,7 +165,8 @@ class raw_env(AECEnv): } def __init__(self, *args, **kwargs): - super().__init__() + EzPickle.__init__(self, *args, **kwargs) + AECEnv.__init__(self) self.env = _env(*args, **kwargs) self.agents = ["pursuer_" + str(r) for r in range(self.env.num_agents)] diff --git a/pettingzoo/utils/__init__.py b/pettingzoo/utils/__init__.py index 18801e8fe..af9445539 100644 --- a/pettingzoo/utils/__init__.py +++ b/pettingzoo/utils/__init__.py @@ -1,10 +1,14 @@ -from .agent_selector import agent_selector -from .average_total_reward import average_total_reward -from .conversions import aec_to_parallel, parallel_to_aec, turn_based_aec_to_parallel -from .env import AECEnv, ParallelEnv -from .random_demo import random_demo -from .save_observation import save_observation -from .wrappers import ( +from pettingzoo.utils.agent_selector import agent_selector +from pettingzoo.utils.average_total_reward import average_total_reward +from pettingzoo.utils.conversions import ( + aec_to_parallel, + parallel_to_aec, + turn_based_aec_to_parallel, +) +from pettingzoo.utils.env import AECEnv, ParallelEnv +from pettingzoo.utils.random_demo import random_demo +from pettingzoo.utils.save_observation import save_observation +from pettingzoo.utils.wrappers import ( AssertOutOfBoundsWrapper, BaseParallelWrapper, BaseWrapper, diff --git a/pettingzoo/utils/wrappers/__init__.py b/pettingzoo/utils/wrappers/__init__.py index c09dd2b5b..494babb3c 100644 --- a/pettingzoo/utils/wrappers/__init__.py +++ b/pettingzoo/utils/wrappers/__init__.py @@ -1,7 +1,7 @@ -from .assert_out_of_bounds import AssertOutOfBoundsWrapper -from .base import BaseWrapper -from .base_parallel import BaseParallelWrapper -from .capture_stdout import CaptureStdoutWrapper -from .clip_out_of_bounds import ClipOutOfBoundsWrapper -from .order_enforcing import OrderEnforcingWrapper -from .terminate_illegal import TerminateIllegalWrapper +from pettingzoo.utils.wrappers.assert_out_of_bounds import AssertOutOfBoundsWrapper +from pettingzoo.utils.wrappers.base import BaseWrapper +from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper +from pettingzoo.utils.wrappers.capture_stdout import CaptureStdoutWrapper +from pettingzoo.utils.wrappers.clip_out_of_bounds import ClipOutOfBoundsWrapper +from pettingzoo.utils.wrappers.order_enforcing import OrderEnforcingWrapper +from pettingzoo.utils.wrappers.terminate_illegal import TerminateIllegalWrapper diff --git a/pettingzoo/utils/wrappers/assert_out_of_bounds.py b/pettingzoo/utils/wrappers/assert_out_of_bounds.py index 125a03533..a394334fc 100644 --- a/pettingzoo/utils/wrappers/assert_out_of_bounds.py +++ b/pettingzoo/utils/wrappers/assert_out_of_bounds.py @@ -1,6 +1,6 @@ from gymnasium.spaces import Discrete -from .base import BaseWrapper +from pettingzoo.utils.wrappers.base import BaseWrapper class AssertOutOfBoundsWrapper(BaseWrapper): diff --git a/pettingzoo/utils/wrappers/base_parallel.py b/pettingzoo/utils/wrappers/base_parallel.py index 599d323b4..332208869 100644 --- a/pettingzoo/utils/wrappers/base_parallel.py +++ b/pettingzoo/utils/wrappers/base_parallel.py @@ -2,7 +2,7 @@ from gymnasium.utils import seeding -from ..env import ParallelEnv +from pettingzoo.utils.env import ParallelEnv class BaseParallelWrapper(ParallelEnv): diff --git a/pettingzoo/utils/wrappers/capture_stdout.py b/pettingzoo/utils/wrappers/capture_stdout.py index 47d843be4..2d7273efa 100644 --- a/pettingzoo/utils/wrappers/capture_stdout.py +++ b/pettingzoo/utils/wrappers/capture_stdout.py @@ -1,5 +1,5 @@ -from ..capture_stdout import capture_stdout -from .base import BaseWrapper +from pettingzoo.utils.capture_stdout import capture_stdout +from pettingzoo.utils.wrappers.base import BaseWrapper class CaptureStdoutWrapper(BaseWrapper): diff --git a/pettingzoo/utils/wrappers/clip_out_of_bounds.py b/pettingzoo/utils/wrappers/clip_out_of_bounds.py index 2fb953117..ea1b596a3 100644 --- a/pettingzoo/utils/wrappers/clip_out_of_bounds.py +++ b/pettingzoo/utils/wrappers/clip_out_of_bounds.py @@ -1,8 +1,8 @@ import numpy as np from gymnasium.spaces import Box -from ..env_logger import EnvLogger -from .base import BaseWrapper +from pettingzoo.utils.env_logger import EnvLogger +from pettingzoo.utils.wrappers.base import BaseWrapper class ClipOutOfBoundsWrapper(BaseWrapper): diff --git a/pettingzoo/utils/wrappers/order_enforcing.py b/pettingzoo/utils/wrappers/order_enforcing.py index dacda2d53..9f0b1a822 100644 --- a/pettingzoo/utils/wrappers/order_enforcing.py +++ b/pettingzoo/utils/wrappers/order_enforcing.py @@ -1,6 +1,6 @@ -from ..env import AECIterable, AECIterator -from ..env_logger import EnvLogger -from .base import BaseWrapper +from pettingzoo.utils.env import AECIterable, AECIterator +from pettingzoo.utils.env_logger import EnvLogger +from pettingzoo.utils.wrappers.base import BaseWrapper class OrderEnforcingWrapper(BaseWrapper): diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index 5a2c0a97d..85e69894c 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -1,5 +1,5 @@ -from ..env_logger import EnvLogger -from .base import BaseWrapper +from pettingzoo.utils.env_logger import EnvLogger +from pettingzoo.utils.wrappers.base import BaseWrapper class TerminateIllegalWrapper(BaseWrapper): diff --git a/test/pickle_test.py b/test/pickle_test.py new file mode 100644 index 000000000..f95d7ea0b --- /dev/null +++ b/test/pickle_test.py @@ -0,0 +1,62 @@ +import pickle + +import pytest +from gymnasium.utils.env_checker import data_equivalence + +from pettingzoo.test.seed_test import seed_action_spaces, seed_observation_spaces + +from .all_modules import all_environments + +ALL_ENVS = list(all_environments.items()) + + +@pytest.mark.parametrize(("name", "env_module"), ALL_ENVS) +def test_pickle_env(name, env_module): + env1 = env_module.env(render_mode=None) + env2 = pickle.loads(pickle.dumps(env1)) + + env1.reset(seed=42) + env2.reset(seed=42) + + seed_action_spaces(env1) + seed_action_spaces(env2) + seed_observation_spaces(env1) + seed_observation_spaces(env2) + + iter = 0 + for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()): + if iter > 10: + break + assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}" + + obs1, rew1, term1, trunc1, info1 = env1.last() + obs2, rew2, term2, trunc2, info2 = env2.last() + + if term1 or term2 or trunc1 or trunc2: + break + + assert data_equivalence(obs1, obs2), f"Incorrect observations: {obs1} {obs2}" + assert data_equivalence(rew1, rew2), f"Incorrect rewards: {rew1} {rew2}" + assert data_equivalence(term1, term2), f"Incorrect terms: {term1} {term2}" + assert data_equivalence(trunc1, trunc2), f"Incorrect truncs: {trunc1} {trunc2}" + assert data_equivalence(info1, info2), f"Incorrect info: {info1} {info2}" + + mask = None + if "action_mask" in info1: + mask = info1["action_mask"] + + if isinstance(obs1, dict) and "action_mask" in obs1: + mask = obs1["action_mask"] + + action1 = env1.action_space(agent1).sample(mask=mask) + action2 = env2.action_space(agent2).sample(mask=mask) + + assert data_equivalence( + action1, action2 + ), f"Incorrect actions: {action1} {action2}" + + env1.step(action1) + env2.step(action2) + iter += 1 + env1.close() + env2.close()