Skip to content

Commit

Permalink
Add pickling tests, adapt all envs to be picklable (#928)
Browse files Browse the repository at this point in the history
Co-authored-by: Ariel Kwiatkowski <[email protected]>
  • Loading branch information
elliottower and RedTachyon authored Apr 21, 2023
1 parent c6e19ad commit 6aca84c
Show file tree
Hide file tree
Showing 50 changed files with 369 additions and 165 deletions.
20 changes: 10 additions & 10 deletions pettingzoo/atari/base_atari_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def __init__(
"""
EzPickle.__init__(
self,
game,
num_players,
mode_num,
seed,
obs_type,
full_action_space,
env_name,
max_cycles,
render_mode,
auto_rom_install_path,
game=game,
num_players=num_players,
mode_num=mode_num,
seed=seed,
obs_type=obs_type,
full_action_space=full_action_space,
env_name=env_name,
max_cycles=max_cycles,
render_mode=render_mode,
auto_rom_install_path=auto_rom_install_path,
)

assert obs_type in (
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/basketball_pong/basketball_pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(num_players=2, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/boxing/boxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/combat_plane/combat_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)

avaliable_versions = {
"bi-plane": 15,
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/combat_tank/combat_tank.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@
import warnings
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(has_maze=True, is_invisible=False, billiard_hit=True, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/double_dunk/double_dunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/flag_capture/flag_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/foozpong/foozpong.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(num_players=4, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/ice_hockey/ice_hockey.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/joust/joust.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/mario_bros/mario_bros.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/maze_craze/maze_craze.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@
import warnings
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)

avaliable_versions = {
"robbers": 2,
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/othello/othello.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/pong/pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)

avaliable_2p_versions = {
"classic": 4,
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/quadrapong/quadrapong.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/space_invaders/space_invaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/space_war/space_war.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/surround/surround.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/tennis/tennis.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/video_checkers/video_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/volleyball_pong/volleyball_pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(num_players=4, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/warlords/warlords.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
6 changes: 5 additions & 1 deletion pettingzoo/atari/wizard_of_wor/wizard_of_wor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
from pettingzoo.atari.base_atari_env import (
BaseAtariEnv,
base_env_wrapper_fn,
parallel_wrapper_fn,
)


def raw_env(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,20 +238,20 @@ def __init__(
):
EzPickle.__init__(
self,
spawn_rate,
num_archers,
num_knights,
max_zombies,
max_arrows,
killable_knights,
killable_archers,
pad_observation,
line_death,
max_cycles,
vector_state,
use_typemasks,
sequence_space,
render_mode,
spawn_rate=spawn_rate,
num_archers=num_archers,
num_knights=num_knights,
max_zombies=max_zombies,
max_arrows=max_arrows,
killable_knights=killable_knights,
killable_archers=killable_archers,
pad_observation=pad_observation,
line_death=line_death,
max_cycles=max_cycles,
vector_state=vector_state,
use_typemasks=use_typemasks,
sequence_space=sequence_space,
render_mode=render_mode,
)
# variable state space
self.sequence_space = sequence_space
Expand Down
Loading

0 comments on commit 6aca84c

Please sign in to comment.