Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickle tests #53

Merged
merged 34 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7f90eb8
Add multi-agent dm control dockerfile and workflow
elliottower Mar 27, 2023
b30e99a
Fix typo in dm control multiagent workflow
elliottower Mar 27, 2023
6a0e2f2
Merge remote-tracking branch 'upstream/HEAD' into dm-lab-ci
elliottower Mar 28, 2023
9e3dc4e
Add dm-lab dockerfile and workflow
elliottower Mar 28, 2023
e8f9d87
Fix typo in dm_lab dockerfile
elliottower Mar 28, 2023
ae08a50
Add shimmy[dm-lab] pip installation to match other envs
elliottower Mar 28, 2023
1bc5a35
Add pickling tests for meltingpot, openspiel, bsuite, EzPickle for op…
elliottower Mar 28, 2023
2e97a55
Add initial pickle test to all third party environments (besides gym)
elliottower Mar 29, 2023
9f0760d
Merge branch 'main' into pickle-tests
elliottower Mar 29, 2023
f0afeef
Update PZ version after 1.22.4 yank
elliottower Mar 29, 2023
706db5f
Add importorskip for dm_lab so main tests don't fail
elliottower Mar 29, 2023
dc9df9c
Try old import deepmind_lab inside of test_check_env()
elliottower Mar 29, 2023
5917bd5
Add dm-env requirement to dm-lab dockerfile (fix CI error)
elliottower Mar 30, 2023
55e13a7
Fix typo in multiagent dm control test
elliottower Mar 30, 2023
29a0e4f
Update dm-lab tests to correct action type (from int to dict)
elliottower Mar 30, 2023
86fc8bc
Fix dm control multiagent init error (recursion limit)
elliottower Mar 30, 2023
efac866
Add all dm-lab levels to test, comment out obs test (not matching)
elliottower Mar 30, 2023
45de236
Attempt to fix dm-lab seeding, fix pickling test typo
elliottower Mar 30, 2023
6be6af8
Fix typo in dm lab test
elliottower Mar 30, 2023
e6a1c88
Fix meltingpot isort issues (ignore files, works locally just not in CI)
elliottower Mar 30, 2023
ad564f8
Fix dm control to take 10x less time for seed testing (1+hrs currently)
elliottower Mar 30, 2023
320f498
Fix typo in dm lab test
elliottower Mar 30, 2023
4992a4c
Make seed warning a print statement so execution doesn't stop during …
elliottower Mar 30, 2023
084d9b2
Skip dm lab tests
elliottower Mar 30, 2023
ef0f81a
Switch dm lab tests to do lt_chasm (env used in official examples)
elliottower Mar 30, 2023
246c8e1
Skip dm lab tests again due to errors
elliottower Mar 30, 2023
9413b02
Fix typo in install script
elliottower Mar 30, 2023
60f7482
Change dm_control_multi_agent test skip reason (weakref can't be pick…
elliottower Mar 30, 2023
0240502
Fix typo in dm control test
elliottower Mar 30, 2023
492a364
Fix typo in dm control multiagent test
elliottower Mar 30, 2023
aa0bb38
Fix typo in dm control multiagent test
elliottower Mar 30, 2023
746a823
Remove isort ignore and don't run local pre-commit hooks
elliottower Mar 31, 2023
ca39529
Remove repeated noqa ignores for file-wide ignores, remove extra imports
elliottower Mar 31, 2023
cce032f
Fix pre-commit in meltingpot test
elliottower Mar 31, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bin/dm_lab.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ RUN git clone https://github.com/deepmind/lab.git \
&& rm -rf lab

ENTRYPOINT ["/usr/local/shimmy/bin/docker_entrypoint"]

RUN ls
2 changes: 2 additions & 0 deletions scripts/install_dm_lab.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ fi

pip3 install numpy

# TODO: fix installation issues on MacOS
# Build
if [ ! -d "lab" ]; then
git clone https://github.com/deepmind/lab.git
fi
cd lab
echo 'build --cxxopt=-std=c++17' > .bazelrc
bazel build -c opt //python/pip_package:build_pip_package
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def get_version():
"dm-control>=1.0.10",
"imageio",
"h5py>=3.7.0",
"pettingzoo>=1.22.4",
"pettingzoo>=1.22.3",
],
"dm-lab": [],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22.4"],
"meltingpot": ["pettingzoo>=1.22.4"],
"dm-lab": ["dm-env>=1.6"],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22.3"],
"meltingpot": ["pettingzoo>=1.22.3"],
"bsuite": ["bsuite>=0.3.5"],
}
extras["all"] = list({lib for libs in extras.values() for lib in libs})
Expand Down
4 changes: 3 additions & 1 deletion shimmy/bsuite_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bsuite.environments import Environment
from gymnasium.core import ObsType
from gymnasium.error import UnsupportedMode
from gymnasium.utils import EzPickle

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

Expand All @@ -17,7 +18,7 @@
np.int = int # pyright: ignore[reportGeneralTypeIssues]


class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
"""A compatibility wrapper that converts a BSuite environment into a gymnasium environment.

Note:
Expand All @@ -33,6 +34,7 @@ def __init__(
render_mode: str | None = None,
):
"""Initialises the environment with a render mode along with render information."""
EzPickle.__init__(self, env, render_mode)
self._env = env

self.observation_space = dm_spec2gym_space(env.observation_spec())
Expand Down
6 changes: 5 additions & 1 deletion shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dm_control.rl import control
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

Expand All @@ -27,7 +28,7 @@ class EnvType(Enum):
RL_CONTROL = 1


class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
"""This compatibility wrapper converts a dm-control environment into a gymnasium environment.

Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics.
Expand Down Expand Up @@ -57,6 +58,9 @@ def __init__(
camera_id: int = 0,
):
"""Initialises the environment with a render mode along with render information."""
EzPickle.__init__(
self, env, render_mode, render_height, render_width, camera_id
)
self._env = env
self.env_type = self._find_env_type(env)

Expand Down
6 changes: 4 additions & 2 deletions shimmy/dm_control_multiagent_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import gymnasium
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space
Expand Down Expand Up @@ -62,7 +63,7 @@ def _unravel_ma_timestep(
)


class DmControlMultiAgentCompatibilityV0(ParallelEnv):
class DmControlMultiAgentCompatibilityV0(ParallelEnv, EzPickle):
"""This compatibility wrapper converts multi-agent dm-control environments, primarily soccer, into a Pettingzoo environment.

Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments,
Expand All @@ -84,7 +85,8 @@ def __init__(
env (dm_env.Environment): dm control multi-agent environment
render_mode (Optional[str]): render_mode
"""
super().__init__()
EzPickle.__init__(self, env=env, render_mode=render_mode)
ParallelEnv.__init__(self)
self._env = env
self.render_mode = render_mode

Expand Down
4 changes: 4 additions & 0 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def reset(
self._env.reset(seed=seed)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we had solved the seeding issue in dm-lab @jjshoots

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but I think this should be a separate PR with the other dm-lab fixes. This PR is just for adding pickle tests, and it turns out the pickle tests for dm-lab don't seem to work and are blocked, so we can update the seed/pickle stuff for dm-lab later.

info = {}

if seed is not None:
print(
"Warning: DM-lab environments must be seeded in initialization, rather than with reset(seed)."
)
return (
self._env.observations(),
info,
Expand Down
5 changes: 3 additions & 2 deletions shimmy/meltingpot_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
and modified to modern pettingzoo API
"""
# pyright: reportOptionalSubscript=false

# isort: skip_file
elliottower marked this conversation as resolved.
Show resolved Hide resolved
from __future__ import annotations

import functools
from typing import Optional

import gymnasium
import meltingpot.python
import numpy as np
import pygame
from gymnasium.utils.ezpickle import EzPickle
from ml_collections import config_dict
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

import meltingpot.python
import shimmy.utils.meltingpot as utils


Expand Down Expand Up @@ -89,6 +89,7 @@ def __init__(
for index in range(self._num_players)
]
self.agents = [agent for agent in self.possible_agents]
self.num_cycles = 0

# Set up pygame rendering
if self.render_mode == "human":
Expand Down
5 changes: 3 additions & 2 deletions shimmy/openspiel_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import pettingzoo as pz
import pyspiel
from gymnasium import spaces
from gymnasium.utils import seeding
from gymnasium.utils import EzPickle, seeding
from pettingzoo.utils.env import AgentID, ObsType


class OpenspielCompatibilityV0(pz.AECEnv):
class OpenspielCompatibilityV0(pz.AECEnv, EzPickle):
"""This compatibility wrapper converts an openspiel environment into a pettingzoo environment.

OpenSpiel is a collection of environments and algorithms for research in general reinforcement learning
Expand All @@ -35,6 +35,7 @@ def __init__(
game (pyspiel.Game): game
render_mode (Optional[str]): render_mode
"""
EzPickle.__init__(self, game, render_mode)
super().__init__()
self.game = game
self.possible_agents = [
Expand Down
63 changes: 62 additions & 1 deletion tests/test_atari.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the ale-py environments are correctly registered."""
import pickle
import warnings

import gymnasium as gym
Expand All @@ -7,7 +8,7 @@
from ale_py.roms import utils as rom_utils
from gymnasium.envs.registration import registry
from gymnasium.error import Error
from gymnasium.utils.env_checker import check_env
from gymnasium.utils.env_checker import check_env, data_equivalence

from shimmy.utils.envs_configs import ALL_ATARI_GAMES

Expand Down Expand Up @@ -47,3 +48,63 @@ def test_atari_envs(env_id):
assert isinstance(warning_message.message, Warning)
if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise Error(f"Unexpected warning: {warning_message.message}")


@pytest.mark.parametrize(
    "env_id",
    [
        name
        for name, spec in registry.items()
        if "Pong" in name and spec.entry_point == "shimmy.atari_env:AtariEnv"
    ],
)
def test_atari_pickle(env_id):
    """Check that an Atari env behaves identically after a pickle round trip.

    Only the registered Pong variants that use the shimmy Atari entry point
    are exercised. The original env and its unpickled copy are reset with the
    same seed and then stepped with identical actions; observations, rewards,
    termination/truncation flags and info dicts must match at every step.
    """
    original = gym.make(env_id)
    restored = pickle.loads(pickle.dumps(original))

    reset_obs_a, reset_info_a = original.reset(seed=42)
    reset_obs_b, reset_info_b = restored.reset(seed=42)
    assert data_equivalence(reset_obs_a, reset_obs_b)
    assert data_equivalence(reset_info_a, reset_info_b)

    for _step in range(100):
        # One action is sampled from the original env and applied to both.
        action = int(original.action_space.sample())
        obs_a, reward_a, terminated_a, truncated_a, info_a = original.step(action)
        obs_b, reward_b, terminated_b, truncated_b, info_b = restored.step(action)

        assert data_equivalence(obs_a, obs_b)
        assert reward_a == reward_b
        assert terminated_a == terminated_b
        assert truncated_a == truncated_b
        assert data_equivalence(info_a, info_b)

    original.close()
    restored.close()


@pytest.mark.parametrize(
    "env_id",
    [
        name
        for name, spec in registry.items()
        if "Pong" in name and spec.entry_point == "shimmy.atari_env:AtariEnv"
    ],
)
def test_atari_seeding(env_id):
    """Check that two separately created Atari envs agree under the same seed.

    Two instances of the same Pong variant are built independently, reset with
    seed 42 and stepped with identical actions; observations, rewards,
    termination/truncation flags and info dicts must match at every step.
    """
    first_env = gym.make(env_id)
    second_env = gym.make(env_id)

    reset_obs_a, reset_info_a = first_env.reset(seed=42)
    reset_obs_b, reset_info_b = second_env.reset(seed=42)
    assert data_equivalence(reset_obs_a, reset_obs_b)
    assert data_equivalence(reset_info_a, reset_info_b)

    for _step in range(100):
        # One action is sampled from the first env and applied to both.
        action = int(first_env.action_space.sample())
        obs_a, reward_a, terminated_a, truncated_a, info_a = first_env.step(action)
        obs_b, reward_b, terminated_b, truncated_b, info_b = second_env.step(action)

        assert data_equivalence(obs_a, obs_b)
        assert reward_a == reward_b
        assert terminated_a == terminated_b
        assert truncated_a == truncated_b
        assert data_equivalence(info_a, info_b)

    first_env.close()
    second_env.close()
56 changes: 56 additions & 0 deletions tests/test_bsuite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs."""
import pickle
import warnings

import bsuite
Expand Down Expand Up @@ -109,3 +110,58 @@ def test_seeding(env_id):

env_1.close()
env_2.close()


# Without EzPickle: _register_bsuite_envs.<locals>._make_bsuite_env cannot be pickled
# With EzPickle: maximum recursion limit reached
FAILING_PICKLE_ENVS = [
"bsuite/bandit_noise-v0",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I have found this before. It is generally because one of the classes implemented __getattr__.
Is there a particular parameter that is missing?

I suspect the issue is that the environment only defines a variable on reset or another function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll go check that individual env and see. But what would we be able to do to fix it, besides submitting a PR to their repo? I guess we could in the compatibility wrapper specifically check if it’s that env or if an env has that specific variable not defined in init, and then do whatever modifications are required?

Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts Mar 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In short, we can't fix it. The environments failing use a wrapper that contains

def __getattr__(self, attr):
    return getattr(self._env, attr)

The problem occurs when _env doesn't exist in the wrapper, i.e., when a static method (__setstate__) is looked up; this causes an infinite loop of __getattr__("static_method") -> __getattr__("_env") -> __getattr__("_env") -> ad infinitum
The second issue is that dm don't seem to be maintaining the project anymore.

The solution is simple

def __getattr__(self, attr):
    if "_env" in self.__dict__:
          return getattr(self._env, attr)
    else:
          return super().__getattribute__(attr)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put a PR up by any chance? Even if they don't end up merging it, I feel like we might as well try. You seem to understand this stuff better than me, though; I'm not sure I'd be able to explain it well or respond to any questions about it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"bsuite/bandit_scale-v0",
"bsuite/cartpole-v0",
"bsuite/cartpole_noise-v0",
"bsuite/cartpole_scale-v0",
"bsuite/cartpole_swingup-v0",
"bsuite/catch_noise-v0",
"bsuite/catch_scale-v0",
"bsuite/mnist_noise-v0",
"bsuite/mnist_scale-v0",
"bsuite/mountain_car_noise-v0",
"bsuite/mountain_car_scale-v0",
]

# bsuite environment ids used to parametrize the pickle round-trip test
# (test_pickle) below.
PASSING_PICKLE_ENVS = [
"bsuite/mnist-v0",
"bsuite/umbrella_length-v0",
"bsuite/discounting_chain-v0",
"bsuite/deep_sea-v0",
"bsuite/umbrella_distract-v0",
"bsuite/catch-v0",
"bsuite/memory_len-v0",
"bsuite/mountain_car-v0",
"bsuite/memory_size-v0",
"bsuite/deep_sea_stochastic-v0",
"bsuite/bandit-v0",
]


@pytest.mark.parametrize("env_id", PASSING_PICKLE_ENVS)
def test_pickle(env_id):
    """Check that a bsuite env behaves identically after a pickle round trip.

    The env is built with its entry in BSUITE_ENV_SETTINGS, pickled and
    restored. Both copies are reset with seed 42 and stepped with identical
    actions; observations, rewards, termination/truncation flags and info
    dicts must match at every step.
    """
    original = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
    restored = pickle.loads(pickle.dumps(original))

    reset_obs_a, reset_info_a = original.reset(seed=42)
    reset_obs_b, reset_info_b = restored.reset(seed=42)
    assert data_equivalence(reset_obs_a, reset_obs_b)
    assert data_equivalence(reset_info_a, reset_info_b)

    for _step in range(100):
        # One action is sampled from the original env and applied to both.
        action = int(original.action_space.sample())
        obs_a, reward_a, terminated_a, truncated_a, info_a = original.step(action)
        obs_b, reward_b, terminated_b, truncated_b, info_b = restored.step(action)

        assert data_equivalence(obs_a, obs_b)
        assert reward_a == reward_b
        assert terminated_a == terminated_b
        assert truncated_a == truncated_b
        assert data_equivalence(info_a, info_b)

    original.close()
    restored.close()
31 changes: 31 additions & 0 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the functionality of the DmControlCompatibility Wrapper on dm_control envs."""
import pickle
import warnings
from typing import Callable

Expand Down Expand Up @@ -82,6 +83,36 @@ def test_seeding(env_id):
env_1 = gym.make(env_id)
env_2 = gym.make(env_id)

if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic):
# LQR fails this test currently.
return

obs_1, info_1 = env_1.reset(seed=42)
obs_2, info_2 = env_2.reset(seed=42)
assert data_equivalence(obs_1, obs_2)
assert data_equivalence(info_1, info_2)
for _ in range(10):
actions = env_1.action_space.sample()
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
assert data_equivalence(obs_1, obs_2)
assert reward_1 == reward_2
assert term_1 == term_2 and trunc_1 == trunc_2
assert data_equivalence(info_1, info_2)

env_1.close()
env_2.close()


@pytest.mark.skip(
reason="Fatal Python error: Segmentation fault (with or without EzPickle)"
)
@pytest.mark.parametrize("env_id", DM_CONTROL_ENV_IDS[0])
def test_pickle(env_id):
"""Test that dm-control seeding works."""
env_1 = gym.make(env_id)
env_2 = pickle.loads(pickle.dumps(env_1))

if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic):
# LQR fails this test currently.
return
Expand Down
Loading