Skip to content

Commit

Permalink
Fix bugs in OpenSpiel wrapper (obs/action space, seeding, reset confi…
Browse files Browse the repository at this point in the history
…g) (#96)
  • Loading branch information
elliottower authored Jul 15, 2023
1 parent cc753aa commit d433ddd
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 16 deletions.
28 changes: 22 additions & 6 deletions shimmy/openspiel_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,21 @@ def __init__(
env: pyspiel.Game | None = None,
game_name: str | None = None,
render_mode: str | None = None,
config: dict | None = None,
):
"""Wrapper to convert an OpenSpiel environment into a PettingZoo environment.
Args:
env (Optional[pyspiel.Game]): existing OpenSpiel environment to wrap
game_name (Optional[str]): name of OpenSpiel game to load
render_mode (Optional[str]): rendering mode
config (Optional[dict]): PySpiel config
"""
EzPickle.__init__(self, env, game_name, render_mode)
super().__init__()

self.config = config

# Only one of game_name and env can be provided, the other should be None
if env is None and game_name is None:
raise ValueError(
Expand All @@ -55,7 +59,10 @@ def __init__(
"Two environments provided. Use `env` to specify an existing environment, or load an environment with `game_name`."
)
elif game_name is not None:
self._env = pyspiel.load_game(game_name)
if self.config is not None:
self._env = pyspiel.load_game(game_name, self.config)
else:
self._env = pyspiel.load_game(game_name)
elif env is not None:
self._env = env

Expand Down Expand Up @@ -175,10 +182,19 @@ def reset(
# initialize the np random generator with the seed
self.np_random, self.np_seed = seeding.np_random(seed)

self.game_name = self.game_type.short_name

# seed argument is only valid for three games
if self.game_name in ["deep_sea", "hanabi", "mfg_garnet"] and seed is not None:
self.game_name = self.game_type.short_name
self._env = pyspiel.load_game(self.game_name, {"seed": seed})
if self.config is None:
reset_config = {"seed": seed}
else:
reset_config = self.config.copy() if self.config is not None else {}
reset_config["seed"] = seed
self._env = pyspiel.load_game(self.game_name, reset_config)

else:
self._env = pyspiel.load_game(self.game_name)

# all agents
self.agents = self.possible_agents[:]
Expand Down Expand Up @@ -304,15 +320,15 @@ def _update_observations(self):
if self.game_type.provides_observation_tensor:
self.observations = {
self.agents[a]: np.array(self.game_state.observation_tensor(a)).reshape(
self.observation_space(self.agents[0]).shape
self.observation_space(a).shape
)
for a in self.agent_ids
}
elif self.game_type.provides_information_state_tensor:
self.observations = {
self.agents[a]: np.array(
self.game_state.information_state_tensor(a)
).reshape(self.observation_space(self.agents[0]).shape)
).reshape(self.observation_space(a).shape)
for a in self.agent_ids
}
elif self.game_type.provides_observation_string:
Expand All @@ -333,7 +349,7 @@ def _update_observations(self):
def _update_action_masks(self):
"""Updates all the action masks inside the infos dictionary."""
for agent_id, agent_name in zip(self.agent_ids, self.agents):
action_mask = np.zeros(self.action_space(agent_name).n, dtype=np.int8)
action_mask = np.zeros(self._env.num_distinct_actions(), dtype=np.int8)
action_mask[self.game_state.legal_actions(agent_id)] = 1

self.infos[agent_name] = {"action_mask": action_mask}
Expand Down
4 changes: 4 additions & 0 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_dm_control_suite_envs():
"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: float64. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.",
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
"arrays to stack must be passed as a 'sequence' type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.",
"Calling `env.close()` on the closed environment should be allowed, but it raised an exception: _data",
]
]
CHECK_ENV_IGNORE_WARNINGS.append(
Expand Down Expand Up @@ -213,6 +214,9 @@ def test_render_height_widths(height, width):
env.close()


@pytest.mark.skip(
reason="This test is currently broken due to an issue with DM control and Gymnasium v29."
)
@pytest.mark.parametrize(
"wrapper_fn",
(
Expand Down
1 change: 1 addition & 0 deletions tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"A Box observation space minimum value is -infinity. This is probably too low.",
"A Box observation space maximum value is -infinity. This is probably too high.",
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
"The environment CartPole-v0 is out of date. You should consider upgrading to version `v1`.",
]
]
CHECK_ENV_IGNORE_WARNINGS.append(
Expand Down
22 changes: 12 additions & 10 deletions tests/test_openspiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,23 @@
"universal_poker",
"y",
"mfg_dynamic_routing",
]

_SOMETIMES_FAILING_GAMES = [
"backgammon",
"solitaire",
]

_FAILING_GAMES = [
# See https://github.com/deepmind/open_spiel/blob/efa004d8c5f5088224e49fdc198c5d74b6b600d0/open_spiel/python/tests/pyspiel_test.py#L162
_NON_DEFAULT_LOADABLE_GAMES = [
"add_noise",
"efg_game",
"misere",
"nfg_game" "misere",
"turn_based_simultaneous_game",
"normal_form_extensive_game",
"repeated_game",
"restricted_nash_response",
"start_at",
"turn_based_simultaneous_game",
"zerosum",
]

_UNKNOWN_BUGS_GAMES = ["nfg_game"]


@pytest.mark.parametrize("game_name", _PASSING_GAMES)
def test_passing_games(game_name):
Expand All @@ -131,11 +129,15 @@ def test_passing_games(game_name):
env.step(action)


@pytest.mark.parametrize("game_name", _FAILING_GAMES)
@pytest.mark.parametrize("game_name", _NON_DEFAULT_LOADABLE_GAMES)
def test_failing_games(game_name):
"""Ensures that failing OpenSpiel games are still failing."""
with pytest.raises(pyspiel.SpielError):
test_passing_games(game_name)
if game_name == "nfg_game":
with pytest.raises(IndexError):
test_passing_games(game_name)
else:
test_passing_games(game_name)


@pytest.mark.parametrize("game_name", _PASSING_GAMES)
Expand Down

0 comments on commit d433ddd

Please sign in to comment.