diff --git a/.gitignore b/.gitignore index e925889d6..29e3d2cf4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,7 @@ __pycache__/ *.swp .DS_Store .vscode/ -saved_observations/ \ No newline at end of file +saved_observations/ +build/ +dist/ +PettingZoo.egg-info/ \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..bc70680d6 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +recursive-include pettingzoo * +global-exclude __pycache__ +global-exclude *.pyc \ No newline at end of file diff --git a/README.md b/README.md index f3f859870..d3d0990f8 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,9 @@ PettingZoo includes the following sets of games: * mpe: A set of simple nongraphical communication tasks, originally from https://github.com/openai/multiagent-particle-envs * sisl: 3 cooperative environments, originally from https://github.com/sisl/MADRL -To install a set of games, use `pip3 install pettingzoo[atari]` (or whichever set of games you want). +To install, use `pip install pettingzoo` -We support Python 3.5, 3.6, 3.7 and 3.8 +We support Python 3.6, 3.7 and 3.8 ## Initializing Environments @@ -155,7 +155,7 @@ from pettingzoo.utils import random_demo random_demo(env) ``` -### Observation Saver +### Observation Saving If the agents in a game make observations that are images, the observations can be saved to an image file. This function takes in the environment, along with a specified agent. If no agent is specified, the current selected agent for the environment is chosen. If all_agents is passed in as True, then the observations of all agents in the environment is saved. By default the images are saved to the current working directory, in a folder matching the environment name. The saved image will match the name of the observing agent. If save_dir is passed in, a new folder is created where images will be saved to. 
@@ -180,14 +180,10 @@ Our cooperative games have leaderboards for best total (summed over all agents) The following environments are under active development: * atari/* (Ben) -* classic/checkers (Ben) -* classic/go (Luis) +* classic/backgammon (Caroline) +* classic/checkers (Caroline) * classic/hanabi (Clemens) +* classic/shogi (Caroline) * gamma/prospector (Yashas) * magent/* (Mario) * robotics/* (Yiling) -* classic/backgammon (Caroline) - -Development has not yet started on: - -* classic/shogi (python-shogi) diff --git a/docs/gamma.md b/docs/gamma.md index ef88fd38b..b0f572f3d 100644 --- a/docs/gamma.md +++ b/docs/gamma.md @@ -42,10 +42,10 @@ Move the left paddle using the 'W' and 'S' keys. Move the right paddle using 'UP ``` cooperative_pong.env(ball_speed=18, left_paddle_speed=25, -right_paddle_speed=25, is_cake_paddle=True, max_frames=900, bounce_randomness=False) +right_paddle_speed=25, cake_paddle=True, max_frames=900, bounce_randomness=False) ``` -The speed of the ball (`ball_speed` )is held constant throughout the game, while the initial direction of the ball is randomized when `reset()` method is called. The speed of left and right paddles are controlled by `left_paddle_speed` and `right_paddle_speed` respectively. If `is_cake_paddle` is `True`, the right paddle has the shape of a 4-tiered wedding cake. `done` of all agents are set to `True` after `max_frames` number of frames elapse. If `bounce_randomness` is `True`, each collision of the ball with the paddles adds a small random angle to the direction of the ball, while the speed of the ball remains unchanged. +The speed of the ball (`ball_speed` )is held constant throughout the game, while the initial direction of the ball is randomized when `reset()` method is called. The speed of left and right paddles are controlled by `left_paddle_speed` and `right_paddle_speed` respectively. If `cake_paddle` is `True`, the right paddle has the shape of a 4-tiered wedding cake. 
`done` of all agents are set to `True` after `max_frames` number of frames elapse. If `bounce_randomness` is `True`, each collision of the ball with the paddles adds a small random angle to the direction of the ball, while the speed of the ball remains unchanged. Leaderboard: @@ -67,7 +67,7 @@ Leaderboard: *AEC diagram* -Zombies walk from the top border of the screen down to the bottom border in unpredictable paths. The agents you control are knights and archers (default 2 knights and 2 archers) that are initially positioned at the bottom border of the screen. Each agent can rotate clockwise or counter-clockwise and move forward or backward. Each agent can also attack to kill zombies. When a knight attacks, it swings a mace in an arc in front of its current heading direction. When an archer attacks, it fires an arrow in a straight line in the direction of the archer's heading. The game ends when all agents die (collide with a zombie) or a zombie reaches the bottom screen border. An agent gets a reward when it kills a zombie. Each agent observes the environment as a square region around itself, with its own body in the center of the square. The observation is represented as a 512x512 image around the agent. +Zombies walk from the top border of the screen down to the bottom border in unpredictable paths. The agents you control are knights and archers (default 2 knights and 2 archers) that are initially positioned at the bottom border of the screen. Each agent can rotate clockwise or counter-clockwise and move forward or backward. Each agent can also attack to kill zombies. When a knight attacks, it swings a mace in an arc in front of its current heading direction. When an archer attacks, it fires an arrow in a straight line in the direction of the archer's heading. The game ends when all agents die (collide with a zombie) or a zombie reaches the bottom screen border. A knight is rewarded 1 point when its mace hits and kills a zombie. 
An archer is rewarded 1 point when one of their arrows hits and kills a zombie. Each agent observes the environment as a square region around itself, with its own body in the center of the square. The observation is represented as a 512x512 pixel image around the agent, or in other words, a 16x16 agent sized space around the agent. Manual Control: @@ -80,7 +80,7 @@ Press 'M' key to spawn a new knight. ``` knights_archers_zombies.env(spawn_rate=20, knights=2, archers=2, -killable_knights=True, killable_archers=True, line_death=True, pad_observation=True, max_frames=900) +killable_knights=True, killable_archers=True, black_death=True, line_death=True, pad_observation=True, max_frames=900) ``` *about arguments* @@ -96,7 +96,9 @@ killable_knights: if set to False, knight agents cannot be killed by zombies. killable_archers: if set to False, archer agents cannot be killed by zombies. -line_death: +black_death: if set to True, agents who die will observe only black. If False, dead agents do not have reward, done, info or observations and are removed from agent list. + +line_death: if set to False, agents do not die when they touch the top or bottom border. If True, agents die as soon as they touch the top or bottom border. pad_observation: if agents are near edge of environment, their observation cannot form a 40x40 grid. If this is set to True, the observation is padded with black. ``` @@ -122,7 +124,7 @@ Leaderboard: *AEC diagram* This is a simple physics based cooperative game where the goal is to move the ball to the left wall of the game border by activating any of the twenty vertically moving pistons. Pistons can only see themselves, and the two pistons next to them. -Thus, pistons must learn highly coordinated emergent behavior to achieve an optimal policy for the environment. Each agent get's a reward that is a combination of how much the ball moved left overall, and how much the ball moved left if it was close to the piston (i.e. movement it contributed to). 
Balancing the ratio between these appears to be critical to learning this environment, and as such is an environment parameter. If the ball moves to the left, a positive global reward is applied. If the ball moves to the right then a negative global reward is applied. Additionally, pistons that are within a radius of the ball are given a local reward. +Thus, pistons must learn highly coordinated emergent behavior to achieve an optimal policy for the environment. Each agent get's a reward that is a combination of how much the ball moved left overall, and how much the ball moved left if it was close to the piston (i.e. movement it contributed to). Balancing the ratio between these appears to be critical to learning this environment, and as such is an environment parameter. The local reward applied is 0.5 times the change in the ball's x-position. Additionally, the global reward is change in x-position divided by the starting position, times 100. For each piston, the reward is .02 * local_reward + 0.08 * global_reward. The local reward is applied to pistons surrounding the ball while the global reward is provided to all pistons. Pistonball uses the chipmunk physics engine, and are thus the physics are about as realistic as Angry Birds. @@ -167,9 +169,9 @@ Continuous Leaderboard: ### Prison -| Actions | Agents | Manual Control | Action Shape | Action Values | Observation Shape | Observation Values | Num States | -|---------|--------|----------------|--------------|---------------|-------------------|--------------------|------------| -| Either | 8 | Yes | (1,) | [0, 2] | (100, 300, 3) | (0, 255) | ? | +| Actions | Agents | Manual Control | Action Shape | Action Values | Observation Shape | Observation Values | Num States | +|---------|--------|----------------|--------------|---------------|----------------------|------------------------|------------| +| Either | 8 | Yes | (1,) | [0, 2] | (100, 300, 3) or (1,)| (0, 255) or (-300, 300)| ? 
| `from pettingzoo.gamma import prison_v0` @@ -181,6 +183,10 @@ Continuous Leaderboard: In prison, 8 aliens locked in identical prison cells are controlled by the user. They cannot communicate with each other in any way, and can only pace in their cell. Every time they touch one end of the cell and then the other, they get a reward of 1. Due to the fully independent nature of these agents and the simplicity of the task, this is an environment primarily intended for debugging purposes- it's multiple individual purely single agent tasks. To make this debugging tool as compatible with as many methods as possible, it can accept both discrete and continuous actions and the observation can be automatically turned into a number representing position of the alien from the left of it's cell instead of the normal graphical output. +Manual Control: + +Select different aliens with 'W', 'A', 'S' or 'D'. Move the selected alien left with 'J' and right with 'K'. + Arguments: ``` diff --git a/docs/mpe.md b/docs/mpe.md index 7fd20e2cb..149e206f4 100644 --- a/docs/mpe.md +++ b/docs/mpe.md @@ -159,7 +159,7 @@ max_frames: number of frames (a step for each agent) until game terminates *AEC diagram* -In this environment, there are 2 good agents (Alice and Bob) and 1 adversary (Eve). Alice must sent a private 1 bit message to Bob over a public channel. Alice and Bob are rewarded if Bob reconstructs the message, but are negatively rewarded if Eve reconstruct the message. Eve is rewarded based on how well it can reconstruct the signal. Alice and Bob have a private key (randomly generated at beginning of each episode), which they must learn to use to encrypt the message. +In this environment, there are 2 good agents (Alice and Bob) and 1 adversary (Eve). Alice must send a private 1 bit message to Bob over a public channel. Alice and Bob are rewarded +2 if Bob reconstructs the message, but are rewarded -2 if Eve reconstructs the message (that adds to 0 if both teams reconstruct the bit). 
Eve is rewarded -2 if it cannot reconstruct the signal, zero if it can. Alice and Bob have a private key (randomly generated at beginning of each episode), which they must learn to use to encrypt the message. Alice observation space: `[message, private_key]` diff --git a/docs/sisl.md b/docs/sisl.md index 359de0442..f5250757e 100644 --- a/docs/sisl.md +++ b/docs/sisl.md @@ -39,7 +39,7 @@ Please additionally cite: *AEC diagram* -A package is placed on top of (by default) 3 pairs of robot legs which you control. The robots must learn to move the package as far as possible to the right. Each walker gets a reward of 1 for moving the package forward, and a reward of -100 for dropping the package. Each walker exerts force on two joints in their two legs, giving a continuous action space represented as a 4 element vector. Each walker observes via a 32 element vector, containing simulated noisy lidar data about the environment and information about neighboring walkers. The environment runs for 500 frames by default. +A package is placed on top of (by default) 3 pairs of robot legs which you control. The robots must learn to move the package as far as possible to the right. A positive reward is awarded to each walker, which is the change in the package distance summed with 130 times the change in the walker's position. A walker is given a reward of -100 if they fall and a reward of -100 for each fallen walker in the environment. If the global reward mechanic is chosen, the mean of all rewards is given to each agent. Each walker exerts force on two joints in their two legs, giving a continuous action space represented as a 4 element vector. Each walker observes via a 32 element vector, containing simulated noisy lidar data about the environment and information about neighboring walkers. The environment runs for 500 frames by default. 
``` multiwalker.env(n_walkers=3, position_noise=1e-3, angle_noise=1e-3, reward_mech='local', @@ -94,6 +94,11 @@ Add Gupta et al and DDPG paper results too By default there are 30 blue evaders and 8 red pursuer agents, in a 16 x 16 grid with an obstacle in the center, shown in white. The evaders move randomly, and the pursuers are controlled. Every time the pursuers fully surround an evader, each of the surrounding agents receives a reward of 5, and the evader is removed from the environment. Pursuers also receive a reward of 0.01 every time they touch an evader. The pursuers have a discrete action space of up, down, left, right and stay. Each pursuer observes a 7 x 7 grid centered around itself, depicted by the orange boxes surrounding the red pursuer agents. The enviroment runs for 500 frames by default. Observation shape takes the full form of `(3, obs_range, obs_range)`. +Manual Control: + +Select different pursuers with 'J' and 'K'. The selected pursuer can be moved with the arrow keys. 
+ + ``` pursuit.env(max_frames=500, xs=16, ys=16, reward_mech='local', n_evaders=30, n_pursuers=8, obs_range=7, layer_norm=10, n_catch=2, random_opponents=False, max_opponents=10, diff --git a/pettingzoo/__init__.py b/pettingzoo/__init__.py index 53efe715b..e27d5aac8 100644 --- a/pettingzoo/__init__.py +++ b/pettingzoo/__init__.py @@ -1,7 +1,2 @@ from pettingzoo.utils.env import AECEnv import pettingzoo.utils -import pettingzoo.gamma -import pettingzoo.sisl -import pettingzoo.classic -import pettingzoo.tests -import pettingzoo.magent diff --git a/pettingzoo/classic/checkers/checkers.py b/pettingzoo/classic/checkers/checkers.py index 16de81819..9104c7510 100644 --- a/pettingzoo/classic/checkers/checkers.py +++ b/pettingzoo/classic/checkers/checkers.py @@ -14,7 +14,7 @@ class env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self): - super(env, self).__init__() + super().__init__() self.ch = CheckersRules() self.num_agents = 2 diff --git a/pettingzoo/classic/chess/chess_env.py b/pettingzoo/classic/chess/chess_env.py index 543a73c05..0d99d5a49 100644 --- a/pettingzoo/classic/chess/chess_env.py +++ b/pettingzoo/classic/chess/chess_env.py @@ -6,14 +6,24 @@ import warnings from pettingzoo.utils.agent_selector import agent_selector from pettingzoo.utils.env_logger import EnvLogger +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(): + env = raw_env() + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human', 'ascii']} def __init__(self): - super(env, self).__init__() + super().__init__() self.board = chess.Board() @@ -25,20 +35,15 @@ def __init__(self): self.action_spaces = {name: spaces.Discrete(8 * 8 * 73) for name in self.agents} self.observation_spaces = {name: spaces.Box(low=0, high=1, shape=(8, 8, 20), dtype=np.float32) 
for name in self.agents} - # self.rewards = None - # self.dones = None - # self.infos = None - # - # self.agent_selection = None + self.rewards = None + self.dones = None + self.infos = {name: {} for name in self.agents} - self.has_reset = False - self.has_rendered = False + self.agent_selection = None self.num_agents = len(self.agents) def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() return chess_utils.get_observation(self.board, self.agents.index(agent)) def reset(self, observe=True): @@ -66,49 +71,32 @@ def set_game_result(self, result_val): self.infos[name] = {'legal_moves': []} def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() - backup_policy = "game terminating with current player losing" - act_space = self.action_spaces[self.agent_selection] - if np.isnan(action).any(): - EnvLogger.warn_action_is_NaN(backup_policy) - if not act_space.contains(action): - EnvLogger.warn_action_out_of_bound(action, act_space, backup_policy) - current_agent = self.agent_selection current_index = self.agents.index(current_agent) self.agent_selection = next_agent = self._agent_selector.next() - old_legal_moves = self.infos[current_agent]['legal_moves'] + chosen_move = chess_utils.action_to_move(self.board, action, current_index) + assert chosen_move in self.board.legal_moves + self.board.push(chosen_move) - if action not in old_legal_moves: - EnvLogger.warn_on_illegal_move() - player_loses_val = -1 if current_index == 0 else 1 - self.set_game_result(player_loses_val) - self.rewards[next_agent] = 0 - else: - chosen_move = chess_utils.action_to_move(self.board, action, current_index) - assert chosen_move in self.board.legal_moves - self.board.push(chosen_move) - - next_legal_moves = chess_utils.legal_moves(self.board) + next_legal_moves = chess_utils.legal_moves(self.board) - is_stale_or_checkmate = not any(next_legal_moves) + is_stale_or_checkmate = not any(next_legal_moves) - # claim draw 
is set to be true to allign with normal tournament rules - is_repetition = self.board.is_repetition(3) - is_50_move_rule = self.board.can_claim_fifty_moves() - is_claimable_draw = is_repetition or is_50_move_rule - game_over = is_claimable_draw or is_stale_or_checkmate + # claim draw is set to be true to allign with normal tournament rules + is_repetition = self.board.is_repetition(3) + is_50_move_rule = self.board.can_claim_fifty_moves() + is_claimable_draw = is_repetition or is_50_move_rule + game_over = is_claimable_draw or is_stale_or_checkmate - if game_over: - result = self.board.result(claim_draw=True) - result_val = chess_utils.result_to_int(result) - self.set_game_result(result_val) - else: - self.infos[current_agent] = {'legal_moves': []} - self.infos[next_agent] = {'legal_moves': next_legal_moves} - assert len(self.infos[next_agent]['legal_moves']) + if game_over: + result = self.board.result(claim_draw=True) + result_val = chess_utils.result_to_int(result) + self.set_game_result(result_val) + else: + self.infos[current_agent] = {'legal_moves': []} + self.infos[next_agent] = {'legal_moves': next_legal_moves} + assert len(self.infos[next_agent]['legal_moves']) if observe: next_observation = self.observe(next_agent) @@ -117,10 +105,7 @@ def step(self, action, observe=True): return next_observation def render(self, mode='human'): - self.has_rendered = True print(self.board) def close(self): - if not self.has_rendered: - EnvLogger.warn_close_unrendered_env() - self.has_rendered = False + pass diff --git a/pettingzoo/classic/connect_four/__init__.py b/pettingzoo/classic/connect_four/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/connect_four/connect_four.py b/pettingzoo/classic/connect_four/connect_four.py index 27c780274..5d01b42cd 100644 --- a/pettingzoo/classic/connect_four/connect_four.py +++ b/pettingzoo/classic/connect_four/connect_four.py @@ -4,13 +4,24 @@ import warnings from .manual_control import 
manual_control +from pettingzoo.utils import wrappers +from pettingzoo.utils.agent_selector import agent_selector -class env(AECEnv): - metadata = {'render.modes': ['ansi']} +def env(): + env = raw_env() + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): + metadata = {'render.modes': ['human']} def __init__(self): - super(env, self).__init__() + super().__init__() # 6 rows x 7 columns # blank space = 0 # agent 0 -- 1 @@ -23,14 +34,8 @@ def __init__(self): self.agent_order = list(self.agents) - self.action_spaces = {i: spaces.Discrete(7) for i in range(2)} - self.observation_spaces = {i: spaces.Box(low=0, high=2, shape=(6, 7), dtype=np.int8) for i in range(2)} - - self.rewards = {i: 0 for i in range(self.num_agents)} - self.dones = {i: False for i in range(self.num_agents)} - self.infos = {i: {'legal_moves': list(range(7))} for i in range(self.num_agents)} - - self.agent_selection = 0 + self.action_spaces = {i: spaces.Discrete(7) for i in self.agents} + self.observation_spaces = {i: spaces.Box(low=0, high=1, shape=(6, 7, 2), dtype=np.int8) for i in self.agents} # Key # ---- @@ -46,44 +51,45 @@ def __init__(self): # [2, 0, 0, 0, 1, 1, 0], # [1, 1, 2, 1, 0, 1, 0]], dtype=int8) def observe(self, agent): - return np.array(self.board).reshape(6, 7) + board_vals = np.array(self.board).reshape(6, 7) + cur_player = self.agents.index(self.agent_selection) + opp_player = (cur_player + 1) % 2 + + cur_p_board = np.equal(board_vals, cur_player + 1) + opp_p_board = np.equal(board_vals, opp_player + 1) + return np.stack([cur_p_board, opp_p_board], axis=2).astype(np.int8) # action in this case is a value from 0 to 6 indicating position to move on the flat representation of the connect4 board def step(self, action, observe=True): - # check if input action is a valid move (0 == empty spot) - 
if(self.board[0:7][action] == 0): - # valid move - for i in list(filter(lambda x: x % 7 == action, list(range(41, -1, -1)))): - if self.board[i] == 0: - self.board[i] = self.agent_selection + 1 - break - - next_agent = 1 if (self.agent_selection == 0) else 0 - - # update infos with valid moves - self.infos[self.agent_selection]['legal_moves'] = [i for i in range(7) if self.board[i] == 0] - self.infos[next_agent]['legal_moves'] = [i for i in range(7) if self.board[i] == 0] - - winner = self.check_for_winner() - - # check if there is a winner - if winner: - self.rewards[self.agent_selection] += 1 - self.rewards[next_agent] -= 1 - self.dones = {i: True for i in range(self.num_agents)} - # check if there is a tie - elif all(x in [1, 2] for x in self.board): - # once either play wins or there is a draw, game over, both players are done - self.dones = {i: True for i in range(self.num_agents)} - else: - # no winner yet - self.agent_selection = next_agent - + # assert valid move + assert (self.board[0:7][action] == 0), "played illegal move." 
+ + piece = self.agents.index(self.agent_selection) + 1 + for i in list(filter(lambda x: x % 7 == action, list(range(41, -1, -1)))): + if self.board[i] == 0: + self.board[i] = piece + break + + next_agent = self._agent_selector.next() + + # update infos with valid moves + self.infos[self.agent_selection]['legal_moves'] = [i for i in range(7) if self.board[i] == 0] + self.infos[next_agent]['legal_moves'] = [i for i in range(7) if self.board[i] == 0] + + winner = self.check_for_winner() + + # check if there is a winner + if winner: + self.rewards[self.agent_selection] += 1 + self.rewards[next_agent] -= 1 + self.dones = {i: True for i in self.agents} + # check if there is a tie + elif all(x in [1, 2] for x in self.board): + # once either play wins or there is a draw, game over, both players are done + self.dones = {i: True for i in self.agents} else: - # invalid move, end game - self.rewards[self.agent_selection] -= 1 - self.dones = {i: True for i in range(self.num_agents)} - warnings.warn("Bad connect four move made, game terminating with current player losing. 
env.infos[player]['legal_moves'] contains a list of all legal moves that can be chosen.") + # no winner yet + self.agent_selection = next_agent if observe: return self.observe(self.agent_selection) @@ -94,26 +100,29 @@ def reset(self, observe=True): # reset environment self.board = [0] * (6 * 7) - self.rewards = {i: 0 for i in range(self.num_agents)} - self.dones = {i: False for i in range(self.num_agents)} - self.infos = {i: {'legal_moves': list(range(7))} for i in range(self.num_agents)} + self.rewards = {i: 0 for i in self.agents} + self.dones = {i: False for i in self.agents} + self.infos = {i: {'legal_moves': list(range(7))} for i in self.agents} + + self._agent_selector = agent_selector(self.agents) + + self.agent_selection = self._agent_selector.reset() - # selects the first agent - self.agent_selection = 0 if observe: return self.observe(self.agent_selection) else: return def render(self, mode='ansi'): - print(str(self.observe(self.agent_selection))) + print("{}'s turn'".format(self.agent_selection)) + print(str(np.array(self.board).reshape(6, 7))) def close(self): pass def check_for_winner(self): board = np.array(self.board).reshape(6, 7) - piece = self.agent_selection + 1 + piece = self.agents.index(self.agent_selection) + 1 # Check horizontal locations for win column_count = 7 diff --git a/pettingzoo/classic/dou_dizhu/__init__.py b/pettingzoo/classic/dou_dizhu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/dou_dizhu/dou_dizhu.py b/pettingzoo/classic/dou_dizhu/dou_dizhu.py index 1f04fd2f6..bc229aea5 100644 --- a/pettingzoo/classic/dou_dizhu/dou_dizhu.py +++ b/pettingzoo/classic/dou_dizhu/dou_dizhu.py @@ -5,18 +5,25 @@ import rlcard import random import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = 
wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, **kwargs): - super(env, self).__init__() - if seed is not None: - np.random.seed(seed) - random.seed(seed) - self.env = rlcard.make('doudizhu', **kwargs) + def __init__(self, seed=None): + super().__init__() + self.env = rlcard.make('doudizhu', config={"seed": seed}) self.agents = ['landlord_0', 'peasant_0', 'peasant_1'] self.num_agents = len(self.agents) self.has_reset = False @@ -83,7 +90,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_selection = self._agent_selector.reset() self.rewards = self._convert_to_dict(np.array([0.0, 0.0, 0.0])) self.dones = self._convert_to_dict([False for _ in range(self.num_agents)]) diff --git a/pettingzoo/classic/gin_rummy/__init__.py b/pettingzoo/classic/gin_rummy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/gin_rummy/gin_rummy.py b/pettingzoo/classic/gin_rummy/gin_rummy.py index 9b8f7f1cc..7de44f7f3 100644 --- a/pettingzoo/classic/gin_rummy/gin_rummy.py +++ b/pettingzoo/classic/gin_rummy/gin_rummy.py @@ -10,20 +10,27 @@ from rlcard.games.gin_rummy.utils.action_event import KnockAction, GinAction import rlcard.games.gin_rummy.utils.melding as melding import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, knock_reward: float = 0.5, gin_reward: float = 1.0, **kwargs): - super(env, self).__init__() - if 
seed is not None: - np.random.seed(seed) - random.seed(seed) + def __init__(self, seed=None, knock_reward: float = 0.5, gin_reward: float = 1.0): + super().__init__() self._knock_reward = knock_reward self._gin_reward = gin_reward - self.env = rlcard.make('gin-rummy', **kwargs) + self.env = rlcard.make('gin-rummy', config={"seed": seed}) self.agents = ['player_0', 'player_1'] self.num_agents = len(self.agents) self.has_reset = False @@ -120,7 +127,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_order = [self._int_to_name(agent) for agent in [player_id, 0 if player_id == 1 else 1]] self._agent_selector.reinit(self.agent_order) self.agent_selection = self._agent_selector.reset() diff --git a/pettingzoo/classic/go/__init__.py b/pettingzoo/classic/go/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/go/go_env.py b/pettingzoo/classic/go/go_env.py index a1b9bc363..0a261755b 100644 --- a/pettingzoo/classic/go/go_env.py +++ b/pettingzoo/classic/go/go_env.py @@ -5,16 +5,27 @@ from . import go from . import coords import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(): + env = raw_env() + pass_move = env._N * env._N + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NanNoOpWrapper(env, pass_move, "passing turn with action {}".format(pass_move)) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self, board_size: int = 19, komi: float = 7.5): # board_size: a int, representing the board size (board has a board_size x board_size shape) # komi: a float, representing points given to the second player. 
- super(env, self).__init__() + super().__init__() self._overwrite_go_global_variables(board_size=board_size) self._komi = komi diff --git a/pettingzoo/classic/leduc_holdem/__init__.py b/pettingzoo/classic/leduc_holdem/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/leduc_holdem/leduc_holdem.py b/pettingzoo/classic/leduc_holdem/leduc_holdem.py index 6e7782ee7..755f1a878 100644 --- a/pettingzoo/classic/leduc_holdem/leduc_holdem.py +++ b/pettingzoo/classic/leduc_holdem/leduc_holdem.py @@ -6,18 +6,28 @@ import random from rlcard.utils.utils import print_card import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, **kwargs): - super(env, self).__init__() + def __init__(self, seed=None): + super().__init__() if seed is not None: np.random.seed(seed) random.seed(seed) - self.env = rlcard.make('leduc-holdem', **kwargs) + self.env = rlcard.make('leduc-holdem', config={"seed": seed}) self.agents = ['player_0', 'player_1'] self.num_agents = len(self.agents) self.has_reset = False @@ -78,7 +88,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_order = [self._int_to_name(agent) for agent in [player_id, 0 if player_id == 1 else 1]] self._agent_selector.reinit(self.agent_order) self.agent_selection = self._agent_selector.reset() diff --git a/pettingzoo/classic/mahjong/__init__.py b/pettingzoo/classic/mahjong/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/mahjong/mahjong.py 
b/pettingzoo/classic/mahjong/mahjong.py index 803378fe3..3ee3fa50e 100644 --- a/pettingzoo/classic/mahjong/mahjong.py +++ b/pettingzoo/classic/mahjong/mahjong.py @@ -5,18 +5,25 @@ import rlcard import random import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, **kwargs): - super(env, self).__init__() - if seed is not None: - np.random.seed(seed) - random.seed(seed) - self.env = rlcard.make('mahjong', **kwargs) + def __init__(self, seed=None): + super().__init__() + self.env = rlcard.make('mahjong', config={"seed": seed}) self.agents = ['player_0', 'player_1', 'player_2', 'player_3'] self.num_agents = len(self.agents) self.has_reset = False @@ -97,7 +104,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_order = self.agents self._agent_selector = agent_selector(self.agent_order) self.agent_selection = self._agent_selector.reset() diff --git a/pettingzoo/classic/rps/rps.py b/pettingzoo/classic/rps/rps.py index bfc15f1fe..7fbe93075 100644 --- a/pettingzoo/classic/rps/rps.py +++ b/pettingzoo/classic/rps/rps.py @@ -2,6 +2,7 @@ import numpy as np from pettingzoo import AECEnv from pettingzoo.utils import agent_selector +from pettingzoo.utils import wrappers rock = 0 paper = 1 @@ -11,7 +12,15 @@ NUM_ITERS = 100 -class env(AECEnv): +def env(): + env = raw_env() + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): """Two-player environment for rock paper 
scissors. The observation is simply the last opponent action.""" @@ -56,11 +65,7 @@ def reset(self, observe=True): def step(self, action, observe=True): agent = self.agent_selection - if np.isnan(action): - action = 0 - elif not self.action_spaces[agent].contains(action): - raise Exception('Action for agent {} must be in Discrete({}).' - 'It is currently {}'.format(agent, self.action_spaces[agent].n, action)) + self.state[self.agent_selection] = action # collect reward if it is the last agent to act diff --git a/pettingzoo/classic/rpsls/rpsls.py b/pettingzoo/classic/rpsls/rpsls.py index 0c5fdaa68..ba4fe9d01 100644 --- a/pettingzoo/classic/rpsls/rpsls.py +++ b/pettingzoo/classic/rpsls/rpsls.py @@ -2,6 +2,7 @@ import numpy as np from pettingzoo import AECEnv from pettingzoo.utils import agent_selector +from pettingzoo.utils import wrappers rock = 0 paper = 1 @@ -13,7 +14,15 @@ NUM_ITERS = 100 -class env(AECEnv): +def env(): + env = raw_env() + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): """Two-player environment for rock paper scissors lizard spock. 
The observation is simply the last opponent action.""" @@ -37,12 +46,12 @@ def reinit(self): self.rewards = {agent: 0 for agent in self.agents} self.dones = {agent: False for agent in self.agents} self.infos = {agent: {} for agent in self.agents} - self.state = {} + self.state = {agent: none for agent in self.agents} self.observations = {agent: none for agent in self.agents} self.num_moves = 0 def render(self, mode="human"): - print("Current state: Agent1: {} , Agent2: {}".format(MOVES[self.state[0]], MOVES[self.state[1]])) + print("Current state: Agent1: {} , Agent2: {}".format(MOVES[self.state[self.agents[0]]], MOVES[self.state[self.agents[1]]])) def observe(self, agent): # observation of one agent is the previous state of the other @@ -58,12 +67,8 @@ def reset(self, observe=True): def step(self, action, observe=True): agent = self.agent_selection - if np.isnan(action): - action = 0 - elif not self.action_spaces[agent].contains(action): - raise Exception('Action for agent {} must be in Discrete({}).' 
- 'It is currently {}'.format(agent, self.action_spaces[agent].n, action)) - self.state[self.agent_name_mapping[self.agent_selection]] = action + + self.state[self.agent_selection] = action # collect reward if it is the last agent to act if self._agent_selector.is_last(): @@ -97,17 +102,16 @@ def step(self, action, observe=True): (spock, scissors): (1, -1), (spock, lizard): (-1, 1), (spock, spock): (0, 0), - }[(self.state[0], self.state[1])] + }[(self.state[self.agents[0]], self.state[self.agents[1]])] self.num_moves += 1 self.dones = {agent: self.num_moves >= NUM_ITERS for agent in self.agents} # observe the current state for i in self.agents: - self.observations[i] = self.state[1 - self.agent_name_mapping[i]] - self.state[1 - self.agent_name_mapping[i]] = none + self.observations[i] = self.state[self.agents[1 - self.agent_name_mapping[i]]] else: - self.state[1 - self.agent_name_mapping[agent]] = none + self.state[self.agents[1 - self.agent_name_mapping[agent]]] = none self.agent_selection = self._agent_selector.next() if observe: diff --git a/pettingzoo/classic/texas_holdem/__init__.py b/pettingzoo/classic/texas_holdem/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/texas_holdem/texas_holdem.py b/pettingzoo/classic/texas_holdem/texas_holdem.py index 33610a52e..e3a8042f1 100644 --- a/pettingzoo/classic/texas_holdem/texas_holdem.py +++ b/pettingzoo/classic/texas_holdem/texas_holdem.py @@ -6,18 +6,28 @@ import random from rlcard.utils.utils import print_card import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, **kwargs): - super(env, self).__init__() + 
def __init__(self, seed=None): + super().__init__() if seed is not None: np.random.seed(seed) random.seed(seed) - self.env = rlcard.make('limit-holdem', **kwargs) + self.env = rlcard.make('limit-holdem', config={"seed": seed},) self.agents = ['player_0', 'player_1'] self.num_agents = len(self.agents) self.has_reset = False @@ -80,7 +90,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_order = [self._int_to_name(agent) for agent in [player_id, 0 if player_id == 1 else 1]] self._agent_selector.reinit(self.agent_order) self.agent_selection = self._agent_selector.reset() diff --git a/pettingzoo/classic/texas_holdem_no_limit/__init__.py b/pettingzoo/classic/texas_holdem_no_limit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/texas_holdem_no_limit/texas_holdem_no_limit.py b/pettingzoo/classic/texas_holdem_no_limit/texas_holdem_no_limit.py index b5c8014b9..df6a218e0 100644 --- a/pettingzoo/classic/texas_holdem_no_limit/texas_holdem_no_limit.py +++ b/pettingzoo/classic/texas_holdem_no_limit/texas_holdem_no_limit.py @@ -6,18 +6,28 @@ import random from rlcard.utils.utils import print_card import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, **kwargs): - super(env, self).__init__() + def __init__(self, seed=None): + super().__init__() if seed is not None: np.random.seed(seed) random.seed(seed) - self.env = rlcard.make('no-limit-holdem', **kwargs) + self.env = rlcard.make('no-limit-holdem', config={"seed": seed}) self.agents = 
['player_0', 'player_1'] self.num_agents = len(self.agents) self.has_reset = False @@ -80,7 +90,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_order = [self._int_to_name(agent) for agent in [player_id, 0 if player_id == 1 else 1]] self._agent_selector.reinit(self.agent_order) self.agent_selection = self._agent_selector.reset() diff --git a/pettingzoo/classic/tictactoe/tictactoe.py b/pettingzoo/classic/tictactoe/tictactoe.py index d18e1fbaf..d8385a10b 100644 --- a/pettingzoo/classic/tictactoe/tictactoe.py +++ b/pettingzoo/classic/tictactoe/tictactoe.py @@ -5,15 +5,25 @@ import warnings from .manual_control import manual_control +from pettingzoo.utils import wrappers from .board import Board -class env(AECEnv): +def env(): + env = raw_env() + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self): - super(env, self).__init__() + super().__init__() self.board = Board() self.num_agents = 2 @@ -22,7 +32,7 @@ def __init__(self): self.agent_order = list(self.agents) self.action_spaces = {i: spaces.Discrete(9) for i in self.agents} - self.observation_spaces = {i: spaces.Box(low=0, high=2, shape=(3, 3), dtype=np.int8) for i in self.agents} + self.observation_spaces = {i: spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8) for i in self.agents} self.rewards = {i: 0 for i in self.agents} self.dones = {i: False for i in self.agents} @@ -42,47 +52,45 @@ def __init__(self): # [1,2,1] # [2,1,0]] def observe(self, agent): - # return observation of an agent - s = np.array(self.board.squares) - return s.reshape(3, 3).T + board_vals = np.array(self.board.squares).reshape(3, 3) + cur_player = 
self.agents.index(self.agent_selection) + opp_player = (cur_player + 1) % 2 + + cur_p_board = np.equal(board_vals, cur_player + 1) + opp_p_board = np.equal(board_vals, opp_player + 1) + return np.stack([cur_p_board, opp_p_board], axis=2).astype(np.int8) # action in this case is a value from 0 to 8 indicating position to move on tictactoe board def step(self, action, observe=True): # check if input action is a valid move (0 == empty spot) - if(self.board.squares[action] == 0): - # play turn - self.board.play_turn(self.agents.index(self.agent_selection), action) - - # update infos - # list of valid actions (indexes in board) - # next_agent = self.agents[(self.agents.index(self.agent_selection) + 1) % len(self.agents)] - next_agent = self._agent_selector.next() - self.infos[self.agent_selection]['legal_moves'] = [i for i in range(len(self.board.squares)) if self.board.squares[i] == 0] - self.infos[next_agent]['legal_moves'] = [i for i in range(len(self.board.squares)) if self.board.squares[i] == 0] - - if self.board.check_game_over(): - winner = self.board.check_for_winner() - - if winner == -1: - # tie - pass - elif winner == 1: - # agent 0 won - self.rewards[self.agents[0]] += 1 - self.rewards[self.agents[1]] -= 1 - else: - # agent 1 won - self.rewards[self.agents[1]] += 1 - self.rewards[self.agents[0]] -= 1 - - # once either play wins or there is a draw, game over, both players are done - self.dones = {i: True for i in self.agents} + assert (self.board.squares[action] == 0), "played illegal move" + # play turn + self.board.play_turn(self.agents.index(self.agent_selection), action) + + # update infos + # list of valid actions (indexes in board) + # next_agent = self.agents[(self.agents.index(self.agent_selection) + 1) % len(self.agents)] + next_agent = self._agent_selector.next() + self.infos[self.agent_selection]['legal_moves'] = [i for i in range(len(self.board.squares)) if self.board.squares[i] == 0] + self.infos[next_agent]['legal_moves'] = [i for i in 
range(len(self.board.squares)) if self.board.squares[i] == 0] + + if self.board.check_game_over(): + winner = self.board.check_for_winner() + + if winner == -1: + # tie + pass + elif winner == 1: + # agent 0 won + self.rewards[self.agents[0]] += 1 + self.rewards[self.agents[1]] -= 1 + else: + # agent 1 won + self.rewards[self.agents[1]] += 1 + self.rewards[self.agents[0]] -= 1 - else: - # invalid move, end game - self.rewards[self.agent_selection] -= 1 + # once either play wins or there is a draw, game over, both players are done self.dones = {i: True for i in self.agents} - warnings.warn("Bad tictactoe move made, game terminating with current player losing. env.infos[player]['legal_moves'] contains a list of all legal moves that can be chosen.") # Switch selection to next agents self.agent_selection = next_agent diff --git a/pettingzoo/classic/uno/__init__.py b/pettingzoo/classic/uno/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/classic/uno/uno.py b/pettingzoo/classic/uno/uno.py index 2854a301e..de0d38e32 100644 --- a/pettingzoo/classic/uno/uno.py +++ b/pettingzoo/classic/uno/uno.py @@ -6,18 +6,28 @@ import random from rlcard.games.uno.card import UnoCard import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NaNRandomWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, **kwargs): - super(env, self).__init__() + def __init__(self, seed=None): + super().__init__() if seed is not None: np.random.seed(seed) random.seed(seed) - self.env = rlcard.make('uno', **kwargs) + self.env = rlcard.make('uno', config={"seed": seed}) self.agents = ['player_0', 'player_1'] self.num_agents = len(self.agents) self.has_reset = False @@ 
-87,7 +97,7 @@ def step(self, action, observe=True): def reset(self, observe=True): self.has_reset = True - obs, player_id = self.env.init_game() + obs, player_id = self.env.reset() self.agent_order = [self.agents[agent] for agent in [player_id, 0 if player_id == 1 else 1]] self._agent_selector.reinit(self.agent_order) self.agent_selection = self._agent_selector.reset() diff --git a/pettingzoo/gamma/__init__.py b/pettingzoo/gamma/__init__.py index 071410b52..4ea823557 100644 --- a/pettingzoo/gamma/__init__.py +++ b/pettingzoo/gamma/__init__.py @@ -2,3 +2,4 @@ from .pistonball import pistonball as pistonball_v0 from .cooperative_pong import cooperative_pong as cooperative_pong_v0 from .prison import prison as prison_v0 +from .prospector import prospector as prospector_v0 diff --git a/pettingzoo/gamma/cooperative_pong/cooperative_pong.py b/pettingzoo/gamma/cooperative_pong/cooperative_pong.py index edbc94bec..56f812488 100755 --- a/pettingzoo/gamma/cooperative_pong/cooperative_pong.py +++ b/pettingzoo/gamma/cooperative_pong/cooperative_pong.py @@ -1,9 +1,11 @@ import os import numpy as np import gym +from gym.utils import seeding from .cake_paddle import CakePaddle from .manual_control import manual_control from pettingzoo import AECEnv +from pettingzoo.utils import wrappers from pettingzoo.utils.agent_selector import agent_selector os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide' import pygame @@ -28,7 +30,7 @@ def original_obs_shape(screen_width, screen_height): return (int(screen_height / KERNEL_WINDOW_LENGTH), int(screen_width / (2 * KERNEL_WINDOW_LENGTH)), 1) -def get_valid_angle(): +def get_valid_angle(randomizer): # generates an angle in [0, 2*np.pi) that \ # excludes (90 +- ver_deg_range), (270 +- ver_deg_range), (0 +- hor_deg_range), (180 +- hor_deg_range) # (65, 115), (245, 295), (170, 190), (0, 10), (350, 360) @@ -45,14 +47,14 @@ def get_valid_angle(): angle = 0 while ((angle > a1 and angle < b1) or (angle > a2 and angle < b2) or (angle > c1 and angle < 
d1) or (angle > c2) or (angle < d2)): - angle = 2 * np.pi * np.random.rand() + angle = 2 * np.pi * randomizer.rand() return angle -def get_small_random_value(): +def get_small_random_value(randomizer): # generates a small random value between [0, 1/100) - return (1 / 100) * np.random.rand() + return (1 / 100) * randomizer.rand() class PaddleSprite(pygame.sprite.Sprite): @@ -129,7 +131,7 @@ def process_collision(self, b_rect, dx, dy, b_speed, paddle_type): class BallSprite(pygame.sprite.Sprite): - def __init__(self, dims, speed, bounce_randomness=False): # def __init__(self, image, speed): + def __init__(self, randomizer, dims, speed, bounce_randomness=False): # def __init__(self, image, speed): # self.surf = get_image(image) self.surf = pygame.Surface(dims) self.rect = self.surf.get_rect() @@ -138,6 +140,7 @@ def __init__(self, dims, speed, bounce_randomness=False): # def __init__(self, self.bounce_randomness = bounce_randomness self.done = False self.hit = False + self.randomizer = randomizer def update2(self, area, p0, p1): (speed_x, speed_y) = self.speed @@ -174,7 +177,7 @@ def move_single_axis(self, dx, dy, area, p0, p1): # add some randomness r_val = 0 if self.bounce_randomness: - r_val = get_small_random_value() + r_val = get_small_random_value(self.randomizer) # ball in left half of screen if self.rect.center[0] < area.center[0]: @@ -199,7 +202,7 @@ class CooperativePong(gym.Env): metadata = {'render.modes': ['human']} # ball_speed = [3,3], left_paddle_speed = 3, right_paddle_speed = 3 - def __init__(self, ball_speed=18, left_paddle_speed=25, right_paddle_speed=25, is_cake_paddle=True, max_frames=900, bounce_randomness=False): + def __init__(self, randomizer, ball_speed=18, left_paddle_speed=25, right_paddle_speed=25, cake_paddle=True, max_frames=900, bounce_randomness=False): super(CooperativePong, self).__init__() pygame.init() @@ -228,7 +231,7 @@ def __init__(self, ball_speed=18, left_paddle_speed=25, right_paddle_speed=25, i # paddles self.p0 = 
PaddleSprite((20, 80), left_paddle_speed) - if is_cake_paddle: + if cake_paddle: self.p1 = CakePaddle(right_paddle_speed) else: self.p1 = PaddleSprite((20, 100), right_paddle_speed) @@ -236,7 +239,8 @@ def __init__(self, ball_speed=18, left_paddle_speed=25, right_paddle_speed=25, i self.agents = ["paddle_0", "paddle_1"] # list(range(self.num_agents)) # ball - self.ball = BallSprite((20, 20), ball_speed, bounce_randomness) + self.ball = BallSprite(randomizer, (20, 20), ball_speed, bounce_randomness) + self.randomizer = randomizer self.reinit() @@ -252,7 +256,7 @@ def reset(self): # reset ball and paddle init conditions self.ball.rect.center = self.area.center # set the direction to an angle between [0, 2*np.pi) - angle = get_valid_angle() + angle = get_valid_angle(self.randomizer) # angle = deg_to_rad(89) self.ball.speed = [int(self.ball.speed_val * np.cos(angle)), int(self.ball.speed_val * np.sin(angle))] @@ -349,17 +353,27 @@ def step(self, action, agent): self.dones[ag] = self.done self.infos[ag] = {} - pygame.event.pump() + if self.renderOn: + pygame.event.pump() self.draw() -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NanNoOpWrapper(env, 0, "doing nothing") + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): # class env(MultiAgentEnv): metadata = {'render.modes': ['human']} - def __init__(self, **kwargs): - super(env, self).__init__() - self.env = CooperativePong(**kwargs) + def __init__(self, seed=None, **kwargs): + super().__init__() + self.randomizer, seed = seeding.np_random(seed) + self.env = CooperativePong(self.randomizer, **kwargs) self.agents = self.env.agents self.num_agents = len(self.agents) diff --git a/pettingzoo/gamma/knights_archers_zombies/__init__.py b/pettingzoo/gamma/knights_archers_zombies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/pettingzoo/gamma/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/gamma/knights_archers_zombies/knights_archers_zombies.py index 5d020d5b0..0348f6cc5 100644 --- a/pettingzoo/gamma/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/gamma/knights_archers_zombies/knights_archers_zombies.py @@ -10,12 +10,11 @@ from .manual_control import manual_control import numpy as np from skimage import measure -import matplotlib.pyplot as plt from pettingzoo import AECEnv from pettingzoo.utils import agent_selector -from pettingzoo.utils import EnvLogger from gym.spaces import Box, Discrete from gym.utils import seeding +from pettingzoo.utils import wrappers def get_image(path): @@ -25,11 +24,20 @@ def get_image(path): return image -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.AssertOutOfBoundsWrapper(env) + default_val = 1 + env = wrappers.NanNoOpWrapper(env, default_val, "setting action to 1") + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} - def __init__(self, seed=None, spawn_rate=20, num_archers=2, num_knights=2, killable_knights=True, killable_archers=True, pad_observation=True, max_frames=900): + def __init__(self, seed=None, spawn_rate=20, num_archers=2, num_knights=2, killable_knights=True, killable_archers=True, pad_observation=True, black_death=True, line_death=False, max_frames=900): # Game Constants self.ZOMBIE_SPAWN = spawn_rate self.FPS = 90 @@ -40,6 +48,8 @@ def __init__(self, seed=None, spawn_rate=20, num_archers=2, num_knights=2, killa self.pad_observation = pad_observation self.killable_knights = killable_knights self.killable_archers = killable_archers + self.black_death = black_death + self.line_death = line_death self.has_reset = False self.np_random, seed = seeding.np_random(seed) @@ -68,6 +78,9 @@ def __init__(self, seed=None, spawn_rate=20, num_archers=2, num_knights=2, killa self.archer_list = 
pygame.sprite.Group() self.knight_list = pygame.sprite.Group() + # If black_death, this represents agents to remove at end of cycle + self.kill_list = [] + self.num_archers = num_archers self.num_knights = num_knights @@ -91,7 +104,8 @@ def __init__(self, seed=None, spawn_rate=20, num_archers=2, num_knights=2, killa self.agents = [] for i in range(num_archers): - self.archer_dict["archer{0}".format(self.archer_player_num)] = Archer() + name = "archer_" + str(i) + self.archer_dict["archer{0}".format(self.archer_player_num)] = Archer(agent_name=name) self.archer_dict["archer{0}".format(self.archer_player_num)].offset(i * 50, 0) self.archer_list.add(self.archer_dict["archer{0}".format(self.archer_player_num)]) self.all_sprites.add(self.archer_dict["archer{0}".format(self.archer_player_num)]) @@ -100,7 +114,8 @@ def __init__(self, seed=None, spawn_rate=20, num_archers=2, num_knights=2, killa self.archer_player_num += 1 for i in range(num_knights): - self.knight_dict["knight{0}".format(self.knight_player_num)] = Knight() + name = "knight_" + str(i) + self.knight_dict["knight{0}".format(self.knight_player_num)] = Knight(agent_name=name) self.knight_dict["knight{0}".format(self.knight_player_num)].offset(i * 50, 0) self.knight_list.add(self.knight_dict["knight{0}".format(self.knight_player_num)]) self.all_sprites.add(self.knight_dict["knight{0}".format(self.knight_player_num)]) @@ -251,7 +266,7 @@ def sword_stab(self, sword_list, all_sprites): # Spawning Zombies at Random Location at every 100 iterations def spawn_zombie(self, zombie_spawn_rate, zombie_list, all_sprites): zombie_spawn_rate += 1 - zombie = Zombie() + zombie = Zombie(self.np_random) if zombie_spawn_rate >= self.ZOMBIE_SPAWN: zombie.rect.x = self.np_random.randint(0, self.WIDTH) @@ -273,6 +288,8 @@ def zombie_knight(self, zombie_list, knight_list, all_sprites, knight_killed, sw all_sprites.remove(knight) sword_killed = True knight_killed = True + if knight.agent_name not in self.kill_list: + 
self.kill_list.append(knight.agent_name) return zombie_list, knight_list, all_sprites, knight_killed, sword_list, sword_killed # Kill the Sword when Knight dies @@ -294,6 +311,8 @@ def zombie_archer(self, zombie_list, archer_list, all_sprites, archer_killed): archer_list.remove(archer) all_sprites.remove(archer) archer_killed = True + if archer.agent_name not in self.kill_list: + self.kill_list.append(archer.agent_name) return zombie_list, archer_list, all_sprites, archer_killed # Zombie Kills the Sword @@ -345,8 +364,6 @@ def zombie_all_players(self, knight_list, archer_list, run): return run def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() screen = pygame.surfarray.pixels3d(self.WINDOW) i = self.agent_name_mapping[agent] @@ -389,15 +406,7 @@ def observe(self, agent): return cropped def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() agent = self.agent_selection - if action is None or np.isnan(action): - EnvLogger.warn_action_is_NaN(backup_policy="setting action to 1") - action = 1 - elif not self.action_spaces[agent].contains(action): - EnvLogger.warn_action_out_of_bound(action=action, action_space=self.action_spaces[agent], backup_policy="setting action to 1") - action = 1 if self.render_on: self.clock.tick(self.FPS) # FPS else: @@ -420,8 +429,19 @@ def step(self, action, observe=True): # Reset Environment if event.key == pygame.K_BACKSPACE: self.reset(observe=False) + agent_name = self.agent_list[self.agent_name_mapping[agent]] - agent_name.update(action) + action = action + 1 + out_of_bounds = agent_name.update(action) + + if self.line_death and out_of_bounds: + agent_name.alive = False + if agent_name in self.archer_list: + self.archer_list.remove(agent_name) + else: + self.knight_list.remove(agent_name) + self.all_sprites.remove(agent_name) + self.kill_list.append(agent_name.agent_name) sp = self.spawnPlayers(action, self.knight_player_num, self.archer_player_num, 
self.knight_list, self.archer_list, self.all_sprites, self.knight_dict, self.archer_dict) # Knight @@ -434,6 +454,7 @@ def step(self, action, observe=True): self.sword_spawn_rate, self.knight_killed, self.knight_dict, self.knight_list, self.knight_player_num, self.all_sprites, self.sword_dict, self.sword_list = sw.spawnSword() # Arrow self.arrow_spawn_rate, self.archer_killed, self.archer_dict, self.archer_list, self.archer_player_num, self.all_sprites, self.arrow_dict, self.arrow_list = sw.spawnArrow() + if self._agent_selector.is_last(): # Spawning Zombies at Random Location at every 100 iterations self.zombie_spawn_rate, self.zombie_list, self.all_sprites = self.spawn_zombie(self.zombie_spawn_rate, self.zombie_list, self.all_sprites) @@ -482,9 +503,28 @@ def step(self, action, observe=True): self.check_game_end() self.frames += 1 - self.agent_selection = self._agent_selector.next() + self.rewards[agent] = agent_name.score self.dones[agent] = not self.run or self.frames >= self.max_frames + + if self._agent_selector.is_last() and not self.black_death: + # self.agents must be recreated + for k in self.kill_list: + print("Killed ", k) + self.agents.remove(k) + self.dones.pop(k, None) + self.rewards.pop(k, None) + self.infos.pop(k, None) + + # reinit agent_order from agents + self.agent_order = self.agents[:] + self._agent_selector.reinit(self.agent_order) + self.num_agents = len(self.agents) + + # reset the kill list + self.kill_list = [] + + self.agent_selection = self._agent_selector.next() if observe: return self.observe(self.agent_selection) @@ -495,20 +535,15 @@ def enable_render(self): self.reset() def render(self, mode="human"): - if not self.has_reset: - EnvLogger.error_render_before_reset() - else: - if not self.render_on: - # sets self.render_on to true and initializes display - self.enable_render() - pygame.display.flip() + if not self.render_on: + # sets self.render_on to true and initializes display + self.enable_render() + pygame.display.flip() def 
close(self): if not self.closed: self.closed = True - if not self.render_on: - EnvLogger.warn_close_unrendered_env() - else: + if self.render_on: # self.WINDOW = pygame.display.set_mode([self.WIDTH, self.HEIGHT]) self.WINDOW = pygame.Surface((self.WIDTH, self.HEIGHT)) self.render_on = False @@ -556,7 +591,8 @@ def reinit(self): self.agents = [] for i in range(self.num_archers): - self.archer_dict["archer{0}".format(self.archer_player_num)] = Archer() + name = "archer_" + str(i) + self.archer_dict["archer{0}".format(self.archer_player_num)] = Archer(agent_name=name) self.archer_dict["archer{0}".format(self.archer_player_num)].offset(i * 50, 0) self.archer_list.add(self.archer_dict["archer{0}".format(self.archer_player_num)]) self.all_sprites.add(self.archer_dict["archer{0}".format(self.archer_player_num)]) @@ -565,7 +601,8 @@ def reinit(self): self.archer_player_num += 1 for i in range(self.num_knights): - self.knight_dict["knight{0}".format(self.knight_player_num)] = Knight() + name = "knight_" + str(i) + self.knight_dict["knight{0}".format(self.knight_player_num)] = Knight(agent_name=name) self.knight_dict["knight{0}".format(self.knight_player_num)].offset(i * 50, 0) self.knight_list.add(self.knight_dict["knight{0}".format(self.knight_player_num)]) self.all_sprites.add(self.knight_dict["knight{0}".format(self.knight_player_num)]) diff --git a/pettingzoo/gamma/knights_archers_zombies/manual_control.py b/pettingzoo/gamma/knights_archers_zombies/manual_control.py index 4bbce8034..7915bc7c8 100644 --- a/pettingzoo/gamma/knights_archers_zombies/manual_control.py +++ b/pettingzoo/gamma/knights_archers_zombies/manual_control.py @@ -15,10 +15,10 @@ def manual_control(**kwargs): while not done: # while frame_count < frame_limit: # Uncomment this if you want the game to run for fame_limit amount of frames instead of ending by normal game conditions (useful for testing purposes) - agents = env.agent_list + agents = env.agents frame_count += 1 - actions = [6 for x in 
range(len(env.agents))] # If you want to do manual input - + actions = [5 for x in range(len(env.agents))] # If you want to do manual input + # 5 is do nothing, 0 is up, 1 is down, 2 is turn CW, 3 is CCW, 4 is attack for event in pygame.event.get(): if event.type == pygame.KEYDOWN: if event.key == pygame.K_ESCAPE: @@ -37,19 +37,18 @@ def manual_control(**kwargs): if cur_agent > len(agents) - 1: cur_agent = 0 if event.key == pygame.K_q: - actions[cur_agent] = 3 + actions[cur_agent] = 2 if event.key == pygame.K_e: - actions[cur_agent] = 4 + actions[cur_agent] = 3 if event.key == pygame.K_w: - actions[cur_agent] = 1 + actions[cur_agent] = 0 if event.key == pygame.K_s: - actions[cur_agent] = 2 + actions[cur_agent] = 1 if event.key == pygame.K_f: - actions[cur_agent] = 5 + actions[cur_agent] = 4 if quit_game: break - for a in actions: env.step(a) env.render() diff --git a/pettingzoo/gamma/knights_archers_zombies/src/__init__.py b/pettingzoo/gamma/knights_archers_zombies/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/gamma/knights_archers_zombies/src/players.py b/pettingzoo/gamma/knights_archers_zombies/src/players.py index 796c71b30..d1bc6d1c7 100644 --- a/pettingzoo/gamma/knights_archers_zombies/src/players.py +++ b/pettingzoo/gamma/knights_archers_zombies/src/players.py @@ -7,16 +7,15 @@ HEIGHT = 720 ARCHER_SPEED = 25 KNIGHT_SPEED = 25 -ARCHER_X, ARCHER_Y = 400, 710 -KNIGHT_X, KNIGHT_Y = 800, 710 +ARCHER_X, ARCHER_Y = 400, 610 +KNIGHT_X, KNIGHT_Y = 800, 610 ANGLE_RATE = 10 class Archer(pygame.sprite.Sprite): - def __init__(self): + def __init__(self, agent_name): super().__init__() - # rand_x = random.randint(20, 1260) img_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'img')) self.image = pygame.image.load(os.path.join(img_path, 'archer.png')) self.rect = self.image.get_rect(center=(ARCHER_X, ARCHER_Y)) @@ -30,8 +29,10 @@ def __init__(self): self.score = 0 self.is_archer = True self.is_knight = False + 
self.agent_name = agent_name def update(self, action): + went_out_of_bounds = False if not self.attacking: move_angle = math.radians(self.angle + 90) @@ -54,12 +55,16 @@ def update(self, action): pass # Clamp to stay inside the screen + if self.rect.y < 0 or self.rect.y > (HEIGHT - 40): + went_out_of_bounds = True + self.rect.x = max(min(self.rect.x, WIDTH - 132), 100) self.rect.y = max(min(self.rect.y, HEIGHT - 40), 0) self.direction = pygame.Vector2(0, -1).rotate(-self.angle) self.image = pygame.transform.rotate(self.org_image, self.angle) self.rect = self.image.get_rect(center=self.rect.center) + return went_out_of_bounds def offset(self, x_offset, y_offset): self.rect.x += x_offset @@ -70,7 +75,7 @@ def is_done(self): class Knight(pygame.sprite.Sprite): - def __init__(self): + def __init__(self, agent_name): super().__init__() img_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'img')) self.image = pygame.image.load(os.path.join(img_path, 'knight.png')) @@ -88,10 +93,11 @@ def __init__(self): self.score = 0 self.is_archer = False self.is_knight = True + self.agent_name = agent_name def update(self, action): self.action = action - + went_out_of_bounds = False if not self.attacking: move_angle = math.radians(self.angle + 90) # Up and Down movement @@ -112,6 +118,9 @@ def update(self, action): pass # Clamp to stay inside the screen + if self.rect.y < 0 or self.rect.y > (HEIGHT - 40): + went_out_of_bounds = True + self.rect.x = max(min(self.rect.x, WIDTH - 132), 100) self.rect.y = max(min(self.rect.y, HEIGHT - 40), 0) @@ -119,6 +128,8 @@ def update(self, action): self.image = pygame.transform.rotate(self.org_image, self.angle) self.rect = self.image.get_rect(center=self.rect.center) + return went_out_of_bounds + def offset(self, x_offset, y_offset): self.rect.x += x_offset self.rect.y += y_offset diff --git a/pettingzoo/gamma/knights_archers_zombies/src/zombie.py b/pettingzoo/gamma/knights_archers_zombies/src/zombie.py index 
4ae3140dd..029340944 100644 --- a/pettingzoo/gamma/knights_archers_zombies/src/zombie.py +++ b/pettingzoo/gamma/knights_archers_zombies/src/zombie.py @@ -1,5 +1,4 @@ import pygame -import random import os ZOMBIE_Y_SPEED = 5 @@ -9,14 +8,15 @@ class Zombie(pygame.sprite.Sprite): - def __init__(self): + def __init__(self, randomizer): super().__init__() img_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'img')) self.image = pygame.image.load(os.path.join(img_path, 'zombie.png')) self.rect = self.image.get_rect(center=(50, 50)) + self.randomizer = randomizer def update(self): - rand_x = random.randint(0, 10) + rand_x = self.randomizer.randint(0, 10) # Wobbling in X-Y Direction self.rect.y += ZOMBIE_Y_SPEED diff --git a/pettingzoo/gamma/pistonball/pistonball.py b/pettingzoo/gamma/pistonball/pistonball.py index 99915e777..5536092fe 100755 --- a/pettingzoo/gamma/pistonball/pistonball.py +++ b/pettingzoo/gamma/pistonball/pistonball.py @@ -12,8 +12,8 @@ from gym.utils import seeding from pettingzoo import AECEnv from pettingzoo.utils import agent_selector -from pettingzoo.utils import EnvLogger from .manual_control import manual_control +from pettingzoo.utils import wrappers _image_library = {} @@ -25,12 +25,21 @@ def get_image(path): return image -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + default_val = np.zeros((1,)) if env.continuous else 1 + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NanNoOpWrapper(env, default_val, "setting action to {}".format(default_val)) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self, seed=None, local_ratio=0.02, continuous=False, random_drop=True, starting_angular_momentum=True, ball_mass=0.75, ball_friction=0.3, ball_elasticity=1.5, max_frames=900): - super(env, self).__init__() + super().__init__() self.agents = ["piston_" + str(r) for r in range(20)] self.agent_name_mapping = 
dict(zip(self.agents, list(range(20)))) self.agent_order = self.agents[:] @@ -106,8 +115,6 @@ def __init__(self, seed=None, local_ratio=0.02, continuous=False, random_drop=Tr self.closed = False def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() observation = pygame.surfarray.pixels3d(self.screen) i = self.agent_name_mapping[agent] x_low = 40 * i @@ -125,9 +132,7 @@ def enable_render(self): def close(self): if not self.closed: self.closed = True - if not self.renderOn: - EnvLogger.warn_close_unrendered_env() - else: + if self.renderOn: self.screen = pygame.Surface((960, 560)) self.renderOn = False pygame.event.pump() @@ -246,25 +251,13 @@ def get_local_reward(self, prev_position, curr_position): return local_reward * self.local_reward_weight def render(self, mode="human"): - if not self.has_reset: - EnvLogger.error_render_before_reset() - else: - if not self.renderOn: - # sets self.renderOn to true and initializes display - self.enable_render() - pygame.display.flip() + if not self.renderOn: + # sets self.renderOn to true and initializes display + self.enable_render() + pygame.display.flip() def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() agent = self.agent_selection - if action is None or np.isnan(action): - action = 1 - EnvLogger.warn_action_is_NaN(backup_policy="setting action to 1") - elif not self.action_spaces[agent].contains(action): - EnvLogger.warn_action_out_of_bound(action=action, action_space=self.action_spaces[agent], backup_policy="setting action to 1") - action = 1 - if self.continuous: self.move_piston(self.pistonList[self.agent_name_mapping[agent]], action) else: diff --git a/pettingzoo/gamma/prison/prison.py b/pettingzoo/gamma/prison/prison.py index e9ea4f17b..4416f495d 100644 --- a/pettingzoo/gamma/prison/prison.py +++ b/pettingzoo/gamma/prison/prison.py @@ -5,7 +5,7 @@ import numpy as np from gym import spaces from .manual_control import manual_control 
-from pettingzoo.utils import EnvLogger +from pettingzoo.utils import wrappers from gym.utils import seeding os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide' @@ -71,10 +71,18 @@ def update_sprite(self, movement): self.last_sprite_movement = 0 -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NanNoOpWrapper(env, 0, "setting action to 0") + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): def __init__(self, seed=None, continuous=False, vector_observation=False, max_frames=900, num_floors=4, synchronized_start=False, identical_aliens=False, random_aliens=False): - # super(env, self).__init__() + # super().__init__() self.num_agents = 2 * num_floors self.agents = ["prisoner_" + str(s) for s in range(0, self.num_agents)] self.agent_order = self.agents[:] @@ -237,8 +245,6 @@ def close(self): if self.rendering: pygame.event.pump() pygame.display.quit() - else: - EnvLogger.warn_close_unrendered_env() pygame.quit() def draw(self): @@ -249,8 +255,6 @@ def draw(self): self.screen.blit(self.prisoners[p].get_sprite(), self.prisoners[p].position) def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() if self.vector_obs: p = self.prisoners[agent] x = p.position[0] @@ -290,17 +294,8 @@ def reset(self, observe=True): return self.observe(self.agent_selection) def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() # move prisoners, -1 = move left, 0 = do nothing and 1 is move right agent = self.agent_selection - # if not continuous, input must be normalized - if None in [action] or np.isnan(action): - EnvLogger.warn_action_is_NaN(backup_policy="setting action to 0") - action = np.zeros_like(self.action_spaces[agent].sample()) - elif not self.action_spaces[agent].contains(action): - EnvLogger.warn_action_out_of_bound(action=action, action_space=self.action_spaces[agent], 
backup_policy="setting action to zero") - action = np.zeros_like(self.action_spaces[agent].sample()) reward = 0 if self.continuous: reward = self.move_prisoner(agent, action) @@ -336,22 +331,19 @@ def step(self, action, observe=True): return observation def render(self, mode='human'): - if not self.has_reset: - EnvLogger.error_render_before_reset() - else: - if not self.rendering: - pygame.display.init() - old_screen = self.screen - self.screen = pygame.display.set_mode((750, 50 + 150 * self.num_floors)) - self.screen.blit(old_screen, (0, 0)) - self.screen.blit(self.background, (0, 0)) - if self.num_floors > 4: - min_rows = self.num_floors - 4 - for k in range(min_rows): - h = 650 + 150 * k - self.screen.blit(self.background_append, (0, h)) - self.rendering = True - pygame.display.flip() + if not self.rendering: + pygame.display.init() + old_screen = self.screen + self.screen = pygame.display.set_mode((750, 50 + 150 * self.num_floors)) + self.screen.blit(old_screen, (0, 0)) + self.screen.blit(self.background, (0, 0)) + if self.num_floors > 4: + min_rows = self.num_floors - 4 + for k in range(min_rows): + h = 650 + 150 * k + self.screen.blit(self.background_append, (0, h)) + self.rendering = True + pygame.display.flip() # Sprites other than bunny and tank purchased from https://nebelstern.itch.io/futura-seven # Tank and bunny sprites commissioned from https://www.fiverr.com/jeimansutrisman diff --git a/pettingzoo/gamma/prospector/constants.py b/pettingzoo/gamma/prospector/constants.py index 3092b589f..9584be667 100644 --- a/pettingzoo/gamma/prospector/constants.py +++ b/pettingzoo/gamma/prospector/constants.py @@ -15,9 +15,12 @@ OBSERVATION_SIDE_LENGTH = 5 * AGENT_DIAMETER OBSERVATION_SHAPE = (OBSERVATION_SIDE_LENGTH, OBSERVATION_SIDE_LENGTH, 3) - MAX_SPRITE_ROTATION = math.pi / 4 +NUM_PROSPECTORS = 4 +NUM_BANKERS = 3 +NUM_AGENTS = NUM_PROSPECTORS + NUM_BANKERS + PROSPECTOR_SPEED = 150 BANKER_SPEED = 100 BANKER_HANDOFF_TOLERANCE = math.pi / 4 @@ -44,9 +47,14 @@ ) 
FENCE_INFO = [ - ("left", [0, 0], [0, 0],FENCE_VERT_VERTICES), # left boundary + ("left", [0, 0], [0, 0], FENCE_VERT_VERTICES), # left boundary ("top", [0, 0], [0, 0], FENCE_HORIZ_VERTICES), # top boundary - ("right", [SCREEN_WIDTH - FENCE_WIDTH, 0], [SCREEN_WIDTH - (FENCE_WIDTH + FENCE_COLLISION_BUFFER), 0], FENCE_VERT_VERTICES), + ( + "right", + [SCREEN_WIDTH - FENCE_WIDTH, 0], + [SCREEN_WIDTH - (FENCE_WIDTH + FENCE_COLLISION_BUFFER), 0], + FENCE_VERT_VERTICES, + ), ] BANK_SIZE = BANK_WIDTH, BANK_HEIGHT = 184, 100 @@ -65,11 +73,11 @@ ] WATER_INFO = [ - (0, SCREEN_HEIGHT - WATER_HEIGHT), # position - ( # vertices + (0, SCREEN_HEIGHT - WATER_HEIGHT), # position + ( # vertices (0, 0), (SCREEN_WIDTH, 0), (SCREEN_WIDTH, WATER_HEIGHT), (0, WATER_HEIGHT), ), -] \ No newline at end of file +] diff --git a/pettingzoo/gamma/prospector/data/bankers/0-big.png b/pettingzoo/gamma/prospector/data/bankers/0-big.png new file mode 100644 index 000000000..8dfa18e45 Binary files /dev/null and b/pettingzoo/gamma/prospector/data/bankers/0-big.png differ diff --git a/pettingzoo/gamma/prospector/data/bankers/0.png b/pettingzoo/gamma/prospector/data/bankers/0.png new file mode 100644 index 000000000..6cbffc11d Binary files /dev/null and b/pettingzoo/gamma/prospector/data/bankers/0.png differ diff --git a/pettingzoo/gamma/prospector/data/bankers/1-big.png b/pettingzoo/gamma/prospector/data/bankers/1-big.png index 8dfa18e45..d1df4bf99 100644 Binary files a/pettingzoo/gamma/prospector/data/bankers/1-big.png and b/pettingzoo/gamma/prospector/data/bankers/1-big.png differ diff --git a/pettingzoo/gamma/prospector/data/bankers/1.png b/pettingzoo/gamma/prospector/data/bankers/1.png index 6cbffc11d..180f38670 100644 Binary files a/pettingzoo/gamma/prospector/data/bankers/1.png and b/pettingzoo/gamma/prospector/data/bankers/1.png differ diff --git a/pettingzoo/gamma/prospector/data/bankers/2-big.png b/pettingzoo/gamma/prospector/data/bankers/2-big.png index d1df4bf99..031cd44b1 100644 Binary files 
a/pettingzoo/gamma/prospector/data/bankers/2-big.png and b/pettingzoo/gamma/prospector/data/bankers/2-big.png differ diff --git a/pettingzoo/gamma/prospector/data/bankers/2.png b/pettingzoo/gamma/prospector/data/bankers/2.png index 180f38670..70084f4a2 100644 Binary files a/pettingzoo/gamma/prospector/data/bankers/2.png and b/pettingzoo/gamma/prospector/data/bankers/2.png differ diff --git a/pettingzoo/gamma/prospector/data/bankers/3-big.png b/pettingzoo/gamma/prospector/data/bankers/3-big.png deleted file mode 100644 index 031cd44b1..000000000 Binary files a/pettingzoo/gamma/prospector/data/bankers/3-big.png and /dev/null differ diff --git a/pettingzoo/gamma/prospector/data/bankers/3.png b/pettingzoo/gamma/prospector/data/bankers/3.png deleted file mode 100644 index 70084f4a2..000000000 Binary files a/pettingzoo/gamma/prospector/data/bankers/3.png and /dev/null differ diff --git a/pettingzoo/gamma/prospector/manual_control.py b/pettingzoo/gamma/prospector/manual_control.py new file mode 100644 index 000000000..799a7c874 --- /dev/null +++ b/pettingzoo/gamma/prospector/manual_control.py @@ -0,0 +1,72 @@ +import pygame +import numpy as np +from . import constants as const + + +def manual_control(**kwargs): + from .prospector import env as _env + + env = _env(**kwargs) + env.reset() + default_scalar = 0.8 + + while True: + agent_actions = np.array( + [[0, 0, 0] for _ in range(const.NUM_PROSPECTORS)] + + [[0, 0, 0] for _ in range(const.NUM_BANKERS)] + ) + num_actions = 0 + agent = 0 + for event in pygame.event.get(): + # Use left/right arrow keys to switch between agents + # Use WASD to control bankers + # Use WASD and QE to control prospectors + # Note: QE while selecting a banker has no effect. 
+ if event.type == pygame.KEYDOWN: + # Agent selection + if event.key == pygame.K_LEFT: + agent = (agent - 1) % const.NUM_AGENTS + elif event.key == pygame.K_RIGHT: + agent = (agent + 1) % const.NUM_AGENTS + # Forward/backward or up/down movement + elif event.key == pygame.K_w: + num_actions += 1 + agent_actions[agent][0] = default_scalar + elif event.key == pygame.K_s: + num_actions += 1 + agent_actions[agent][0] = -default_scalar + # left/right movement + elif event.key == pygame.K_a: + num_actions += 1 + agent_actions[agent][1] = -default_scalar + elif event.key == pygame.K_d: + num_actions += 1 + agent_actions[agent][1] = default_scalar + # rotation + elif event.key == pygame.K_q: + if 0 <= agent <= 3: + num_actions += 1 + agent_actions[agent][2] = default_scalar + elif event.key == pygame.K_e: + if 0 <= agent <= 3: + num_actions += 1 + agent_actions[agent][2] = -default_scalar + elif event.key == pygame.K_ESCAPE: + test_done = True + actions = dict(zip(env.agents, agent_actions)) + test_done = False + for i in env.agents: + reward, done, info = env.last() + if done: + test_done = True + action = actions[i] + env.step(action, observe=False) + env.render() + + if test_done: + break + env.close() + + +if __name__ == "__main__": + manual_control() diff --git a/pettingzoo/gamma/prospector/prospector.py b/pettingzoo/gamma/prospector/prospector.py index a7bd00ade..b24ad1fc7 100644 --- a/pettingzoo/gamma/prospector/prospector.py +++ b/pettingzoo/gamma/prospector/prospector.py @@ -1,450 +1,795 @@ -import pygame import pygame as pg -import os -import random -import gym +import pymunk as pm +from pymunk import Vec2d +from gym import spaces +from gym.utils import seeding import numpy as np -# Define some colors +from pettingzoo import AECEnv +from pettingzoo.utils import agent_selector +from pettingzoo.utils import wrappers +from . import constants as const +from . 
import utils +from .manual_control import manual_control -PLAYER_SPEED = 30.0 -PLAYER_ROT_SPEED = 20.0 -PLAYER_IMG = "manBlue_gun.png" -PLAYER_HIT_RECT = pg.Rect(0, 0, 35, 35) -TILESIZE = 2 +import math +import os +from enum import IntEnum, auto +import itertools as it + + +class CollisionTypes(IntEnum): + PROSPECTOR = auto() + BOUNDARY = auto() + WATER = auto() + BANK = auto() + GOLD = auto() + BANKER = auto() + + +class Prospector(pg.sprite.Sprite): + def __init__(self, pos, space, num, *sprite_groups): + super().__init__(sprite_groups) + # self.image = load_image(['prospec.png']) + self.image = utils.load_image(["prospector-pickaxe-big.png"]) + self.image = pg.transform.scale( + self.image, (int(const.AGENT_RADIUS * 2), int(const.AGENT_RADIUS * 2)) + ) -BLACK = (0, 0, 0) -WHITE = (255, 255, 255) -GREEN = (0, 255, 0) -RED = (255, 0, 0) + self.id = num -_image_library = {} -vec = pygame.math.Vector2 + self.rect = self.image.get_rect(topleft=pos) + self.orig_image = self.image + # Create the physics body and shape of this object. + # moment = pm.moment_for_poly(mass, vertices) -def get_image(path): - global _image_library - image = _image_library.get(path) - if image is None: - canonicalized_path = path.replace("/", os.sep).replace("\\", os.sep) - image = pygame.image.load(canonicalized_path) - _image_library[path] = image - return image + moment = pm.moment_for_circle(1, 0, self.rect.width / 2) + self.body = pm.Body(1, moment) + self.body.nugget = None + self.body.sprite_type = "prospector" + # self.shape = pm.Poly(self.body, vertices, radius=3) -def get_small_random_value(): - # generates a small random value between [0, 1/100) - return (1 / 100) * np.random.rand() + self.shape = pm.Circle(self.body, const.AGENT_RADIUS) + self.shape.elasticity = 0.0 + self.shape.collision_type = CollisionTypes.PROSPECTOR + self.body.position = utils.flipy(pos) + # Add them to the Pymunk space. 
+ self.space = space + self.space.add(self.body, self.shape) -class Block(pygame.sprite.Sprite): - """ - This class represents the ball. - It derives from the "Sprite" class in Pygame. - """ + def reset(self, pos): + self.body.angle = 0 + self.image = pg.transform.rotozoom(self.orig_image, 0, 1) + self.rect = self.image.get_rect(topleft=pos) + self.body.position = utils.flipy(pos) + self.body.velocity = Vec2d(0.0, 0.0) - def __init__(self, color=(250, 250, 0), width=20, height=20): - """ Constructor. Pass in the color of the block, - and its size. """ + @property + def center(self): + return self.rect.x + const.AGENT_RADIUS, self.rect.y + const.AGENT_RADIUS - # Call the parent class (Sprite) constructor - super().__init__() + def update(self, action): + # forward/backward action + y_vel = action[0] * const.PROSPECTOR_SPEED + # left/right action + x_vel = action[1] * const.PROSPECTOR_SPEED - # Create an image of the block, and fill it with a color. - # This could also be an image loaded from the disk. - self.image = pygame.Surface([width, height]) - self.image.fill(color) - # Fetch the rectangle object that has the dimensions of the image - # image. - # Update the position of this object by setting the values - # of rect.x and rect.y - self.rect = self.image.get_rect() - self.dim = self.rect.size - - def reset_pos(self): - self.rect.y = random.randrange(0, 20) - self.rect.x = random.randrange(0, self.screen_width[0]) - # print(self.rect.y, self.rect.x) - - def update(self, pos, corner): - """ - rect corner numbers - 1-----2 - | | - 4-----3 - """ - self.rect.bottomright = (pos.x, pos.y) - - -class agent1(pygame.sprite.Sprite): - """ - This class represents the ball - """ - - def __init__(self, _screen_width, x, y, speed=20): - """ Constructor. Pass in the color of the block, - and its x and y position. 
""" - # Call the parent class (Sprite) constructor - super().__init__() - - self.image = get_image("agent1.jpg") - self.base_image = get_image("agent1.jpg") - self.screen_width = _screen_width - self.rect = self.image.get_rect() - self.dim = self.rect.size + delta_angle = action[2] * const.MAX_SPRITE_ROTATION + + self.body.angle += delta_angle + self.body.angular_velocity = 0 + + self.body.velocity = Vec2d(x_vel, y_vel).rotated(self.body.angle) - self.speed_val = speed - self.vel = vec( - int(self.speed_val * np.cos(np.pi / 4)), - int(self.speed_val * np.sin(np.pi / 4)), + self.rect.center = utils.flipy(self.body.position) + self.image = pg.transform.rotozoom( + self.orig_image, math.degrees(self.body.angle), 1 ) - self.rot = 0 - self.bounce_randomness = 1 - self.pos = vec(x, y) * TILESIZE - self.collision = [False] * 9 - - def handle_keyboard_input(self): - self.rot_speed = 0 - self.vel = vec(0, 0) - keys = pygame.key.get_pressed() - if keys[pg.K_LEFT] or keys[pg.K_a]: - print("Left") - self.rot_speed = PLAYER_ROT_SPEED - if keys[pg.K_RIGHT] or keys[pg.K_d]: - self.rot_speed = -PLAYER_ROT_SPEED - print("Right") - if keys[pg.K_UP] or keys[pg.K_w]: - self.vel = vec(PLAYER_SPEED, 0).rotate(-self.rot) - print("Up") - if keys[pg.K_DOWN] or keys[pg.K_s]: - self.vel = vec(-PLAYER_SPEED / 2, 0).rotate(-self.rot) - print("down") - - def reset_pos(self): - self.rect.y = random.randrange(350, 600) - self.rect.x = random.randrange(0, self.screen_width[0]) - print(self.rect.y, self.rect.x) - - def update(self, pos, area, p1): - """ Called each frame. 
""" - self.handle_keyboard_input() - self.rot = (self.rot + self.rot_speed) % 360 - self.image = pg.transform.rotate(self.base_image, self.rot) - self.image = pygame.transform.scale(self.image, (50, 50)) - self.rect = self.base_image.get_rect() - self.rect.center = self.pos - self.pos += self.vel - # print(':::::::::',self.rect.x,self.rect.y) - if self.vel[0] != 0: - done_x = self.move_single_axis(self.vel[0], 0, area, p1) - if self.vel[1] != 0: - done_y = self.move_single_axis(0, self.vel[1], area, p1) - - def check_collision(self, rect): - self.collision[0] = rect.collidepoint(self.rect.topleft) - self.collision[1] = rect.collidepoint(self.rect.topright) - self.collision[2] = rect.collidepoint(self.rect.bottomleft) - self.collision[3] = rect.collidepoint(self.rect.bottomright) - - self.collision[4] = rect.collidepoint(self.rect.midleft) - self.collision[5] = rect.collidepoint(self.rect.midright) - self.collision[6] = rect.collidepoint(self.rect.midtop) - self.collision[7] = rect.collidepoint(self.rect.midbottom) - self.collision[8] = rect.collidepoint(self.rect.center) - - if True in self.collision: - return self.collision.index(True) - else: + self.rect = self.image.get_rect(center=self.rect.center) + + curr_vel = self.body.velocity + + if self.body.nugget is not None: + self.body.nugget.update(self.body.position, self.body.angle, False) + + self.body.velocity = curr_vel + + def convert_img(self): + self.image = self.image.convert_alpha() + + def __str__(self): + return "prospector_%s" % self.id + + def __repr__(self): + return self.__str__() + + +class Banker(pg.sprite.Sprite): + def __init__(self, pos, space, num, *sprite_groups): + super().__init__(sprite_groups) + self.image = utils.load_image(["bankers", "%s-big.png" % num]) + self.image = pg.transform.scale( + self.image, (int(const.AGENT_RADIUS * 2), int(const.AGENT_RADIUS * 2)) + ) + + self.id = num + + self.rect = self.image.get_rect(topleft=pos) + self.orig_image = self.image + + moment = 
pm.moment_for_circle(1, 0, self.rect.width / 2) + + self.body = pm.Body(1, moment) + self.body.nugget = None + self.body.sprite_type = "banker" + self.body.nugget_offset = None + + self.shape = pm.Circle(self.body, const.AGENT_RADIUS) + self.shape.collision_type = CollisionTypes.BANKER + + self.body.position = utils.flipy(pos) + # Add them to the Pymunk space. + self.space = space + self.space.add(self.body, self.shape) + + @property + def center(self): + return self.rect.x + const.AGENT_RADIUS, self.rect.y + const.AGENT_RADIUS + + def reset(self, pos): + self.body.angle = 0 + self.image = pg.transform.rotozoom(self.orig_image, 0, 1) + self.rect = self.image.get_rect(topleft=pos) + self.body.position = utils.flipy(pos) + self.body.velocity = Vec2d(0.0, 0.0) + + def update(self, action): + # up/down action + y_vel = action[0] * const.BANKER_SPEED + # left/right action + x_vel = action[1] * const.BANKER_SPEED + + # Subtract math.pi / 2 because sprite starts off with math.pi / 2 rotated + angle_radians = math.atan2(y_vel, x_vel) - (math.pi / 2) - print(self.rect.topleft, self.rect.midtop, self.rect.midleft) - - for x in range(self.rect.topleft[0], self.rect.midtop[0]): - if rect.collidepoint(x, self.rect.topleft[1]): - self.collision[0] = True - for y in range(self.rect.topleft[1], self.rect.midleft[1]): - if rect.collidepoint(self.rect.topleft[0], y): - self.collision[0] = True - - for x in range(self.rect.bottomleft[1], self.rect.midbottom[1]): - if rect.collidepoint(x, self.rect.bottomleft[0]): - self.collision[2] = True - for y in range(self.rect.bottomleft[0], self.rect.midleft[0]): - if rect.collidepoint(self.rect.bottomleft[1], y): - self.collision[2] = True - print("collision:", self.collision) - return self.collision.index(True) - - def rotate(self, angle): - self.image = pygame.transform.rotate(self.image, angle) + # Angle is determined only by current trajectory. 
+ self.body.angle = angle_radians + self.body.angular_velocity = 0 + + self.body.velocity = Vec2d(x_vel, y_vel) + + self.rect.center = utils.flipy(self.body.position) + self.image = pg.transform.rotozoom( + self.orig_image, math.degrees(self.body.angle), 1 + ) self.rect = self.image.get_rect(center=self.rect.center) - def move_single_axis(self, dx, dy, area, p1): - # returns done - - # move ball rect - self.rect.x += dx - self.rect.y += dy - - if not area.contains(self.rect): - # bottom wall - if dy > 0: - self.rect.bottom = area.bottom - self.vel[1] = -self.vel[1] - # top wall - elif dy < 0: - self.rect.top = area.top - self.vel[1] = -self.vel[1] - # right or left walls - else: - self.vel[0] = -self.vel[0] - return True + curr_vel = self.body.velocity + + if self.body.nugget is not None: + self.body.nugget.update(self.body.position, self.body.angle, False) + + self.body.velocity = curr_vel + def convert_img(self): + self.image = self.image.convert_alpha() + + def __str__(self): + return "banker_%s" % self.id + + def __repr__(self): + return self.__str__() + + +class Fence(pg.sprite.Sprite): + def __init__(self, w_type, sprite_pos, body_pos, verts, space, *sprite_groups): + super().__init__(sprite_groups) + + if w_type == "top": + # self.image = utils.load_image(["horiz-fence.png"]) + self.image = utils.load_image(["top-down-horiz-fence.png"]) + elif w_type in ["right", "left"]: + # self.image = utils.load_image(["vert-fence.png"]) + self.image = utils.load_image(["top-down-vert-fence.png"]) else: - # Do ball and bat collide? 
- # add some randomness - r_val = 0 - if self.bounce_randomness: - r_val = get_small_random_value() - - # ball in left half of screen - is_collision, self.rect, self.vel = p1.process_collision( - self.rect, dx, dy, self.vel - ) - if is_collision: - self.vel = vec( - 0, 0 - ) # self.speed[0] + np.sign(self.speed[0]) * r_val, self.speed[1] + np.sign(self.speed[1]) * r_val] - - return False - - def process_collision(self, b_rect, dx, dy, b_speed): - if self.rect.colliderect(b_rect): - is_collision = True - if dx < 0: - b_rect.left = self.rect.right - b_speed[0] = -b_speed[0] - # top or bottom edge - elif dy > 0: - b_rect.bottom = self.rect.top - b_speed[1] = -b_speed[1] - elif dy < 0: - b_rect.top = self.rect.bottom - b_speed[1] = -b_speed[1] - return is_collision, b_rect, b_speed - return False, b_rect, b_speed - - -class agent2(pygame.sprite.Sprite): - """ - This class represents the triangle - """ - - def __init__(self, _screen_width, speed=5): - """ Constructor. Pass in the color of the block, - and its x and y position. """ - # Call the parent class (Sprite) constructor - super().__init__() - - self.image = get_image( - "agent2.jpg" - ) # pygame.draw.polygon(screen, BLACK, [[0, 0], [0, 50], [50, 0]], 5)# - self.screen_width = _screen_width - self.rect = self.image.get_rect() - self.dim = self.rect.size - self.rect.y = 350 - self.mode = True - - self.speed_val = speed - self.vel = vec( - int(self.speed_val * np.cos(np.pi / 4)), - int(self.speed_val * np.sin(np.pi / 4)), + raise ValueError("Fence image not found! 
Check the spelling") + # elif w_type == "left": + # # self.image = utils.load_image(["vert-fence.png"]) + # self.image = utils.load_image(["top-down-vert-fence.png"]) + + self.rect = self.image.get_rect(topleft=sprite_pos) + + self.body = pm.Body(body_type=pm.Body.STATIC) + + # Transform pygame vertices to fit Pymunk body + invert_verts = utils.invert_y(verts) + self.shape = pm.Poly(self.body, invert_verts) + self.shape.elasticity = 0.0 + self.shape.collision_type = CollisionTypes.BOUNDARY + + self.body.position = utils.flipy(body_pos) + space.add(self.shape) + + def convert_img(self): + self.image = self.image.convert_alpha() + + +class Water(pg.sprite.Sprite): + def __init__(self, pos, verts, space, *sprite_groups): + super().__init__(*sprite_groups) + # Determine the width and height of the surface. + self.image = utils.load_image(["water.png"]) + self.image = pg.transform.scale( + self.image, (const.SCREEN_WIDTH, const.WATER_HEIGHT) ) - self.hit = False - self.bounce_randomness = 1 - - def reset_pos(self): - self.rect.y = random.randrange(350, 600) - self.rect.x = random.randrange(0, self.screen_width[0]) - # print(self.rect.y, self.rect.x) - - def update(self, pos): - """ Called each frame. 
""" - if self.mode: - # if self.rect.x > 1000: - # self.rect.x = 0 - # self.rect.x += 5 - if self.rect.y > 100: - self.rect.center = pos - # self.rect.x = pos[0] - # else: - # self.reset_pos() + self.rect = self.image.get_rect(topleft=pos) + + self.body = pm.Body(body_type=pm.Body.STATIC) + + # Transform pygame vertices to fit Pymunk body + invert_verts = utils.invert_y(verts) + self.shape = pm.Poly(self.body, invert_verts) + self.shape.collision_type = CollisionTypes.WATER + + # self.shape.friction = 1.0 + self.body.position = utils.flipy(pos) + self.space = space + self.space.add(self.shape) + + def convert_img(self): + self.image = self.image.convert_alpha() + + +class Bank(pg.sprite.Sprite): + def __init__(self, pos, verts, space, *sprite_groups): + super().__init__(sprite_groups) + + self.image = utils.load_image(["bank-2.png"]) + self.image = pg.transform.scale(self.image, (184, 100)) + self.rect = self.image.get_rect(topleft=pos) + + self.body = pm.Body(body_type=pm.Body.STATIC) + + invert_verts = utils.invert_y(verts) + self.shape = pm.Poly(self.body, invert_verts) + self.shape.collision_type = CollisionTypes.BANK + + self.body.position = utils.flipy(pos) + self.space = space + self.space.add(self.shape, self.body) + + def convert_img(self): + self.image = self.image.convert_alpha() + + +class Gold(pg.sprite.Sprite): + ids = it.count(0) + + def __init__(self, pos, body, space, *sprite_groups): + super().__init__(sprite_groups) + self.id = next(self.ids) + + self.image = utils.load_image(["gold", "6.png"]) + self.image = pg.transform.scale(self.image, (16, 16)) + self.orig_image = self.image + + self.rect = self.image.get_rect() + + self.moment = pm.moment_for_circle(1, 0, 8) + self.body = pm.Body(1, self.moment) + self.body.position = body.position + + self.shape = pm.Circle(self.body, 8) + self.shape.collision_type = CollisionTypes.GOLD + self.shape.id = self.id + + self.space = space + self.space.add(self.body, self.shape) + + self.initial_angle = 
body.angle - Vec2d(0, -1).angle + self.parent_body = body + + def update(self, pos, angle, banker: bool): + + if banker: + new_angle = angle else: - if self.rect.y < 100: - self.change_command() - self.rect.y = 350 - self.rect.y -= 5 - - def change_command(self): - self.mode = not self.mode - - def process_collision(self, b_rect, dx, dy, b_speed): - if self.rect.colliderect(b_rect): - is_collision = True - if dx < 0: - b_rect.left = self.rect.right - b_speed[0] = -b_speed[0] - # top or bottom edge - elif dy > 0: - b_rect.bottom = self.rect.top - b_speed[1] = -b_speed[1] - elif dy < 0: - b_rect.top = self.rect.bottom - b_speed[1] = -b_speed[1] - return is_collision, b_rect, b_speed - return False, b_rect, b_speed - - def move_single_axis(self, dx, dy, area, p1): - self.rect.x += dx - self.rect.y += dy - - if not area.contains(self.rect): - # bottom wall - if dy > 0: - self.rect.bottom = area.bottom - self.vel[1] = -self.vel[1] - # top wall - elif dy < 0: - self.rect.top = area.top - self.vel[1] = -self.vel[1] - # right or left walls - else: - self.vel[0] = -self.vel[0] + new_angle = angle - self.initial_angle + new_pos = pos + Vec2d(const.AGENT_RADIUS + 9, 0).rotated(new_angle) + + self.body.position = new_pos + self.body.angular_velocity = 0 + self.rect.center = utils.flipy(self.body.position) + self.image = pg.transform.rotozoom( + self.orig_image, math.degrees(self.body.angle), 1 + ) + self.rect = self.image.get_rect(center=self.rect.center) + + def convert_img(self): + self.image = self.image.convert_alpha() + + +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.ClipOutOfBoundsWrapper(env) + env = wrappers.NanNoOpWrapper(env, [0, 0, 0], "setting action to 0") + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): + def __init__( + self, + ind_reward=0.8, + group_reward=0.1, + other_group_reward=0.1, + prospec_find_gold_reward=1, + prospec_handoff_gold_reward=1, + banker_receive_gold_reward=1, + 
banker_deposit_gold_reward=1, + max_frames=900, + seed=None, + ): + if ind_reward + group_reward + other_group_reward != 1.0: + raise ValueError( + "Individual reward, group reward, and other group reward should " + "add up to 1.0" + ) + + self.num_agents = const.NUM_AGENTS + # self.agents = list(range(0, self.num_agents)) + self.agents = [] + + self.sprite_list = [ + "bankers/1-big.png", + "bankers/2-big.png", + "bankers/3-big.png", + "prospector-pickaxe-big.png", + ] + self.rendering = False + self.max_frames = max_frames + self.frame = 0 + + pg.init() + self.rng, seed = seeding.np_random(seed) + self.screen = pg.Surface(const.SCREEN_SIZE) + self.clock = pg.time.Clock() + self.done = False + self.closed = False + + self.background = utils.load_image(["background-debris.png"]) + self.background_rect = pg.Rect(0, 0, *const.SCREEN_SIZE) + self.screen.blit(self.background, self.background_rect) + + self.space = pm.Space() + self.space.gravity = Vec2d(0.0, 0.0) + self.space.damping = 0.0 + + # self.all_sprites = pg.sprite.Group() + self.all_sprites = pg.sprite.RenderUpdates() + self.gold = [] + + # Generate random positions for each prospector agent + prospector_info = [ + (i, utils.rand_pos("prospector", self.rng)) for i in range(const.NUM_PROSPECTORS) + ] + self.prospectors = {} + for num, pos in prospector_info: + prospector = Prospector(pos, self.space, num, self.all_sprites) + identifier = f"prospector_{num}" + self.prospectors[identifier] = prospector + self.agents.append(identifier) + + banker_info = [(i, utils.rand_pos("banker", self.rng)) for i in range(const.NUM_BANKERS)] + self.bankers = {} + for num, pos in banker_info: + banker = Banker(pos, self.space, num, self.all_sprites) + identifier = f"banker_{num}" + self.bankers[identifier] = banker + self.agents.append(identifier) + + self.banks = [] + for pos, verts in const.BANK_INFO: + self.banks.append(Bank(pos, verts, self.space, self.all_sprites)) + + for w_type, s_pos, b_pos, verts in const.FENCE_INFO: + 
Fence(w_type, s_pos, b_pos, verts, self.space, self.all_sprites) + + Water(const.WATER_INFO[0], const.WATER_INFO[1], self.space, self.all_sprites) + + self.metadata = {"render.modes": ["human"]} + + self.action_spaces = {} + for p in self.prospectors: + self.action_spaces[p] = spaces.Box(low=np.float32(-1.), high=np.float32(1.), shape=(3,)) + + for b in self.bankers: + self.action_spaces[b] = spaces.Box(low=np.float32(-1.), high=np.float32(1.), shape=(3,)) + + self.observation_spaces = {} + self.last_observation = {} + for a in self.agents: + self.last_observation[a] = None + # low, high for RGB values + self.observation_spaces[a] = spaces.Box( + low=0, high=255, shape=const.OBSERVATION_SHAPE, dtype=np.uint8 + ) + + self.agent_order = self.agents[:] + self._agent_selector = agent_selector(self.agent_order) + self.agent_selection = self._agent_selector.next() + self.reset() + + # Collision Handler Functions -------------------------------------------- + # Water to Prospector + def add_gold(arbiter, space, data): + prospec_shape = arbiter.shapes[0] + prospec_body = prospec_shape.body + + position = arbiter.contact_point_set.points[0].point_a + normal = arbiter.contact_point_set.normal + + prospec_body.position = position - (24 * normal) + prospec_body.velocity = (0, 0) + + for k, v in self.prospectors.items(): + if v.body is prospec_body: + self.rewards[k] += ind_reward * prospec_find_gold_reward + else: + self.rewards[k] += group_reward * prospec_find_gold_reward + + for k in self.bankers: + self.rewards[k] += other_group_reward * prospec_find_gold_reward + + if prospec_body.nugget is None: + position = arbiter.contact_point_set.points[0].point_a + + gold = Gold(position, prospec_body, self.space, self.all_sprites) + self.gold.append(gold) + prospec_body.nugget = gold + + return True + + # Prospector to banker + def handoff_gold_handler(arbiter, space, data): + banker_shape, gold_shape = arbiter.shapes[0], arbiter.shapes[1] + + gold_sprite = None + for g in 
self.gold: + if g.id == gold_shape.id: + gold_sprite = g + + # This collision handler is only for prospector -> banker gold handoffs + if gold_sprite.parent_body.sprite_type != "prospector": return True + banker_body = banker_shape.body + prospec_body = gold_sprite.parent_body + + for k, v in self.prospectors.items(): + self.rewards[k] += other_group_reward * banker_receive_gold_reward + if v.body is prospec_body: + self.rewards[k] += prospec_handoff_gold_reward + else: + self.rewards[k] += group_reward * prospec_handoff_gold_reward + + for k, v in self.bankers.items(): + self.rewards[k] += other_group_reward * prospec_handoff_gold_reward + if v.body is banker_body: + self.rewards[k] += banker_receive_gold_reward + else: + self.rewards[k] += group_reward * banker_receive_gold_reward + + normal = arbiter.contact_point_set.normal + # Correct the angle because banker's head is rotated pi/2 + corrected = utils.normalize_angle(banker_body.angle + (math.pi / 2)) + if ( + corrected - const.BANKER_HANDOFF_TOLERANCE + <= normal.angle + <= corrected + const.BANKER_HANDOFF_TOLERANCE + ): + gold_sprite.parent_body.nugget = None + + gold_sprite.parent_body = banker_body + banker_body.nugget = gold_sprite + banker_body.nugget_offset = normal.angle + + return True + + # Banker to bank + def gold_score_handler(arbiter, space, data): + gold_shape, _ = arbiter.shapes[0], arbiter.shapes[1] + + for g in self.gold: + if g.id == gold_shape.id: + gold_class = g + + if gold_class.parent_body.sprite_type == "banker": + self.space.remove(gold_shape, gold_shape.body) + gold_class.parent_body.nugget = None + banker_body = gold_class.parent_body + + for k, v in self.bankers.items(): + if v.body is banker_body: + self.rewards[k] += banker_deposit_gold_reward + # banker_sprite = v + else: + self.rewards[k] += group_reward * banker_deposit_gold_reward + + for k in self.prospectors: + self.rewards[k] += other_group_reward * banker_deposit_gold_reward + + self.gold.remove(gold_class) + 
self.all_sprites.remove(gold_class) + + return False + + # Create the collision event generators + gold_dispenser = self.space.add_collision_handler( + CollisionTypes.PROSPECTOR, CollisionTypes.WATER + ) + + gold_dispenser.begin = add_gold + + handoff_gold = self.space.add_collision_handler( + CollisionTypes.BANKER, CollisionTypes.GOLD + ) + + handoff_gold.begin = handoff_gold_handler + + gold_score = self.space.add_collision_handler( + CollisionTypes.GOLD, CollisionTypes.BANK + ) + + gold_score.begin = gold_score_handler + + def observe(self, agent): + capture = pg.surfarray.pixels3d(self.screen) + if agent in self.prospectors: + ag = self.prospectors[agent] else: - r_val = 0 - if self.bounce_randomness: - r_val = get_small_random_value() + ag = self.bankers[agent] - # ball in left half of screen - is_collision, self.rect, self.vel = p1.process_collision( - self.rect, dx, dy, self.vel - ) - if is_collision: - self.vel = vec( - 0, 0 - ) # self.speed[0] + np.sign(self.speed[0]) * r_val, self.speed[1] + np.sign(self.speed[1]) * r_val] - - return False - - -class env(gym.Env): - def __init__(self): - super(env, self).__init__() - global agent2, agent1 - pygame.init() - - # Set the width and height of the screen [width, height] - size = (1002, 699) - self.screen = pygame.display.set_mode(size) - background = get_image("background.jpg") - pygame.display.set_caption("My Game") - self.screen.blit(background, (0, 0)) - self.area = self.screen.get_rect() - - # Loop until the user clicks the close button. 
- done = False - - # Used to manage how fast the screen updates - clock = pygame.time.Clock() - agent1 = agent1(size, x=50, y=50, speed=20) - agent2 = agent2(size) - - block_list, all_sprites_list = self.create_targets() - - vis = pygame.sprite.Group() # Visualize block that is being carried by agent 1 - vis2 = pygame.sprite.Group() # Visualize block that is being carried by agent 2 - block_picked = None - block_transfered = None - flag = 0 - blocks_hit_list = [] - # cropped = pygame.Surface((100,100)) - - # -------- Main Program Loop ----------- - while not done: - # --- Main event loop - for event in pygame.event.get(): - if event.type == pygame.QUIT: - done = True - - self.screen.blit(background, (0, 0)) - self.screen.blit(agent1.image, agent1.rect) - # cropped.blit(agent1.image, (agent1.rect.x,agent1.rect.y)) - self.screen.blit(agent2.image, agent2.rect) - pos = pygame.mouse.get_pos() - agent1.update(pos, self.area, agent2) - agent2.update(pos) # , self.area, agent1) - if flag == 0: - blocks_hit_list = pygame.sprite.spritecollide(agent1, block_list, True) - pygame.draw.circle(self.screen, RED, agent1.rect.topleft, 5) - while len(blocks_hit_list) > 1: - block_list.add(blocks_hit_list.pop()) - if blocks_hit_list: - # print(len(blocks_hit_list), len(block_list)) - vis.add(blocks_hit_list[0]) - block_picked = blocks_hit_list[0] - flag = 1 - if block_picked: - corner = agent1.check_collision(block_picked.rect) - - block_picked.update(agent1.rect, corner) - agent1.rotate_flag = True - # --- Go ahead and update the screen with what we've drawn. 
- # print(len(block_list)) - if agent1.rect.y < 355: - flag = 0 - block_picked = None - # agent1.rotate_flag = False - - blocks_transfer_list = pygame.sprite.spritecollide(agent2, vis, True) - if blocks_transfer_list: - block_transfered = blocks_transfer_list[0] - block_transfered.update(agent2.rect) - vis2.add(block_transfered) - agent2.change_command() - block_picked = None - flag = 0 - - if block_transfered: - block_transfered.update(agent2.rect) - if agent2.rect.y < 105: - block_transfered = None - - block_list.draw(self.screen) - vis.draw(self.screen) - vis2.draw(self.screen) - pygame.display.flip() - clock.tick(10) - - pygame.quit() - - def create_targets(self): - block_list = pygame.sprite.Group() - - # This is a list of every sprite. - # All blocks and the player block as well. - all_sprites_list = pygame.sprite.Group() - x = 20 - for i in range(18): - # This represents a block - block = Block() - # Set a random location for the block - block.rect.x = x - x += 75 - block.rect.y = 630 - # Add the block to the list of objects - block_list.add(block) - all_sprites_list.add(block) - - return block_list, all_sprites_list + assert ag is not None + + delta = const.OBSERVATION_SIDE_LENGTH // 2 + x, y = ag.center # Calculated property added to prospector and banker classes + sub_screen = np.array(capture[ + max(0, x - delta): min(const.SCREEN_WIDTH, x + delta), + max(0, y - delta): min(const.SCREEN_HEIGHT, y + delta), :], dtype=np.uint8) + + s_x, s_y, _ = sub_screen.shape + pad_x = const.OBSERVATION_SIDE_LENGTH - s_x + if x > const.SCREEN_WIDTH - delta: # Right side of the screen + sub_screen = np.pad(sub_screen, pad_width=((0, pad_x), (0, 0), (0, 0)), mode='constant') + elif x < 0 + delta: + sub_screen = np.pad(sub_screen, pad_width=((pad_x, 0), (0, 0), (0, 0)), mode='constant') + + pad_y = const.OBSERVATION_SIDE_LENGTH - s_y + if y > const.SCREEN_HEIGHT - delta: # Bottom of the screen + sub_screen = np.pad(sub_screen, pad_width=((0, 0), (0, pad_y), (0, 0)), 
mode='constant') + elif y < 0 + delta: + sub_screen = np.pad(sub_screen, pad_width=((0, 0), (pad_y, 0), (0, 0)), mode='constant') + + self.last_observation[agent] = sub_screen + + return sub_screen + + def step(self, action, observe=True): + agent_id = self.agent_selection + + if agent_id in self.prospectors: + agent = self.prospectors[agent_id] + else: + agent = self.bankers[agent_id] + + agent.update(action) + + all_agents_updated = self._agent_selector.is_last() + # Only take next step in game if all agents have received an action + if all_agents_updated: + if self.rendering: + self.clock.tick(const.FPS) + else: + self.clock.tick() + self.space.step(1 / const.FPS) + + self.draw() + + self.frame += 1 + # If we reached max frames, we're done + if self.frame == self.max_frames: + self.dones = dict(zip(self.agents, [True for _ in self.agents])) + + if self.rendering: + pg.event.pump() + + self.agent_selection = self._agent_selector.next() + + if observe: + return self.observe(self.agent_selection) + + def reward(self): + return self.rewards + + def reset(self, observe=True): + self.screen = pg.Surface(const.SCREEN_SIZE) + self.screen.blit(self.background, self.background_rect) + self.done = False + + for p in self.prospectors.values(): + p.reset(utils.rand_pos("prospector", self.rng)) + + for b in self.bankers.values(): + b.reset(utils.rand_pos("banker", self.rng)) + + self.rewards = dict(zip(self.agents, [0 for _ in self.agents])) + self.dones = dict(zip(self.agents, [False for _ in self.agents])) + self.infos = dict(zip(self.agents, [{} for _ in self.agents])) + self.metadata = {"render.modes": ["human"]} + self.rendering = False + self.frame = 0 + + self.agent_order = self.agents[:] + self._agent_selector.reinit(self.agent_order) + self.agent_selection = self._agent_selector.next() + self.draw() + if observe: + return self.observe(self.agent_selection) + + def render(self, mode="human"): + if not self.rendering: + pg.display.init() + self.screen = 
pg.display.set_mode(const.SCREEN_SIZE) + self.background = self.background.convert_alpha() + self.screen.blit(self.background, self.background_rect) + for s in self.all_sprites.sprites(): + s.convert_img() + self.rendering = True + self.draw() + pg.display.flip() + + def draw(self): + self.screen.blit(self.background, self.background_rect) + self.all_sprites.draw(self.screen) + + def close(self): + if not self.closed: + self.closed = True + if self.rendering: + pg.event.pump() + pg.display.quit() + pg.quit() + + +# class env(gym.Env): +# def __init__(self): +# super().__init__() +# global agent2, agent1 +# pygame.init() + +# # Set the width and height of the screen [width, height] +# size = (1002, 699) +# self.screen = pygame.display.set_mode(size) +# background = get_image("background.jpg") +# pygame.display.set_caption("My Game") +# self.screen.blit(background, (0, 0)) +# self.area = self.screen.get_rect() + +# # Loop until the user clicks the close button. +# done = False + +# # Used to manage how fast the screen updates +# clock = pygame.time.Clock() +# agent1 = agent1(size, x=50, y=50, speed=20) +# agent2 = agent2(size) + +# block_list, all_sprites_list = self.create_targets() + +# vis = pygame.sprite.Group() # Visualize block that is being carried by agent 1 +# vis2 = pygame.sprite.Group() # Visualize block that is being carried by agent 2 +# block_picked = None +# block_transfered = None +# flag = 0 +# blocks_hit_list = [] +# # cropped = pygame.Surface((100,100)) + +# # -------- Main Program Loop ----------- +# while not done: +# # --- Main event loop +# for event in pygame.event.get(): +# if event.type == pygame.QUIT: +# done = True + +# self.screen.blit(background, (0, 0)) +# self.screen.blit(agent1.image, agent1.rect) +# # cropped.blit(agent1.image, (agent1.rect.x,agent1.rect.y)) +# self.screen.blit(agent2.image, agent2.rect) +# pos = pygame.mouse.get_pos() +# agent1.update(pos, self.area, agent2) +# agent2.update(pos) # , self.area, agent1) +# if flag == 
0: +# blocks_hit_list = pygame.sprite.spritecollide(agent1, block_list, True) +# pygame.draw.circle(self.screen, RED, agent1.rect.topleft, 5) +# while len(blocks_hit_list) > 1: +# block_list.add(blocks_hit_list.pop()) +# if blocks_hit_list: +# # print(len(blocks_hit_list), len(block_list)) +# vis.add(blocks_hit_list[0]) +# block_picked = blocks_hit_list[0] +# flag = 1 +# if block_picked: +# corner = agent1.check_collision(block_picked.rect) + +# block_picked.update(agent1.rect, corner) +# agent1.rotate_flag = True +# # --- Go ahead and update the screen with what we've drawn. +# # print(len(block_list)) +# if agent1.rect.y < 355: +# flag = 0 +# block_picked = None +# # agent1.rotate_flag = False + +# blocks_transfer_list = pygame.sprite.spritecollide(agent2, vis, True) +# if blocks_transfer_list: +# block_transfered = blocks_transfer_list[0] +# block_transfered.update(agent2.rect) +# vis2.add(block_transfered) +# agent2.change_command() +# block_picked = None +# flag = 0 + +# if block_transfered: +# block_transfered.update(agent2.rect) +# if agent2.rect.y < 105: +# block_transfered = None + +# block_list.draw(self.screen) +# vis.draw(self.screen) +# vis2.draw(self.screen) +# pygame.display.flip() +# clock.tick(10) + +# pygame.quit() + +# def create_targets(self): +# block_list = pygame.sprite.Group() + +# # This is a list of every sprite. +# # All blocks and the player block as well. 
+# all_sprites_list = pygame.sprite.Group() +# x = 20 +# for i in range(18): +# # This represents a block +# block = Block() +# # Set a random location for the block +# block.rect.x = x +# x += 75 +# block.rect.y = 630 +# # Add the block to the list of objects +# block_list.add(block) +# all_sprites_list.add(block) + +# return block_list, all_sprites_list +# def draw(self): +# self.screen.blit(self.background, self.background_rect) +# self.all_sprites.draw(self.screen) + +# for p in self.prospectors: +# if p.body.nugget is not None: +# p.body.nugget.draw(self.screen) + +# for b in self.bankers: +# if b.body.nugget is not None: +# b.body.nugget.draw(self.screen) + +# def close(self): +# pg.event.pump() +# pg.display.quit() +# pg.quit() + +# Except for the gold png images, all other sprite art was created by Yashas Lokesh diff --git a/pettingzoo/gamma/prospector/pymunk-game.py b/pettingzoo/gamma/prospector/pymunk-game.py deleted file mode 100644 index 9c2d908cb..000000000 --- a/pettingzoo/gamma/prospector/pymunk-game.py +++ /dev/null @@ -1,919 +0,0 @@ -import pygame as pg -import pygame.locals as locals -import pymunk as pm -from pymunk import Vec2d -from gym import spaces -import numpy as np - -from pettingzoo import AECEnv -from pettingzoo.utils import agent_selector -import pettingzoo.gamma.prospector.constants as const -import pettingzoo.gamma.prospector.utils as utils - -import math -import os -from enum import IntEnum, auto -import itertools as it -from random import randint - - -class CollisionTypes(IntEnum): - PROSPECTOR = auto() - BOUNDARY = auto() - WATER = auto() - CHEST = auto() - GOLD = auto() - BANKER = auto() - - -class Prospector(pg.sprite.Sprite): - def __init__(self, pos, space, *sprite_groups): - super().__init__(sprite_groups) - # self.image = load_image(['prospec.png']) - self.image = utils.load_image(["prospector-pickaxe-big.png"]) - self.image = pg.transform.scale( - self.image, (int(const.AGENT_RADIUS * 2), int(const.AGENT_RADIUS * 2)) - ) - - 
self.rect = self.image.get_rect() - self.orig_image = self.image - - # Create the physics body and shape of this object. - # moment = pm.moment_for_poly(mass, vertices) - - moment = pm.moment_for_circle(1, 0, self.rect.width / 2) - - self.body = pm.Body(1, moment) - self.body.nugget = None - self.body.sprite_type = "prospector" - # self.shape = pm.Poly(self.body, vertices, radius=3) - - self.shape = pm.Circle(self.body, const.AGENT_RADIUS) - self.shape.elasticity = 0.0 - self.shape.collision_type = CollisionTypes.PROSPECTOR - - self.body.position = utils.flipy(pos) - # Add them to the Pymunk space. - self.space = space - self.space.add(self.body, self.shape) - - @property - def center(self): - return self.rect.x + const.AGENT_RADIUS, self.rect.y + const.AGENT_RADIUS - - def _update(self, action): - # forward/backward action - y_vel = action[0] * const.PROSPECTOR_SPEED - # left/right action - x_vel = action[1] * const.PROSPECTOR_SPEED - - delta_angle = action[2] * const.MAX_SPRITE_ROTATION - - self.body.angle += delta_angle - self.body.angular_velocity = 0 - - self.body.velocity = Vec2d(x_vel, y_vel).rotated(self.body.angle) - - self.rect.center = utils.flipy(self.body.position) - self.image = pg.transform.rotozoom( - self.orig_image, math.degrees(self.body.angle), 1 - ) - self.rect = self.image.get_rect(center=self.rect.center) - - curr_vel = self.body.velocity - - if self.body.nugget is not None: - self.body.nugget.update(self.body.position, self.body.angle, False) - - self.body.velocity = curr_vel - - def update(self, keys): - x = y = 0 - if keys[locals.K_w]: - y = 1 - if keys[locals.K_s]: - y = -1 - if keys[locals.K_d]: - x = 1 - if keys[locals.K_a]: - x = -1 - if keys[locals.K_q]: - self.body.angle += 0.1 - self.body.angular_velocity = 0 - if keys[locals.K_e]: - self.body.angle -= 0.1 - self.body.angular_velocity = 0 - - if x != 0 and y != 0: - self.body.velocity = Vec2d(x, y).rotated(self.body.angle) * ( - const.PROSPECTOR_SPEED / math.sqrt(2) - ) - else: - 
self.body.velocity = ( - Vec2d(x, y).rotated(self.body.angle) * const.PROSPECTOR_SPEED - ) - - # Rotate the image of the sprite. - self.rect.center = utils.flipy(self.body.position) - self.image = pg.transform.rotozoom( - self.orig_image, math.degrees(self.body.angle), 1 - ) - self.rect = self.image.get_rect(center=self.rect.center) - - curr_vel = self.body.velocity - - if self.body.nugget is not None: - self.body.nugget.update(self.body.position, self.body.angle, False) - - self.body.velocity = curr_vel - - -class Banker(pg.sprite.Sprite): - def __init__(self, pos, space, num, *sprite_groups): - super().__init__(sprite_groups) - self.image = utils.load_image(["bankers", "%s-big.png" % num]) - self.image = pg.transform.scale( - self.image, (int(const.AGENT_RADIUS * 2), int(const.AGENT_RADIUS * 2)) - ) - - self.rect = self.image.get_rect() - self.orig_image = self.image - - moment = pm.moment_for_circle(1, 0, self.rect.width / 2) - - self.body = pm.Body(1, moment) - self.body.nugget = None - self.body.sprite_type = "banker" - self.body.nugget_offset = None - - self.shape = pm.Circle(self.body, const.AGENT_RADIUS) - self.shape.collision_type = CollisionTypes.BANKER - - self.body.position = utils.flipy(pos) - # Add them to the Pymunk space. - self.space = space - self.space.add(self.body, self.shape) - - @property - def center(self): - return self.rect.x + const.AGENT_RADIUS, self.rect.y + const.AGENT_RADIUS - - def update(self, keys): - move = 0 - if any( - keys[key] - for key in (locals.K_UP, locals.K_DOWN, locals.K_RIGHT, locals.K_LEFT,) - ): - move = 1 - - if keys[locals.K_UP]: - self.body.angle = 0 - elif keys[locals.K_DOWN]: - self.body.angle = math.pi - elif keys[locals.K_RIGHT]: - self.body.angle = -math.pi / 2 - elif keys[locals.K_LEFT]: - self.body.angle = math.pi / 2 - - self.body.velocity = ( - Vec2d(0, move).rotated(self.body.angle) * const.BANKER_SPEED - ) - - # Rotate the image of the sprite. 
- self.rect.center = utils.flipy(self.body.position) - self.image = pg.transform.rotozoom( - self.orig_image, math.degrees(self.body.angle), 1 - ) - - self.rect = self.image.get_rect(center=self.rect.center) - - curr_vel = self.body.velocity - - if self.body.nugget is not None: - corrected = utils.normalize_angle(self.body.angle + (math.pi / 2)) - self.body.nugget.update(self.body.position, corrected, True) - - self.body.velocity = curr_vel - - def _update(self, action): - # left/right action - x_vel = action[0] * const.PROSPECTOR_SPEED - # up/down action - y_vel = action[1] * const.PROSPECTOR_SPEED - - # Add math.pi / 2 because sprite is facing upwards at the start - angle_radians = math.atan2(x_vel, y_vel) + (math.pi / 2) - - # Angle is determined only by current trajectory. - self.body.angle = angle_radians - self.body.angular_velocity = 0 - - self.body.velocity = Vec2d(x_vel, y_vel).rotated(self.body.angle) - - self.rect.center = utils.flipy(self.body.position) - self.image = pg.transform.rotozoom( - self.orig_image, math.degrees(self.body.angle), 1 - ) - self.rect = self.image.get_rect(center=self.rect.center) - - curr_vel = self.body.velocity - - if self.body.nugget is not None: - self.body.nugget.update(self.body.position, self.body.angle, False) - - self.body.velocity = curr_vel - - -class Fence(pg.sprite.Sprite): - def __init__(self, w_type, sprite_pos, body_pos, verts, space, *sprite_groups): - super().__init__(sprite_groups) - - if w_type == "top": - # self.image = utils.load_image(["horiz-fence.png"]) - self.image = utils.load_image(["top-down-horiz-fence.png"]) - elif w_type in ["right", "left"]: - # self.image = utils.load_image(["vert-fence.png"]) - self.image = utils.load_image(["top-down-vert-fence.png"]) - else: - raise ValueError("Fence image not found! 
Check the spelling") - # elif w_type == "left": - # # self.image = utils.load_image(["vert-fence.png"]) - # self.image = utils.load_image(["top-down-vert-fence.png"]) - - self.rect = self.image.get_rect(topleft=sprite_pos) - - self.body = pm.Body(body_type=pm.Body.STATIC) - - # Transform pygame vertices to fit Pymunk body - invert_verts = utils.invert_y(verts) - self.shape = pm.Poly(self.body, invert_verts) - self.shape.elasticity = 0.0 - self.shape.collision_type = CollisionTypes.BOUNDARY - - self.body.position = utils.flipy(body_pos) - space.add(self.shape) - - -class Water(pg.sprite.Sprite): - def __init__(self, pos, verts, space, *sprite_groups): - super().__init__(*sprite_groups) - # Determine the width and height of the surface. - self.image = utils.load_image(["water.png"]) - self.image = pg.transform.scale( - self.image, (const.SCREEN_WIDTH, const.WATER_HEIGHT) - ) - - self.rect = self.image.get_rect(topleft=pos) - - self.body = pm.Body(body_type=pm.Body.STATIC) - - # Transform pygame vertices to fit Pymunk body - invert_verts = utils.invert_y(verts) - self.shape = pm.Poly(self.body, invert_verts) - self.shape.collision_type = CollisionTypes.WATER - - # self.shape.friction = 1.0 - self.body.position = utils.flipy(pos) - self.space = space - self.space.add(self.shape) - - -class Bank(pg.sprite.Sprite): - def __init__(self, pos, verts, space, *sprite_groups): - super().__init__(sprite_groups) - - self.image = utils.load_image(["bank-2.png"]) - self.image = pg.transform.scale(self.image, (184, 100)) - self.rect = self.image.get_rect(topleft=pos) - - self.body = pm.Body(body_type=pm.Body.STATIC) - self.body.score = 0 - - invert_verts = utils.invert_y(verts) - self.shape = pm.Poly(self.body, invert_verts) - self.shape.collision_type = CollisionTypes.CHEST - - self.body.position = utils.flipy(pos) - self.space = space - self.space.add(self.shape, self.body) - - def __str__(self): - return str(self.body.score) - - -class Gold(pg.sprite.Sprite): - ids = it.count(0) 
- - def __init__(self, pos, body, space, *sprite_groups): - super().__init__(sprite_groups) - self.id = next(self.ids) - - self.image = utils.load_image(["gold", "6.png"]) - self.image = pg.transform.scale(self.image, (16, 16)) - self.orig_image = self.image - - self.rect = self.image.get_rect() - - self.moment = pm.moment_for_circle(1, 0, 8) - self.body = pm.Body(1, self.moment) - self.body.position = body.position - - self.shape = pm.Circle(self.body, 8) - self.shape.collision_type = CollisionTypes.GOLD - self.shape.id = self.id - - self.space = space - self.space.add(self.body, self.shape) - - self.initial_angle = body.angle - Vec2d(0, -1).angle - self.parent_body = body - - def update(self, pos, angle, banker: bool): - - if banker: - new_angle = angle - else: - new_angle = angle - self.initial_angle - new_pos = pos + Vec2d(const.AGENT_RADIUS + 9, 0).rotated(new_angle) - - self.body.position = new_pos - self.body.angular_velocity = 0 - self.rect.center = utils.flipy(self.body.position) - self.image = pg.transform.rotozoom( - self.orig_image, math.degrees(self.body.angle), 1 - ) - self.rect = self.image.get_rect(center=self.rect.center) - - def draw(self, surf): - surf.blit(self.image, self.rect) - - -class Game: - def __init__(self): - self.done = False - self.screen = pg.display.set_mode(const.SCREEN_SIZE) - self.clock = pg.time.Clock() - - self.background = utils.load_image(["background-debris.png"]) - # self.background_rect = pg.Rect(0, 0, *const.SCREEN_SIZE) - self.background_rect = self.background.get_rect(topleft=(0, 0)) - - self.space = pm.Space() - self.space.gravity = Vec2d(0.0, 0.0) - # self.space.damping = 0.5 - self.space.damping = 0.0 - - self.all_sprites = pg.sprite.Group() - self.gold = [] - - prospec_info = [utils.rand_pos("prospector") for _ in range(3)] - self.prospectors = [] - for pos in prospec_info: - self.prospectors.append(Prospector(pos, self.space, self.all_sprites)) - - banker_info = ( - (1, utils.rand_pos("banker")), - (2, 
utils.rand_pos("banker")), - (3, utils.rand_pos("banker")), - ) - self.bankers = [] - for num, pos in banker_info: - self.bankers.append(Banker(pos, self.space, num, self.all_sprites)) - - chest_verts = ( - (0, 0), - (184, 0), - (184, 100), - (0, 100), - ) - - chest_info = [ - ([184 * 1, 50], chest_verts), - ([184 * 3, 50], chest_verts), - ([184 * 5, 50], chest_verts), - ] - - self.chests = [] - for pos, verts in chest_info: - self.chests.append(Bank(pos, verts, self.space, self.all_sprites)) - - for w_type, s_pos, b_pos, verts in const.FENCE_INFO: - Fence(w_type, s_pos, b_pos, verts, self.space, self.all_sprites) - - water_info = { - "pos": (0, const.SCREEN_HEIGHT - const.WATER_HEIGHT), - "verts": ( - (0, 0), - (const.SCREEN_WIDTH, 0), - (const.SCREEN_WIDTH, const.WATER_HEIGHT), - (0, const.WATER_HEIGHT), - ), - } - - Water(water_info["pos"], water_info["verts"], self.space, self.all_sprites) - - def add_gold(arbiter, space, data): - prospec = arbiter.shapes[0] - prospec_body = prospec.body - - position = arbiter.contact_point_set.points[0].point_a - normal = arbiter.contact_point_set.normal - - prospec_body.position = position - (24 * normal) - prospec_body.velocity = (0, 0) - - if prospec_body.nugget is None: - - position = arbiter.contact_point_set.points[0].point_a - - gold = Gold(position, prospec_body, self.space) - self.gold.append(gold) - - prospec_body.nugget = gold - - return True - - gold_dispenser = self.space.add_collision_handler( - CollisionTypes.PROSPECTOR, CollisionTypes.WATER - ) - - gold_dispenser.begin = add_gold - - def handoff_gold_handler(arbiter, space, data): - banker, gold = arbiter.shapes[0], arbiter.shapes[1] - - gold_class = None - for g in self.gold: - if g.id == gold.id: - gold_class = g - - if gold_class.parent_body.sprite_type == "prospector": - - banker_body = banker.body - - normal = arbiter.contact_point_set.normal - - corrected = utils.normalize_angle(banker_body.angle + (math.pi / 2)) - - if ( - corrected - 
const.BANKER_HANDOFF_TOLERANCE - <= normal.angle - <= corrected + const.BANKER_HANDOFF_TOLERANCE - ): - - gold_class.parent_body.nugget = None - - gold_class.parent_body = banker_body - banker_body.nugget = gold_class - banker_body.nugget_offset = normal.angle - - return True - - handoff_gold = self.space.add_collision_handler( - CollisionTypes.BANKER, CollisionTypes.GOLD - ) - - handoff_gold.begin = handoff_gold_handler - - def gold_score_handler(arbiter, space, data): - gold, chest = arbiter.shapes[0], arbiter.shapes[1] - - gold_class = None - for g in self.gold: - if g.id == gold.id: - gold_class = g - - if gold_class.parent_body.sprite_type == "banker": - - chest.body.score += 1 - - self.space.remove(gold, gold.body) - - gold_class.parent_body.nugget = None - - # total_score = ", ".join( - # [ - # "Chest %d: %d" % (i, c.body.score) - # for i, c in enumerate(self.chests) - # ] - # ) - - # print(total_score) - - self.gold.remove(gold_class) - self.all_sprites.remove(gold_class) - - return False - - gold_score = self.space.add_collision_handler( - CollisionTypes.GOLD, CollisionTypes.CHEST - ) - - gold_score.begin = gold_score_handler - - def run(self): - while not self.done: - for event in pg.event.get(): - if event.type == locals.QUIT or ( - event.type == locals.KEYDOWN and event.key in [locals.K_ESCAPE] - ): - self.done = True - - self.dt = self.clock.tick(15) - - self.space.step(1 / 15) - self.all_sprites.update(pg.key.get_pressed()) - - self.draw() - - def draw(self): - # self.screen.fill(BACKGROUND_COLOR) - # self.background.blit(self.screen) - self.screen.blit(self.background, self.background_rect) - self.all_sprites.draw(self.screen) - - for p in self.prospectors: - if p.body.nugget is not None: - p.body.nugget.draw(self.screen) - - for b in self.bankers: - if b.body.nugget is not None: - b.body.nugget.draw(self.screen) - - pg.display.flip() - - -class env(AECEnv): - def __init__( - self, - ind_reward=0.8, - group_reward=0.1, - other_group_reward=0.1, - 
prospec_find_gold_reward=1, - prospec_handoff_gold_reward=1, - banker_receive_gold_reward=1, - banker_deposit_gold_reward=1, - max_frames=900, - ): - if ind_reward + group_reward + other_group_reward != 1.0: - raise ValueError( - "Individual reward, group reward, and other group reward should " - "add up to 1.0" - ) - - self.num_agents = 7 - # self.agents = list(range(0, self.num_agents)) - self.agents = [] - - self.sprite_list = [ - "bankers/1-big.png", - "bankers/2-big.png", - "bankers/3-big.png", - "prospector-pickaxe-big.png", - ] - self.rendering = False - self.max_frames = max_frames - self.frame = 0 - - # TODO: Setup game data here - pg.init() - self.clock = pg.time.Clock() - self.done = False - - self.background = utils.load_image(["test.png"]) - self.background_rect = pg.Rect(0, 0, *const.SCREEN_SIZE) - - self.space = pm.Space() - self.space.gravity = Vec2d(0.0, 0.0) - self.space.damping = 0.0 - - self.all_sprites = pg.sprite.Group() - self.gold = [] - - # Generate random positions for each prospector agent - prospector_info = [utils.rand_pos("prospector") for _ in range(4)] - self.prospectors = [] - for pos in prospector_info: - prospector = Prospector(pos, self.space, self.all_sprites) - self.prospectors.append(prospector) - self.agents.append(prospector) - - banker_info = [(i + 1, utils.rand_pos("banker")) for i in range(3)] - self.bankers = [] - for num, pos in banker_info: - banker = Banker(pos, self.space, num, self.all_sprites) - self.bankers.append(banker) - self.agents.append(banker) - - # Create these dictionaries after self.agents is populated - self.rewards = dict(zip(self.agents, [0 for _ in self.agents])) - self.dones = dict(zip(self.agents, [False for _ in self.agents])) - self.infos = dict(zip(self.agents, [[] for _ in self.agents])) - self.metadata = {"render.modes": ["human"]} - - # TODO: Setup action spaces - self.action_spaces = {} - for a in self.agents: - num_actions = 0 - if type(a) is Prospector: - num_actions = 6 - else: - 
num_actions = 4 - self.action_spaces[a] = spaces.Discrete(num_actions + 1) - - # TODO: Setup observation spaces - self.observation_spaces = {} - self.last_observation = {} - for a in self.agents: - self.last_observation[a] = None - # low, high for RGB values - self.observation_spaces[a] = spaces.Box( - low=0, high=255, shape=const.OBSERVATION_SHAPE - ) - - """ Finish setting up environment agents """ - self.agent_order = self.agents[:] - self._agent_selector = agent_selector(self.agent_order) - self.agent_selection = self._agent_selector.next() - - self.banks = [] - for pos, verts in const.BANK_INFO: - self.banks.append(Bank(pos, verts, self.space, self.all_sprites)) - - for w_type, pos, verts in const.BOUNDARY_INFO: - Fence(w_type, pos, verts, self.space, self.all_sprites) - - Water(const.WATER_INFO[0], const.WATER_INFO[1], self.space, self.all_sprites) - - # Collision Handler Functions -------------------------------------------- - # Water to Prospector - def add_gold(arbiter, space, data): - prospec_shape = arbiter.shapes[0] - prospec_body = prospec_shape.body - - position = arbiter.contact_point_set.points[0].point_a - normal = arbiter.contact_point_set.normal - - prospec_body.position = position - (24 * normal) - prospec_body.velocity = (0, 0) - - prospec_sprite = None - for p in self.prospectors: - if p.body is prospec_body: - prospec_sprite = p - self.rewards[prospec_sprite] += ind_reward * prospec_find_gold_reward - - for a in self.agents: - if isinstance(a, Prospector) and a is not prospec_sprite: - self.rewards[a] += group_reward * prospec_find_gold_reward - elif isinstance(a, Banker): - self.rewards[a] += other_group_reward * prospec_find_gold_reward - - if prospec_body.nugget is None: - position = arbiter.contact_point_set.points[0].point_a - - gold = Gold(position, prospec_body, self.space) - self.gold.append(gold) - prospec_body.nugget = gold - - return True - - # Prospector to banker - def handoff_gold_handler(arbiter, space, data): - banker_shape, 
gold_shape = arbiter.shapes[0], arbiter.shapes[1] - - gold_sprite = None - for g in self.gold: - if g.id == gold_shape.id: - gold_sprite = g - - if gold_sprite.parent_body.sprite_type == "prospector": - banker_body = banker_shape.body - prospec_body = gold_sprite.parent_body - - prospec_sprite = None - for p in self.prospectors: - if p.body is prospec_body: - prospec_sprite = p - self.rewards[prospec_sprite] += prospec_handoff_gold_reward - - for a in self.agents: - if isinstance(a, Prospector) and a is not prospec_sprite: - self.rewards[a] += group_reward * prospec_handoff_gold_reward - elif isinstance(a, Banker): - self.rewards[a] += ( - other_group_reward * prospec_handoff_gold_reward - ) - - banker_sprite = None - for b in self.bankers: - if b.shape is banker_shape: - banker_sprite = b - self.rewards[banker_sprite] += banker_receive_gold_reward - - for a in self.agents: - if isinstance(a, Prospector): - self.rewards[a] += ( - other_group_reward * banker_receive_gold_reward - ) - elif isinstance(a, Banker) and a is not banker_sprite: - self.rewards[a] += group_reward * banker_receive_gold_reward - - normal = arbiter.contact_point_set.normal - # Correct the angle because banker's head is rotated pi/2 - corrected = utils.normalize_angle(banker_body.angle + (math.pi / 2)) - if ( - corrected - const.BANKER_HANDOFF_TOLERANCE - <= normal.angle - <= corrected + const.BANKER_HANDOFF_TOLERANCE - ): - gold_sprite.parent_body.nugget = None - - gold_sprite.parent_body = banker_body - banker_body.nugget = gold_sprite - banker_body.nugget_offset = normal.angle - - return True - - # Banker to chest - def gold_score_handler(arbiter, space, data): - gold_shape, chest = arbiter.shapes[0], arbiter.shapes[1] - - gold_class = None - for g in self.gold: - if g.id == gold_shape.id: - gold_class = g - - if gold_class.parent_body.sprite_type == "banker": - chest.body.score += 1 - self.space.remove(gold_shape, gold_shape.body) - gold_class.parent_body.nugget = None - banker_body = 
gold_class.parent_body - - banker_sprite = None - for b in self.bankers: - if b.body is banker_body: - banker_sprite = b - self.rewards[banker_sprite] += banker_deposit_gold_reward - - for a in self.agents: - if isinstance(a, Prospector): - self.rewards[a] += ( - other_group_reward * banker_deposit_gold_reward - ) - elif isinstance(a, Banker) and a is not banker_sprite: - self.rewards[a] += group_reward * banker_deposit_gold_reward - - # total_score = ", ".join( - # [ - # "Chest %d: %d" % (i, c.body.score) - # for i, c in enumerate(self.chests) - # ] - # ) - - # print(total_score) - self.gold.remove(gold_class) - self.all_sprites.remove(gold_class) - - return False - - # Create the collision event generators - gold_dispenser = self.space.add_collision_handler( - CollisionTypes.PROSPECTOR, CollisionTypes.WATER - ) - - gold_dispenser.begin = add_gold - - handoff_gold = self.space.add_collision_handler( - CollisionTypes.BANKER, CollisionTypes.GOLD - ) - - handoff_gold.begin = handoff_gold_handler - - gold_score = self.space.add_collision_handler( - CollisionTypes.GOLD, CollisionTypes.CHEST - ) - - gold_score.begin = gold_score_handler - - def observe(self, agent): - capture = pg.surfarray.array3d(self.screen) - a = self.agents[agent] - - delta = const.OBSERVATION_SIDE_LENGTH // 2 - center = a.center # Calculated property added to prospector and banker classes - x, y = center - sub_screen = capture[x - delta: x + delta, y - delta: y + delta, :] - - return sub_screen - - def step(self, action, observe=True): - agent = self.agent_selection - # TODO: Figure out rewards - if action is None: - print("Error: NoneType received as action") - else: - agent.update(action) - - all_agents_updated = self._agent_selector.is_last() - # Only take next step in game if all agents have received an action - if all_agents_updated: - self.frame += 1 - # If we reached max frames, we're done - if self.frame == self.max_frames: - self.dones = dict(zip(self.agents, [True for _ in self.agents])) 
- - if self.rendering: - self.clock.tick(const.FPS) - else: - self.clock.tick() - self.space.step(1 / 15) - - self.agent_selection = self._agent_selector.next() - observation = self.observe(self.agent_selection) - - pg.event.pump() - if observe: - return observation - - def reward(self): - return self.rewards - - def reset(self, observe=True): - self.done = False - self.agents = [] - - # Re-create all agents and Pymunk space - self.space = pm.Space() - self.space.gravity = Vec2d(0.0, 0.0) - self.space.damping = 0.0 - - self.all_sprites = pg.sprite.Group() - self.gold = [] - - prospector_info = [utils.rand_pos("prospector") for _ in range(4)] - self.prospectors = [] - for pos in prospector_info: - prospector = Prospector(pos, self.space, self.all_sprites) - self.prospectors.append(prospector) - self.agents.append(prospector) - - banker_info = [(i + 1, utils.rand_pos("banker")) for i in range(3)] - self.bankers = [] - for num, pos in banker_info: - banker = Banker(pos, self.space, num, self.all_sprites) - self.bankers.append(banker) - self.agents.append(banker) - - self.banks = [] - for pos, verts in const.BANK_INFO: - self.banks.append(Bank(pos, verts, self.space, self.all_sprites)) - - for w_type, pos, verts in const.BOUNDARY_INFO: - Fence(w_type, pos, verts, self.space, self.all_sprites) - - Water(const.WATER_INFO[0], const.WATER_INFO[1], self.space, self.all_sprites) - - self.rewards = dict(zip(self.agents, [0 for _ in self.agents])) - self.dones = dict(zip(self.agents, [False for _ in self.agents])) - self.infos = dict(zip(self.agents, [[] for _ in self.agents])) - self.metadata = {"render.modes": ["human"]} - - self.rendering = False - self.frame = 0 - self._agent_selector.reinit(self.agent_order) - self.agent_selection = self._agent_selector.next() - - if observe: - return self.observe(self.agent_selection) - - def render(self, mode="human"): - if not self.rendering: - pg.display.init() - self.screen = pg.display.set_mode(const.SCREEN_SIZE) - self.rendering = 
True - self.draw() - pg.display.flip() - - def draw(self): - self.screen.blit(self.background, self.background_rect) - self.all_sprites.draw(self.screen) - - for p in self.prospectors: - if p.body.nugget is not None: - p.body.nugget.draw(self.screen) - - for b in self.bankers: - if b.body.nugget is not None: - b.body.nugget.draw(self.screen) - - def close(self): - pg.event.pump() - pg.display.quit() - pg.quit() - - -if __name__ == "__main__": - pg.init() - game = Game() - game.run() diff --git a/pettingzoo/gamma/prospector/utils.py b/pettingzoo/gamma/prospector/utils.py index f11d186eb..65295f95b 100644 --- a/pettingzoo/gamma/prospector/utils.py +++ b/pettingzoo/gamma/prospector/utils.py @@ -2,14 +2,15 @@ from pymunk import Vec2d import os -from random import randint import math from . import constants as const + def load_image(path: list) -> pg.Surface: # All images stored in data/ - img = pg.image.load(os.path.join("data", *path)) - img = img.convert_alpha() + cwd = os.path.dirname(__file__) + img = pg.image.load(os.path.join(cwd, "data", *path)) + # img = img.convert_alpha() return img @@ -22,12 +23,12 @@ def invert_y(points): return [(x, -y) for x, y in points] -def rand_pos(sprite): - x = randint(100, const.SCREEN_WIDTH - 100) +def rand_pos(sprite, rng): + x = rng.randint(100, const.SCREEN_WIDTH - 100) if sprite == "banker": - return x, randint(150, 300) + return x, rng.randint(150, 300) elif sprite == "prospector": - return x, randint(350, const.SCREEN_HEIGHT - (const.WATER_HEIGHT + 30)) + return x, rng.randint(350, const.SCREEN_HEIGHT - (const.WATER_HEIGHT + 30)) def normalize_angle(angle): diff --git a/pettingzoo/mpe/_mpe_utils/simple_env.py b/pettingzoo/mpe/_mpe_utils/simple_env.py index 0966bf563..048ed2d79 100644 --- a/pettingzoo/mpe/_mpe_utils/simple_env.py +++ b/pettingzoo/mpe/_mpe_utils/simple_env.py @@ -4,6 +4,18 @@ from pettingzoo.utils.agent_selector import agent_selector from pettingzoo.utils.env_logger import EnvLogger from gym.utils import 
seeding +from pettingzoo.utils import wrappers + + +def make_env(raw_env): + def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.AssertOutOfBoundsWrapper(env) + backup_policy = "taking zero action (no movement, communication 0)" + env = wrappers.NanNoOpWrapper(env, 0, backup_policy) + env = wrappers.OrderEnforcingWrapper(env) + return env + return env class SimpleEnv(AECEnv): diff --git a/pettingzoo/mpe/simple.py b/pettingzoo/mpe/simple.py index d8411d2f7..7becf9848 100644 --- a/pettingzoo/mpe/simple.py +++ b/pettingzoo/mpe/simple.py @@ -1,9 +1,12 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): def __init__(self, seed=None, max_frames=100): scenario = Scenario() world = scenario.make_world() - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_adversary.py b/pettingzoo/mpe/simple_adversary.py index 99fdba399..a368f4325 100644 --- a/pettingzoo/mpe/simple_adversary.py +++ b/pettingzoo/mpe/simple_adversary.py @@ -1,9 +1,12 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_adversary import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): def __init__(self, seed=None, N=2, max_frames=100): scenario = Scenario() world = scenario.make_world(N=2) - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_crypto.py b/pettingzoo/mpe/simple_crypto.py index 568573fe0..7a0e2f2a2 100644 --- a/pettingzoo/mpe/simple_crypto.py +++ b/pettingzoo/mpe/simple_crypto.py @@ -1,9 +1,13 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from 
.scenarios.simple_crypto import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): + def __init__(self, seed=None, max_frames=100): scenario = Scenario() world = scenario.make_world() - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_push.py b/pettingzoo/mpe/simple_push.py index 036d102d1..a88eb97b7 100644 --- a/pettingzoo/mpe/simple_push.py +++ b/pettingzoo/mpe/simple_push.py @@ -1,9 +1,13 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_push import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): + def __init__(self, seed=None, max_frames=100): scenario = Scenario() world = scenario.make_world() - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_reference.py b/pettingzoo/mpe/simple_reference.py index 5c1af432b..2a86c779d 100644 --- a/pettingzoo/mpe/simple_reference.py +++ b/pettingzoo/mpe/simple_reference.py @@ -1,10 +1,14 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_reference import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): + def __init__(self, seed=None, local_ratio=0.5, max_frames=100): assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1." 
scenario = Scenario() world = scenario.make_world() - super(env, self).__init__(scenario, world, max_frames, seed, local_ratio) + super().__init__(scenario, world, max_frames, seed, local_ratio) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_speaker_listener.py b/pettingzoo/mpe/simple_speaker_listener.py index 4737fd9d3..8a69bf6b9 100644 --- a/pettingzoo/mpe/simple_speaker_listener.py +++ b/pettingzoo/mpe/simple_speaker_listener.py @@ -1,9 +1,13 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_speaker_listener import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): + def __init__(self, seed=None, max_frames=100): scenario = Scenario() world = scenario.make_world() - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_spread.py b/pettingzoo/mpe/simple_spread.py index 32f365d99..36370fa49 100644 --- a/pettingzoo/mpe/simple_spread.py +++ b/pettingzoo/mpe/simple_spread.py @@ -1,10 +1,14 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_spread import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): + def __init__(self, seed=None, N=3, local_ratio=0.5, max_frames=100): assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1." 
scenario = Scenario() world = scenario.make_world(N) - super(env, self).__init__(scenario, world, max_frames, seed, local_ratio) + super().__init__(scenario, world, max_frames, seed, local_ratio) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_tag.py b/pettingzoo/mpe/simple_tag.py index bd4f49e11..2ad603eca 100644 --- a/pettingzoo/mpe/simple_tag.py +++ b/pettingzoo/mpe/simple_tag.py @@ -1,9 +1,13 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_tag import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): + def __init__(self, seed=None, num_good=1, num_adversaries=3, num_obstacles=2, max_frames=100): scenario = Scenario() world = scenario.make_world(num_good, num_adversaries, num_obstacles) - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_world_comm.py b/pettingzoo/mpe/simple_world_comm.py index 61825be00..c474b7ba3 100644 --- a/pettingzoo/mpe/simple_world_comm.py +++ b/pettingzoo/mpe/simple_world_comm.py @@ -1,9 +1,12 @@ -from ._mpe_utils.simple_env import SimpleEnv +from ._mpe_utils.simple_env import SimpleEnv, make_env from .scenarios.simple_world_comm import Scenario -class env(SimpleEnv): +class raw_env(SimpleEnv): def __init__(self, seed=None, num_good=2, num_adversaries=4, num_obstacles=1, num_food=2, num_forests=2, max_frames=100): scenario = Scenario() world = scenario.make_world(num_good, num_adversaries, num_obstacles, num_food, num_forests) - super(env, self).__init__(scenario, world, max_frames, seed) + super().__init__(scenario, world, max_frames, seed) + + +env = make_env(raw_env) diff --git a/pettingzoo/sisl/multiwalker/multiwalker.py b/pettingzoo/sisl/multiwalker/multiwalker.py index 14b01c306..7997b036c 100755 --- a/pettingzoo/sisl/multiwalker/multiwalker.py +++ b/pettingzoo/sisl/multiwalker/multiwalker.py @@ -1,16 
+1,24 @@ from .multiwalker_base import MultiWalkerEnv as _env from pettingzoo import AECEnv from pettingzoo.utils import agent_selector -from pettingzoo.utils import EnvLogger import numpy as np +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.ClipOutOfBoundsWrapper(env) + env = wrappers.NanZerosWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self, seed=None, *args, **kwargs): - super(env, self).__init__() + super().__init__() self.env = _env(seed, *args, **kwargs) self.num_agents = self.env.num_agents @@ -48,26 +56,13 @@ def close(self): self.env.close() def render(self, mode="human"): - if not self.has_reset: - EnvLogger.error_render_before_reset() - else: - self.env.render() + self.env.render() def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() return self.env.observe(self.agent_name_mapping[agent]) def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() agent = self.agent_selection - if action is None or any(np.isnan(action)): - EnvLogger.warn_action_is_NaN(backup_policy="setting to zeros") - action = np.zeros_like(self.action_spaces[agent].sample()) - elif not self.action_spaces[agent].contains(action): - EnvLogger.warn_action_out_of_bound(action=action, action_space=self.action_spaces[agent], backup_policy="setting to zeros") - action = np.zeros_like(self.action_spaces[agent].sample()) action = np.array(action, dtype=np.float32) self.env.step(action, self.agent_name_mapping[agent], self._agent_selector.is_last()) for r in self.rewards: diff --git a/pettingzoo/sisl/multiwalker/multiwalker_base.py b/pettingzoo/sisl/multiwalker/multiwalker_base.py index ff4f67b91..16a27d088 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker_base.py +++ b/pettingzoo/sisl/multiwalker/multiwalker_base.py @@ -7,7 +7,6 @@ 
from Box2D.b2 import (circleShape, contactListener, edgeShape, fixtureDef, polygonShape, revoluteJointDef) from .. import Agent -from pettingzoo.utils import EnvLogger MAX_AGENTS = 40 @@ -354,8 +353,6 @@ def close(self): if self.viewer is not None: self.viewer.close() self.viewer = None - else: - EnvLogger.warn_close_unrendered_env() def reset(self): self._destroy() diff --git a/pettingzoo/sisl/pursuit/pursuit.py b/pettingzoo/sisl/pursuit/pursuit.py index 1194c5220..e0953105c 100755 --- a/pettingzoo/sisl/pursuit/pursuit.py +++ b/pettingzoo/sisl/pursuit/pursuit.py @@ -3,16 +3,25 @@ from pettingzoo import AECEnv from pettingzoo.utils import agent_selector import numpy as np -from pettingzoo.utils import EnvLogger import pygame +from pettingzoo.utils import wrappers -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + example_space = list(env.action_spaces.values())[0] + env = wrappers.AssertOutOfBoundsWrapper(env) + env = wrappers.NanNoOpWrapper(env, np.zeros(example_space.shape, dtype=example_space.dtype), "taking all zeros action") + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self, seed=None, *args, **kwargs): - super(env, self).__init__() + super().__init__() self.env = _env(*args, seed, **kwargs) pygame.init() self.num_agents = self.env.num_agents @@ -49,21 +58,11 @@ def close(self): self.env.close() def render(self, mode="human"): - if not self.has_reset: - EnvLogger.error_render_before_reset() - elif not self.closed: + if not self.closed: self.env.render() def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() agent = self.agent_selection - if action is None or np.isnan(action): - action = 0 - EnvLogger.warn_action_is_NaN(backup_policy="setting action to 0") - elif not self.action_spaces[agent].contains(action): - EnvLogger.warn_action_out_of_bound(action=action, action_space=self.action_spaces[agent], 
backup_policy="setting action to 0") - action = 0 self.env.step(action, self.agent_name_mapping[agent], self._agent_selector.is_last()) for k in self.dones: if self.env.frames >= self.env.max_frames: @@ -78,7 +77,5 @@ def step(self, action, observe=True): return self.observe(self.agent_selection) def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() o = np.array(self.env.safely_observe(self.agent_name_mapping[agent])) return o diff --git a/pettingzoo/sisl/pursuit/pursuit_base.py b/pettingzoo/sisl/pursuit/pursuit_base.py index f25c3e21f..59d1ce43a 100755 --- a/pettingzoo/sisl/pursuit/pursuit_base.py +++ b/pettingzoo/sisl/pursuit/pursuit_base.py @@ -6,7 +6,6 @@ import numpy as np from gym import spaces from gym.utils import seeding -from pettingzoo.utils import EnvLogger import pygame @@ -160,9 +159,7 @@ def __init__(self, seed=0, **kwargs): self.reset() def close(self): - if not self.renderOn: - EnvLogger.warn_close_unrendered_env() - else: + if self.renderOn: pygame.event.pump() pygame.display.quit() pygame.quit() diff --git a/pettingzoo/sisl/pursuit/utils/agent_utils.py b/pettingzoo/sisl/pursuit/utils/agent_utils.py index 8501c533f..bf581bac9 100644 --- a/pettingzoo/sisl/pursuit/utils/agent_utils.py +++ b/pettingzoo/sisl/pursuit/utils/agent_utils.py @@ -23,7 +23,8 @@ def create_agents(nagents, map_matrix, obs_range, randomizer, flatten=False, ran for i in xrange(nagents): xinit, yinit = (0, 0) if randinit: - xinit, yinit = feasible_position_exp(randomizer, + xinit, yinit = feasible_position_exp( + randomizer, map_matrix, expanded_mat, constraints=constraints) # fill expanded_mat expanded_mat[xinit + 1, yinit + 1] = -1 diff --git a/pettingzoo/sisl/waterworld/waterworld.py b/pettingzoo/sisl/waterworld/waterworld.py index 12ecad7a4..3b31b716a 100755 --- a/pettingzoo/sisl/waterworld/waterworld.py +++ b/pettingzoo/sisl/waterworld/waterworld.py @@ -1,16 +1,24 @@ from .waterworld_base import MAWaterWorld as _env from pettingzoo 
import AECEnv from pettingzoo.utils import agent_selector -from pettingzoo.utils import EnvLogger +from pettingzoo.utils import wrappers import numpy as np -class env(AECEnv): +def env(**kwargs): + env = raw_env(**kwargs) + env = wrappers.ClipOutOfBoundsWrapper(env) + env = wrappers.NanZerosWrapper(env) + env = wrappers.OrderEnforcingWrapper(env) + return env + + +class raw_env(AECEnv): metadata = {'render.modes': ['human']} def __init__(self, seed=None, *args, **kwargs): - super(env, self).__init__() + super().__init__() self.env = _env(seed, *args, **kwargs) self.num_agents = self.env.num_agents @@ -42,24 +50,14 @@ def reset(self, observe=True): return self.observe(self.agent_selection) def close(self): - self.env.close() + if self.has_reset: + self.env.close() def render(self, mode="human"): - if not self.has_reset: - EnvLogger.error_render_before_reset() - else: - self.env.render() + self.env.render() def step(self, action, observe=True): - if not self.has_reset: - EnvLogger.error_step_before_reset() agent = self.agent_selection - if action is None or any(np.isnan(action)): - EnvLogger.warn_action_is_NaN(backup_policy="setting to zeros") - action = np.zeros_like(self.action_spaces[agent].sample()) - elif not self.action_spaces[agent].contains(action): - EnvLogger.warn_action_out_of_bound(action=action, action_space=self.action_spaces[agent], backup_policy="setting to zeros") - action = np.zeros_like(self.action_spaces[agent].sample()) self.env.step(action, self.agent_name_mapping[agent], self._agent_selector.is_last()) for r in self.rewards: @@ -79,6 +77,4 @@ def step(self, action, observe=True): return self.observe(self.agent_selection) def observe(self, agent): - if not self.has_reset: - EnvLogger.error_observe_before_reset() return self.env.observe(self.agent_name_mapping[agent]) diff --git a/pettingzoo/sisl/waterworld/waterworld_base.py b/pettingzoo/sisl/waterworld/waterworld_base.py index 6b6ecf77f..f0a3396d3 100755 --- 
a/pettingzoo/sisl/waterworld/waterworld_base.py +++ b/pettingzoo/sisl/waterworld/waterworld_base.py @@ -4,7 +4,6 @@ from gym.utils import seeding from .. import Agent import cv2 -from pettingzoo.utils import EnvLogger class Archea(Agent): @@ -151,8 +150,6 @@ def close(self): if self.renderOn: cv2.destroyAllWindows() cv2.waitKey(1) - else: - EnvLogger.warn_close_unrendered_env() @property def reward_mech(self): diff --git a/pettingzoo/tests/all_modules.py b/pettingzoo/tests/all_modules.py new file mode 100644 index 000000000..11686aedc --- /dev/null +++ b/pettingzoo/tests/all_modules.py @@ -0,0 +1,69 @@ +from pettingzoo.classic import chess_v0 +from pettingzoo.classic import rps_v0 +from pettingzoo.classic import rpsls_v0 +from pettingzoo.classic import connect_four_v0 +from pettingzoo.classic import tictactoe_v0 +from pettingzoo.classic import leduc_holdem_v0 +from pettingzoo.classic import mahjong_v0 +from pettingzoo.classic import texas_holdem_v0 +from pettingzoo.classic import texas_holdem_no_limit_v0 +from pettingzoo.classic import uno_v0 +from pettingzoo.classic import dou_dizhu_v0 +from pettingzoo.classic import gin_rummy_v0 +from pettingzoo.classic import go_v0 + +from pettingzoo.gamma import knights_archers_zombies_v0 +from pettingzoo.gamma import pistonball_v0 +from pettingzoo.gamma import cooperative_pong_v0 +from pettingzoo.gamma import prison_v0 +from pettingzoo.gamma import prospector_v0 + +from pettingzoo.mpe import simple_adversary_v0 +from pettingzoo.mpe import simple_crypto_v0 +from pettingzoo.mpe import simple_push_v0 +from pettingzoo.mpe import simple_reference_v0 +from pettingzoo.mpe import simple_speaker_listener_v0 +from pettingzoo.mpe import simple_spread_v0 +from pettingzoo.mpe import simple_tag_v0 +from pettingzoo.mpe import simple_world_comm_v0 +from pettingzoo.mpe import simple_v0 + +from pettingzoo.sisl import pursuit_v0 +from pettingzoo.sisl import waterworld_v0 +from pettingzoo.sisl import multiwalker_v0 + +all_environments = { + 
"classic/chess": chess_v0, + "classic/rps": rps_v0, + "classic/rpsls": rpsls_v0, + "classic/connect_four": connect_four_v0, + "classic/tictactoe": tictactoe_v0, + "classic/leduc_holdem": leduc_holdem_v0, + "classic/mahjong": mahjong_v0, + "classic/texas_holdem": texas_holdem_v0, + "classic/texas_holdem_no_limit": texas_holdem_no_limit_v0, + "classic/uno": uno_v0, + "classic/dou_dizhu": dou_dizhu_v0, + "classic/gin_rummy": gin_rummy_v0, + "classic/go": go_v0, + + "gamma/knights_archers_zombies": knights_archers_zombies_v0, + "gamma/pistonball": pistonball_v0, + "gamma/cooperative_pong": cooperative_pong_v0, + "gamma/prison": prison_v0, + "gamma/prospector": prospector_v0, + + "mpe/simple_adversary": simple_adversary_v0, + "mpe/simple_crypto": simple_crypto_v0, + "mpe/simple_push": simple_push_v0, + "mpe/simple_reference": simple_reference_v0, + "mpe/simple_speaker_listener": simple_speaker_listener_v0, + "mpe/simple_spread": simple_spread_v0, + "mpe/simple_tag": simple_tag_v0, + "mpe/simple_world_comm": simple_world_comm_v0, + "mpe/simple": simple_v0, + + "sisl/multiwalker": multiwalker_v0, + "sisl/waterworld": waterworld_v0, + "sisl/pursuit": pursuit_v0, +} diff --git a/pettingzoo/tests/api_test.py b/pettingzoo/tests/api_test.py index 9c416e298..441a3441e 100644 --- a/pettingzoo/tests/api_test.py +++ b/pettingzoo/tests/api_test.py @@ -1,8 +1,6 @@ import pettingzoo from pettingzoo.utils import agent_selector -from pettingzoo.utils import save_observation import warnings -import inspect import numpy as np from copy import copy import gym @@ -110,6 +108,17 @@ def test_reward(reward): def test_rewards_dones(env, agent_0): for agent in env.agents: assert isinstance(env.dones[agent], bool), "Agent's values in dones must be True or False" + print() + print() + print() + print(agent_0 is env.agent_order[0]) + print('class') + # print(env.rewards[agent_0]) + print(env.rewards[agent_0].__class__) + print('class done') + print() + print() + print() assert 
isinstance(env.rewards[agent], env.rewards[agent_0].__class__), "Rewards for each agent must be of the same class" test_reward(env.rewards[agent]) @@ -149,40 +158,13 @@ def play_test(env, observation_0): assert observation is None, "step(observe=False) must not return anything" -def test_observe(env, observation_0, save_obs): - for agent in env.agent_order: - observation = env.observe(agent) - if save_obs: - save_observation(env=env, agent=agent, save_dir="saved_observations") - test_obervation(observation, observation_0) - - -def test_render(env): - render_modes = env.metadata.get('render.modes') - assert render_modes is not None, "Environment's that support rendering must define render modes in metadata" - env.reset(observe=False) - for mode in render_modes: - for _ in range(10): - for agent in env.agent_order: - if 'legal_moves' in env.infos[agent]: - action = random.choice(env.infos[agent]['legal_moves']) - else: - action = env.action_spaces[agent].sample() - env.step(action, observe=False) - env.render(mode=mode) - if all(env.dones.values()): - env.reset() - break - - -def test_agent_selector(env): +def test_agent_order(env): + env.reset() if not hasattr(env, "_agent_selector"): warnings.warn("Env has no object named _agent_selector. We recommend handling agent cycling with the agent_selector utility from utils/agent_selector.py.") - return - if not isinstance(env._agent_selector, agent_selector): + elif not isinstance(env._agent_selector, agent_selector): warnings.warn("You created your own agent_selector utility. 
You might want to use ours, in utils/agent_selector.py") - return assert hasattr(env, "agent_order"), "Env does not have agent_order" @@ -190,7 +172,10 @@ def test_agent_selector(env): agent_order = copy(env.agent_order) _agent_selector = agent_selector(agent_order) agent_selection = _agent_selector.next() - assert env._agent_selector == _agent_selector, "env._agent_selector is initialized incorrectly" + + if hasattr(env, "_agent_selector"): + assert env._agent_selector == _agent_selector, "env._agent_selector is initialized incorrectly" + assert env.agent_selection == agent_selection, "env.agent_selection is not the same as the first agent in agent_order" for _ in range(200): @@ -217,247 +202,15 @@ def test_agent_selector(env): assert env.agent_selection == agent_selection, "env.agent_selection ({}) is not the same as the next agent in agent_order {}".format(env.agent_selection, env.agent_order) -def test_bad_close(env): - from pettingzoo.utils import EnvLogger - EnvLogger.suppress_output() - e1 = copy(env) - # test that immediately closing the environment does not crash - try: - e1.close() - except Exception as e: - warnings.warn("Immediately closing a newly initialized environment should not crash with {}".format(e)) - - # test that closing twice does not crash - - e2 = copy(env) - if "render.modes" in e2.metadata and len(e2.metadata["render.modes"]) > 0: - e2.reset() - e2.render() - e2.close() - try: - e2.close() - except Exception as e: - warnings.warn("Closing an already closed environment should not crash with {}".format(e)) - EnvLogger.unsuppress_output() - - -def test_warnings(env): - from pettingzoo.utils import EnvLogger - EnvLogger.suppress_output() - EnvLogger.flush() - e1 = copy(env) - e1.reset() - e1.close() - # e1 should throw a close_unrendered_environment warning - if len(EnvLogger.mqueue) == 0: - warnings.warn("env does not warn when closing unrendered env. 
Should call EnvLogger.warn_close_unrendered_env") - EnvLogger.unsuppress_output() - - -def inp_handler(name): - from pynput.keyboard import Key, Controller as KeyboardController - import time - - keyboard = KeyboardController() - time.sleep(0.1) - choices = ['w', 'a', 's', 'd', 'j', 'k', Key.left, Key.right, Key.up, Key.down] - NUM_TESTS = 50 - for x in range(NUM_TESTS): - i = random.choice(choices) if x != NUM_TESTS - 1 else Key.esc - keyboard.press(i) - time.sleep(0.1) - keyboard.release(i) - - -def test_manual_control(manual_control): - import threading - manual_in_thread = threading.Thread(target=inp_handler, args=(1,)) - - manual_in_thread.start() - - try: - manual_control() - except Exception: - raise Exception("manual_control() has crashed. Please fix it.") - - manual_in_thread.join() - - -def check_asserts(fn, message=None): - try: - fn() - return False - except AssertionError as e: - if message is not None: - return message == str(e) - return True - except Exception as e: - raise e - - -def check_excepts(fn): - try: - fn() - return False - except Exception: - return True - - -# yields length of mqueue -def check_warns(fn, message=None): - from pettingzoo.utils import EnvLogger - EnvLogger.suppress_output() - EnvLogger.flush() - fn() - EnvLogger.unsuppress_output() - if message is None: - return EnvLogger.mqueue - else: - for item in EnvLogger.mqueue: - if message in item: - return True - return False - - -def test_requires_reset(env): - if not check_excepts(lambda: env.agent_selection): - warnings.warn("env.agent_selection should not be defined until reset is called") - if not check_excepts(lambda: env.dones): - warnings.warn("env.dones should not be defined until reset is called") - if not check_excepts(lambda: env.rewards): - warnings.warn("env.rewards should not be defined until reset is called") - - first_agent_name = env.agents[0] - - print(env.infos[first_agent_name].keys()) - - if not 'legal_moves' in env.infos[first_agent_name]: - first_agent = 
list(env.action_spaces.keys())[0] - first_action_space = env.action_spaces[first_agent] - if not check_asserts(lambda: env.step(first_action_space.sample()), "reset() needs to be called before step"): - warnings.warn("env.step should call EnvLogger.error_step_before_reset if it is called before reset") - if not check_asserts(lambda: env.observe(first_agent), "reset() needs to be called before observe"): - warnings.warn("env.observe should call EnvLogger.error_observe_before_reset if it is called before reset") - if "render.modes" in env.metadata and len(env.metadata["render.modes"]) > 0: - if not check_asserts(lambda: env.render(), "reset() needs to be called before render"): - warnings.warn("env.render should call EnvLogger.error_render_before_reset if it is called before reset") - - -def test_bad_actions(env): - env.reset() - first_action_space = env.action_spaces[env.agent_selection] - - if isinstance(first_action_space, gym.spaces.Box): - try: - if not check_warns(lambda: env.step(np.nan * np.ones_like(first_action_space.low)), "[WARNING]: Received an NaN"): - warnings.warn("NaN actions should call EnvLogger.warn_action_is_NaN") - except Exception: - warnings.warn("nan values should not raise an error, instead, they should call EnvLogger.warn_action_is_NaN and instead perform some reasonable action, (perhaps the all zeros action?)") - - env.reset() - if np.all(np.greater(first_action_space.low.flatten(), -1e10)): - small_value = first_action_space.low - 1e10 - try: - if not check_warns(lambda: env.step(small_value), "[WARNING]: Received an action"): - warnings.warn("out of bounds actions should call EnvLogger.warn_action_out_of_bound") - except Exception: - warnings.warn("out of bounds actions should not raise an error, instead, they should call EnvLogger.warn_action_out_of_bound and instead perform some reasonable action, (perhaps the all zeros action?)") - - if not check_excepts(lambda: env.step(np.ones((29, 67, 17)))): - warnings.warn("actions of a shape not 
equal to the box should fail with some useful error") - elif isinstance(first_action_space, gym.spaces.Discrete): - try: - if not check_warns(lambda: env.step(np.nan), "[WARNING]: Received an NaN"): - warnings.warn("nan actions should call EnvLogger.warn_action_is_NaN, and instead perform some reasonable action (perhaps the do nothing action? Or perhaps the same behavior as an illegal action?)") - except Exception: - warnings.warn("nan actions should not raise an error, instead, they should call EnvLogger.warn_action_is_NaN and instead perform some reasonable action (perhaps the do nothing action? Or perhaps the same behavior as an illegal action?)") - - env.reset() - try: - if not check_warns(lambda: env.step(first_action_space.n)): - warnings.warn("out of bounds actions should call EnvLogger.warn_discrete_out_of_bound") - except Exception: - warnings.warn("out of bounds actions should not raise an error, instead, they should call EnvLogger.warn_discrete_out_of_bound and instead perform some reasonable action (perhaps the do nothing action if your environment has one? Or perhaps the same behavior as an illegal action?)") - - env.reset() - - # test illegal actions - first_agent = env.agent_selection - info = env.infos[first_agent] - action_space = env.action_spaces[first_agent] - if 'legal_moves' in info: - legal_moves = info['legal_moves'] - illegal_moves = set(range(action_space.n)) - set(legal_moves) - - if len(illegal_moves) > 0: - illegal_move = list(illegal_moves)[0] - #if not check_warns(lambda: env.step(env.step(illegal_move)), "[WARNING]: Illegal"): - #warnings.warn("If an illegal move is made, warning should be generated by calling EnvLogger.warn_on_illegal_move") - if not env.dones[first_agent]: - warnings.warn("Environment should terminate after receiving an illegal move") - else: - warnings.warn("The legal moves were just all possible moves. 
This is very usual") - - env.reset() - - -def check_environment_args(env): - args = inspect.getfullargspec(env.__init__) - if len(args.args) < 2 or "seed" != args.args[1]: - warnings.warn("Environment does not take `seed` as its first argument. If it uses any randomness, it should") - else: - if args.defaults[0] != None: - warnings.warn("Environment's seed parameter should have a default value of None. This defaults to a nondeterministic seed, which is what you want as a default.") - def hash_obsevation(obs): - try: - val = hash(obs.tobytes()) - return val - except AttributeError: - try: - return hash(obs) - except TypeError: - warnings.warn("Observation not an int or an Numpy array") - return 0 - - # checks deterministic behavior if seed is set - base_seed = 192312 - actions = {agent: space.sample() for agent, space in env.action_spaces.items()} - hashes = [] - num_seeds = 5 - for x in range(num_seeds): - new_env = env.__class__(seed=base_seed) - cur_hashes = [] - obs = new_env.reset() - for i in range(x + 1): - random.randint(0, 1000) - np.random.normal(size=100) - cur_hashes.append(hash_obsevation(obs)) - for _ in range(50): - rew, done, info = new_env.last() - if done: - break - next_obs = new_env.step(actions[new_env.agent_selection]) - cur_hashes.append(hash_obsevation(next_obs)) - - hashes.append(hash(tuple(cur_hashes))) - new_env = env.__class__(seed=base_seed) - if not all(hashes[0] == h for h in hashes): - warnings.warn("seeded environment is not fully deterministic, depends on random or numpy.random's random state") - - -def api_test(env, render=False, manual_control=None, save_obs=False): +def api_test(env, render=False): print("Starting API test") env_agent_sel = copy(env) - env_warnings = copy(env) - env_bad_close = copy(env) + + env.reset() assert isinstance(env, pettingzoo.AECEnv), "Env must be an instance of pettingzoo.AECEnv" # do this before reset - test_requires_reset(env) - - check_environment_args(env) - observation = env.reset(observe=False) 
assert observation is None, "reset(observe=False) must not return anything" assert not any(env.dones.values()), "dones must all be False after reset" @@ -469,14 +222,6 @@ def api_test(env, render=False, manual_control=None, save_obs=False): observation_0 = env.reset() test_obervation(observation_0, observation_0) - if save_obs: - for agent in env.agents: - assert isinstance(env.observation_spaces[agent], gym.spaces.Box), "Observations must be Box to save observations as image" - assert np.all(np.equal(env.observation_spaces[agent].low, 0)) and np.all(np.equal(env.observation_spaces[agent].high, 255)), "Observations must be 0 to 255 to save as image" - assert len(env.observation_spaces[agent].shape) == 3 or len(env.observation_spaces[agent].shape) == 2, "Observations must be 2D or 3D to save as image" - if len(env.observation_spaces[agent].shape) == 3: - assert env.observation_spaces[agent].shape[2] == 1 or env.observation_spaces[agent].shape[2] == 3, "3D observations can only have 1 or 3 channels to save as an image" - assert isinstance(env.agent_order, list), "agent_order must be a list" agent_0 = env.agent_order[0] @@ -485,8 +230,6 @@ def api_test(env, render=False, manual_control=None, save_obs=False): play_test(env, observation_0) - test_bad_actions(env) - assert isinstance(env.rewards, dict), "rewards must be a dict" assert isinstance(env.dones, dict), "dones must be a dict" assert isinstance(env.infos, dict), "infos must be a dict" @@ -495,19 +238,7 @@ def api_test(env, render=False, manual_control=None, save_obs=False): test_rewards_dones(env, agent_0) - test_observe(env, observation_0, save_obs=save_obs) - - test_agent_selector(env_agent_sel) - - test_warnings(env_warnings) - - if render: - test_render(env) - - if manual_control is not None: - test_manual_control(manual_control) - else: - env.close() + test_agent_order(env_agent_sel) # test that if env has overridden render(), they must have overridden close() as well base_render = 
pettingzoo.utils.env.AECEnv.render @@ -517,6 +248,4 @@ def api_test(env, render=False, manual_control=None, save_obs=False): else: warnings.warn("Environment has not defined a render() method") - test_bad_close(env_bad_close) - print("Passed API test") diff --git a/pettingzoo/tests/ci_test.py b/pettingzoo/tests/ci_test.py index 701ee88cd..a345716fc 100644 --- a/pettingzoo/tests/ci_test.py +++ b/pettingzoo/tests/ci_test.py @@ -1,482 +1,49 @@ import pettingzoo.tests.api_test as api_test import pettingzoo.tests.bombardment_test as bombardment_test import pettingzoo.tests.performance_benchmark as performance_benchmark -import sys - -# classic -_manual_control = None - -render = False -if sys.argv[2] == 'True': - render = True -manual_control = False -if sys.argv[3] == 'True': - manual_control = True -bombardment = False -if sys.argv[4] == 'True': - bombardment = True -performance = False -if sys.argv[5] == 'True': - performance = True -save_obs = False -if sys.argv[6] == 'True': - save_obs = True - -if sys.argv[1] == 'classic/backgammon': - print('classic/backgammon') - from pettingzoo.classic import backgammon_v0 - _env = backgammon_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = backgammon_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = backgammon_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/checkers': - print('classic/checkers_') - from pettingzoo.classic import checkers_v0 - _env = checkers_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = checkers_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = checkers_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/chess': - print('classic/chess') - from pettingzoo.classic import chess_v0 - _env = chess_v0.env() - 
api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = chess_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = chess_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/connect_four': - print('classic/connect_four') - from pettingzoo.classic import connect_four_v0 - _env = connect_four_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = connect_four_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = connect_four_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/dou_dizhu': - print('classic/dou_dizhu') - from pettingzoo.classic import dou_dizhu_v0 - _env = dou_dizhu_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = dou_dizhu_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = dou_dizhu_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/gin_rummy': - print('classic/gin_rummy') - from pettingzoo.classic import gin_rummy_v0 - _env = gin_rummy_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = gin_rummy_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = gin_rummy_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/go': - print('classic/go') - from pettingzoo.classic import go_v0 - _env = go_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = go_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = go_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/hanabi': - 
print('classic/hanabi') - from pettingzoo.classic import hanabi_v0 - _env = hanabi_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = hanabi_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = hanabi_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/leduc_holdem': - print('classic/leduc_holdem') - from pettingzoo.classic import leduc_holdem_v0 - _env = leduc_holdem_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = leduc_holdem_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = leduc_holdem_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/mahjong': - print('classic/mahjong') - from pettingzoo.classic import mahjong_v0 - _env = mahjong_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = mahjong_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = mahjong_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/rps': - print('classic/rps') - from pettingzoo.classic import rps_v0 - _env = rps_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = rps_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = rps_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/rpsls': - print('classic/rpsls') - from pettingzoo.classic import rpsls_v0 - _env = rpsls_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = rpsls_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = rpsls_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') 
- -elif sys.argv[1] == 'classic/texas_holdem': - print('classic/texas_holdem') - from pettingzoo.classic import texas_holdem_v0 - _env = texas_holdem_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = texas_holdem_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = texas_holdem_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/texas_holdem_no_limit': - print('classic/texas_holdem_no_limit') - from pettingzoo.classic import texas_holdem_no_limit_v0 - _env = texas_holdem_no_limit_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = texas_holdem_no_limit_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = texas_holdem_no_limit_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/tictactoe': - print('classic/tictactoe') - from pettingzoo.classic import tictactoe_v0 - _env = tictactoe_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = tictactoe_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = tictactoe_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'classic/uno': - print('classic/uno') - from pettingzoo.classic import uno_v0 - _env = uno_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = uno_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = uno_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -# gamma +import pettingzoo.tests.manual_control_test as test_manual_control -elif sys.argv[1] == 'gamma/cooperative_pong': - print('gamma/cooperative_pong') - from pettingzoo.gamma import cooperative_pong_v0 - _env = cooperative_pong_v0.env() - if 
manual_control: - _manual_control = cooperative_pong_v0.manual_control - api_test.api_test(_env, render=render, manual_control=_manual_control, save_obs=save_obs) - if bombardment: - _env = cooperative_pong_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = cooperative_pong_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'gamma/knights_archers_zombies': - print('gamma/knights_archers_zombies') - from pettingzoo.gamma import knights_archers_zombies_v0 - _env = knights_archers_zombies_v0.env() - if manual_control: - _manual_control = knights_archers_zombies_v0.manual_control - api_test.api_test(_env, render=render, manual_control=_manual_control, save_obs=save_obs) - if bombardment: - _env = knights_archers_zombies_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = knights_archers_zombies_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'gamma/pistonball': - print('gamma/pistonball') - from pettingzoo.gamma import pistonball_v0 - _env = pistonball_v0.env() - if manual_control: - _manual_control = pistonball_v0.manual_control - api_test.api_test(_env, render=render, manual_control=_manual_control, save_obs=save_obs) - if bombardment: - _env = pistonball_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = pistonball_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'gamma/prison': - print('gamma/prison') - from pettingzoo.gamma import prison_v0 - _env = prison_v0.env() - if manual_control: - _manual_control = prison_v0.manual_control - api_test.api_test(_env, render=render, manual_control=_manual_control, save_obs=save_obs) - if bombardment: - _env = prison_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = prison_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 
'gamma/prospector': - print('gamma/prospector') - from pettingzoo.gamma import prospector_v0 - _env = prospector_v0.env() - if manual_control: - _manual_control = prospector_v0.manual_control - api_test.api_test(_env, render=render, manual_control=_manual_control, save_obs=save_obs) - if bombardment: - _env = prospector_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = prospector_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -# mpe - -elif sys.argv[1] == 'mpe/simple': - print('mpe/simple') - from pettingzoo.mpe import simple_v0 - _env = simple_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'mpe/simple_adversary': - print('mpe/simple_adversary') - from pettingzoo.mpe import simple_adversary_v0 - _env = simple_adversary_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_adversary_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_adversary_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'mpe/simple_crypto': - print('mpe/simple_crypto') - from pettingzoo.mpe import simple_crypto_v0 - _env = simple_crypto_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_crypto_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_crypto_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'mpe/simple_push': - print('mpe/simple_push') - from pettingzoo.mpe import simple_push_v0 - _env = simple_push_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if 
bombardment: - _env = simple_push_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_push_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') - -elif sys.argv[1] == 'mpe/simple_reference': - print('mpe/simple_reference') - from pettingzoo.mpe import simple_reference_v0 - _env = simple_reference_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_reference_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_reference_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') +import sys +from .all_modules import all_environments +from .render_test import test_render +from .error_tests import error_test +from .seed_test import seed_test +from .save_obs_test import test_save_obs -elif sys.argv[1] == 'mpe/simple_speak_listener': - print('mpe/simple_speak_listener') - from pettingzoo.mpe import simple_speak_listener_v0 - _env = simple_speak_listener_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_speak_listener_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_speak_listener_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') +render = sys.argv[2] == 'True' +manual_control = sys.argv[3] == 'True' +bombardment = sys.argv[4] == 'True' +performance = sys.argv[5] == 'True' +save_obs = sys.argv[6] == 'True' -elif sys.argv[1] == 'mpe/simple_spread': - print('mpe/simple_spread') - from pettingzoo.mpe import simple_spread_v0 - _env = simple_spread_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_spread_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_spread_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') -elif sys.argv[1] == 'mpe/simple_tag': - 
print('mpe/simple_tag') - from pettingzoo.mpe import simple_tag_v0 - _env = simple_tag_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_tag_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_tag_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') +env_id = sys.argv[1] +if env_id in all_environments: + print("running game {}".format(env_id)) + env_module = all_environments[env_id] + _env = env_module.raw_env() + api_test.api_test(_env, render=render) -elif sys.argv[1] == 'mpe/simple_world_comm': - print('mpe/simple_world_comm') - from pettingzoo.mpe import simple_world_comm_v0 - _env = simple_world_comm_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = simple_world_comm_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = simple_world_comm_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') + seed_test(env_module.env) + # error_test(env_module.env()) -# sisl + if save_obs: + test_save_obs(_env) -elif sys.argv[1] == 'sisl/multiwalker': - print('sisl/multiwalker') - from pettingzoo.sisl import multiwalker_v0 - _env = multiwalker_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = multiwalker_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = multiwalker_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') + if render: + test_render(_env) -elif sys.argv[1] == 'sisl/pursuit': - print('sisl/pursuit') - from pettingzoo.sisl import pursuit_v0 - _env = pursuit_v0.env() if manual_control: - _manual_control = pursuit_v0.manual_control - api_test.api_test(_env, render=render, manual_control=_manual_control, save_obs=False) - if bombardment: - _env = pursuit_v0.env() - bombardment_test.bombardment_test(_env) - if performance: - _env = 
pursuit_v0.env() - performance_benchmark.performance_benchmark(_env) - print('') + manual_control_fn = getattr(env_module, "manual_control", None) + if manual_control_fn is not None: + test_manual_control.test_manual_control(manual_control_fn) -elif sys.argv[1] == 'sisl/waterworld': - print('sisl/waterworld') - from pettingzoo.sisl import waterworld_v0 - _env = waterworld_v0.env() - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) - if bombardment: - _env = waterworld_v0.env() - bombardment_test.bombardment_test(_env) if performance: - _env = waterworld_v0.env() + _env = env_module.env() performance_benchmark.performance_benchmark(_env) - print('') -elif sys.argv[1] == 'magent': - print("magent") - from pettingzoo.magent import magent_env - _env = magent_env.env("pursuit", map_size=100) - api_test.api_test(_env, render=render, manual_control=None, save_obs=False) if bombardment: - _env = magent_env.env("pursuit", map_size=00) + _env = env_module.env() bombardment_test.bombardment_test(_env) - if performance: - _env = magent_env.env("pursuit", map_size=100) - performance_benchmark.performance_benchmark(_env) +else: + print("Environment: '{}' not in the 'all_environments' list".format(env_id)) diff --git a/pettingzoo/tests/error_tests.py b/pettingzoo/tests/error_tests.py new file mode 100644 index 000000000..e361d37e8 --- /dev/null +++ b/pettingzoo/tests/error_tests.py @@ -0,0 +1,177 @@ +import pettingzoo +from pettingzoo.utils import agent_selector +from pettingzoo.utils import save_observation +import warnings +import inspect +import numpy as np +from copy import copy +import gym +import random +import re +import os +from pettingzoo.utils import EnvLogger + + +def test_bad_close(env): + EnvLogger.suppress_output() + EnvLogger.flush() + e1 = copy(env) + # test that immediately closing the environment does not crash + try: + e1.close() + except Exception as e: + warnings.warn("Immediately closing a newly initialized environment should not 
crash with {}".format(e)) + + # test that closing twice does not crash + + e2 = copy(env) + if "render.modes" in e2.metadata and len(e2.metadata["render.modes"]) > 0: + e2.reset() + e2.render() + e2.close() + try: + e2.close() + except Exception as e: + warnings.warn("Closing an already closed environment should not crash with {}".format(e)) + EnvLogger.unsuppress_output() + + +def test_warnings(env): + EnvLogger.suppress_output() + EnvLogger.flush() + e1 = copy(env) + e1.reset() + e1.close() + # e1 should throw a close_unrendered_environment warning + if len(EnvLogger.mqueue) == 0: + warnings.warn("env does not warn when closing unrendered env. Should call EnvLogger.warn_close_unrendered_env") + EnvLogger.unsuppress_output() + + +def check_asserts(fn, message=None): + try: + fn() + return False + except AssertionError as e: + if message is not None: + return message == str(e) + return True + except Exception as e: + raise e + + +def check_excepts(fn): + try: + fn() + return False + except Exception: + return True + + +# yields length of mqueue +def check_warns(fn, message=None): + EnvLogger.suppress_output() + EnvLogger.flush() + fn() + EnvLogger.unsuppress_output() + if message is None: + return EnvLogger.mqueue + else: + for item in EnvLogger.mqueue: + if message in item: + return True + return False + + +def test_requires_reset(env): + if not check_excepts(lambda: env.agent_selection): + warnings.warn("env.agent_selection should not be defined until reset is called") + if not check_excepts(lambda: env.dones): + warnings.warn("env.dones should not be defined until reset is called") + if not check_excepts(lambda: env.rewards): + warnings.warn("env.rewards should not be defined until reset is called") + first_agent = list(env.action_spaces.keys())[0] + first_action_space = env.action_spaces[first_agent] + if not check_asserts(lambda: env.step(first_action_space.sample()), "reset() needs to be called before step"): + warnings.warn("env.step should call 
EnvLogger.error_step_before_reset if it is called before reset") + if not check_asserts(lambda: env.observe(first_agent), "reset() needs to be called before observe"): + warnings.warn("env.observe should call EnvLogger.error_observe_before_reset if it is called before reset") + if "render.modes" in env.metadata and len(env.metadata["render.modes"]) > 0: + if not check_asserts(lambda: env.render(), "reset() needs to be called before render"): + warnings.warn("env.render should call EnvLogger.error_render_before_reset if it is called before reset") + if not check_warns(lambda: env.close(), "reset() needs to be called before close"): + warnings.warn("env should warn_close_before_reset() if closing before reset()") + + +def test_bad_actions(env): + env.reset() + first_action_space = env.action_spaces[env.agent_selection] + + if isinstance(first_action_space, gym.spaces.Box): + try: + if not check_warns(lambda: env.step(np.nan * np.ones_like(first_action_space.low)), "[WARNING]: Received an NaN"): + warnings.warn("NaN actions should call EnvLogger.warn_action_is_NaN") + except Exception: + warnings.warn("nan values should not raise an error, instead, they should call EnvLogger.warn_action_is_NaN and instead perform some reasonable action, (perhaps the all zeros action?)") + + env.reset() + if np.all(np.greater(first_action_space.low.flatten(), -1e10)): + small_value = first_action_space.low - 1e10 + try: + if not check_warns(lambda: env.step(small_value), "[WARNING]: Received an action"): + warnings.warn("out of bounds actions should call EnvLogger.warn_action_out_of_bound") + except Exception: + warnings.warn("out of bounds actions should not raise an error, instead, they should call EnvLogger.warn_action_out_of_bound and instead perform some reasonable action, (perhaps the all zeros action?)") + + if not check_excepts(lambda: env.step(np.ones((29, 67, 17)))): + warnings.warn("actions of a shape not equal to the box should fail with some useful error") + elif 
isinstance(first_action_space, gym.spaces.Discrete): + try: + if not check_warns(lambda: env.step(np.nan), "[WARNING]: Received an NaN"): + warnings.warn("nan actions should call EnvLogger.warn_action_is_NaN, and instead perform some reasonable action (perhaps the do nothing action? Or perhaps the same behavior as an illegal action?)") + except Exception: + warnings.warn("nan actions should not raise an error, instead, they should call EnvLogger.warn_action_is_NaN and instead perform some reasonable action (perhaps the do nothing action? Or perhaps the same behavior as an illegal action?)") + + env.reset() + try: + if not check_asserts(lambda: env.step(first_action_space.n)): + warnings.warn("out of bounds actions should assert") + except Exception: + warnings.warn("out of bounds actions should assert") + + env.reset() + + # test illegal actions + first_agent = env.agent_selection + info = env.infos[first_agent] + action_space = env.action_spaces[first_agent] + if 'legal_moves' in info: + legal_moves = info['legal_moves'] + illegal_moves = set(range(action_space.n)) - set(legal_moves) + + if len(illegal_moves) > 0: + illegal_move = list(illegal_moves)[0] + if not check_warns(lambda: (env.step(illegal_move)), "[WARNING]: Illegal"): + warnings.warn("If an illegal move is made, warning should be generated by calling EnvLogger.warn_on_illegal_move") + if not env.dones[first_agent]: + warnings.warn("Environment should terminate after receiving an illegal move") + else: + warnings.warn("The legal moves were just all possible moves. 
This is very usual") + + env.reset() + + +def error_test(env): + print("Starting Error test") + env_warnings = copy(env) + env_bad_close = copy(env) + + test_warnings(env_warnings) + # do this before reset + test_requires_reset(env) + + test_bad_actions(env) + + test_bad_close(env_bad_close) + + print("Passed Error test") diff --git a/pettingzoo/tests/manual_control_test.py b/pettingzoo/tests/manual_control_test.py new file mode 100644 index 000000000..7774a4d47 --- /dev/null +++ b/pettingzoo/tests/manual_control_test.py @@ -0,0 +1,30 @@ +import random +import time +import threading + + +def inp_handler(name): + from pynput.keyboard import Key, Controller as KeyboardController + + keyboard = KeyboardController() + time.sleep(0.1) + choices = ['w', 'a', 's', 'd', 'j', 'k', Key.left, Key.right, Key.up, Key.down] + NUM_TESTS = 50 + for x in range(NUM_TESTS): + i = random.choice(choices) if x != NUM_TESTS - 1 else Key.esc + keyboard.press(i) + time.sleep(0.1) + keyboard.release(i) + + +def test_manual_control(manual_control): + manual_in_thread = threading.Thread(target=inp_handler, args=(1,)) + + manual_in_thread.start() + + try: + manual_control() + except Exception: + raise Exception("manual_control() has crashed. 
Please fix it.") + + manual_in_thread.join() diff --git a/pettingzoo/tests/pytest_runner.py b/pettingzoo/tests/pytest_runner.py new file mode 100644 index 000000000..d1475c251 --- /dev/null +++ b/pettingzoo/tests/pytest_runner.py @@ -0,0 +1,15 @@ +import pytest +from .all_modules import all_environments +import pettingzoo.tests.api_test as api_test + +from .error_tests import error_test +from .seed_test import seed_test + + +@pytest.mark.parametrize("env_module", list(all_environments.values())) +def test_module(env_module): + _env = env_module.env() + api_test.api_test(_env) + + seed_test(env_module.env) + # error_test(env_module.env()) diff --git a/pettingzoo/tests/render_test.py b/pettingzoo/tests/render_test.py new file mode 100644 index 000000000..80b34f7d0 --- /dev/null +++ b/pettingzoo/tests/render_test.py @@ -0,0 +1,19 @@ +import random + + +def test_render(env): + render_modes = env.metadata.get('render.modes') + assert render_modes is not None, "Environment's that support rendering must define render modes in metadata" + env.reset(observe=False) + for mode in render_modes: + for _ in range(10): + for agent in env.agent_order: + if 'legal_moves' in env.infos[agent]: + action = random.choice(env.infos[agent]['legal_moves']) + else: + action = env.action_spaces[agent].sample() + env.step(action, observe=False) + env.render(mode=mode) + if all(env.dones.values()): + env.reset() + break diff --git a/pettingzoo/tests/save_obs_test.py b/pettingzoo/tests/save_obs_test.py new file mode 100644 index 000000000..e39b966ab --- /dev/null +++ b/pettingzoo/tests/save_obs_test.py @@ -0,0 +1,22 @@ +from pettingzoo.utils import save_observation +import gym +import numpy as np + + +def check_save_obs(env): + for agent in env.agents: + assert isinstance(env.observation_spaces[agent], gym.spaces.Box), "Observations must be Box to save observations as image" + assert np.all(np.equal(env.observation_spaces[agent].low, 0)) and np.all(np.equal(env.observation_spaces[agent].high, 
255)), "Observations must be 0 to 255 to save as image" + assert len(env.observation_spaces[agent].shape) == 3 or len(env.observation_spaces[agent].shape) == 2, "Observations must be 2D or 3D to save as image" + if len(env.observation_spaces[agent].shape) == 3: + assert env.observation_spaces[agent].shape[2] == 1 or env.observation_spaces[agent].shape[2] == 3, "3D observations can only have 1 or 3 channels to save as an image" + + +def test_save_obs(env): + try: + check_save_obs(env) + for agent in env.agent_order: + save_observation(env=env, agent=agent, save_dir="saved_observations") + + except AssertionError as ae: + print("did not save the observations: ", ae) diff --git a/pettingzoo/tests/seed_test.py b/pettingzoo/tests/seed_test.py new file mode 100644 index 000000000..b2569eb16 --- /dev/null +++ b/pettingzoo/tests/seed_test.py @@ -0,0 +1,62 @@ +import warnings +import random +import numpy as np + + +def check_environment_deterministic(env1, env2): + ''' + env1 and env2 should be seeded environments + + returns a bool: true if env1 and env2 execute the same way + ''' + + # checks deterministic behavior if seed is set + actions = {agent: space.sample() for agent, space in env1.action_spaces.items()} + hashes = [] + num_seeds = 2 + envs = [env1, env2] + for x in range(num_seeds): + new_env = envs[x] + cur_hashes = [] + obs = new_env.reset() + for i in range(x + 1): + random.randint(0, 1000) + np.random.normal(size=100) + cur_hashes.append(hash_obsevation(obs)) + for _ in range(50): + rew, done, info = new_env.last() + if done: + break + next_obs = new_env.step(actions[new_env.agent_selection]) + cur_hashes.append(hash_obsevation(next_obs)) + + hashes.append(hash(tuple(cur_hashes))) + + return all(hashes[0] == h for h in hashes) + + +def hash_obsevation(obs): + try: + val = hash(obs.tobytes()) + return val + except AttributeError: + try: + return hash(obs) + except TypeError: + warnings.warn("Observation not an int or an Numpy array") + return 0 + + +def 
seed_test(env_constructor): + try: + env_constructor(seed=None) + except Exception: + if not check_environment_deterministic(env_constructor(), env_constructor()): + warnings.warn("The environment gives different results on multiple runs and does not have a `seed` argument. Environments which use random values should take a seed as an argument.") + return + + base_seed = 42 + if not check_environment_deterministic(env_constructor(seed=base_seed), env_constructor(seed=base_seed)): + warnings.warn("The environment gives different results on multiple runs when intialized with the same seed. This is usually a sign that you are using np.random or random modules directly, which uses a global random state.") + if check_environment_deterministic(env_constructor(), env_constructor()): + warnings.warn("The environment gives same results on multiple runs when intialized by default. By default, environments that take a seed argument should be nondeterministic") diff --git a/pettingzoo/utils/env_logger.py b/pettingzoo/utils/env_logger.py index d99ea23fe..d93b2f341 100644 --- a/pettingzoo/utils/env_logger.py +++ b/pettingzoo/utils/env_logger.py @@ -17,6 +17,8 @@ def _generic_warning(msg): handler = EnvWarningHandler(mqueue=EnvLogger.mqueue) logger.addHandler(handler) logger.warning(msg) + # needed to get the pytest runner to work correctly, and doesn't seem to have serious issues + EnvLogger.mqueue.append(msg) @staticmethod def flush(): @@ -42,6 +44,10 @@ def warn_action_is_NaN(backup_policy): def warn_close_unrendered_env(): EnvLogger._generic_warning("[WARNING]: Called close on an unrendered environment.") + @staticmethod + def warn_close_before_reset(): + EnvLogger._generic_warning("[WARNING]: reset() needs to be called before close.") + @staticmethod def warn_on_illegal_move(): EnvLogger._generic_warning("[WARNING]: Illegal move made, game terminating with current player losing. 
\nenv.infos[player]['legal_moves'] contains a list of all legal moves that can be chosen.") @@ -55,8 +61,8 @@ def error_step_before_reset(): assert False, "reset() needs to be called before step" @staticmethod - def error_close_before_reset(): - assert False, "reset() needs to be called before close" + def warn_step_after_done(): + EnvLogger._generic_warning("[WARNING]: step() called after all agents are done. Should reset() first.") @staticmethod def error_render_before_reset(): diff --git a/pettingzoo/utils/frame_stack.py b/pettingzoo/utils/frame_stack.py deleted file mode 100755 index ed2f132c0..000000000 --- a/pettingzoo/utils/frame_stack.py +++ /dev/null @@ -1,71 +0,0 @@ -import numpy as np -from gym.spaces import Box - - -def stack_obs_space(obs_space_dict, stack_size): - ''' - obs_space_dict: Dictionary of observations spaces of agents - stack_size: Number of frames in the observation stack - Returns: - New obs_space_dict - ''' - assert isinstance(obs_space_dict, dict), "obs_space_dict is not a dictionary." - - new_obs_space_dict = {} - - for agent_id in obs_space_dict.keys(): - obs_space = obs_space_dict[agent_id] - assert isinstance(obs_space, Box), "Stacking is currently only allowed for Box obs space. The given obs space is {}".format(obs_space) - dtype = obs_space.dtype - obs_dim = obs_space_dict[agent_id].low.ndim - # stack 1-D frames and 3-D frames - if obs_dim == 1 or obs_dim == 3: - new_shape = (stack_size,) - # stack 2-D frames - elif obs_dim == 2: - new_shape = (stack_size, 1, 1) - low = np.tile(obs_space.low, new_shape) - high = np.tile(obs_space.high, new_shape) - new_obs_space_dict[agent_id] = Box(low=low, high=high, dtype=dtype) - return new_obs_space_dict - - -def stack_reset_obs(obs, stack_size): - ''' - Input: 1 agent's observation only. - Reset observations are only 1 obs. Tile them. 
- ''' - # stack 1-D frames and 3-D frames - if obs.ndim == 1 or obs.ndim == 3: - new_shape = (stack_size,) - # stack 2-D frames - elif obs.ndim == 2: - new_shape = (stack_size, 1, 1) - frame_stack = np.tile(obs, new_shape) - return frame_stack - - -def stack_obs(frame_stack, agent, obs, stack_size): - ''' - Parameters - ---------- - frame_stack : if not None, it is the stack of frames - obs : new observation - Rearranges frame_stack. Appends the new observation at the end. - Throws away the oldest observation. - stack_size : needed for stacking reset observations - ''' - if frame_stack[agent] is None: - frame_stack[agent] = stack_reset_obs(obs, stack_size) - obs_shape = obs.shape - agent_fs = frame_stack[agent] - - if len(obs_shape) == 1: - agent_fs[:-obs_shape[-1]] = agent_fs[obs_shape[-1]:] - agent_fs[-obs_shape[-1]:] = obs - elif len(obs_shape) == 2: - agent_fs[:-1] = agent_fs[1:] - agent_fs[-1] = obs - elif len(obs_shape) == 3: - agent_fs[:, :, :-obs_shape[-1]] = agent_fs[:, :, obs_shape[-1]:] - agent_fs[:, :, -obs_shape[-1]:] = obs diff --git a/pettingzoo/utils/wrappers.py b/pettingzoo/utils/wrappers.py new file mode 100644 index 000000000..5199858e1 --- /dev/null +++ b/pettingzoo/utils/wrappers.py @@ -0,0 +1,252 @@ +import numpy as np +import copy +from gym.spaces import Box, Discrete +from gym import spaces +import warnings +from skimage import measure +from pettingzoo import AECEnv + +from .env_logger import EnvLogger + + +class BaseWrapper(AECEnv): + ''' + Creates a wrapper around `env` parameter. Extend this class + to create a useful wrapper. 
+ ''' + metadata = {'render.modes': ['human']} + + def __init__(self, env): + super().__init__() + self.env = env + + self.num_agents = self.env.num_agents + self.agents = self.env.agents + self.observation_spaces = self.env.observation_spaces + self.action_spaces = self.env.action_spaces + + # we don't want these defined as we don't want them used before they are gotten + + # self.agent_selection = self.env.agent_selection + + # self.rewards = self.env.rewards + # self.dones = self.env.dones + + # we don't want to care one way or the other whether environments have an infos or not before reset + try: + self.infos = self.env.infos + except AttributeError: + pass + + # self.agent_order = self.env.agent_order + + def close(self): + self.env.close() + + def render(self, mode='human'): + self.env.render(mode) + + def reset(self, observe=True): + observation = self.env.reset(observe) + + self.agent_selection = self.env.agent_selection + self.agent_order = self.env.agent_order + self.rewards = self.env.rewards + self.dones = self.env.dones + self.infos = self.env.infos + + return observation + + def observe(self, agent): + return self.env.observe(agent) + + def step(self, action, observe=True): + next_obs = self.env.step(action, observe=observe) + + self.agent_selection = self.env.agent_selection + self.agent_order = self.env.agent_order + self.rewards = self.env.rewards + self.dones = self.env.dones + self.infos = self.env.infos + + return next_obs + + +class TerminateIllegalWrapper(BaseWrapper): + ''' + this wrapper terminates the game with the current player losing + in case of illegal values + + parameters: + - illegal_reward: number that is the value of the player making an illegal move. 
+ ''' + def __init__(self, env, illegal_reward): + super().__init__(env) + self._illegal_value = illegal_reward + + def step(self, action, observe=True): + current_agent = self.agent_selection + assert 'legal_moves' in self.infos[current_agent], "Illegal moves must always be defined to use the TerminateIllegalWrapper" + if action not in self.infos[current_agent]['legal_moves']: + EnvLogger.warn_on_illegal_move() + self.dones = {d: True for d in self.dones} + for info in self.infos.values(): + info['legal_moves'] = [] + self.rewards = {d: 0 for d in self.dones} + self.rewards[current_agent] = self._illegal_value + else: + return super().step(action, observe) + + +class NanNoOpWrapper(BaseWrapper): + ''' + this wrapper expects there to be a no_op_action parameter which + is the action to take in cases when nothing should be done. + ''' + def __init__(self, env, no_op_action, no_op_policy): + super().__init__(env) + self._no_op_action = no_op_action + self._no_op_policy = no_op_policy + + def step(self, action, observe=True): + if np.isnan(action).any(): + EnvLogger.warn_action_is_NaN(self._no_op_policy) + action = self._no_op_action + return super().step(action, observe) + + +class NanZerosWrapper(BaseWrapper): + ''' + this wrapper warns and executes a zeros action when nothing should be done. + Only for Box action spaces. + ''' + def __init__(self, env): + super().__init__(env) + assert all(isinstance(space, Box) for space in self.action_spaces.values()), "should only use NanZerosWrapper for Box spaces. 
Use NanNoOpWrapper for discrete spaces" + + def step(self, action, observe=True): + if np.isnan(action).any(): + EnvLogger.warn_action_is_NaN("taking the all zeros action") + action = np.zeros_like(action) + return super().step(action, observe) + + +class NaNRandomWrapper(BaseWrapper): + ''' + this wrapper takes a random action + ''' + def __init__(self, env): + super().__init__(env) + assert all(isinstance(space, Discrete) for space in env.action_spaces.values()), "action space should be discrete for NaNRandomWrapper" + SEED = 0x33bb9cc9 + self.np_random = np.random.RandomState(SEED) + + def step(self, action, observe=True): + if np.isnan(action).any(): + cur_info = self.infos[self.agent_selection] + if 'legal_moves' in cur_info: + backup_policy = "taking a random legal action" + EnvLogger.warn_action_is_NaN(backup_policy) + action = self.np_random.choice(cur_info['legal_moves']) + else: + backup_policy = "taking a random action" + EnvLogger.warn_action_is_NaN(backup_policy) + act_space = self.action_spaces[self.agent_selection] + action = self.np_random.choice(act_space.n) + + return super().step(action, observe) + + +class AssertOutOfBoundsWrapper(BaseWrapper): + ''' + this wrapper crashes for out of bounds actions + Should be used for Discrete spaces + ''' + def __init__(self, env): + super().__init__(env) + assert all(isinstance(space, Discrete) for space in self.action_spaces.values()), "should only use AssertOutOfBoundsWrapper for Discrete spaces" + + def step(self, action, observe=True): + assert self.action_spaces[self.agent_selection].contains(action), "action is not in action space" + return super().step(action, observe) + + +class ClipOutOfBoundsWrapper(BaseWrapper): + ''' + this wrapper crops out of bounds actions for Box spaces + ''' + def __init__(self, env): + super().__init__(env) + assert all(isinstance(space, Box) for space in self.action_spaces.values()), "should only use ClipOutOfBoundsWrapper for Box spaces" + + def step(self, action, 
observe=True): + space = self.action_spaces[self.agent_selection] + if not space.contains(action): + assert space.shape == action.shape, "action should have shape {}, has shape {}".format(space.shape, action.shape) + + EnvLogger.warn_action_out_of_bound(action=action, action_space=space, backup_policy="clipping to space") + action = np.clip(action, space.low, space.high) + + return super().step(action, observe) + + +class OrderEnforcingWrapper(BaseWrapper): + ''' + check all orders: + + * error on getting rewards, dones, infos, agent_selection, agent_order before reset + * error on calling step, observe before reset + * warn on calling close before render or reset + * warn on calling step after environment is done + ''' + def __init__(self, env): + self._has_reset = False + self._has_rendered = False + super().__init__(env) + + def __getattr__(self, value): + ''' + raises an error message when data is gotten from the env + which should only be gotten after reset + ''' + if value in {"rewards", "dones", "infos", "agent_selection", "agent_order"}: + EnvLogger.error_field_before_reset(value) + return None + else: + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, value)) + + def render(self, mode='human'): + if not self._has_reset: + EnvLogger.error_render_before_reset() + self._has_rendered = True + super().render(mode) + + def close(self): + super().close() + if not self._has_rendered: + EnvLogger.warn_close_unrendered_env() + if not self._has_reset: + EnvLogger.warn_close_before_reset() + + self._has_rendered = False + self._has_reset = False + + def step(self, action, observe=True): + if not self._has_reset: + EnvLogger.error_step_before_reset() + elif self.dones[self.agent_selection]: + EnvLogger.warn_step_after_done() + self.dones = {agent: True for agent in self.dones} + self.rewards = {agent: 0 for agent in self.rewards} + return super().observe(action) if observe else None + else: + return super().step(action, observe) + + 
def observe(self, agent): + if not self._has_reset: + EnvLogger.error_observe_before_reset() + return super().observe(observe) + + def reset(self, observe=True): + self._has_reset = True + return super().reset(observe) diff --git a/requirements.txt b/requirements.txt index 57e3d6dcf..a899af852 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,10 @@ -gym>=0.15.4 +gym>=0.17.2 pygame==2.0.0.dev6 scikit-image>=0.16.2 numpy>=1.18.0 -matplotlib>=3.1.2 pymunk>=5.6.0 gym[box2d]>=0.15.4 python-chess -rlcard >= 0.1.14 +rlcard >= 0.2.1 pynput opencv-python diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..5f464bca2 --- /dev/null +++ b/setup.py @@ -0,0 +1,40 @@ +from setuptools import find_packages, setup + +with open("README.md", "r") as fh: + long_description = "" + header_count = 0 + for line in fh: + if line.startswith("##"): + header_count += 1 + if header_count < 2: + long_description += line + else: + break + +setup( + name='PettingZoo', + version="0.1.5", + author='PettingZoo Team', + author_email="justinkterry@gmail.com", + description="Gym for multi-agent reinforcement learning", + url='https://github.com/PettingZoo-Team/PettingZoo', + long_description=long_description, + long_description_content_type="text/markdown", + keywords=["Reinforcement Learning", "game", "RL", "AI", "gym"], + python_requires=">=3.6", + data_files=[("", ["LICENSE.txt"])], + packages=find_packages(), + include_package_data=True, + install_requires=[ + "gym>=0.17.2", + "pygame==2.0.0.dev6", + "scikit-image>=0.16.2", + "numpy>=1.18.0", + "pymunk>=5.6.0", + "gym[box2d]>=0.15.4", + "python-chess", + "rlcard >= 0.1.14", + "pynput", + "opencv-python" + ], +) diff --git a/test.sh b/test.sh index 260aa374b..a433bc67d 100755 --- a/test.sh +++ b/test.sh @@ -30,7 +30,7 @@ python3 -m pettingzoo.tests.ci_test gamma/cooperative_pong $render $manual_contr python3 -m pettingzoo.tests.ci_test gamma/knights_archers_zombies $render $manual_control $bombardment 
$performance $save_obs python3 -m pettingzoo.tests.ci_test gamma/pistonball $render $manual_control $bombardment $performance $save_obs python3 -m pettingzoo.tests.ci_test gamma/prison $render $manual_control $bombardment $performance $save_obs -# python3 -m pettingzoo.tests.ci_test gamma/prospector $render $manual_control $bombardment $performance $save_obs +python3 -m pettingzoo.tests.ci_test gamma/prospector $render $manual_control $bombardment $performance $save_obs # MAgent flake8 pettingzoo/magent --ignore E501,E731,E741,E402,F401,W503