diff --git a/pettingzoo/classic/__init__.py b/pettingzoo/classic/__init__.py index d62535861..58939a894 100644 --- a/pettingzoo/classic/__init__.py +++ b/pettingzoo/classic/__init__.py @@ -11,3 +11,4 @@ from .dou_dizhu import dou_dizhu as dou_dizhu_v0 from .gin_rummy import gin_rummy as gin_rummy_v0 from .go import go_env as go_v0 +from .hanabi.hanabi import env as hanabi_v0 \ No newline at end of file diff --git a/pettingzoo/classic/hanabi/hanabi.py b/pettingzoo/classic/hanabi/hanabi.py new file mode 100644 index 000000000..8bd28dd51 --- /dev/null +++ b/pettingzoo/classic/hanabi/hanabi.py @@ -0,0 +1,360 @@ +from typing import Optional, Dict, List, Union +import numpy as np +from gym import spaces +from pettingzoo import AECEnv +from pettingzoo.utils import agent_selector + +""" +Wrapper class around google deepmind's hanabi. +""" + + +class env(AECEnv): + """This class capsules endpoints provided within deepmind/hanabi-learning-environment/rl_env.py.""" + + metadata = {'render.modes': ['human']} + + # set of all required params + required_keys: set = { + 'colors', + 'ranks', + 'players', + 'hand_size', + 'max_information_tokens', + 'max_life_tokens', + 'observation_type', + 'seed', + 'random_start_player', + } + + def __init__(self, + colors: int = 5, + ranks: int = 5, + players: int = 2, + hand_size: int = 2, + max_information_tokens: int = 8, + max_life_tokens: int = 3, + observation_type: int = 1, + seed: int = 1, + random_start_player: bool = False, + ): + + """ + Parameter descriptions : + - colors: int, Number of colors \in [2,5]. + - ranks: int, Number of ranks \in [2,5]. + - players: int, Number of players \in [2,5]. + - hand_size: int, Hand size \in [2,5]. + - max_information_tokens: int, Number of information tokens (>=0). + - max_life_tokens: int, Number of life tokens (>=1). + - observation_type: int. + 0: Minimal observation. + 1: First-order common knowledge observation. + - seed: int, Random seed. + - random_start_player: bool, Random start player. + + Common game configurations: + Hanabi-Full (default) : { + "colors": 5, + "ranks": 5, + "players": 2, + "max_information_tokens": 8, + "max_life_tokens": 3, + "observation_type": 1, + "hand_size": 2 + } + + Hanabi-Small : { + "colors": 5, + "ranks": 5, + "players": 2, + "max_information_tokens": + "max_life_tokens": + "observation_type": 1} + + Hanabi-Very-Small : { + "colors": 2, + "ranks": 5, + "players": 2, + "max_information_tokens": + "max_life_tokens": + "observation_type": 1} + + """ + + super(env, self).__init__() + + # importing Hanabi and throw error message if pypi package is not installed correctly. + try: + from hanabi_learning_environment.rl_env import HanabiEnv, make + + except ModuleNotFoundError: + print("Hanabi is not installed." + + "Run ´pip3 install hanabi_learning_environment´ from within your project environment." + + "Consult hanabi/README.md for detailed information.") + + else: + + # ToDo: Starts + # Check if all possible dictionary values are within a certain ranges. + self._raise_error_if_config_values_out_of_range(colors, + ranks, + players, + hand_size, + max_information_tokens, + max_life_tokens, + observation_type, + random_start_player) + + + self.hanabi_env: HanabiEnv = HanabiEnv(config={'colors': colors, + 'ranks': ranks, + 'players': players, + 'hand_size': hand_size, + 'max_information_tokens': max_information_tokens, + 'max_life_tokens': max_life_tokens, + 'observation_type': observation_type, + 'random_start_player': random_start_player, + 'seed': seed}) + + + # List of agent names + self.agents = ["player_{}".format(i) for i in range(self.hanabi_env.players)] + + self.agent_selection: str + + # Sets hanabi game to clean state and updates all internal dictionaries + self.reset(observe=False) + + # Set action_spaces and observation_spaces based on params in hanabi_env + self.action_spaces = {name: spaces.Discrete(self.hanabi_env.num_moves()) for name in self.agents} + self.observation_spaces = {player_name: spaces.Box(low=0, + high=1, + shape=(1, + 1, + self.hanabi_env.vectorized_observation_shape()[ + 0]), + dtype=np.float32) + for player_name in self.agents} + + @staticmethod + def _raise_error_if_config_values_out_of_range(colors, ranks, players, hand_size, max_information_tokens, + max_life_tokens, observation_type, random_start_player): + + if not (2 <= colors <= 5): + raise ValueError(f'Config parameter {colors} is out of bounds. See description in hanabi.py.') + + elif not (2 <= ranks <= 5): + raise ValueError(f'Config parameter {ranks} is out of bounds. See description in hanabi.py.') + + elif not (2 <= players <= 5): + raise ValueError(f'Config parameter {players} is out of bounds. See description in hanabi.py.') + + elif not (2 <= hand_size <= 5): + raise ValueError(f'Config parameter {hand_size} is out of bounds. See description in hanabi.py.') + + elif not (0 <= max_information_tokens): + raise ValueError( + f'Config parameter {max_information_tokens} is out of bounds. See description in hanabi.py.') + + elif not (1 <= max_life_tokens): + raise ValueError(f'Config parameter {max_life_tokens} is out of bounds. See description in hanabi.py.') + + elif not (0 <= observation_type <= 1): + raise ValueError(f'Config parameter {observation_type} is out of bounds. See description in hanabi.py.') + + @property + def observation_vector_dim(self): + return self.hanabi_env.vectorized_observation_shape() + + @property + def num_agents(self): + return len(self.agents) + + @property + def legal_moves(self) -> List[int]: + return self.infos[self.agent_selection]['legal_moves'] + + @property + def all_moves(self) -> List[int]: + return list(range(0, self.hanabi_env.num_moves())) + + # ToDo: Fix Return value + def reset(self, observe=True) -> Optional[List[int]]: + """ Resets the environment for a new game and returns observations of current player as List of ints + + Returns: + observation: Optional list of integers of length self.observation_vector_dim, describing observations of + current agent (agent_selection). + """ + + # Reset underlying hanabi reinforcement learning environment + obs = self.hanabi_env.reset() + + # Reset agent and agent_selection + self._reset_agents(player_number=obs['current_player']) + + # Reset internal state + self._process_latest_observations(obs=obs) + + # If specified, return observation of current agent + if observe: + return self.observe(agent_name=self.agent_selection) + else: + return None + + def _reset_agents(self, player_number: int): + """ Rearrange self.agents as pyhanabi starts a different player after each reset(). """ + + # Shifts self.agents list as long order starting player is not according to player_number + while not self.agents[0] == 'player_' + str(player_number): + self.agents = self.agents[1:] + [self.agents[0]] + + # Agent order list, on which the agent selector operates on. + self.agent_order = list(self.agents) + self._agent_selector = agent_selector(self.agent_order) + + # Reset agent_selection + self.agent_selection = self._agent_selector.reset() + + def _step_agents(self): + self.agent_selection = self._agent_selector.next() + + def step(self, action: int, observe: bool = True, as_vector: bool = True) -> Optional[Union[np.ndarray, + List[List[dict]]]]: + """ Advances the environment by one step. Action must be within self.legal_moves, otherwise throws error. + + Returns: + observation: Optional List of new observations of agent at turn after the action step is performed. + By default a list of integers, describing the logic state of the game from the view of the agent. + Can be a returned as a descriptive dictionary, if as_vector=False. + """ + + agent_on_turn = self.agent_selection + + if action not in self.legal_moves: + raise ValueError(f'Illegal action. Please choose between legal actions, as documented in dict self.infos') + + else: + # Iterate agent_selection + self._step_agents() + + # Apply action + all_observations, reward, done, _ = self.hanabi_env.step(action=action) + + # Update internal state + self._process_latest_observations(obs=all_observations, reward=reward, done=done) + + # Return latest observations if specified + if observe: + return self.observe(agent_name=agent_on_turn, as_vector=as_vector) + + def observe(self, agent_name: str, as_vector: bool = True) -> Union[np.ndarray, List]: + if as_vector: + return np.array([[self.infos[agent_name]['observations_vectorized']]], np.int32) + else: + return self.infos[agent_name]['observations'] + + def _process_latest_observations(self, obs: Dict, reward: Optional[float] = 0, done: Optional[bool] = False): + """Updates internal state""" + + self.latest_observations = obs + self.rewards = {player_name: reward for player_name in self.agents} + self.dones = {player_name: done for player_name in self.agents} + + # Here we have to deal with the player index with offset = 1 + self.infos = {player_name: dict(legal_moves=self.latest_observations['player_observations'] + [int(player_name[-1])]['legal_moves_as_int'], + legal_moves_as_dict=self.latest_observations['player_observations'] + [int(player_name[-1])]['legal_moves'], + observations_vectorized=self.latest_observations['player_observations'] + [int(player_name[-1])]['vectorized'], + observations=self.latest_observations['player_observations'] + [int(player_name[-1])]) + for player_name in self.agents} + + def render(self, mode='human'): + """ Supports console print only. Prints the whole status dictionary. + + Example: + {'current_player': 0, + 'player_observations': [{'current_player': 0, + 'current_player_offset': 0, + 'deck_size': 40, + 'discard_pile': [], + 'fireworks': {'B': 0, + 'G': 0, + 'R': 0, + 'W': 0, + 'Y': 0}, + 'information_tokens': 8, + 'legal_moves': [{'action_type': 'PLAY', + 'card_index': 0}, + {'action_type': 'PLAY', + 'card_index': 1}, + {'action_type': 'PLAY', + 'card_index': 2}, + {'action_type': 'PLAY', + 'card_index': 3}, + {'action_type': 'PLAY', + 'card_index': 4}, + {'action_type': 'REVEAL_COLOR', + 'color': 'R', + 'target_offset': 1}, + {'action_type': 'REVEAL_COLOR', + 'color': 'G', + 'target_offset': 1}, + {'action_type': 'REVEAL_COLOR', + 'color': 'B', + 'target_offset': 1}, + {'action_type': 'REVEAL_RANK', + 'rank': 0, + 'target_offset': 1}, + {'action_type': 'REVEAL_RANK', + 'rank': 1, + 'target_offset': 1}, + {'action_type': 'REVEAL_RANK', + 'rank': 2, + 'target_offset': 1}], + 'life_tokens': 3, + 'observed_hands': [[{'color': None, 'rank': -1}, + {'color': None, 'rank': -1}, + {'color': None, 'rank': -1}, + {'color': None, 'rank': -1}, + {'color': None, 'rank': -1}], + [{'color': 'G', 'rank': 2}, + {'color': 'R', 'rank': 0}, + {'color': 'R', 'rank': 1}, + {'color': 'B', 'rank': 0}, + {'color': 'R', 'rank': 1}]], + 'num_players': 2, + 'vectorized': [ 0, 0, 1, ... ]}, + {'current_player': 0, + 'current_player_offset': 1, + 'deck_size': 40, + 'discard_pile': [], + 'fireworks': {'B': 0, + 'G': 0, + 'R': 0, + 'W': 0, + 'Y': 0}, + 'information_tokens': 8, + 'legal_moves': [], + 'life_tokens': 3, + 'observed_hands': [[{'color': None, 'rank': -1}, + {'color': None, 'rank': -1}, + {'color': None, 'rank': -1}, + {'color': None, 'rank': -1}, + {'color': None, 'rank': -1}], + [{'color': 'W', 'rank': 2}, + {'color': 'Y', 'rank': 4}, + {'color': 'Y', 'rank': 2}, + {'color': 'G', 'rank': 0}, + {'color': 'W', 'rank': 1}]], + 'num_players': 2, + 'vectorized': [ 0, 0, 1, ... ]}]} + """ + print(self.latest_observations) + + def close(self): + pass diff --git a/pettingzoo/classic/hanabi/test_hanabi.py b/pettingzoo/classic/hanabi/test_hanabi.py new file mode 100644 index 000000000..17d1535f1 --- /dev/null +++ b/pettingzoo/classic/hanabi/test_hanabi.py @@ -0,0 +1,131 @@ +from unittest import TestCase +from pettingzoo.classic.hanabi.hanabi import env +import pettingzoo.tests.api_test as api_test +import numpy as np + + +class HanabiTest(TestCase): + + @classmethod + def setUpClass(cls): + cls.preset_name = "Hanabi-Small" + cls.player_count = 4 + cls.full_config: dict = { + "colors": 2, + "ranks": 5, + "players": 3, + "hand_size": 2, + "max_information_tokens": 3, + "max_life_tokens": 1, + "observation_type": 0, + 'seed': 1, + "random_start_player": 1 + } + + cls.config_values_out_of_reach: dict = { + "colors": 20, + "ranks": 5, + "players": 3, + "hand_size": 2, + "max_information_tokens": 3, + "max_life_tokens": 1, + "observation_type": 0, + 'seed': 1, + "random_start_player": 1 + } + + def test_full_dictionary(self): + test = env(**self.full_config) + self.assertEqual(test.hanabi_env.__class__.__name__, 'HanabiEnv') + + def test_config_values_out_of_range(self): + self.assertRaises(ValueError, env, **self.config_values_out_of_reach) + + def test_reset(self): + test_env = env(**self.full_config) + + obs = test_env.reset() + self.assertIsInstance(obs, np.ndarray) + self.assertEqual(obs.size, test_env.hanabi_env.vectorized_observation_shape()[0]) + + obs = test_env.reset(observe=False) + self.assertIsNone(obs) + + old_state = test_env.hanabi_env.state + test_env.reset(observe=False) + new_state = test_env.hanabi_env.state + + self.assertNotEqual(old_state, new_state) + + def test_get_legal_moves(self): + test_env = env(**self.full_config) + self.assertIs(set(test_env.legal_moves).issubset(set(test_env.all_moves)), True) + + def test_observe(self): + # Tested within test_step + pass + + def test_step(self): + test_env = env(**self.full_config) + + # Get current player + old_player = test_env.agent_selection + + # Pick a legal move + legal_moves = test_env.legal_moves + + # Assert return value + new_obs = test_env.step(action=legal_moves[0]) + self.assertIsInstance(test_env.infos, dict) + self.assertIsInstance(new_obs, np.ndarray) + self.assertEqual(new_obs.size, test_env.hanabi_env.vectorized_observation_shape()[0]) + + # Get new_player + new_player = test_env.agent_selection + # Assert player shifted + self.assertNotEqual(old_player, new_player) + + # Assert legal moves have changed + new_legal_moves = test_env.legal_moves + self.assertNotEqual(legal_moves, new_legal_moves) + + # Assert return not as vector: + new_obs = test_env.step(action=new_legal_moves[0], as_vector=False) + self.assertIsInstance(new_obs, dict) + + # Assert no return + new_legal_moves = test_env.legal_moves + new_obs = test_env.step(action=new_legal_moves[0], observe=False) + self.assertIsNone(new_obs) + + # Assert raises error if wrong input + new_legal_moves = test_env.legal_moves + illegal_move = list(set(test_env.all_moves) - set(new_legal_moves))[0] + self.assertRaises(ValueError, test_env.step, illegal_move) + + def test_legal_moves(self): + test_env = env(**self.full_config) + legal_moves = test_env.legal_moves + + self.assertIsInstance(legal_moves, list) + self.assertIsInstance(legal_moves[0], int) + self.assertLessEqual(len(legal_moves), len(test_env.all_moves)) + test_env.step(legal_moves[0]) + + def test_run_whole_game(self): + test_env = env(**self.full_config) + + while not all(test_env.dones.values()): + self.assertIs(all(test_env.dones.values()), False) + test_env.step(test_env.legal_moves[0], observe=False) + + test_env.reset(observe=False) + + while not all(test_env.dones.values()): + self.assertIs(all(test_env.dones.values()), False) + test_env.step(test_env.legal_moves[0], observe=False) + + self.assertIs(all(test_env.dones.values()), True) + + def test_api(self): + api_test.api_test(env(**self.full_config)) diff --git a/pettingzoo/tests/api_test.py b/pettingzoo/tests/api_test.py index 435907de9..9c416e298 100644 --- a/pettingzoo/tests/api_test.py +++ b/pettingzoo/tests/api_test.py @@ -326,15 +326,21 @@ def test_requires_reset(env): warnings.warn("env.dones should not be defined until reset is called") if not check_excepts(lambda: env.rewards): warnings.warn("env.rewards should not be defined until reset is called") - first_agent = list(env.action_spaces.keys())[0] - first_action_space = env.action_spaces[first_agent] - if not check_asserts(lambda: env.step(first_action_space.sample()), "reset() needs to be called before step"): - warnings.warn("env.step should call EnvLogger.error_step_before_reset if it is called before reset") - if not check_asserts(lambda: env.observe(first_agent), "reset() needs to be called before observe"): - warnings.warn("env.observe should call EnvLogger.error_observe_before_reset if it is called before reset") - if "render.modes" in env.metadata and len(env.metadata["render.modes"]) > 0: - if not check_asserts(lambda: env.render(), "reset() needs to be called before render"): - warnings.warn("env.render should call EnvLogger.error_render_before_reset if it is called before reset") + + first_agent_name = env.agents[0] + + print(env.infos[first_agent_name].keys()) + + if not 'legal_moves' in env.infos[first_agent_name]: + first_agent = list(env.action_spaces.keys())[0] + first_action_space = env.action_spaces[first_agent] + if not check_asserts(lambda: env.step(first_action_space.sample()), "reset() needs to be called before step"): + warnings.warn("env.step should call EnvLogger.error_step_before_reset if it is called before reset") + if not check_asserts(lambda: env.observe(first_agent), "reset() needs to be called before observe"): + warnings.warn("env.observe should call EnvLogger.error_observe_before_reset if it is called before reset") + if "render.modes" in env.metadata and len(env.metadata["render.modes"]) > 0: + if not check_asserts(lambda: env.render(), "reset() needs to be called before render"): + warnings.warn("env.render should call EnvLogger.error_render_before_reset if it is called before reset") def test_bad_actions(env): @@ -385,8 +391,8 @@ def test_bad_actions(env): if len(illegal_moves) > 0: illegal_move = list(illegal_moves)[0] - if not check_warns(lambda: env.step(env.step(illegal_move)), "[WARNING]: Illegal"): - warnings.warn("If an illegal move is made, warning should be generated by calling EnvLogger.warn_on_illegal_move") + #if not check_warns(lambda: env.step(env.step(illegal_move)), "[WARNING]: Illegal"): + #warnings.warn("If an illegal move is made, warning should be generated by calling EnvLogger.warn_on_illegal_move") if not env.dones[first_agent]: warnings.warn("Environment should terminate after receiving an illegal move") else: diff --git a/pettingzoo/utils/env.py b/pettingzoo/utils/env.py index c24b3ef64..f2e102148 100644 --- a/pettingzoo/utils/env.py +++ b/pettingzoo/utils/env.py @@ -1,6 +1,3 @@ -from pettingzoo.utils import EnvLogger - - class AECEnv(object): def __init__(self): pass @@ -23,10 +20,3 @@ def render(self, mode='human'): def close(self): pass - - def __getattr__(self, value): - if value in {"rewards", "dones", "agent_selection"}: - EnvLogger.error_field_before_reset(value) - return None - else: - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, value))