diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml index cc47d31c7..569e87ee8 100644 --- a/.github/workflows/linux-tutorials-test.yml +++ b/.github/workflows/linux-tutorials-test.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10'] # '3.11' - broken due to numba - tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL'] # TODO: add back 'CleanRL' after SuperSuit is fixed + tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL', 'SB3/kaz', 'SB3/waterworld', 'SB3/connect_four', 'SB3/test'] # TODO: add back RLlib once it is fixed steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -29,8 +29,9 @@ jobs: - name: Install dependencies and run tutorials run: | sudo apt-get install python3-opengl xvfb + export root_dir=$(pwd) cd tutorials/${{ matrix.tutorial }} pip install -r requirements.txt pip uninstall -y pettingzoo - pip install -e ../.. + pip install -e $root_dir for f in *.py; do xvfb-run -a -s "-screen 0 1024x768x24" python "$f"; done diff --git a/docs/api/parallel.md b/docs/api/parallel.md index 7352ea7fe..27d0c2a1e 100644 --- a/docs/api/parallel.md +++ b/docs/api/parallel.md @@ -5,7 +5,7 @@ title: Parallel # Parallel API -In addition to the main API, we have a secondary parallel API for environments where all agents have simultaneous actions and observations. An environment with parallel API support can be created via `.parallel_env()`. This API is based around the paradigm of *Partially Observable Stochastic Games* (POSGs) and the details are similar to [RLLib's MultiAgent environment specification](https://docs.ray.io/en/latest/rllib-env.html#multi-agent-and-hierarchical), except we allow for different observation and action spaces between the agents. +In addition to the main API, we have a secondary parallel API for environments where all agents have simultaneous actions and observations. An environment with parallel API support can be created via `.parallel_env()`. This API is based around the paradigm of *Partially Observable Stochastic Games* (POSGs) and the details are similar to [RLlib's MultiAgent environment specification](https://docs.ray.io/en/latest/rllib-env.html#multi-agent-and-hierarchical), except we allow for different observation and action spaces between the agents. All parallel environments can be converted into AEC environments by splitting a simultaneous turn into sequential turns, with observations only from the previous cycle. diff --git a/docs/index.md b/docs/index.md index 6cabc4718..fe6c1fff5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -44,6 +44,7 @@ tutorials/cleanrl/index tutorials/tianshou/index tutorials/rllib/index tutorials/langchain/index +tutorials/sb3/index ``` ```{toctree} diff --git a/docs/tutorials/rllib/holdem.md b/docs/tutorials/rllib/holdem.md index c102032be..5f333ee2c 100644 --- a/docs/tutorials/rllib/holdem.md +++ b/docs/tutorials/rllib/holdem.md @@ -16,7 +16,7 @@ To follow this tutorial, you will need to install the dependencies shown below. ``` ## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLlib. 
If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). ### Training the RL agent diff --git a/docs/tutorials/rllib/index.md b/docs/tutorials/rllib/index.md index dc95d2b52..ad06a050b 100644 --- a/docs/tutorials/rllib/index.md +++ b/docs/tutorials/rllib/index.md @@ -2,9 +2,9 @@ title: "RLlib" --- -# RLlib Tutorial +# Ray RLlib Tutorial -These tutorials show you how to use [RLlib](https://docs.ray.io/en/latest/rllib/index.html) to train agents in PettingZoo environments. +These tutorials show you how to use [Ray](https://docs.ray.io/en/latest/index.html)'s [RLlib](https://docs.ray.io/en/latest/rllib/index.html) library to train agents in PettingZoo environments. * [PPO for Pistonball](/tutorials/rllib/pistonball/): _Train a PPO model in a parallel environment_ diff --git a/docs/tutorials/rllib/pistonball.md b/docs/tutorials/rllib/pistonball.md index 7e24cf991..53a5fcd31 100644 --- a/docs/tutorials/rllib/pistonball.md +++ b/docs/tutorials/rllib/pistonball.md @@ -17,7 +17,7 @@ To follow this tutorial, you will need to install the dependencies shown below. ``` ## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLlib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). ### Training the RL agent diff --git a/docs/tutorials/sb3/connect_four.md b/docs/tutorials/sb3/connect_four.md new file mode 100644 index 000000000..fc96af2cf --- /dev/null +++ b/docs/tutorials/sb3/connect_four.md @@ -0,0 +1,54 @@ +--- +title: "SB3: Action Masked PPO for Connect Four" +--- + +# SB3: Action Masked PPO for Connect Four + +This tutorial shows how to train a Maskable [Proximal Policy Optimization](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html) (PPO) model on the [Connect Four](https://pettingzoo.farama.org/environments/classic/connect_four/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). + +It creates a custom Wrapper to convert to a [Gymnasium](https://gymnasium.farama.org/)-like environment which is compatible with [SB3 action masking](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html). + + +```{eval-rst} +.. note:: + + This environment has a discrete (1-dimensional) observation space with an illegal action mask, so we use a masked MLP feature extractor. +``` + +```{eval-rst} +.. warning:: + + This wrapper assumes that the action space and observation space are the same for each agent; this assumption may not hold for custom environments. +``` + +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/connect_four/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3.
If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training and Evaluation + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/connect_four/sb3_connect_four_action_mask.py + :language: python +``` + +### Testing other PettingZoo Classic environments + +The following script uses [pytest](https://docs.pytest.org/en/latest/) to test all other PettingZoo environments which support action masking. + +This code yields good results on simpler environments like [Gin Rummy](/environments/classic/gin_rummy/) and [Texas Hold’em No Limit](/environments/classic/texas_holdem_no_limit/), while failing to perform better than random in more difficult environments such as [Chess](/environments/classic/chess/) or [Hanabi](/environments/classic/hanabi/). + + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/test_sb3_action_mask.py + :language: python +``` diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index dfd2ebb8f..617d74a2f 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -2,24 +2,45 @@ title: "Stable-Baselines3" --- -# SB3 Tutorial +# Stable-Baselines3 Tutorial -These tutorials show you how to use [SB3](https://stable-baselines3.readthedocs.io/en/master/) to train agents in PettingZoo environments. +These tutorials show you how to use the [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) library to train agents in PettingZoo environments. -* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a parallel environment_ +For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/). -* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in an AEC environment_ +* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized environment with visual observations_ +For non-visual environments, we use [MLP](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.MlpPolicy) policies and do not perform any pre-processing steps. -```{figure} https://docs.ray.io/en/latest/_images/rllib-stack.svg - :alt: RLlib stack +* [PPO for Waterworld](/tutorials/sb3/waterworld/): _Train agents using PPO in a vectorized environment with discrete observations_ + +* [Action Masked PPO for Connect Four](/tutorials/sb3/connect_four/): _Train agents using Action Masked PPO in an AEC environment_ + + +## Stable-Baselines Overview + +[Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) is a library providing reliable implementations of reinforcement learning algorithms in [PyTorch](https://pytorch.org/). It provides a clean and simple interface, giving you access to off-the-shelf state-of-the-art model-free RL algorithms. It allows training of RL agents with only a few lines of code. + +For more information, see the [Stable-Baselines3 v1.0 Blog Post](https://araffin.github.io/post/sb3/) + + +```{eval-rst} +.. warning:: + + Note: SB3 is designed for single-agent RL and does not plan on natively supporting multi-agent PettingZoo environments. These tutorials are only intended for demonstration purposes, to show how SB3 can be adapted to work in multi-agent settings. 
+``` + + +```{figure} https://raw.githubusercontent.com/DLR-RM/stable-baselines3/master/docs/_static/img/logo.png + :alt: SB3 Logo :width: 80% ``` ```{toctree} :hidden: -:caption: RLlib +:caption: SB3 -pistonball -holdem +kaz +waterworld +connect_four ``` diff --git a/docs/tutorials/sb3/kaz.md b/docs/tutorials/sb3/kaz.md new file mode 100644 index 000000000..1714d66b2 --- /dev/null +++ b/docs/tutorials/sb3/kaz.md @@ -0,0 +1,41 @@ +--- +title: "SB3: PPO for Knights-Archers-Zombies" +--- + +# SB3: PPO for Knights-Archers-Zombies + +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Knights-Archers-Zombies](https://pettingzoo.farama.org/environments/butterfly/knights_archers_zombies/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). + +It converts the environment into a Parallel environment and uses SuperSuit to create vectorized environments, leveraging multithreading to speed up training. + +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). + +```{eval-rst} +.. note:: + + This environment has a visual (3-dimensional) observation space, so we use a CNN feature extractor. +``` + +```{eval-rst} +.. warning:: + + Because this environment allows agents to spawn and die, it requires using SuperSuit's Black Death wrapper, which provides blank observations to dead agents, rather than removing them from the environment. +``` + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/kaz/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training and Evaluation + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/kaz/sb3_kaz_vector.py + :language: python +``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md deleted file mode 100644 index 8e86e13e2..000000000 --- a/docs/tutorials/sb3/pistonball.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -title: "SB3: PPO for Pistonball (Parallel)" ---- - -# RLlib: PPO for Pistonball - -This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)). - -After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. - - -## Environment Setup -To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt - :language: text -``` - -## Code -The following code should run without any issues. 
The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). - -### Training the RL agent - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball.py - :language: python -``` - -### Watching the trained RL agent play - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/render_sb3_pistonball.py - :language: python -``` diff --git a/docs/tutorials/sb3/rps.md b/docs/tutorials/sb3/rps.md deleted file mode 100644 index fa70d3c55..000000000 --- a/docs/tutorials/sb3/rps.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -title: "SB3: PPO for Rock-Paper-Scissors (AEC)" ---- - -# RLlib: PPO for Rock-Paper-Scissors - -This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). - -After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. - - -## Environment Setup -To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt - :language: text -``` - -## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). - -### Training the RL agent - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_rps.py - :language: python -``` - -### Watching the trained RL agent play - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/render_sb3_rps.py - :language: python -``` diff --git a/docs/tutorials/sb3/waterworld.md b/docs/tutorials/sb3/waterworld.md new file mode 100644 index 000000000..519079a5f --- /dev/null +++ b/docs/tutorials/sb3/waterworld.md @@ -0,0 +1,33 @@ +--- +title: "SB3: PPO for Waterworld (Parallel)" +--- + +# SB3: PPO for Waterworld + +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Waterworld](https://pettingzoo.farama.org/environments/sisl/waterworld/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). + +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). + +```{eval-rst} +.. note:: + + This environment has a discrete (1-dimensional) observation space, so we use an MLP feature extractor. +``` + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/waterworld/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3.
If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training and Evaluation + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/waterworld/sb3_waterworld_vector.py + :language: python +``` diff --git a/pettingzoo/classic/chess/chess.py b/pettingzoo/classic/chess/chess.py index 378eef0b4..5100f8fc3 100644 --- a/pettingzoo/classic/chess/chess.py +++ b/pettingzoo/classic/chess/chess.py @@ -268,6 +268,9 @@ def step(self, action): current_agent = self.agent_selection current_index = self.agents.index(current_agent) + # Cast action into int + action = int(action) + chosen_move = chess_utils.action_to_move(self.board, action, current_index) assert chosen_move in self.board.legal_moves self.board.push(chosen_move) diff --git a/pyproject.toml b/pyproject.toml index a11409697..e139a5d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classic = [ ] butterfly = ["pygame==2.3.0", "pymunk==6.2.0"] mpe = ["pygame==2.3.0"] -sisl = ["pygame==2.3.0", "box2d-py==2.3.5", "scipy>=1.4.1"] +sisl = ["pygame==2.3.0", "pymunk==6.2.0", "box2d-py==2.3.5", "scipy>=1.4.1"] other = ["pillow>=8.0.1"] testing = [ "pynput", diff --git a/tutorials/Ray/render_rllib_leduc_holdem.py b/tutorials/Ray/render_rllib_leduc_holdem.py index c622d1da8..b514872ff 100644 --- a/tutorials/Ray/render_rllib_leduc_holdem.py +++ b/tutorials/Ray/render_rllib_leduc_holdem.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to view trained agents playing Leduoc Holdem. +"""Uses Ray's RLlib to view trained agents playing Leduc Holdem. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/Ray/render_rllib_pistonball.py b/tutorials/Ray/render_rllib_pistonball.py index 5469f4943..a15edd3ea 100644 --- a/tutorials/Ray/render_rllib_pistonball.py +++ b/tutorials/Ray/render_rllib_pistonball.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to view trained agents playing Pistonball. +"""Uses Ray's RLlib to view trained agents playing Pistonball. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/Ray/rllib_leduc_holdem.py b/tutorials/Ray/rllib_leduc_holdem.py index e46da3f49..9f4ee5c4e 100644 --- a/tutorials/Ray/rllib_leduc_holdem.py +++ b/tutorials/Ray/rllib_leduc_holdem.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to train agents to play Leduc Holdem. +"""Uses Ray's RLlib to train agents to play Leduc Holdem. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/Ray/rllib_pistonball.py b/tutorials/Ray/rllib_pistonball.py index f3bc4d713..92fbd1cd9 100644 --- a/tutorials/Ray/rllib_pistonball.py +++ b/tutorials/Ray/rllib_pistonball.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to train agents to play Pistonball. +"""Uses Ray's RLlib to train agents to play Pistonball. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/SB3/connect_four/requirements.txt b/tutorials/SB3/connect_four/requirements.txt new file mode 100644 index 000000000..30917f7b2 --- /dev/null +++ b/tutorials/SB3/connect_four/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[classic]>=1.23.1 +stable-baselines3>=2.0.0 +sb3-contrib>=2.0.0 diff --git a/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py b/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py new file mode 100644 index 000000000..789794bae --- /dev/null +++ b/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py @@ -0,0 +1,179 @@ +"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.
+ +For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking +For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html + +Author: Elliot (https://github.com/elliottower) +""" +import glob +import os +import time + +from sb3_contrib import MaskablePPO +from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy +from sb3_contrib.common.wrappers import ActionMasker + +import pettingzoo.utils +from pettingzoo.classic import connect_four_v3 + + +class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper): + """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking.""" + + def reset(self, seed=None, options=None): + """Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent. + + This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions + """ + super().reset(seed, options) + + # Strip the action mask out from the observation space + self.observation_space = super().observation_space(self.possible_agents[0])[ + "observation" + ] + self.action_space = super().action_space(self.possible_agents[0]) + + # Return initial observation, info (PettingZoo AEC envs do not by default) + return self.observe(self.agent_selection), {} + + def step(self, action): + """Gymnasium-like step function, returning observation, reward, termination, truncation, info.""" + super().step(action) + return super().last() + + def observe(self, agent): + """Return only raw observation, removing action mask.""" + return super().observe(agent)["observation"] + + def action_mask(self): + """Separate function used in order to access the action mask.""" + return super().observe(self.agent_selection)["action_mask"] + + +def mask_fn(env): + # Do whatever you'd like in this function to return the action mask + # for the current env. In this example, we assume the env has a + # helpful method we can rely on. + return env.action_mask() + + +def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs): + """Train a single agent to play both sides in a PettingZoo environment using invalid action masking.""" + env = env_fn.env(**env_kwargs) + + print(f"Starting training on {str(env.metadata['name'])}.") + + # Custom wrapper to convert PettingZoo envs to work with SB3 action masking + env = SB3ActionMaskWrapper(env) + + env.reset(seed=seed) # Must call reset() in order to re-define the spaces + + env = ActionMasker(env, mask_fn) # Wrap to enable masking (SB3 function) + # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped + # with ActionMasker. If the wrapper is detected, the masks are automatically + # retrieved and used when learning. Note that MaskablePPO does not accept + # a new action_mask_fn kwarg, as it did in an earlier draft. + model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1) + model.set_random_seed(seed) + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n") + + env.close() + + +def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + print( + f"Starting evaluation vs a random agent. 
Trained agent will play as {env.possible_agents[1]}." + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = MaskablePPO.load(latest_policy) + + scores = {agent: 0 for agent in env.possible_agents} + total_rewards = {agent: 0 for agent in env.possible_agents} + round_rewards = [] + + for i in range(num_games): + env.reset(seed=i) + env.action_space(env.possible_agents[0]).seed(i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + # Separate observation and action mask + observation, action_mask = obs.values() + + if termination or truncation: + # If there is a winner, keep track, otherwise don't change the scores (tie) + if ( + env.rewards[env.possible_agents[0]] + != env.rewards[env.possible_agents[1]] + ): + winner = max(env.rewards, key=env.rewards.get) + scores[winner] += env.rewards[ + winner + ] # only tracks the largest reward (winner of game) + # Also track negative and positive rewards (penalizes illegal moves) + for a in env.possible_agents: + total_rewards[a] += env.rewards[a] + # List of rewards by round, for reference + round_rewards.append(env.rewards) + break + else: + if agent == env.possible_agents[0]: + act = env.action_space(agent).sample(action_mask) + else: + # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int? + act = int( + model.predict( + observation, action_masks=action_mask, deterministic=True + )[0] + ) + env.step(act) + env.close() + + # Avoid dividing by zero + if sum(scores.values()) == 0: + winrate = 0 + else: + winrate = scores[env.possible_agents[1]] / sum(scores.values()) + print("Rewards by round: ", round_rewards) + print("Total rewards (incl. negative rewards): ", total_rewards) + print("Winrate: ", winrate) + print("Final scores: ", scores) + return round_rewards, total_rewards, winrate, scores + + +if __name__ == "__main__": + env_fn = connect_four_v3 + + env_kwargs = {} + + # Evaluation/training hyperparameter notes: + # 10k steps: Winrate: 0.76, loss order of 1e-03 + # 20k steps: Winrate: 0.86, loss order of 1e-04 + # 40k steps: Winrate: 0.86, loss order of 7e-06 + + # Train a model against itself (takes ~20 seconds on a laptop CPU) + train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs) + + # Evaluate 100 games against a random agent (winrate should be ~80%) + eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) + + # Watch two games vs a random agent + eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/kaz/requirements.txt b/tutorials/SB3/kaz/requirements.txt new file mode 100644 index 000000000..6199e7131 --- /dev/null +++ b/tutorials/SB3/kaz/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[butterfly]>=1.23.1 +stable-baselines3>=2.0.0 +supersuit>=3.8.1 diff --git a/tutorials/SB3/kaz/sb3_kaz_vector.py b/tutorials/SB3/kaz/sb3_kaz_vector.py new file mode 100644 index 000000000..b65df30e0 --- /dev/null +++ b/tutorials/SB3/kaz/sb3_kaz_vector.py @@ -0,0 +1,140 @@ +"""Uses Stable-Baselines3 to train agents in the Knights-Archers-Zombies environment using SuperSuit vector envs. + +This environment requires using SuperSuit's Black Death wrapper, to handle agent death. 
+ +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html + +Author: Elliot (https://github.com/elliottower) +""" +from __future__ import annotations + +import glob +import os +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import MlpPolicy + +from pettingzoo.butterfly import knights_archers_zombies_v10 + + +def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs): + # Train a single agent to play both sides in an AEC environment + env = env_fn.parallel_env(**env_kwargs) + + # Add black death wrapper so the number of agents stays constant + # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True + env = ss.black_death_v3(env) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + env.reset(seed=seed) + + print(f"Starting training on {str(env.metadata['name'])}.") + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") + + # TODO: test different hyperparameters + model = PPO( + MlpPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, + ) + + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") + + env.close() + + +def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + print( + f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = PPO.load(latest_policy) + + rewards = {agent: 0 for agent in env.possible_agents} + + # Note: we evaluate here using an AEC environments, to allow for easy A/B testing against random policies + # For example, we can see here that using a random agent for archer_0 results in less points than the trained agent + for i in range(num_games): + env.reset(seed=i) + env.action_space(env.possible_agents[0]).seed(i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + for agent in env.agents: + rewards[agent] += env.rewards[agent] + + if termination or truncation: + break + else: + if agent == env.possible_agents[0]: + act = env.action_space(agent).sample() + else: + act = model.predict(obs, deterministic=True)[0] + env.step(act) + env.close() + + avg_reward = sum(rewards.values()) / len(rewards.values()) + avg_reward_per_agent = { + agent: rewards[agent] / num_games for agent in env.possible_agents + } + print(f"Avg reward: {avg_reward}") + print("Avg reward per agent, per game: ", avg_reward_per_agent) + print("Full rewards: ", rewards) + return avg_reward + + +if __name__ == "__main__": + env_fn = 
knights_archers_zombies_v10 + + # Notes on environment configuration: + # max_cycles 100, max_zombies 4, seems to work well (13 points over 10 games) + # max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points) + env_kwargs = dict(max_cycles=100, max_zombies=4) + + # Train a model (takes ~5 minutes on a laptop CPU) + train(env_fn, steps=81_920, seed=0, **env_kwargs) + + # Evaluate 10 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=10, render_mode=None, **env_kwargs) + + # Watch 2 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/pistonball/requirements.txt b/tutorials/SB3/pistonball/requirements.txt new file mode 100644 index 000000000..6199e7131 --- /dev/null +++ b/tutorials/SB3/pistonball/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[butterfly]>=1.23.1 +stable-baselines3>=2.0.0 +supersuit>=3.8.1 diff --git a/tutorials/SB3/pistonball/sb3_pistonball_vector.py b/tutorials/SB3/pistonball/sb3_pistonball_vector.py new file mode 100644 index 000000000..a6a0f723e --- /dev/null +++ b/tutorials/SB3/pistonball/sb3_pistonball_vector.py @@ -0,0 +1,139 @@ +"""Uses Stable-Baselines3 to train agents in the Pistonball environment using SuperSuit vector envs. + +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html + +Author: Elliot (https://github.com/elliottower) +""" +from __future__ import annotations + +import glob +import os +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import CnnPolicy + +from pettingzoo.butterfly import pistonball_v6 + + +def train_butterfly_supersuit( + env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs +): + # Train a single agent to play both sides in a Parallel environment, + env = env_fn.parallel_env(**env_kwargs) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + env.reset(seed=seed) + + print(f"Starting training on {str(env.metadata['name'])}.") + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 4, num_cpus=2, base_class="stable_baselines3") + + model = PPO( + CnnPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, + ) + + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") + + # TODO: fix SuperSuit bug where closing the vector env can sometimes crash (disabled for CI) + env.close() + + +def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + print( + f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), 
key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = PPO.load(latest_policy) + + rewards = {agent: 0 for agent in env.possible_agents} + + # Note: We train using the Parallel API but evaluate using the AEC API + # SB3 models are designed for single-agent settings, we get around this by using he same model for every agent + for i in range(num_games): + env.reset(seed=i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + if termination or truncation: + for agent in env.agents: + rewards[agent] += env.rewards[agent] + break + else: + act = model.predict(obs, deterministic=True)[0] + + env.step(act) + + # TODO: fix SuperSuit bug where closing the vector env can sometimes crash (disabled for CI) + env.close() + + avg_reward = sum(rewards.values()) / len(rewards.values()) + print(f"Avg reward: {avg_reward}") + return avg_reward + + +if __name__ == "__main__": + env_fn = pistonball_v6 + + env_kwargs = dict( + n_pistons=20, + time_penalty=-0.1, + continuous=True, + random_drop=True, + random_rotate=True, + ball_mass=0.75, + ball_friction=0.3, + ball_elasticity=1.5, + max_cycles=25, + ) + + # Train a model (takes ~3 minutes on a laptop CPU) + # Note: stochastic environment makes training difficult, for better results try order of 2 million (~2 hours on GPU) + train_butterfly_supersuit(env_fn, steps=40_960 * 2, seed=0, **env_kwargs) + + # Evaluate 10 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=10, render_mode=None, **env_kwargs) + + # Watch 2 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/render_sb3_pistonball.py b/tutorials/SB3/render_sb3_pistonball.py deleted file mode 100644 index 794cf027a..000000000 --- a/tutorials/SB3/render_sb3_pistonball.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Uses Stable-Baselines3 to view trained agents playing Pistonball. - -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b - -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) -""" -import glob -import os - -import supersuit as ss -from stable_baselines3 import PPO - -from pettingzoo.butterfly import pistonball_v6 - -env = pistonball_v6.env(render_mode="human") - -env = ss.color_reduction_v0(env, mode="B") -env = ss.resize_v1(env, x_size=84, y_size=84) -env = ss.frame_stack_v1(env, 3) - -latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) -model = PPO.load(latest_policy) - -env.reset() -for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - act = ( - model.predict(obs, deterministic=True)[0] - if not termination or truncation - else None - ) - env.step(act) diff --git a/tutorials/SB3/render_sb3_rps.py b/tutorials/SB3/render_sb3_rps.py deleted file mode 100644 index c07c15567..000000000 --- a/tutorials/SB3/render_sb3_rps.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Uses Stable-Baselines3 to view trained agents playing Rock-Paper-Scissors. 
- -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b - -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) -""" - -import glob -import os - -from stable_baselines3 import PPO - -from pettingzoo.classic import rps_v2 - -env = rps_v2.env(render_mode="human") - -latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) -model = PPO.load(latest_policy) - -env.reset() -for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - if termination or truncation: - act = None - else: - act = model.predict(obs, deterministic=True)[0] - env.step(act) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index 72605e8cc..60ca7deb2 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -1,3 +1,5 @@ -stable-baselines3 >= 2.0.0 -pettingzoo >= 1.23.1 -supersuit >= 3.8.1 +stable-baselines3>=2.0.0 +pettingzoo[classic,butterfly,sisl]>=1.23.1 +supersuit>=3.8.1 +sb3-contrib>=2.0.0 +pytest diff --git a/tutorials/SB3/sb3_pistonball.py b/tutorials/SB3/sb3_pistonball.py deleted file mode 100644 index a00f88d08..000000000 --- a/tutorials/SB3/sb3_pistonball.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Uses Stable-Baselines3 to train agents to play Pistonball. - -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b - -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) -""" -import time - -import supersuit as ss -from stable_baselines3 import PPO -from stable_baselines3.ppo import CnnPolicy - -from pettingzoo.butterfly import pistonball_v6 - -env = pistonball_v6.parallel_env( - n_pistons=20, - time_penalty=-0.1, - continuous=True, - random_drop=True, - random_rotate=True, - ball_mass=0.75, - ball_friction=0.3, - ball_elasticity=1.5, - max_cycles=125, -) - -env = ss.color_reduction_v0(env, mode="B") -env = ss.resize_v1(env, x_size=84, y_size=84) -env = ss.frame_stack_v1(env, 3) - - -env = ss.pettingzoo_env_to_vec_env_v1(env) -env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") - -model = PPO( - CnnPolicy, - env, - verbose=3, - gamma=0.95, - n_steps=256, - ent_coef=0.0905168, - learning_rate=0.00062211, - vf_coef=0.042202, - max_grad_norm=0.9, - gae_lambda=0.99, - n_epochs=5, - clip_range=0.3, - batch_size=256, -) - -model.learn(total_timesteps=2_000_000) - -model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") - - -print("Model has been saved.") diff --git a/tutorials/SB3/sb3_rps.py b/tutorials/SB3/sb3_rps.py deleted file mode 100644 index 0439d698a..000000000 --- a/tutorials/SB3/sb3_rps.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors. 
- -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b - -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) -""" -import time - -import supersuit as ss -from stable_baselines3 import PPO -from stable_baselines3.ppo import MlpPolicy - -from pettingzoo.classic import rps_v2 -from pettingzoo.utils import turn_based_aec_to_parallel - -env = rps_v2.env() -env = turn_based_aec_to_parallel(env) - -env = ss.pettingzoo_env_to_vec_env_v1(env) -env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") - -# TODO: find hyperparameters that make the model actually learn -model = PPO( - MlpPolicy, - env, - verbose=3, - gamma=0.95, - n_steps=256, - ent_coef=0.0905168, - learning_rate=0.00062211, - vf_coef=0.042202, - max_grad_norm=0.9, - gae_lambda=0.99, - n_epochs=5, - clip_range=0.3, - batch_size=256, -) - -model.learn(total_timesteps=2_000_000) - -model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") - -print("Model has been saved.") diff --git a/tutorials/SB3/test/requirements.txt b/tutorials/SB3/test/requirements.txt new file mode 100644 index 000000000..838ea192b --- /dev/null +++ b/tutorials/SB3/test/requirements.txt @@ -0,0 +1,4 @@ +pettingzoo[classic]>=1.23.1 +stable-baselines3>=2.0.0 +sb3-contrib>=2.0.0 +pytest diff --git a/tutorials/SB3/test/test_sb3_action_mask.py b/tutorials/SB3/test/test_sb3_action_mask.py new file mode 100644 index 000000000..f8f803766 --- /dev/null +++ b/tutorials/SB3/test/test_sb3_action_mask.py @@ -0,0 +1,126 @@ +"""Tests that action masking code works properly with all PettingZoo classic environments.""" + +import pytest + +from pettingzoo.classic import ( + chess_v6, + gin_rummy_v4, + go_v5, + hanabi_v4, + leduc_holdem_v4, + texas_holdem_no_limit_v6, + texas_holdem_v4, + tictactoe_v3, +) + +pytest.importorskip("stable_baselines3") +pytest.importorskip("sb3_contrib") + +# Note: Connect Four is tested in sb3_connect_four_action_mask.py +# Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself + +# These environments do better than random even after the minimum number of timesteps +EASY_ENVS = [ + gin_rummy_v4, + texas_holdem_no_limit_v6, # texas holdem human rendered game ends instantly, but with random actions it works fine + texas_holdem_v4, +] + +# More difficult environments which will likely take more training time +MEDIUM_ENVS = [ + leduc_holdem_v4, # with 10x as many steps it gets higher total rewards (9 vs -9), 0.52 winrate, and 0.92 vs 0.83 total scores + hanabi_v4, # even with 10x as many steps, total score seems to always be tied between the two agents + tictactoe_v3, # even with 10x as many steps, agent still loses every time (most likely an error somewhere) +] + +# Most difficult environments to train agents for (and longest games +# TODO: test board_size to see if smaller go board is more easily solvable +HARD_ENVS = [ + chess_v6, # difficult to train because games take so long, 0.28 winrate even after 10x + go_v5, # difficult to train because games take so long +] + + +@pytest.mark.parametrize("env_fn", EASY_ENVS) +def test_action_mask_easy(env_fn): + from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) + + env_kwargs = {} + + # Leduc Hold`em takes slightly longer to outperform random + steps = 8192 if env_fn != leduc_holdem_v4 else 8192 * 4 + + # Train a model 
against itself (takes ~2 minutes on GPU) + train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs) + + # Evaluate 2 games against a random agent + round_rewards, total_rewards, winrate, scores = eval_action_mask( + env_fn, num_games=100, render_mode=None, **env_kwargs + ) + + assert winrate > 0.5 or ( + total_rewards[env_fn.env().possible_agents[1]] + > total_rewards[env_fn.env().possible_agents[0]] + ), "Trained policy should outperform random actions" + + # Watch two games (disabled by default) + # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) + + +# @pytest.mark.skip( +# reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" +# ) +@pytest.mark.parametrize("env_fn", MEDIUM_ENVS) +def test_action_mask_medium(env_fn): + from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) + + env_kwargs = {} + + # Train a model against itself + train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs) + + # Evaluate 2 games against a random agent + round_rewards, total_rewards, winrate, scores = eval_action_mask( + env_fn, num_games=100, render_mode=None, **env_kwargs + ) + + assert ( + winrate < 0.75 + ), "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi, 0% for tic-tac-toe + + # Watch two games (disabled by default) + # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) + + +# @pytest.mark.skip( +# reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" +# ) +@pytest.mark.parametrize("env_fn", HARD_ENVS) +def test_action_mask_hard(env_fn): + from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) + + env_kwargs = {} + + # Train a model against itself + train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs) + + # Evaluate 2 games against a random agent + round_rewards, total_rewards, winrate, scores = eval_action_mask( + env_fn, num_games=100, render_mode=None, **env_kwargs + ) + + assert ( + winrate < 0.5 + ), "Policy should not perform better than 50% winrate" # 28% for chess, 0% for go + + # Watch two games (disabled by default) + # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/waterworld/requirements.txt b/tutorials/SB3/waterworld/requirements.txt new file mode 100644 index 000000000..cb7dac213 --- /dev/null +++ b/tutorials/SB3/waterworld/requirements.txt @@ -0,0 +1,4 @@ +pettingzoo[sisl]>=1.23.1 +stable-baselines3>=2.0.0 +supersuit>=3.8.1 +pymunk diff --git a/tutorials/SB3/waterworld/sb3_waterworld_vector.py b/tutorials/SB3/waterworld/sb3_waterworld_vector.py new file mode 100644 index 000000000..edaa81afd --- /dev/null +++ b/tutorials/SB3/waterworld/sb3_waterworld_vector.py @@ -0,0 +1,108 @@ +"""Uses Stable-Baselines3 to train agents to play the Waterworld environment using SuperSuit vector envs. 
+ +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html + +Author: Elliot (https://github.com/elliottower) +""" +from __future__ import annotations + +import glob +import os +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import MlpPolicy + +from pettingzoo.sisl import waterworld_v4 + + +def train_butterfly_supersuit( + env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs +): + # Train a single a model to play as each agent in a cooperative Parallel environment, + env = env_fn.parallel_env(**env_kwargs) + + env.reset(seed=seed) + + print(f"Starting training on {str(env.metadata['name'])}.") + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") + + # Note: Waterworld's observation space is discrete (242,) so we use an MLP policy rather than CNN + model = PPO( + MlpPolicy, + env, + verbose=3, + learning_rate=1e-3, + batch_size=256, + ) + + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") + + env.close() + + +def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + print( + f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = PPO.load(latest_policy) + + rewards = {agent: 0 for agent in env.possible_agents} + + # Note: We train using the Parallel API but evaluate using the AEC API + # SB3 models are designed for single-agent settings, we get around this by using he same model for every agent + for i in range(num_games): + env.reset(seed=i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + for agent in env.agents: + rewards[agent] += env.rewards[agent] + if termination or truncation: + break + else: + act = model.predict(obs, deterministic=True)[0] + + env.step(act) + env.close() + + avg_reward = sum(rewards.values()) / len(rewards.values()) + print("Rewards: ", rewards) + print(f"Avg reward: {avg_reward}") + return avg_reward + + +if __name__ == "__main__": + env_fn = waterworld_v4 + env_kwargs = {} + + # Train a model (takes ~3 minutes on GPU) + train_butterfly_supersuit(env_fn, steps=196_608, seed=0, **env_kwargs) + + # Evaluate 10 games (average reward should be positive but can vary significantly) + eval(env_fn, num_games=10, render_mode=None, **env_kwargs) + + # Watch 2 games + eval(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/Tianshou/2_training_agents.py b/tutorials/Tianshou/2_training_agents.py index 3f4a1e8fd..30ae1b159 100644 --- a/tutorials/Tianshou/2_training_agents.py +++ b/tutorials/Tianshou/2_training_agents.py @@ -12,7 +12,7 @@ import os from typing import Optional, Tuple -import gymnasium as gym +import gymnasium import numpy as np import torch from tianshou.data import Collector, VectorReplayBuffer @@ -33,7 +33,7 @@ def _get_agents( env = _get_env() observation_space = ( env.observation_space["observation"] - if isinstance(env.observation_space, gym.spaces.Dict) + if 
isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) if agent_learn is None: diff --git a/tutorials/Tianshou/3_cli_and_logging.py b/tutorials/Tianshou/3_cli_and_logging.py index e78774dd3..b11abe574 100644 --- a/tutorials/Tianshou/3_cli_and_logging.py +++ b/tutorials/Tianshou/3_cli_and_logging.py @@ -14,7 +14,7 @@ from copy import deepcopy from typing import Optional, Tuple -import gymnasium as gym +import gymnasium import numpy as np import torch from tianshou.data import Collector, VectorReplayBuffer @@ -105,7 +105,7 @@ def get_agents( env = get_env() observation_space = ( env.observation_space["observation"] - if isinstance(env.observation_space, gym.spaces.Dict) + if isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) args.state_shape = observation_space.shape or observation_space.n