From cf2bef7e55de60b35657b109a03c2679ea79506a Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 15:51:58 -0400 Subject: [PATCH 01/38] Update SB3 tutorial to have __main__ (error on macOS) --- tutorials/SB3/requirements.txt | 6 +-- tutorials/SB3/sb3_pistonball.py | 84 ++++++++++++++++----------------- tutorials/SB3/sb3_rps.py | 61 +++++++++++++----------- 3 files changed, 78 insertions(+), 73 deletions(-) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index 72605e8cc..db26f6ece 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -1,3 +1,3 @@ -stable-baselines3 >= 2.0.0 -pettingzoo >= 1.23.1 -supersuit >= 3.8.1 +stable-baselines3>=2.0.0 +pettingzoo>=1.23.1 +supersuit>=3.8.1 diff --git a/tutorials/SB3/sb3_pistonball.py b/tutorials/SB3/sb3_pistonball.py index a00f88d08..d645468f5 100644 --- a/tutorials/SB3/sb3_pistonball.py +++ b/tutorials/SB3/sb3_pistonball.py @@ -12,45 +12,45 @@ from pettingzoo.butterfly import pistonball_v6 -env = pistonball_v6.parallel_env( - n_pistons=20, - time_penalty=-0.1, - continuous=True, - random_drop=True, - random_rotate=True, - ball_mass=0.75, - ball_friction=0.3, - ball_elasticity=1.5, - max_cycles=125, -) - -env = ss.color_reduction_v0(env, mode="B") -env = ss.resize_v1(env, x_size=84, y_size=84) -env = ss.frame_stack_v1(env, 3) - - -env = ss.pettingzoo_env_to_vec_env_v1(env) -env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") - -model = PPO( - CnnPolicy, - env, - verbose=3, - gamma=0.95, - n_steps=256, - ent_coef=0.0905168, - learning_rate=0.00062211, - vf_coef=0.042202, - max_grad_norm=0.9, - gae_lambda=0.99, - n_epochs=5, - clip_range=0.3, - batch_size=256, -) - -model.learn(total_timesteps=2_000_000) - -model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") - - -print("Model has been saved.") +if __name__ == "__main__": + env = pistonball_v6.parallel_env( + n_pistons=20, + time_penalty=-0.1, + continuous=True, + random_drop=True, + random_rotate=True, + ball_mass=0.75, + ball_friction=0.3, + ball_elasticity=1.5, + max_cycles=125, + ) + + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") + + model = PPO( + CnnPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, + ) + + model.learn(total_timesteps=2_000_000) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + env.close() diff --git a/tutorials/SB3/sb3_rps.py b/tutorials/SB3/sb3_rps.py index 0439d698a..410951cb2 100644 --- a/tutorials/SB3/sb3_rps.py +++ b/tutorials/SB3/sb3_rps.py @@ -13,31 +13,36 @@ from pettingzoo.classic import rps_v2 from pettingzoo.utils import turn_based_aec_to_parallel -env = rps_v2.env() -env = turn_based_aec_to_parallel(env) - -env = ss.pettingzoo_env_to_vec_env_v1(env) -env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") - -# TODO: find hyperparameters that make the model actually learn -model = PPO( - MlpPolicy, - env, - verbose=3, - gamma=0.95, - n_steps=256, - ent_coef=0.0905168, - learning_rate=0.00062211, - vf_coef=0.042202, - max_grad_norm=0.9, - gae_lambda=0.99, - n_epochs=5, - clip_range=0.3, - batch_size=256, -) - -model.learn(total_timesteps=2_000_000) - -model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") - -print("Model has been saved.") +if __name__ == "__main__": + env = rps_v2.env() + env = turn_based_aec_to_parallel(env) + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") + + # TODO: find hyperparameters that make the model actually learn + model = PPO( + MlpPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, + ) + + model.learn(total_timesteps=10_000) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + env.close() + import sys + + sys.exit(0) From 86916e6dcde6e32d6853982fcff3a1ca896ffe99 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 15:54:41 -0400 Subject: [PATCH 02/38] Add SB3 tests for tutorials --- .github/workflows/linux-tutorials-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml index cc47d31c7..4e46ad714 100644 --- a/.github/workflows/linux-tutorials-test.yml +++ b/.github/workflows/linux-tutorials-test.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10'] # '3.11' - broken due to numba - tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL'] # TODO: add back 'CleanRL' after SuperSuit is fixed + tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL', 'SB3'] # TODO: add back RLlib once it is fixed steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} From f3e0cc924bd92792dc61c9994b558c3b548a95cc Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 19:45:20 -0400 Subject: [PATCH 03/38] Add action masking tutorial, fix typos/update documentation --- docs/api/parallel.md | 2 +- docs/index.md | 1 + docs/tutorials/rllib/holdem.md | 2 +- docs/tutorials/rllib/index.md | 4 +- docs/tutorials/rllib/pistonball.md | 2 +- docs/tutorials/sb3/chess.md | 38 ++++++++ docs/tutorials/sb3/index.md | 28 ++++-- docs/tutorials/sb3/pistonball.md | 4 +- docs/tutorials/sb3/rps.md | 8 +- pettingzoo/classic/chess/chess.py | 3 + tutorials/Ray/render_rllib_leduc_holdem.py | 2 +- tutorials/Ray/render_rllib_pistonball.py | 2 +- tutorials/Ray/rllib_leduc_holdem.py | 2 +- tutorials/Ray/rllib_pistonball.py | 2 +- tutorials/SB3/render_sb3_chess_action_mask.py | 36 ++++++++ tutorials/SB3/render_sb3_pistonball.py | 40 +++++---- tutorials/SB3/render_sb3_rps.py | 25 +++--- tutorials/SB3/requirements.txt | 1 + tutorials/SB3/sb3_chess_action_mask.py | 86 +++++++++++++++++++ ...pistonball.py => sb3_pistonball_vector.py} | 7 +- .../SB3/{sb3_rps.py => sb3_rps_vector.py} | 9 +- tutorials/SB3/test_sb3_action_mask.py | 34 ++++++++ 22 files changed, 277 insertions(+), 61 deletions(-) create mode 100644 docs/tutorials/sb3/chess.md create mode 100644 tutorials/SB3/render_sb3_chess_action_mask.py create mode 100644 tutorials/SB3/sb3_chess_action_mask.py rename tutorials/SB3/{sb3_pistonball.py => sb3_pistonball_vector.py} (81%) rename tutorials/SB3/{sb3_rps.py => sb3_rps_vector.py} (80%) create mode 100644 tutorials/SB3/test_sb3_action_mask.py diff --git a/docs/api/parallel.md b/docs/api/parallel.md index 7352ea7fe..27d0c2a1e 100644 --- a/docs/api/parallel.md +++ b/docs/api/parallel.md @@ -5,7 +5,7 @@ title: Parallel # Parallel API -In addition to the main API, we have a secondary parallel API for environments where all agents have simultaneous actions and observations. An environment with parallel API support can be created via `.parallel_env()`. This API is based around the paradigm of *Partially Observable Stochastic Games* (POSGs) and the details are similar to [RLLib's MultiAgent environment specification](https://docs.ray.io/en/latest/rllib-env.html#multi-agent-and-hierarchical), except we allow for different observation and action spaces between the agents. +In addition to the main API, we have a secondary parallel API for environments where all agents have simultaneous actions and observations. An environment with parallel API support can be created via `.parallel_env()`. This API is based around the paradigm of *Partially Observable Stochastic Games* (POSGs) and the details are similar to [RLlib's MultiAgent environment specification](https://docs.ray.io/en/latest/rllib-env.html#multi-agent-and-hierarchical), except we allow for different observation and action spaces between the agents. All parallel environments can be converted into AEC environments by splitting a simultaneous turn into sequential turns, with observations only from the previous cycle. diff --git a/docs/index.md b/docs/index.md index 6cabc4718..fe6c1fff5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -44,6 +44,7 @@ tutorials/cleanrl/index tutorials/tianshou/index tutorials/rllib/index tutorials/langchain/index +tutorials/sb3/index ``` ```{toctree} diff --git a/docs/tutorials/rllib/holdem.md b/docs/tutorials/rllib/holdem.md index c102032be..5f333ee2c 100644 --- a/docs/tutorials/rllib/holdem.md +++ b/docs/tutorials/rllib/holdem.md @@ -16,7 +16,7 @@ To follow this tutorial, you will need to install the dependencies shown below. ``` ## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLlib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). ### Training the RL agent diff --git a/docs/tutorials/rllib/index.md b/docs/tutorials/rllib/index.md index dc95d2b52..ad06a050b 100644 --- a/docs/tutorials/rllib/index.md +++ b/docs/tutorials/rllib/index.md @@ -2,9 +2,9 @@ title: "RLlib" --- -# RLlib Tutorial +# Ray RLlib Tutorial -These tutorials show you how to use [RLlib](https://docs.ray.io/en/latest/rllib/index.html) to train agents in PettingZoo environments. +These tutorials show you how to use [Ray](https://docs.ray.io/en/latest/index.html)'s [RLlib](https://docs.ray.io/en/latest/rllib/index.html) library to train agents in PettingZoo environments. * [PPO for Pistonball](/tutorials/rllib/pistonball/): _Train a PPO model in a parallel environment_ diff --git a/docs/tutorials/rllib/pistonball.md b/docs/tutorials/rllib/pistonball.md index 7e24cf991..53a5fcd31 100644 --- a/docs/tutorials/rllib/pistonball.md +++ b/docs/tutorials/rllib/pistonball.md @@ -17,7 +17,7 @@ To follow this tutorial, you will need to install the dependencies shown below. ``` ## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLlib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). ### Training the RL agent diff --git a/docs/tutorials/sb3/chess.md b/docs/tutorials/sb3/chess.md new file mode 100644 index 000000000..ce5fd496d --- /dev/null +++ b/docs/tutorials/sb3/chess.md @@ -0,0 +1,38 @@ +--- +title: "SB3: Action Masking for Chess (AEC)" +--- + +# SB3: Action Masking for Chess (AEC) + +This tutorial shows how to train a Maskable [Proximal Policy Optimization](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html) (PPO) model on the [Chess](https://pettingzoo.farama.org/environments/classic/chess/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). + +It creates a custom Wrapper to convert to a Gymnasium-like environment which is compatible with SB3's action masking format. + +Note: This assumes that the action space and observation space is the same for each agent, this assumption may not hold for custom environments. + +After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training the RL agent + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/sb3_chess_action_mask.py + :language: python +``` + +### Watching the trained RL agent play + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/render_sb3_chess_action_mask.py + :language: python +``` diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index dfd2ebb8f..3c9585a9b 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -2,24 +2,36 @@ title: "Stable-Baselines3" --- -# SB3 Tutorial +# Stable-Baselines3 Tutorial -These tutorials show you how to use [SB3](https://stable-baselines3.readthedocs.io/en/master/) to train agents in PettingZoo environments. +These tutorials show you how to use the [SB3](https://stable-baselines3.readthedocs.io/en/master/) library to train agents in PettingZoo environments. -* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a parallel environment_ +* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in a vectorized AEC environment_ -* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in an AEC environment_ +* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a vectorized Parallel environment_ +* [Action Masking for Chess](/tutorials/sb3/chess/): _Train an action masked PPO model in an AEC environment_ -```{figure} https://docs.ray.io/en/latest/_images/rllib-stack.svg - :alt: RLlib stack + +## Stable-Baselines Overview + +[Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) is a library providing reliable implementations of reinforcement learning algorithms in [PyTorch](https://pytorch.org/). It provides a clean and simple interface, giving you access to off-the-shelf state-of-the-art model-free RL algorithms. It allows training of RL agents with only a few lines of code. + +For more information, see the [Stable-Baselines3 v1.0 Blog Post](https://araffin.github.io/post/sb3/) + +Note: SB3 does not officially support PettingZoo, as it is designed for single-agent RL. These tutorials demonstrate how to adapt SB3 to work in multi-agent settings, but we cannot guarantee training convergence. + + +```{figure} https://raw.githubusercontent.com/DLR-RM/stable-baselines3/master/docs/_static/img/logo.png + :alt: SB3 Logo :width: 80% ``` ```{toctree} :hidden: -:caption: RLlib +:caption: SB3 pistonball -holdem +rps +chess ``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md index 8e86e13e2..3e83f8087 100644 --- a/docs/tutorials/sb3/pistonball.md +++ b/docs/tutorials/sb3/pistonball.md @@ -2,7 +2,7 @@ title: "SB3: PPO for Pistonball (Parallel)" --- -# RLlib: PPO for Pistonball +# SB3: PPO for Pistonball This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)). @@ -17,7 +17,7 @@ To follow this tutorial, you will need to install the dependencies shown below. ``` ## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). ### Training the RL agent diff --git a/docs/tutorials/sb3/rps.md b/docs/tutorials/sb3/rps.md index fa70d3c55..c5d1bd2ed 100644 --- a/docs/tutorials/sb3/rps.md +++ b/docs/tutorials/sb3/rps.md @@ -2,9 +2,11 @@ title: "SB3: PPO for Rock-Paper-Scissors (AEC)" --- -# RLlib: PPO for Rock-Paper-Scissors +# SB3: PPO for Rock-Paper-Scissors -This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Rock-Paper-Scissors](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). + +It converts the environment into a Parallel environment and uses SuperSuit to create vectorized environments, leveraging multithreading to speed up training. After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. @@ -17,7 +19,7 @@ To follow this tutorial, you will need to install the dependencies shown below. ``` ## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). ### Training the RL agent diff --git a/pettingzoo/classic/chess/chess.py b/pettingzoo/classic/chess/chess.py index 378eef0b4..5100f8fc3 100644 --- a/pettingzoo/classic/chess/chess.py +++ b/pettingzoo/classic/chess/chess.py @@ -268,6 +268,9 @@ def step(self, action): current_agent = self.agent_selection current_index = self.agents.index(current_agent) + # Cast action into int + action = int(action) + chosen_move = chess_utils.action_to_move(self.board, action, current_index) assert chosen_move in self.board.legal_moves self.board.push(chosen_move) diff --git a/tutorials/Ray/render_rllib_leduc_holdem.py b/tutorials/Ray/render_rllib_leduc_holdem.py index c622d1da8..b514872ff 100644 --- a/tutorials/Ray/render_rllib_leduc_holdem.py +++ b/tutorials/Ray/render_rllib_leduc_holdem.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to view trained agents playing Leduoc Holdem. +"""Uses Ray's RLlib to view trained agents playing Leduoc Holdem. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/Ray/render_rllib_pistonball.py b/tutorials/Ray/render_rllib_pistonball.py index 5469f4943..a15edd3ea 100644 --- a/tutorials/Ray/render_rllib_pistonball.py +++ b/tutorials/Ray/render_rllib_pistonball.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to view trained agents playing Pistonball. +"""Uses Ray's RLlib to view trained agents playing Pistonball. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/Ray/rllib_leduc_holdem.py b/tutorials/Ray/rllib_leduc_holdem.py index e46da3f49..9f4ee5c4e 100644 --- a/tutorials/Ray/rllib_leduc_holdem.py +++ b/tutorials/Ray/rllib_leduc_holdem.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to train agents to play Leduc Holdem. +"""Uses Ray's RLlib to train agents to play Leduc Holdem. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/Ray/rllib_pistonball.py b/tutorials/Ray/rllib_pistonball.py index f3bc4d713..92fbd1cd9 100644 --- a/tutorials/Ray/rllib_pistonball.py +++ b/tutorials/Ray/rllib_pistonball.py @@ -1,4 +1,4 @@ -"""Uses Ray's RLLib to train agents to play Pistonball. +"""Uses Ray's RLlib to train agents to play Pistonball. Author: Rohan (https://github.com/Rohan138) """ diff --git a/tutorials/SB3/render_sb3_chess_action_mask.py b/tutorials/SB3/render_sb3_chess_action_mask.py new file mode 100644 index 000000000..343cbf1cd --- /dev/null +++ b/tutorials/SB3/render_sb3_chess_action_mask.py @@ -0,0 +1,36 @@ +import glob +import os + +from sb3_contrib import MaskablePPO + +from pettingzoo.classic import chess_v6 + + +def watch_action_mask(env_fn): + # Watch a game between two trained agents + env = env_fn.env(render_mode="human") + env.reset() + + latest_policy = max(glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime) + model = MaskablePPO.load(latest_policy) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + # Separate observation and action mask + observation, action_mask = obs.values() + + if termination or truncation: + act = None + else: + # Note that use of masks is manual and optional outside of learning, + # so masking can be "removed" at testing time + act = int( + model.predict(observation, action_masks=action_mask)[0] + ) # PettingZoo expects integer actions + env.step(act) + env.close() + + +if __name__ == "__main__": + watch_action_mask(chess_v6) diff --git a/tutorials/SB3/render_sb3_pistonball.py b/tutorials/SB3/render_sb3_pistonball.py index 794cf027a..93cfd8f82 100644 --- a/tutorials/SB3/render_sb3_pistonball.py +++ b/tutorials/SB3/render_sb3_pistonball.py @@ -2,7 +2,7 @@ Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +Author: Elliot (https://github.com/elliottower) """ import glob import os @@ -12,21 +12,23 @@ from pettingzoo.butterfly import pistonball_v6 -env = pistonball_v6.env(render_mode="human") - -env = ss.color_reduction_v0(env, mode="B") -env = ss.resize_v1(env, x_size=84, y_size=84) -env = ss.frame_stack_v1(env, 3) - -latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) -model = PPO.load(latest_policy) - -env.reset() -for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - act = ( - model.predict(obs, deterministic=True)[0] - if not termination or truncation - else None - ) - env.step(act) +if __name__ == "__main__": + # Watch a game between two trained agents + env = pistonball_v6.env(render_mode="human") + + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + latest_policy = max(glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime) + model = PPO.load(latest_policy) + + env.reset() + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + act = ( + model.predict(obs, deterministic=True)[0] + if not termination or truncation + else None + ) + env.step(act) diff --git a/tutorials/SB3/render_sb3_rps.py b/tutorials/SB3/render_sb3_rps.py index c07c15567..c4cdbd34d 100644 --- a/tutorials/SB3/render_sb3_rps.py +++ b/tutorials/SB3/render_sb3_rps.py @@ -2,7 +2,7 @@ Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +Author: Elliot (https://github.com/elliottower) """ import glob @@ -12,16 +12,17 @@ from pettingzoo.classic import rps_v2 -env = rps_v2.env(render_mode="human") +if __name__ == "__main__": + env = rps_v2.env(render_mode="human") -latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) -model = PPO.load(latest_policy) + latest_policy = max(glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime) + model = PPO.load(latest_policy) -env.reset() -for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - if termination or truncation: - act = None - else: - act = model.predict(obs, deterministic=True)[0] - env.step(act) + env.reset() + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + if termination or truncation: + act = None + else: + act = model.predict(obs, deterministic=True)[0] + env.step(act) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index db26f6ece..2ca60adc3 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -1,3 +1,4 @@ stable-baselines3>=2.0.0 pettingzoo>=1.23.1 supersuit>=3.8.1 +sb3-contrib>=2.0.0 diff --git a/tutorials/SB3/sb3_chess_action_mask.py b/tutorials/SB3/sb3_chess_action_mask.py new file mode 100644 index 000000000..4bd47eed7 --- /dev/null +++ b/tutorials/SB3/sb3_chess_action_mask.py @@ -0,0 +1,86 @@ +"""Uses Stable-Baselines3 to train agents to play Connect Four using invalid action masking. + +For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking +For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html + +Author: Elliot (https://github.com/elliottower) +""" +import time + +from sb3_contrib import MaskablePPO +from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy +from sb3_contrib.common.wrappers import ActionMasker + +import pettingzoo.utils +from pettingzoo.classic import chess_v6 + + +class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper): + """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking.""" + + def reset(self, seed=None, options=None): + """Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent. + + This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions + """ + super().reset(seed, options) + + # Strip the action mask out from the observation space + self.observation_space = super().observation_space(self.possible_agents[0])[ + "observation" + ] + self.action_space = super().action_space(self.possible_agents[0]) + + # Return initial observation, info (PettingZoo AEC envs do not by default) + return self.observe(self.agent_selection), {} + + def step(self, action): + """Gymnasium-like step function, returning observation, reward, termination, truncation, info.""" + super().step(action) + return super().last() + + def observe(self, agent): + """Return only raw observation, removing action mask.""" + return super().observe(agent)["observation"] + + def action_mask(self): + """Separate function used in order to access the action mask.""" + return super().observe(self.agent_selection)["action_mask"] + + +def mask_fn(env): + # Do whatever you'd like in this function to return the action mask + # for the current env. In this example, we assume the env has a + # helpful method we can rely on. + return env.action_mask() + + +def train_action_mask(env_fn, steps=10_000): + """Train a single agent to play both sides in a PettingZoo environment using invalid action masking.""" + env = env_fn.env() + + print(f"Starting training on {str(env.metadata['name'])}.") + + # Custom wrapper to convert PettingZoo envs to work with SB3 action masking + env = SB3ActionMaskWrapper(env) + + env.reset() # Must call reset() in order to re-define the spaces + + env = ActionMasker(env, mask_fn) # Wrap to enable masking (SB3 function) + # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped + # with ActionMasker. If the wrapper is detected, the masks are automatically + # retrieved and used when learning. Note that MaskablePPO does not accept + # a new action_mask_fn kwarg, as it did in an earlier draft. + model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1) + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + env.close() + + print(f"Finished training on {str(env.metadata['name'])}.\n") + + +if __name__ == "__main__": + train_action_mask(chess_v6, steps=20_000) diff --git a/tutorials/SB3/sb3_pistonball.py b/tutorials/SB3/sb3_pistonball_vector.py similarity index 81% rename from tutorials/SB3/sb3_pistonball.py rename to tutorials/SB3/sb3_pistonball_vector.py index d645468f5..39470dbc0 100644 --- a/tutorials/SB3/sb3_pistonball.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -1,8 +1,10 @@ -"""Uses Stable-Baselines3 to train agents to play Pistonball. +"""Uses Stable-Baselines3 to train agents to play Pistonball using SuperSuit vector envs. Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html + +Author: Elliot (https://github.com/elliottower) """ import time @@ -13,6 +15,7 @@ from pettingzoo.butterfly import pistonball_v6 if __name__ == "__main__": + # Train a single agent to play both sides in a PettingZoo Pistonball environment env = pistonball_v6.parallel_env( n_pistons=20, time_penalty=-0.1, diff --git a/tutorials/SB3/sb3_rps.py b/tutorials/SB3/sb3_rps_vector.py similarity index 80% rename from tutorials/SB3/sb3_rps.py rename to tutorials/SB3/sb3_rps_vector.py index 410951cb2..eae5b8dcd 100644 --- a/tutorials/SB3/sb3_rps.py +++ b/tutorials/SB3/sb3_rps_vector.py @@ -1,8 +1,8 @@ -"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors. +"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors using SuperSuit vector envs. -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html -Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +Author: Elliot (https://github.com/elliottower) """ import time @@ -43,6 +43,3 @@ print("Model has been saved.") env.close() - import sys - - sys.exit(0) diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py new file mode 100644 index 000000000..c427ab70c --- /dev/null +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -0,0 +1,34 @@ +"""Test file to ensure that action masking code works for all PettingZoo classic environments (except rps, which has no action mask).""" + +import pytest +from tutorials.SB3.render_sb3_chess_action_mask import watch_action_mask +from tutorials.SB3.sb3_chess_action_mask import train_action_mask + +from pettingzoo.classic import ( + chess_v6, + connect_four_v3, + gin_rummy_v4, + hanabi_v4, + leduc_holdem_v4, + texas_holdem_no_limit_v6, + texas_holdem_v4, + tictactoe_v3, +) + +WORKING_ENVS = [ + tictactoe_v3, + connect_four_v3, + chess_v6, + leduc_holdem_v4, + gin_rummy_v4, + hanabi_v4, + # texas holdem likely broken, game ends instantly, but with random actions it works fine + texas_holdem_no_limit_v6, + texas_holdem_v4, +] + + +@pytest.mark.parametrize("env_fn", WORKING_ENVS) +def test_action_mask(env_fn): + train_action_mask(env_fn, steps=4096) + watch_action_mask(env_fn) From d0784beedab97c31866742f9251abcb19d0d7112 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 19:57:17 -0400 Subject: [PATCH 04/38] Add try catch for test sb3 action mask (pytest -v shouldn't require sb3) --- tutorials/SB3/test_sb3_action_mask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index c427ab70c..745150399 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -1,8 +1,11 @@ """Test file to ensure that action masking code works for all PettingZoo classic environments (except rps, which has no action mask).""" -import pytest -from tutorials.SB3.render_sb3_chess_action_mask import watch_action_mask -from tutorials.SB3.sb3_chess_action_mask import train_action_mask +try: + import pytest + from tutorials.SB3.render_sb3_chess_action_mask import watch_action_mask + from tutorials.SB3.sb3_chess_action_mask import train_action_mask +except ModuleNotFoundError: + pass from pettingzoo.classic import ( chess_v6, From 87c46ef3c5bd9a89894809c5226d721aef36c210 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:07:09 -0400 Subject: [PATCH 05/38] Clean up documentation --- docs/tutorials/sb3/chess.md | 4 ++-- docs/tutorials/sb3/index.md | 6 +++--- docs/tutorials/sb3/rps.md | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/tutorials/sb3/chess.md b/docs/tutorials/sb3/chess.md index ce5fd496d..574e4d8f3 100644 --- a/docs/tutorials/sb3/chess.md +++ b/docs/tutorials/sb3/chess.md @@ -1,8 +1,8 @@ --- -title: "SB3: Action Masking for Chess (AEC)" +title: "SB3: Action Masked PPO for Chess" --- -# SB3: Action Masking for Chess (AEC) +# SB3: Action Masked PPO for Chess This tutorial shows how to train a Maskable [Proximal Policy Optimization](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html) (PPO) model on the [Chess](https://pettingzoo.farama.org/environments/classic/chess/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index 3c9585a9b..f400ef5cc 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -6,11 +6,11 @@ title: "Stable-Baselines3" These tutorials show you how to use the [SB3](https://stable-baselines3.readthedocs.io/en/master/) library to train agents in PettingZoo environments. -* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in a vectorized AEC environment_ +* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in vectorized Parallel environments_ -* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a vectorized Parallel environment_ +* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in vectorized AEC environments_ -* [Action Masking for Chess](/tutorials/sb3/chess/): _Train an action masked PPO model in an AEC environment_ +* [Action Masked PPO for Chess](/tutorials/sb3/chess/): _Train an action masked PPO model in an AEC environment_ ## Stable-Baselines Overview diff --git a/docs/tutorials/sb3/rps.md b/docs/tutorials/sb3/rps.md index c5d1bd2ed..7f585d022 100644 --- a/docs/tutorials/sb3/rps.md +++ b/docs/tutorials/sb3/rps.md @@ -1,5 +1,5 @@ --- -title: "SB3: PPO for Rock-Paper-Scissors (AEC)" +title: "SB3: PPO for Rock-Paper-Scissors" --- # SB3: PPO for Rock-Paper-Scissors From baf59fa11a4ff09d441e24a16dd58d398f93d765 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:08:24 -0400 Subject: [PATCH 06/38] Fix requirements.txt to specify pettingzoo[classic] --- tutorials/SB3/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index 2ca60adc3..f4edfff8e 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -1,4 +1,4 @@ stable-baselines3>=2.0.0 -pettingzoo>=1.23.1 +pettingzoo[classic]>=1.23.1 supersuit>=3.8.1 sb3-contrib>=2.0.0 From db8331a5124219fc1b7cda8511187bfc29b46b61 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:21:08 -0400 Subject: [PATCH 07/38] Add try catch for render action mask --- tutorials/SB3/render_sb3_chess_action_mask.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tutorials/SB3/render_sb3_chess_action_mask.py b/tutorials/SB3/render_sb3_chess_action_mask.py index 343cbf1cd..916b8a2b2 100644 --- a/tutorials/SB3/render_sb3_chess_action_mask.py +++ b/tutorials/SB3/render_sb3_chess_action_mask.py @@ -11,7 +11,20 @@ def watch_action_mask(env_fn): env = env_fn.env(render_mode="human") env.reset() - latest_policy = max(glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime) + # If training script has not been run, run it now + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + from tutorials.SB3.sb3_chess_action_mask import train_action_mask + + train_action_mask(env_fn) + + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + model = MaskablePPO.load(latest_policy) for agent in env.agent_iter(): @@ -23,11 +36,8 @@ def watch_action_mask(env_fn): if termination or truncation: act = None else: - # Note that use of masks is manual and optional outside of learning, - # so masking can be "removed" at testing time - act = int( - model.predict(observation, action_masks=action_mask)[0] - ) # PettingZoo expects integer actions + # Note: PettingZoo expects integer actions + act = int(model.predict(observation, action_masks=action_mask)[0]) env.step(act) env.close() From 05b3dcc1365ae0746863a2f706a47716a63d60b3 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:22:35 -0400 Subject: [PATCH 08/38] Add try catch for render action mask --- tutorials/SB3/render_sb3_chess_action_mask.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tutorials/SB3/render_sb3_chess_action_mask.py b/tutorials/SB3/render_sb3_chess_action_mask.py index 916b8a2b2..174f27d21 100644 --- a/tutorials/SB3/render_sb3_chess_action_mask.py +++ b/tutorials/SB3/render_sb3_chess_action_mask.py @@ -17,6 +17,8 @@ def watch_action_mask(env_fn): glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime ) except ValueError: + print("Policy not found. Running training to generate new policy.") + from tutorials.SB3.sb3_chess_action_mask import train_action_mask train_action_mask(env_fn) From ecb96bf1ab389c5321cd354cc60d099dd1ffd548 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:25:08 -0400 Subject: [PATCH 09/38] Add try catch for other render files --- tutorials/SB3/render_sb3_pistonball.py | 9 ++++++++- tutorials/SB3/render_sb3_rps.py | 9 ++++++++- tutorials/SB3/sb3_rps_vector.py | 2 +- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tutorials/SB3/render_sb3_pistonball.py b/tutorials/SB3/render_sb3_pistonball.py index 93cfd8f82..de879a856 100644 --- a/tutorials/SB3/render_sb3_pistonball.py +++ b/tutorials/SB3/render_sb3_pistonball.py @@ -20,7 +20,14 @@ env = ss.resize_v1(env, x_size=84, y_size=84) env = ss.frame_stack_v1(env, 3) - latest_policy = max(glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime) + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + raise UserWarning("Policy not found.") + pass + model = PPO.load(latest_policy) env.reset() diff --git a/tutorials/SB3/render_sb3_rps.py b/tutorials/SB3/render_sb3_rps.py index c4cdbd34d..b48650933 100644 --- a/tutorials/SB3/render_sb3_rps.py +++ b/tutorials/SB3/render_sb3_rps.py @@ -15,7 +15,14 @@ if __name__ == "__main__": env = rps_v2.env(render_mode="human") - latest_policy = max(glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime) + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + raise UserWarning("Policy not found.") + pass + model = PPO.load(latest_policy) env.reset() diff --git a/tutorials/SB3/sb3_rps_vector.py b/tutorials/SB3/sb3_rps_vector.py index eae5b8dcd..bc4392007 100644 --- a/tutorials/SB3/sb3_rps_vector.py +++ b/tutorials/SB3/sb3_rps_vector.py @@ -20,7 +20,7 @@ env = ss.pettingzoo_env_to_vec_env_v1(env) env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") - # TODO: find hyperparameters that make the model actually learn + # TODO: test different hyperparameters model = PPO( MlpPolicy, env, From d887db27985796fbedb2597ed8a87bed5e129662 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:31:14 -0400 Subject: [PATCH 10/38] Fix code which doesn't work due to modules (tutorials not included) --- tutorials/SB3/render_sb3_chess_action_mask.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tutorials/SB3/render_sb3_chess_action_mask.py b/tutorials/SB3/render_sb3_chess_action_mask.py index 174f27d21..4150b36a2 100644 --- a/tutorials/SB3/render_sb3_chess_action_mask.py +++ b/tutorials/SB3/render_sb3_chess_action_mask.py @@ -17,15 +17,7 @@ def watch_action_mask(env_fn): glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime ) except ValueError: - print("Policy not found. Running training to generate new policy.") - - from tutorials.SB3.sb3_chess_action_mask import train_action_mask - - train_action_mask(env_fn) - - latest_policy = max( - glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime - ) + raise UserWarning("Policy not found.") model = MaskablePPO.load(latest_policy) From 085ed0a344f758105f3794e89a38b76b36dcf420 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:43:02 -0400 Subject: [PATCH 11/38] Switch userwarnings to print statements and exit (so it doesn't fail) --- tutorials/SB3/render_sb3_chess_action_mask.py | 3 ++- tutorials/SB3/render_sb3_pistonball.py | 4 ++-- tutorials/SB3/render_sb3_rps.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tutorials/SB3/render_sb3_chess_action_mask.py b/tutorials/SB3/render_sb3_chess_action_mask.py index 4150b36a2..800b79261 100644 --- a/tutorials/SB3/render_sb3_chess_action_mask.py +++ b/tutorials/SB3/render_sb3_chess_action_mask.py @@ -17,7 +17,8 @@ def watch_action_mask(env_fn): glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime ) except ValueError: - raise UserWarning("Policy not found.") + print("Policy not found.") + exit(0) model = MaskablePPO.load(latest_policy) diff --git a/tutorials/SB3/render_sb3_pistonball.py b/tutorials/SB3/render_sb3_pistonball.py index de879a856..e42a375a1 100644 --- a/tutorials/SB3/render_sb3_pistonball.py +++ b/tutorials/SB3/render_sb3_pistonball.py @@ -25,8 +25,8 @@ glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime ) except ValueError: - raise UserWarning("Policy not found.") - pass + print("Policy not found.") + exit(0) model = PPO.load(latest_policy) diff --git a/tutorials/SB3/render_sb3_rps.py b/tutorials/SB3/render_sb3_rps.py index b48650933..20d280aa8 100644 --- a/tutorials/SB3/render_sb3_rps.py +++ b/tutorials/SB3/render_sb3_rps.py @@ -20,8 +20,8 @@ glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime ) except ValueError: - raise UserWarning("Policy not found.") - pass + print("Policy not found.") + exit(0) model = PPO.load(latest_policy) From 18eca559bc90c2081f64a014ffa9efecca0004e8 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:46:28 -0400 Subject: [PATCH 12/38] Add butterfly requirement to sb3 tutorial --- tutorials/SB3/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index f4edfff8e..9010c7b29 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -1,4 +1,4 @@ stable-baselines3>=2.0.0 -pettingzoo[classic]>=1.23.1 +pettingzoo[classic,butterfly]>=1.23.1 supersuit>=3.8.1 sb3-contrib>=2.0.0 From 429cbd82ee51ebfca0ad09a76201beafbf97bc41 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 20:58:17 -0400 Subject: [PATCH 13/38] Switch default timesteps to be more reasonable (10,000) --- tutorials/SB3/sb3_chess_action_mask.py | 2 +- tutorials/SB3/sb3_pistonball_vector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/SB3/sb3_chess_action_mask.py b/tutorials/SB3/sb3_chess_action_mask.py index 4bd47eed7..3c99d41de 100644 --- a/tutorials/SB3/sb3_chess_action_mask.py +++ b/tutorials/SB3/sb3_chess_action_mask.py @@ -83,4 +83,4 @@ def train_action_mask(env_fn, steps=10_000): if __name__ == "__main__": - train_action_mask(chess_v6, steps=20_000) + train_action_mask(chess_v6, steps=10_000) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index 39470dbc0..296f83e55 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -51,7 +51,7 @@ batch_size=256, ) - model.learn(total_timesteps=2_000_000) + model.learn(total_timesteps=10_000) model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") From c9f0024c8145ca2384615ea98d12ff641694a412 Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 21:05:46 -0400 Subject: [PATCH 14/38] Switch default timesteps to be lower (2048), just so CI runs faster --- tutorials/SB3/sb3_chess_action_mask.py | 2 +- tutorials/SB3/sb3_pistonball_vector.py | 2 +- tutorials/SB3/sb3_rps_vector.py | 2 +- tutorials/SB3/test_sb3_action_mask.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tutorials/SB3/sb3_chess_action_mask.py b/tutorials/SB3/sb3_chess_action_mask.py index 3c99d41de..f3ab46c9f 100644 --- a/tutorials/SB3/sb3_chess_action_mask.py +++ b/tutorials/SB3/sb3_chess_action_mask.py @@ -83,4 +83,4 @@ def train_action_mask(env_fn, steps=10_000): if __name__ == "__main__": - train_action_mask(chess_v6, steps=10_000) + train_action_mask(chess_v6, steps=2048) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index 296f83e55..c76a00a76 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -51,7 +51,7 @@ batch_size=256, ) - model.learn(total_timesteps=10_000) + model.learn(total_timesteps=2048) model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") diff --git a/tutorials/SB3/sb3_rps_vector.py b/tutorials/SB3/sb3_rps_vector.py index bc4392007..bf936af1a 100644 --- a/tutorials/SB3/sb3_rps_vector.py +++ b/tutorials/SB3/sb3_rps_vector.py @@ -37,7 +37,7 @@ batch_size=256, ) - model.learn(total_timesteps=10_000) + model.learn(total_timesteps=2048) model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index 745150399..05c00a9f1 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -33,5 +33,5 @@ @pytest.mark.parametrize("env_fn", WORKING_ENVS) def test_action_mask(env_fn): - train_action_mask(env_fn, steps=4096) + train_action_mask(env_fn, steps=2048) watch_action_mask(env_fn) From a64022d547b38d0a76e7fe5665ea63bf748ef29c Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 21:14:40 -0400 Subject: [PATCH 15/38] Switch num cpus to 2 by default (github ations only get 2 cores) --- tutorials/SB3/sb3_pistonball_vector.py | 2 +- tutorials/SB3/sb3_rps_vector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index c76a00a76..c94c20f62 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -33,7 +33,7 @@ env = ss.frame_stack_v1(env, 3) env = ss.pettingzoo_env_to_vec_env_v1(env) - env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") + env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") model = PPO( CnnPolicy, diff --git a/tutorials/SB3/sb3_rps_vector.py b/tutorials/SB3/sb3_rps_vector.py index bf936af1a..96fe3ce9d 100644 --- a/tutorials/SB3/sb3_rps_vector.py +++ b/tutorials/SB3/sb3_rps_vector.py @@ -18,7 +18,7 @@ env = turn_based_aec_to_parallel(env) env = ss.pettingzoo_env_to_vec_env_v1(env) - env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") + env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") # TODO: test different hyperparameters model = PPO( From ee083173299dbe6e8810534af931aaeb90133d9d Mon Sep 17 00:00:00 2001 From: elliottower Date: Fri, 7 Jul 2023 21:46:23 -0400 Subject: [PATCH 16/38] Fix print statements logic --- tutorials/SB3/sb3_chess_action_mask.py | 5 +++-- tutorials/SB3/sb3_pistonball_vector.py | 7 ++++++- tutorials/SB3/sb3_rps_vector.py | 5 +++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tutorials/SB3/sb3_chess_action_mask.py b/tutorials/SB3/sb3_chess_action_mask.py index f3ab46c9f..2f754adf3 100644 --- a/tutorials/SB3/sb3_chess_action_mask.py +++ b/tutorials/SB3/sb3_chess_action_mask.py @@ -77,9 +77,10 @@ def train_action_mask(env_fn, steps=10_000): model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") print("Model has been saved.") - env.close() - print(f"Finished training on {str(env.metadata['name'])}.\n") + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n") + + env.close() if __name__ == "__main__": diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index c94c20f62..7c2797119 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -28,6 +28,8 @@ max_cycles=125, ) + print(f"Starting training on {str(env.metadata['name'])}.") + env = ss.color_reduction_v0(env, mode="B") env = ss.resize_v1(env, x_size=84, y_size=84) env = ss.frame_stack_v1(env, 3) @@ -51,9 +53,12 @@ batch_size=256, ) - model.learn(total_timesteps=2048) + model.learn(total_timesteps=4096) model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n") + env.close() diff --git a/tutorials/SB3/sb3_rps_vector.py b/tutorials/SB3/sb3_rps_vector.py index 96fe3ce9d..0065a8c84 100644 --- a/tutorials/SB3/sb3_rps_vector.py +++ b/tutorials/SB3/sb3_rps_vector.py @@ -17,6 +17,8 @@ env = rps_v2.env() env = turn_based_aec_to_parallel(env) + print(f"Starting training on {str(env.metadata['name'])}.") + env = ss.pettingzoo_env_to_vec_env_v1(env) env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") @@ -42,4 +44,7 @@ model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n") + env.close() From bd83f300626976ac4008b406271b914f3fd6a126 Mon Sep 17 00:00:00 2001 From: elliottower Date: Sun, 9 Jul 2023 19:04:37 -0400 Subject: [PATCH 17/38] Update tutorials to evaluate, add KAZ example, test hyperparameters --- .../sb3/{chess.md => connect_four.md} | 19 +- docs/tutorials/sb3/index.md | 10 +- docs/tutorials/sb3/kaz.md | 30 +++ docs/tutorials/sb3/pistonball.md | 13 +- docs/tutorials/sb3/rps.md | 36 ---- tutorials/SB3/render_sb3_chess_action_mask.py | 41 ----- ...ask.py => sb3_connect_four_action_mask.py} | 75 +++++++- tutorials/SB3/sb3_kaz_vector.py | 172 ++++++++++++++++++ tutorials/SB3/sb3_pistonball_vector.py | 111 +++++++++-- tutorials/SB3/sb3_rps_vector.py | 50 ----- tutorials/SB3/test_sb3_action_mask.py | 2 +- tutorials/Tianshou/2_training_agents.py | 4 +- tutorials/Tianshou/3_cli_and_logging.py | 4 +- 13 files changed, 383 insertions(+), 184 deletions(-) rename docs/tutorials/sb3/{chess.md => connect_four.md} (58%) create mode 100644 docs/tutorials/sb3/kaz.md delete mode 100644 docs/tutorials/sb3/rps.md delete mode 100644 tutorials/SB3/render_sb3_chess_action_mask.py rename tutorials/SB3/{sb3_chess_action_mask.py => sb3_connect_four_action_mask.py} (56%) create mode 100644 tutorials/SB3/sb3_kaz_vector.py delete mode 100644 tutorials/SB3/sb3_rps_vector.py diff --git a/docs/tutorials/sb3/chess.md b/docs/tutorials/sb3/connect_four.md similarity index 58% rename from docs/tutorials/sb3/chess.md rename to docs/tutorials/sb3/connect_four.md index 574e4d8f3..d6c674780 100644 --- a/docs/tutorials/sb3/chess.md +++ b/docs/tutorials/sb3/connect_four.md @@ -1,16 +1,16 @@ --- -title: "SB3: Action Masked PPO for Chess" +title: "SB3: Action Masked PPO for Connect Four" --- -# SB3: Action Masked PPO for Chess +# SB3: Action Masked PPO for Connect Four -This tutorial shows how to train a Maskable [Proximal Policy Optimization](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html) (PPO) model on the [Chess](https://pettingzoo.farama.org/environments/classic/chess/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). +This tutorial shows how to train a Maskable [Proximal Policy Optimization](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html) (PPO) model on the [Connect Four](https://pettingzoo.farama.org/environments/classic/chess/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). It creates a custom Wrapper to convert to a Gymnasium-like environment which is compatible with SB3's action masking format. Note: This assumes that the action space and observation space is the same for each agent, this assumption may not hold for custom environments. -After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. +After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). ## Environment Setup @@ -23,16 +23,9 @@ To follow this tutorial, you will need to install the dependencies shown below. ## Code The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). -### Training the RL agent +### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_chess_action_mask.py - :language: python -``` - -### Watching the trained RL agent play - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/render_sb3_chess_action_mask.py +.. literalinclude:: ../../../tutorials/SB3/sb3_connect_four_action_mask.py :language: python ``` diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index f400ef5cc..7d8f84b84 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -6,11 +6,11 @@ title: "Stable-Baselines3" These tutorials show you how to use the [SB3](https://stable-baselines3.readthedocs.io/en/master/) library to train agents in PettingZoo environments. -* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in vectorized Parallel environments_ +* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in vectorized Parallel environment_ -* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in vectorized AEC environments_ +* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train a PPO model in a vectorized AEC environment_ -* [Action Masked PPO for Chess](/tutorials/sb3/chess/): _Train an action masked PPO model in an AEC environment_ +* [Action Masked PPO for Chess](/tutorials/sb3/connect_four/): _Train an action masked PPO model in an AEC environment_ ## Stable-Baselines Overview @@ -32,6 +32,6 @@ Note: SB3 does not officially support PettingZoo, as it is designed for single-a :caption: SB3 pistonball -rps -chess +kaz +connect_four ``` diff --git a/docs/tutorials/sb3/kaz.md b/docs/tutorials/sb3/kaz.md new file mode 100644 index 000000000..74d2965c1 --- /dev/null +++ b/docs/tutorials/sb3/kaz.md @@ -0,0 +1,30 @@ +--- +title: "SB3: PPO for Knights-Archers-Zombies" +--- + +# SB3: PPO for Knights-Archers-Zombies + +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Knights-Archers-Zombies](https://pettingzoo.farama.org/environments/butterfly/knights_archers_zombies/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). + +It converts the environment into a Parallel environment and uses SuperSuit to create vectorized environments, leveraging multithreading to speed up training. + +After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). + +Note: because this environment allows agents to spawn and die, it requires using SuperSuit's [Black Death](https://pettingzoo.farama.org/api/wrappers/supersuit_wrappers/#black_death_v2) wrapper, which provides blank observations to dead agents, rather than removing them from the environment. + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training and Evaluation + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/sb3_kaz_vector.py + :language: python +``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md index 3e83f8087..b102b8874 100644 --- a/docs/tutorials/sb3/pistonball.md +++ b/docs/tutorials/sb3/pistonball.md @@ -6,7 +6,7 @@ title: "SB3: PPO for Pistonball (Parallel)" This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)). -After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. +After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). ## Environment Setup @@ -19,16 +19,9 @@ To follow this tutorial, you will need to install the dependencies shown below. ## Code The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). -### Training the RL agent +### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball.py - :language: python -``` - -### Watching the trained RL agent play - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/render_sb3_pistonball.py +.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball_vector.py :language: python ``` diff --git a/docs/tutorials/sb3/rps.md b/docs/tutorials/sb3/rps.md deleted file mode 100644 index 7f585d022..000000000 --- a/docs/tutorials/sb3/rps.md +++ /dev/null @@ -1,36 +0,0 @@ ---- -title: "SB3: PPO for Rock-Paper-Scissors" ---- - -# SB3: PPO for Rock-Paper-Scissors - -This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Rock-Paper-Scissors](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). - -It converts the environment into a Parallel environment and uses SuperSuit to create vectorized environments, leveraging multithreading to speed up training. - -After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. - - -## Environment Setup -To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt - :language: text -``` - -## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). - -### Training the RL agent - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_rps.py - :language: python -``` - -### Watching the trained RL agent play - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/render_sb3_rps.py - :language: python -``` diff --git a/tutorials/SB3/render_sb3_chess_action_mask.py b/tutorials/SB3/render_sb3_chess_action_mask.py deleted file mode 100644 index 800b79261..000000000 --- a/tutorials/SB3/render_sb3_chess_action_mask.py +++ /dev/null @@ -1,41 +0,0 @@ -import glob -import os - -from sb3_contrib import MaskablePPO - -from pettingzoo.classic import chess_v6 - - -def watch_action_mask(env_fn): - # Watch a game between two trained agents - env = env_fn.env(render_mode="human") - env.reset() - - # If training script has not been run, run it now - try: - latest_policy = max( - glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime - ) - except ValueError: - print("Policy not found.") - exit(0) - - model = MaskablePPO.load(latest_policy) - - for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - - # Separate observation and action mask - observation, action_mask = obs.values() - - if termination or truncation: - act = None - else: - # Note: PettingZoo expects integer actions - act = int(model.predict(observation, action_masks=action_mask)[0]) - env.step(act) - env.close() - - -if __name__ == "__main__": - watch_action_mask(chess_v6) diff --git a/tutorials/SB3/sb3_chess_action_mask.py b/tutorials/SB3/sb3_connect_four_action_mask.py similarity index 56% rename from tutorials/SB3/sb3_chess_action_mask.py rename to tutorials/SB3/sb3_connect_four_action_mask.py index 2f754adf3..da015731c 100644 --- a/tutorials/SB3/sb3_chess_action_mask.py +++ b/tutorials/SB3/sb3_connect_four_action_mask.py @@ -5,6 +5,8 @@ Author: Elliot (https://github.com/elliottower) """ +import glob +import os import time from sb3_contrib import MaskablePPO @@ -12,7 +14,7 @@ from sb3_contrib.common.wrappers import ActionMasker import pettingzoo.utils -from pettingzoo.classic import chess_v6 +from pettingzoo.classic import connect_four_v3 class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper): @@ -55,16 +57,16 @@ def mask_fn(env): return env.action_mask() -def train_action_mask(env_fn, steps=10_000): +def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs): """Train a single agent to play both sides in a PettingZoo environment using invalid action masking.""" - env = env_fn.env() + env = env_fn.env(**env_kwargs) print(f"Starting training on {str(env.metadata['name'])}.") # Custom wrapper to convert PettingZoo envs to work with SB3 action masking env = SB3ActionMaskWrapper(env) - env.reset() # Must call reset() in order to re-define the spaces + env.reset(seed=seed) # Must call reset() in order to re-define the spaces env = ActionMasker(env, mask_fn) # Wrap to enable masking (SB3 function) # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped @@ -72,6 +74,7 @@ def train_action_mask(env_fn, steps=10_000): # retrieved and used when learning. Note that MaskablePPO does not accept # a new action_mask_fn kwarg, as it did in an earlier draft. model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1) + model.set_random_seed(seed) model.learn(total_timesteps=steps) model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") @@ -83,5 +86,67 @@ def train_action_mask(env_fn, steps=10_000): env.close() +def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + print( + f"Starting evaluation on {str(env.metadata['name'])} vs a random agent." + f"Trained agent will play as {env.possible_agents[1]}" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = MaskablePPO.load(latest_policy) + + scores = {agent: 0 for agent in env.possible_agents} + for i in range(num_games): + env.reset(seed=i) + env.action_space(env.possible_agents[0]).seed(i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + # Separate observation and action mask + observation, action_mask = obs.values() + + if termination or truncation: + scores[agent] += reward # winning agent gets +1 + break + else: + if agent == env.possible_agents[0]: + act = env.action_space(agent).sample(action_mask) + else: + # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int? + act = int( + model.predict( + observation, action_masks=action_mask, deterministic=True + )[0] + ) + env.step(act) + env.close() + + winrate = scores[env.possible_agents[1]] / sum(scores.values()) + print("Winrate: ", winrate) + print("Final scores: ", scores) + return winrate + + if __name__ == "__main__": - train_action_mask(chess_v6, steps=2048) + env_fn = connect_four_v3 + env_kwargs = {} + + # Train a model against itself (takes ~20 seconds on a laptop CPU) + train_action_mask(env_fn, steps=20480, seed=0, **env_kwargs) + + # Evaluate 100 games against a random agent (winrate should be ~80%) + eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) + + # Watch two games vs a random agent + eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/sb3_kaz_vector.py b/tutorials/SB3/sb3_kaz_vector.py new file mode 100644 index 000000000..42ef8fb4f --- /dev/null +++ b/tutorials/SB3/sb3_kaz_vector.py @@ -0,0 +1,172 @@ +"""Uses Stable-Baselines3 to train agents to play Knights-Archers-Zombies using SuperSuit vector envs. + +This environment requires using SuperSuit's Black Death wrapper, to handle agent death. + +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html + +Author: Elliot (https://github.com/elliottower) +""" +from __future__ import annotations + +import glob +import os +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import MlpPolicy + +from pettingzoo.butterfly import knights_archers_zombies_v10 + + +def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs): + # Train a single agent to play both sides in an AEC environment + env = env_fn.parallel_env(**env_kwargs) + + env = ss.black_death_v3(env) + + # Convert into a Parallel environment in order to vectorize it (SuperSuit does not currently support vectorized AEC envs) + # env = turn_based_aec_to_parallel(env) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + # env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + # Add black death wrapper so the number of agents stays constant + # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True + + env.reset(seed=seed) + + print(f"Starting training on {str(env.metadata['name'])}.") + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") + + # TODO: test different hyperparameters + model = PPO( + MlpPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, + ) + + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") + + env.close() + + +def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + print( + f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = PPO.load(latest_policy) + + rewards = {agent: 0 for agent in env.possible_agents} + + # TODO: figure out why Parallel performs differently at test time (my guess is maybe the way it counts num_cycles is different?) + # # It seems to make the rewards worse, the same policy scores 2/3 points per archer vs 6/7 with AEC. n + + # from pettingzoo.utils.wrappers import RecordEpisodeStatistics + # + # env = env_fn.parallel_env(render_mode=render_mode, **env_kwargs) + # + # # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + # env = ss.resize_v1(env, x_size=84, y_size=84) + # env = ss.frame_stack_v1(env, 3) + # env = RecordEpisodeStatistics(env) + # + # stats = [] + # for i in range(num_games): + # observations, infos = env.reset(seed=i) + # done = False + # while not done: + # actions = {agent: model.predict(observations[agent], deterministic=True)[0] for agent in env.agents} + # obss, rews, terms, truncs, infos = env.step(actions) + # + # for agent in env.possible_agents: + # rewards[agent] += rews[agent] + # done = any(terms.values()) or any(truncs.values()) + # stats.append(infos["episode"]) + + # Note: we evaluate here using an AEC environments, to allow for easy A/B testing against random policies + # For example, we can see here that using a random agent for archer_0 results in less points than the trained agent + for i in range(num_games): + env.reset(seed=i) + env.action_space(env.possible_agents[0]).seed(i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + for agent in env.agents: + rewards[agent] += env.rewards[agent] + + if termination or truncation: + break + else: + if agent == env.possible_agents[0]: + act = env.action_space(agent).sample() + else: + act = model.predict(obs, deterministic=True)[0] + env.step(act) + env.close() + + avg_reward = sum(rewards.values()) / len(rewards.values()) + avg_reward_per_agent = { + agent: rewards[agent] / num_games for agent in env.possible_agents + } + print(f"Avg reward: {avg_reward}") + print("Avg reward per agent, per game: ", avg_reward_per_agent) + print("Full rewards: ", rewards) + return avg_reward + + +if __name__ == "__main__": + env_fn = knights_archers_zombies_v10 + + # TODO: test out more hyperparameter combos + # max_cycles 100, max zombies 4, 8192 * 10 works decently, but sometimes fails due to agents dying + # black death wrapper, max cycles 100, max zombies 4, 8192*10, seems to work well (13 points over 10 games) + # black death wrapper, max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points) + + env_kwargs = dict(max_cycles=100, max_zombies=4) + + # Train a model (takes ~5 minutes on a laptop CPU) + # train(env_fn, steps=8192*10, seed=0, **env_kwargs) + + # Evaluate 10 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=10, render_mode=None, **env_kwargs) + + # Watch 2 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index 7c2797119..bc3ad4da3 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -1,11 +1,13 @@ -"""Uses Stable-Baselines3 to train agents to play Pistonball using SuperSuit vector envs. - -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b +"""Uses Stable-Baselines3 to train agents to play PettingZoo Butterfly (cooprative) environments using SuperSuit vector envs. For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html Author: Elliot (https://github.com/elliottower) """ +from __future__ import annotations + +import glob +import os import time import supersuit as ss @@ -14,26 +16,22 @@ from pettingzoo.butterfly import pistonball_v6 -if __name__ == "__main__": - # Train a single agent to play both sides in a PettingZoo Pistonball environment - env = pistonball_v6.parallel_env( - n_pistons=20, - time_penalty=-0.1, - continuous=True, - random_drop=True, - random_rotate=True, - ball_mass=0.75, - ball_friction=0.3, - ball_elasticity=1.5, - max_cycles=125, - ) - print(f"Starting training on {str(env.metadata['name'])}.") +def train_butterfly_supersuit( + env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs +): + # Train a single agent to play both sides in a Parallel environment, + env = env_fn.parallel_env(**env_kwargs) + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) env = ss.color_reduction_v0(env, mode="B") env = ss.resize_v1(env, x_size=84, y_size=84) env = ss.frame_stack_v1(env, 3) + env.reset(seed=seed) + + print(f"Starting training on {str(env.metadata['name'])}.") + env = ss.pettingzoo_env_to_vec_env_v1(env) env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") @@ -53,12 +51,87 @@ batch_size=256, ) - model.learn(total_timesteps=4096) + model.learn(total_timesteps=steps) model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") print("Model has been saved.") - print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n") + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") env.close() + + +def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + # Pre-process using SuperSuit (color reduction, resizing and frame stacking) + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) + + print( + f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = PPO.load(latest_policy) + + rewards = {agent: 0 for agent in env.possible_agents} + + # Note: We train using the Parallel API but evaluate using the AEC API + # SB3 models are designed for single-agent settings, we get around this by using he same model for every agent + for i in range(num_games): + env.reset(seed=i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + if termination or truncation: + for agent in env.agents: + rewards[agent] += env.rewards[agent] + break + else: + act = model.predict(obs, deterministic=True)[0] + + env.step(act) + env.close() + + avg_reward = sum(rewards.values()) / len(rewards.values()) + print(f"Avg reward: {avg_reward}") + return avg_reward + + +if __name__ == "__main__": + env_fn = pistonball_v6 + env_kwargs = dict( + n_pistons=20, + time_penalty=-0.1, + continuous=True, + random_drop=True, + random_rotate=True, + ball_mass=0.75, + ball_friction=0.3, + ball_elasticity=1.5, + max_cycles=25, + ) + + # TODO: figure out why pistonball takes so long to train and seems to not save the model? + # 296 seconds for 81_920 steps (n_updates = 5) vs 46 seconds for 40_960 + + # Train a model (takes ~3 minutes on a laptop CPU) + train_butterfly_supersuit(env_fn, steps=40_960, seed=0, **env_kwargs) + + # Evaluate 10 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=10, render_mode=None, **env_kwargs) + + # Watch 2 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/sb3_rps_vector.py b/tutorials/SB3/sb3_rps_vector.py deleted file mode 100644 index 0065a8c84..000000000 --- a/tutorials/SB3/sb3_rps_vector.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors using SuperSuit vector envs. - -For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html - -Author: Elliot (https://github.com/elliottower) -""" -import time - -import supersuit as ss -from stable_baselines3 import PPO -from stable_baselines3.ppo import MlpPolicy - -from pettingzoo.classic import rps_v2 -from pettingzoo.utils import turn_based_aec_to_parallel - -if __name__ == "__main__": - env = rps_v2.env() - env = turn_based_aec_to_parallel(env) - - print(f"Starting training on {str(env.metadata['name'])}.") - - env = ss.pettingzoo_env_to_vec_env_v1(env) - env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") - - # TODO: test different hyperparameters - model = PPO( - MlpPolicy, - env, - verbose=3, - gamma=0.95, - n_steps=256, - ent_coef=0.0905168, - learning_rate=0.00062211, - vf_coef=0.042202, - max_grad_norm=0.9, - gae_lambda=0.99, - n_epochs=5, - clip_range=0.3, - batch_size=256, - ) - - model.learn(total_timesteps=2048) - - model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") - - print("Model has been saved.") - - print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n") - - env.close() diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index 05c00a9f1..70687b03c 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -3,7 +3,7 @@ try: import pytest from tutorials.SB3.render_sb3_chess_action_mask import watch_action_mask - from tutorials.SB3.sb3_chess_action_mask import train_action_mask + from tutorials.SB3.sb3_connect_four_action_mask import train_action_mask except ModuleNotFoundError: pass diff --git a/tutorials/Tianshou/2_training_agents.py b/tutorials/Tianshou/2_training_agents.py index 3f4a1e8fd..30ae1b159 100644 --- a/tutorials/Tianshou/2_training_agents.py +++ b/tutorials/Tianshou/2_training_agents.py @@ -12,7 +12,7 @@ import os from typing import Optional, Tuple -import gymnasium as gym +import gymnasium import numpy as np import torch from tianshou.data import Collector, VectorReplayBuffer @@ -33,7 +33,7 @@ def _get_agents( env = _get_env() observation_space = ( env.observation_space["observation"] - if isinstance(env.observation_space, gym.spaces.Dict) + if isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) if agent_learn is None: diff --git a/tutorials/Tianshou/3_cli_and_logging.py b/tutorials/Tianshou/3_cli_and_logging.py index e78774dd3..b11abe574 100644 --- a/tutorials/Tianshou/3_cli_and_logging.py +++ b/tutorials/Tianshou/3_cli_and_logging.py @@ -14,7 +14,7 @@ from copy import deepcopy from typing import Optional, Tuple -import gymnasium as gym +import gymnasium import numpy as np import torch from tianshou.data import Collector, VectorReplayBuffer @@ -105,7 +105,7 @@ def get_agents( env = get_env() observation_space = ( env.observation_space["observation"] - if isinstance(env.observation_space, gym.spaces.Dict) + if isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) args.state_shape = observation_space.shape or observation_space.n From 0185af8419b4cc645089160c445b2b35f4417214 Mon Sep 17 00:00:00 2001 From: elliottower Date: Sun, 9 Jul 2023 21:09:24 -0400 Subject: [PATCH 18/38] Update code to check more in depth statistics like winrate and total rewards --- tutorials/SB3/sb3_connect_four_action_mask.py | 26 ++++-- tutorials/SB3/test_sb3_action_mask.py | 83 ++++++++++++++++--- 2 files changed, 93 insertions(+), 16 deletions(-) diff --git a/tutorials/SB3/sb3_connect_four_action_mask.py b/tutorials/SB3/sb3_connect_four_action_mask.py index da015731c..40ea7442a 100644 --- a/tutorials/SB3/sb3_connect_four_action_mask.py +++ b/tutorials/SB3/sb3_connect_four_action_mask.py @@ -91,8 +91,7 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): env = env_fn.env(render_mode=render_mode, **env_kwargs) print( - f"Starting evaluation on {str(env.metadata['name'])} vs a random agent." - f"Trained agent will play as {env.possible_agents[1]}" + f"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}." ) try: @@ -106,6 +105,9 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): model = MaskablePPO.load(latest_policy) scores = {agent: 0 for agent in env.possible_agents} + total_rewards = {agent: 0 for agent in env.possible_agents} + round_rewards = [] + for i in range(num_games): env.reset(seed=i) env.action_space(env.possible_agents[0]).seed(i) @@ -117,7 +119,15 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): observation, action_mask = obs.values() if termination or truncation: - scores[agent] += reward # winning agent gets +1 + # If there is a winner, keep track, otherwise don't change the scores (tie) + if env.rewards[env.possible_agents[0]] != env.rewards[env.possible_agents[1]]: + winner = max(env.rewards, key=env.rewards.get) + scores[winner] += env.rewards[winner] # only tracks the largest reward (winner of game) + # Also track negative and positive rewards (penalizes illegal moves) + for a in env.possible_agents: + total_rewards[a] += env.rewards[a] + # List of rewards by round, for reference + round_rewards.append(env.rewards) break else: if agent == env.possible_agents[0]: @@ -132,10 +142,16 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): env.step(act) env.close() - winrate = scores[env.possible_agents[1]] / sum(scores.values()) + # Avoid dividing by zero + if sum(scores.values()) == 0: + winrate = 0 + else: + winrate = scores[env.possible_agents[1]] / sum(scores.values()) + print("Rewards by round: ", round_rewards) + print("Total rewards (incl. negative rewards): ", total_rewards) print("Winrate: ", winrate) print("Final scores: ", scores) - return winrate + return round_rewards, total_rewards, winrate, scores if __name__ == "__main__": diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index 70687b03c..f1e68ac61 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -2,8 +2,7 @@ try: import pytest - from tutorials.SB3.render_sb3_chess_action_mask import watch_action_mask - from tutorials.SB3.sb3_connect_four_action_mask import train_action_mask + from tutorials.SB3.sb3_connect_four_action_mask import train_action_mask, eval_action_mask except ModuleNotFoundError: pass @@ -11,6 +10,7 @@ chess_v6, connect_four_v3, gin_rummy_v4, + go_v5, hanabi_v4, leduc_holdem_v4, texas_holdem_no_limit_v6, @@ -18,20 +18,81 @@ tictactoe_v3, ) -WORKING_ENVS = [ - tictactoe_v3, +# Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself + +# These environments do better than random even after the minimum number of timesteps +EASY_ENVS = [ connect_four_v3, - chess_v6, - leduc_holdem_v4, gin_rummy_v4, - hanabi_v4, # texas holdem likely broken, game ends instantly, but with random actions it works fine texas_holdem_no_limit_v6, texas_holdem_v4, ] +# More difficult environments which will likely take more training time +MEDIUM_ENVS = [ + leduc_holdem_v4, # with 10x as many steps it gets higher total rewards (9 vs -9), 0.52 winrate, and 0.92 vs 0.83 total scores + hanabi_v4, # even with 10x as many steps, total score seems to always be tied between the two agents + tictactoe_v3, # even with 10x as many steps, agent still loses every time (most likely an error somewhere) +] + +# Most difficult environments to train agents for (and longest games +# TODO: test board_size to see if smaller go board is more easily solvable +HARD_ENVS = [ + chess_v6, # difficult to train because games take so long, 0.28 winrate even after 10x + go_v5, # difficult to train because games take so long, +] + + + + +@pytest.mark.parametrize("env_fn", EASY_ENVS) +def test_action_mask_easy(env_fn): + env_kwargs = {} + + # Train a model against itself + train_action_mask(env_fn, steps=2048, seed=0, **env_kwargs) + + # Evaluate 2 games against a random agent + round_rewards, total_rewards, winrate, scores = eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) + + assert total_rewards[env_fn.env().possible_agents[0]] > total_rewards[env_fn.env().possible_agents[1]], "Trained policy should outperform random actions" + + + # Watch two games + # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) + +# @pytest.mark.skip(reason="training for these environments can be compute intensive") +@pytest.mark.parametrize("env_fn", MEDIUM_ENVS) +def test_action_mask_medium(env_fn): + env_kwargs = {} + + # Train a model against itself + train_action_mask(env_fn, steps=2048, seed=0, **env_kwargs) + + # Evaluate 2 games against a random agent + round_rewards, total_rewards, winrate, scores = eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) + + assert winrate < 0.75, "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi, 0% for tic-tac-toe + + # Watch two games + # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) + + +# @pytest.mark.skip(reason="training for these environments can be compute intensive") +@pytest.mark.parametrize("env_fn", HARD_ENVS) +def test_action_mask_hard(env_fn): + env_kwargs = {} + + # Train a model against itself + train_action_mask(env_fn, steps=2048 * 10, seed=0, **env_kwargs) + + # Evaluate 2 games against a random agent + round_rewards, total_rewards, winrate, scores = eval_action_mask(env_fn, num_games=100, render_mode=None, + **env_kwargs) + + assert winrate < 0.5, "Policy should not perform better than 50% winrate" # 28% for chess, 0% for go + -@pytest.mark.parametrize("env_fn", WORKING_ENVS) -def test_action_mask(env_fn): - train_action_mask(env_fn, steps=2048) - watch_action_mask(env_fn) + # Watch two games + # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) From c977d892e62e5f10d61a419afdf9882f614fa8f2 Mon Sep 17 00:00:00 2001 From: elliottower Date: Sun, 9 Jul 2023 21:10:12 -0400 Subject: [PATCH 19/38] Pre-commit --- tutorials/SB3/sb3_connect_four_action_mask.py | 9 ++++- tutorials/SB3/test_sb3_action_mask.py | 38 ++++++++++++------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/tutorials/SB3/sb3_connect_four_action_mask.py b/tutorials/SB3/sb3_connect_four_action_mask.py index 40ea7442a..e6cac2860 100644 --- a/tutorials/SB3/sb3_connect_four_action_mask.py +++ b/tutorials/SB3/sb3_connect_four_action_mask.py @@ -120,9 +120,14 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): if termination or truncation: # If there is a winner, keep track, otherwise don't change the scores (tie) - if env.rewards[env.possible_agents[0]] != env.rewards[env.possible_agents[1]]: + if ( + env.rewards[env.possible_agents[0]] + != env.rewards[env.possible_agents[1]] + ): winner = max(env.rewards, key=env.rewards.get) - scores[winner] += env.rewards[winner] # only tracks the largest reward (winner of game) + scores[winner] += env.rewards[ + winner + ] # only tracks the largest reward (winner of game) # Also track negative and positive rewards (penalizes illegal moves) for a in env.possible_agents: total_rewards[a] += env.rewards[a] diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index f1e68ac61..083b1aa43 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -2,7 +2,10 @@ try: import pytest - from tutorials.SB3.sb3_connect_four_action_mask import train_action_mask, eval_action_mask + from tutorials.SB3.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) except ModuleNotFoundError: pass @@ -31,7 +34,7 @@ # More difficult environments which will likely take more training time MEDIUM_ENVS = [ - leduc_holdem_v4, # with 10x as many steps it gets higher total rewards (9 vs -9), 0.52 winrate, and 0.92 vs 0.83 total scores + leduc_holdem_v4, # with 10x as many steps it gets higher total rewards (9 vs -9), 0.52 winrate, and 0.92 vs 0.83 total scores hanabi_v4, # even with 10x as many steps, total score seems to always be tied between the two agents tictactoe_v3, # even with 10x as many steps, agent still loses every time (most likely an error somewhere) ] @@ -44,8 +47,6 @@ ] - - @pytest.mark.parametrize("env_fn", EASY_ENVS) def test_action_mask_easy(env_fn): env_kwargs = {} @@ -54,14 +55,19 @@ def test_action_mask_easy(env_fn): train_action_mask(env_fn, steps=2048, seed=0, **env_kwargs) # Evaluate 2 games against a random agent - round_rewards, total_rewards, winrate, scores = eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) - - assert total_rewards[env_fn.env().possible_agents[0]] > total_rewards[env_fn.env().possible_agents[1]], "Trained policy should outperform random actions" + round_rewards, total_rewards, winrate, scores = eval_action_mask( + env_fn, num_games=100, render_mode=None, **env_kwargs + ) + assert ( + total_rewards[env_fn.env().possible_agents[0]] + > total_rewards[env_fn.env().possible_agents[1]] + ), "Trained policy should outperform random actions" # Watch two games # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) + # @pytest.mark.skip(reason="training for these environments can be compute intensive") @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): @@ -71,9 +77,13 @@ def test_action_mask_medium(env_fn): train_action_mask(env_fn, steps=2048, seed=0, **env_kwargs) # Evaluate 2 games against a random agent - round_rewards, total_rewards, winrate, scores = eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) + round_rewards, total_rewards, winrate, scores = eval_action_mask( + env_fn, num_games=100, render_mode=None, **env_kwargs + ) - assert winrate < 0.75, "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi, 0% for tic-tac-toe + assert ( + winrate < 0.75 + ), "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi, 0% for tic-tac-toe # Watch two games # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) @@ -88,11 +98,13 @@ def test_action_mask_hard(env_fn): train_action_mask(env_fn, steps=2048 * 10, seed=0, **env_kwargs) # Evaluate 2 games against a random agent - round_rewards, total_rewards, winrate, scores = eval_action_mask(env_fn, num_games=100, render_mode=None, - **env_kwargs) - - assert winrate < 0.5, "Policy should not perform better than 50% winrate" # 28% for chess, 0% for go + round_rewards, total_rewards, winrate, scores = eval_action_mask( + env_fn, num_games=100, render_mode=None, **env_kwargs + ) + assert ( + winrate < 0.5 + ), "Policy should not perform better than 50% winrate" # 28% for chess, 0% for go # Watch two games # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) From 062529606e3ae5b74cc69313c7775c93450a2ed2 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 00:19:05 -0400 Subject: [PATCH 20/38] Un-comment training code for KAZ --- tutorials/SB3/sb3_kaz_vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/SB3/sb3_kaz_vector.py b/tutorials/SB3/sb3_kaz_vector.py index 42ef8fb4f..461259e23 100644 --- a/tutorials/SB3/sb3_kaz_vector.py +++ b/tutorials/SB3/sb3_kaz_vector.py @@ -163,7 +163,7 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa env_kwargs = dict(max_cycles=100, max_zombies=4) # Train a model (takes ~5 minutes on a laptop CPU) - # train(env_fn, steps=8192*10, seed=0, **env_kwargs) + train(env_fn, steps=8192*10, seed=0, **env_kwargs) # Evaluate 10 games (takes ~10 seconds on a laptop CPU) eval(env_fn, num_games=10, render_mode=None, **env_kwargs) From 459cc8699d9294c9a76b9cebf0b1913b013160c2 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 02:13:44 -0400 Subject: [PATCH 21/38] Update hyperparameters and fix pistonball crashing issue --- tutorials/SB3/sb3_connect_four_action_mask.py | 2 +- tutorials/SB3/sb3_kaz_vector.py | 10 ++++------ tutorials/SB3/sb3_pistonball_vector.py | 6 ++++-- tutorials/SB3/test_sb3_action_mask.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tutorials/SB3/sb3_connect_four_action_mask.py b/tutorials/SB3/sb3_connect_four_action_mask.py index e6cac2860..af5beabdc 100644 --- a/tutorials/SB3/sb3_connect_four_action_mask.py +++ b/tutorials/SB3/sb3_connect_four_action_mask.py @@ -164,7 +164,7 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): env_kwargs = {} # Train a model against itself (takes ~20 seconds on a laptop CPU) - train_action_mask(env_fn, steps=20480, seed=0, **env_kwargs) + train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs) # Evaluate 100 games against a random agent (winrate should be ~80%) eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs) diff --git a/tutorials/SB3/sb3_kaz_vector.py b/tutorials/SB3/sb3_kaz_vector.py index 461259e23..747b1e744 100644 --- a/tutorials/SB3/sb3_kaz_vector.py +++ b/tutorials/SB3/sb3_kaz_vector.py @@ -155,15 +155,13 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa if __name__ == "__main__": env_fn = knights_archers_zombies_v10 - # TODO: test out more hyperparameter combos - # max_cycles 100, max zombies 4, 8192 * 10 works decently, but sometimes fails due to agents dying - # black death wrapper, max cycles 100, max zombies 4, 8192*10, seems to work well (13 points over 10 games) - # black death wrapper, max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points) - + # Notes on environment configuration: + # max_cycles 100, max_zombies 4, seems to work well (13 points over 10 games) + # max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points) env_kwargs = dict(max_cycles=100, max_zombies=4) # Train a model (takes ~5 minutes on a laptop CPU) - train(env_fn, steps=8192*10, seed=0, **env_kwargs) + train(env_fn, steps=81_920, seed=0, **env_kwargs) # Evaluate 10 games (takes ~10 seconds on a laptop CPU) eval(env_fn, num_games=10, render_mode=None, **env_kwargs) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index bc3ad4da3..a63a5f62d 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -33,7 +33,7 @@ def train_butterfly_supersuit( print(f"Starting training on {str(env.metadata['name'])}.") env = ss.pettingzoo_env_to_vec_env_v1(env) - env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") + env = ss.concat_vec_envs_v1(env, 4, num_cpus=2, base_class="stable_baselines3") model = PPO( CnnPolicy, @@ -59,7 +59,8 @@ def train_butterfly_supersuit( print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") - env.close() + # TODO: fix SuperSuit bug where closing the vector env can sometimes crash (disabled for CI) + # env.close() def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): @@ -128,6 +129,7 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa # 296 seconds for 81_920 steps (n_updates = 5) vs 46 seconds for 40_960 # Train a model (takes ~3 minutes on a laptop CPU) + # Note: stochastic environment makes training difficult, for better results try order of 2 million (~2 hours on GPU) train_butterfly_supersuit(env_fn, steps=40_960, seed=0, **env_kwargs) # Evaluate 10 games (takes ~10 seconds on a laptop CPU) diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index 083b1aa43..c5e2a7a7e 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -95,7 +95,7 @@ def test_action_mask_hard(env_fn): env_kwargs = {} # Train a model against itself - train_action_mask(env_fn, steps=2048 * 10, seed=0, **env_kwargs) + train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs) # Evaluate 2 games against a random agent round_rewards, total_rewards, winrate, scores = eval_action_mask( From 9546c9cc3860e08e239264e2b66a5fee64c7230c Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 02:21:31 -0400 Subject: [PATCH 22/38] Add hyperparameter notes --- tutorials/SB3/sb3_connect_four_action_mask.py | 6 ++++++ tutorials/SB3/sb3_pistonball_vector.py | 4 +--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tutorials/SB3/sb3_connect_four_action_mask.py b/tutorials/SB3/sb3_connect_four_action_mask.py index af5beabdc..5baa00d3c 100644 --- a/tutorials/SB3/sb3_connect_four_action_mask.py +++ b/tutorials/SB3/sb3_connect_four_action_mask.py @@ -161,8 +161,14 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): if __name__ == "__main__": env_fn = connect_four_v3 + env_kwargs = {} + # Evaluation/training hyperparameter notes: + # 10k steps: Winrate: 0.76, loss order of 1e-03 + # 20k steps: Winrate: 0.86, loss order of 1e-04 + # 40k steps: Winrate: 0.86, loss order of 7e-06 + # Train a model against itself (takes ~20 seconds on a laptop CPU) train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index a63a5f62d..9c6ae2550 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -113,6 +113,7 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa if __name__ == "__main__": env_fn = pistonball_v6 + env_kwargs = dict( n_pistons=20, time_penalty=-0.1, @@ -125,9 +126,6 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa max_cycles=25, ) - # TODO: figure out why pistonball takes so long to train and seems to not save the model? - # 296 seconds for 81_920 steps (n_updates = 5) vs 46 seconds for 40_960 - # Train a model (takes ~3 minutes on a laptop CPU) # Note: stochastic environment makes training difficult, for better results try order of 2 million (~2 hours on GPU) train_butterfly_supersuit(env_fn, steps=40_960, seed=0, **env_kwargs) From 8cfd867292d4ec629e029869220b0ed7919ba754 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 02:42:02 -0400 Subject: [PATCH 23/38] Add multiwalker tutorial for MLP example --- docs/tutorials/sb3/index.md | 13 ++- docs/tutorials/sb3/multiwalker.md | 29 +++++ docs/tutorials/sb3/pistonball.md | 2 +- tutorials/SB3/sb3_connect_four_action_mask.py | 2 +- tutorials/SB3/sb3_kaz_vector.py | 2 +- tutorials/SB3/sb3_multiwalker_vector.py | 109 ++++++++++++++++++ tutorials/SB3/sb3_pistonball_vector.py | 2 +- tutorials/SB3/test_sb3_action_mask.py | 2 +- 8 files changed, 153 insertions(+), 8 deletions(-) create mode 100644 docs/tutorials/sb3/multiwalker.md create mode 100644 tutorials/SB3/sb3_multiwalker_vector.py diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index 7d8f84b84..daeeeb0ad 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -6,11 +6,17 @@ title: "Stable-Baselines3" These tutorials show you how to use the [SB3](https://stable-baselines3.readthedocs.io/en/master/) library to train agents in PettingZoo environments. -* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in vectorized Parallel environment_ +For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/) -* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train a PPO model in a vectorized AEC environment_ +* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train agents using PPO in vectorized Parallel environment_ -* [Action Masked PPO for Chess](/tutorials/sb3/connect_four/): _Train an action masked PPO model in an AEC environment_ +* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized AEC environment_ + +For non-visual environments, we use [Actor Critic](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) or [Maskable Actor Critic](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html#maskableppo-policies) policies and do not perform any pre-processing steps. + +* [PPO for Multiwalker](/tutorials/sb3/multiwalker/): _Train agents using PPO in a vectorized AEC environment_ + +* [Action Masked PPO for Connect Four](/tutorials/sb3/connect_four/): _Train an agent using Action Masked PPO in an AEC environment_ ## Stable-Baselines Overview @@ -33,5 +39,6 @@ Note: SB3 does not officially support PettingZoo, as it is designed for single-a pistonball kaz +multiwalker connect_four ``` diff --git a/docs/tutorials/sb3/multiwalker.md b/docs/tutorials/sb3/multiwalker.md new file mode 100644 index 000000000..ea8095efb --- /dev/null +++ b/docs/tutorials/sb3/multiwalker.md @@ -0,0 +1,29 @@ +--- +title: "SB3: PPO for Pistonball (Parallel)" +--- + +# SB3: PPO for Pistonball + +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Multiwalker](https://pettingzoo.farama.org/environments/sisl/multiwalker/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). + +Note: this environment uses a discrete 1-dimensional observation space, so we use an MLP extractor rather than CNN + +After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training and Evaluation + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/sb3_multiwalker_vector.py + :language: python +``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md index b102b8874..f2e369b9c 100644 --- a/docs/tutorials/sb3/pistonball.md +++ b/docs/tutorials/sb3/pistonball.md @@ -4,7 +4,7 @@ title: "SB3: PPO for Pistonball (Parallel)" # SB3: PPO for Pistonball -This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)). +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). diff --git a/tutorials/SB3/sb3_connect_four_action_mask.py b/tutorials/SB3/sb3_connect_four_action_mask.py index 5baa00d3c..789794bae 100644 --- a/tutorials/SB3/sb3_connect_four_action_mask.py +++ b/tutorials/SB3/sb3_connect_four_action_mask.py @@ -1,4 +1,4 @@ -"""Uses Stable-Baselines3 to train agents to play Connect Four using invalid action masking. +"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking. For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html diff --git a/tutorials/SB3/sb3_kaz_vector.py b/tutorials/SB3/sb3_kaz_vector.py index 747b1e744..b02e72f10 100644 --- a/tutorials/SB3/sb3_kaz_vector.py +++ b/tutorials/SB3/sb3_kaz_vector.py @@ -1,4 +1,4 @@ -"""Uses Stable-Baselines3 to train agents to play Knights-Archers-Zombies using SuperSuit vector envs. +"""Uses Stable-Baselines3 to train agents in the Knights-Archers-Zombies environment using SuperSuit vector envs. This environment requires using SuperSuit's Black Death wrapper, to handle agent death. diff --git a/tutorials/SB3/sb3_multiwalker_vector.py b/tutorials/SB3/sb3_multiwalker_vector.py new file mode 100644 index 000000000..8ca27c44a --- /dev/null +++ b/tutorials/SB3/sb3_multiwalker_vector.py @@ -0,0 +1,109 @@ +"""Uses Stable-Baselines3 to train agents to play the Multiwalker environment using SuperSuit vector envs. + +For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html + +Author: Elliot (https://github.com/elliottower) +""" +from __future__ import annotations + +import glob +import os +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import MlpPolicy + +from pettingzoo.sisl import multiwalker_v9 + + +def train_butterfly_supersuit( + env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs +): + # Train a single agent to play both sides in a Parallel environment, + env = env_fn.parallel_env(**env_kwargs) + + env.reset(seed=seed) + + print(f"Starting training on {str(env.metadata['name'])}.") + + env = ss.pettingzoo_env_to_vec_env_v1(env) + env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") + + # Note: Multiwalker's observation space is discrete, therefore we use an MLP policy rather than CNN + model = PPO( + MlpPolicy, + env, + verbose=3, + learning_rate=1e-3, + batch_size=256, + ) + + model.learn(total_timesteps=steps) + + model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + print("Model has been saved.") + + print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") + + env.close() + + +def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): + # Evaluate a trained agent vs a random agent + env = env_fn.env(render_mode=render_mode, **env_kwargs) + + print( + f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" + ) + + try: + latest_policy = max( + glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime + ) + except ValueError: + print("Policy not found.") + exit(0) + + model = PPO.load(latest_policy) + + rewards = {agent: 0 for agent in env.possible_agents} + + # Note: We train using the Parallel API but evaluate using the AEC API + # SB3 models are designed for single-agent settings, we get around this by using he same model for every agent + for i in range(num_games): + env.reset(seed=i) + + for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + + if termination or truncation: + for agent in env.agents: + rewards[agent] += env.rewards[agent] + break + else: + act = model.predict(obs, deterministic=True)[0] + + env.step(act) + env.close() + + avg_reward = sum(rewards.values()) / len(rewards.values()) + print(f"Avg reward: {avg_reward}") + return avg_reward + + +if __name__ == "__main__": + env_fn = multiwalker_v9 + + env_kwargs = {} + + # Train a model (takes ~3 minutes on a laptop CPU) + # Note: stochastic environment makes training difficult, hyperparameters have not been fully tuned for this example + train_butterfly_supersuit(env_fn, steps=49_152 * 4, seed=0, **env_kwargs) + + # Evaluate 10 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=10, render_mode=None, **env_kwargs) + + # Watch 2 games (takes ~10 seconds on a laptop CPU) + eval(env_fn, num_games=2, render_mode="human", **env_kwargs) diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/sb3_pistonball_vector.py index 9c6ae2550..ec102edc9 100644 --- a/tutorials/SB3/sb3_pistonball_vector.py +++ b/tutorials/SB3/sb3_pistonball_vector.py @@ -1,4 +1,4 @@ -"""Uses Stable-Baselines3 to train agents to play PettingZoo Butterfly (cooprative) environments using SuperSuit vector envs. +"""Uses Stable-Baselines3 to train agents in the Pistonball environment using SuperSuit vector envs. For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index c5e2a7a7e..ad8ec5dca 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -1,4 +1,4 @@ -"""Test file to ensure that action masking code works for all PettingZoo classic environments (except rps, which has no action mask).""" +"""Tests that action masking code works properly with all PettingZoo classic environments.""" try: import pytest From 41e26fcc3ea6479c1fe8adce97e09cee97b97d3d Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 02:45:01 -0400 Subject: [PATCH 24/38] Fix typo in docs --- docs/tutorials/sb3/multiwalker.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/sb3/multiwalker.md b/docs/tutorials/sb3/multiwalker.md index ea8095efb..d8d82766c 100644 --- a/docs/tutorials/sb3/multiwalker.md +++ b/docs/tutorials/sb3/multiwalker.md @@ -1,8 +1,8 @@ --- -title: "SB3: PPO for Pistonball (Parallel)" +title: "SB3: PPO for Multiwalker (Parallel)" --- -# SB3: PPO for Pistonball +# SB3: PPO for Multiwalker This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Multiwalker](https://pettingzoo.farama.org/environments/sisl/multiwalker/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). From 6af9e1822a03198a84a2ec56edd1a02a4913335d Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 03:17:45 -0400 Subject: [PATCH 25/38] Polish up documentation and add sphinx warnings/notes --- docs/tutorials/sb3/connect_four.md | 29 ++++++++++++++++++++++++--- docs/tutorials/sb3/index.md | 7 ++++++- docs/tutorials/sb3/kaz.md | 15 ++++++++++++-- docs/tutorials/sb3/multiwalker.md | 8 ++++++-- docs/tutorials/sb3/pistonball.md | 7 ++++++- tutorials/SB3/test_sb3_action_mask.py | 2 +- 6 files changed, 58 insertions(+), 10 deletions(-) diff --git a/docs/tutorials/sb3/connect_four.md b/docs/tutorials/sb3/connect_four.md index d6c674780..6a228ffdd 100644 --- a/docs/tutorials/sb3/connect_four.md +++ b/docs/tutorials/sb3/connect_four.md @@ -6,11 +6,22 @@ title: "SB3: Action Masked PPO for Connect Four" This tutorial shows how to train a Maskable [Proximal Policy Optimization](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html) (PPO) model on the [Connect Four](https://pettingzoo.farama.org/environments/classic/chess/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). -It creates a custom Wrapper to convert to a Gymnasium-like environment which is compatible with SB3's action masking format. +It creates a custom Wrapper to convert to a [Gymnasium](https://gymnasium.farama.org/)-like environment which is compatible with [SB3 action masking](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html). -Note: This assumes that the action space and observation space is the same for each agent, this assumption may not hold for custom environments. -After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). +```{eval-rst} +.. note:: + + This environment has a discrete (1-dimensional) observation space with an illegal action mask, so we use a masked MLP feature extractor. +``` + +```{eval-rst} +.. warning:: + + This wrapper assumes that the action space and observation space is the same for each agent, this assumption may not hold for custom environments. +``` + +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). ## Environment Setup @@ -29,3 +40,15 @@ The following code should run without any issues. The comments are designed to h .. literalinclude:: ../../../tutorials/SB3/sb3_connect_four_action_mask.py :language: python ``` + +### Testing other PettingZoo Classic environments + +The following script uses [pytest](https://docs.pytest.org/en/latest/) to test all other PettingZoo environments which support action masking. + +This code yields good results on simpler environments like [Gin Rummy](/environments/classic/gin_rummy/) and [Texas Hold’em No Limit](/environments/classic/texas_holdem_no_limit/), while failing to perform better than random in more difficult environments such as [Chess](/environments/classic/chess/) or [Hanabi](/environments/classic/hanabi/). + + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/test_sb3_action_mask.py + :language: python +``` diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index daeeeb0ad..6c1c123f2 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -25,7 +25,12 @@ For non-visual environments, we use [Actor Critic](https://stable-baselines3.rea For more information, see the [Stable-Baselines3 v1.0 Blog Post](https://araffin.github.io/post/sb3/) -Note: SB3 does not officially support PettingZoo, as it is designed for single-agent RL. These tutorials demonstrate how to adapt SB3 to work in multi-agent settings, but we cannot guarantee training convergence. + +```{eval-rst} +.. warning:: + + Note: SB3 is designed for single-agent RL and does not plan on natively supporting multi-agent PettingZoo environments. These tutorials are only intended for demonstration purposes, to show how SB3 can be adapted to work in multi-agent settings. +``` ```{figure} https://raw.githubusercontent.com/DLR-RM/stable-baselines3/master/docs/_static/img/logo.png diff --git a/docs/tutorials/sb3/kaz.md b/docs/tutorials/sb3/kaz.md index 74d2965c1..247d43c8c 100644 --- a/docs/tutorials/sb3/kaz.md +++ b/docs/tutorials/sb3/kaz.md @@ -8,9 +8,20 @@ This tutorial shows how to train a [Proximal Policy Optimization](https://stable It converts the environment into a Parallel environment and uses SuperSuit to create vectorized environments, leveraging multithreading to speed up training. -After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). + +```{eval-rst} +.. note:: + + This environment has a visual (3-dimensional) observation space, so we use a CNN feature extractor. +``` + +```{eval-rst} +.. warning:: + + Because this environment allows agents to spawn and die, it requires using SuperSuit's Black Death wrapper, which provides blank observations to dead agents, rather than removing them from the environment. +``` -Note: because this environment allows agents to spawn and die, it requires using SuperSuit's [Black Death](https://pettingzoo.farama.org/api/wrappers/supersuit_wrappers/#black_death_v2) wrapper, which provides blank observations to dead agents, rather than removing them from the environment. ## Environment Setup To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. diff --git a/docs/tutorials/sb3/multiwalker.md b/docs/tutorials/sb3/multiwalker.md index d8d82766c..3393573ea 100644 --- a/docs/tutorials/sb3/multiwalker.md +++ b/docs/tutorials/sb3/multiwalker.md @@ -6,9 +6,13 @@ title: "SB3: PPO for Multiwalker (Parallel)" This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Multiwalker](https://pettingzoo.farama.org/environments/sisl/multiwalker/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). -Note: this environment uses a discrete 1-dimensional observation space, so we use an MLP extractor rather than CNN +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). -After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). +```{eval-rst} +.. note:: + + This environment has a discrete (1-dimensional) observation space, so we use an MLP feature extractor. +``` ## Environment Setup diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md index f2e369b9c..5b75404f3 100644 --- a/docs/tutorials/sb3/pistonball.md +++ b/docs/tutorials/sb3/pistonball.md @@ -6,8 +6,13 @@ title: "SB3: PPO for Pistonball (Parallel)" This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). -After training and evaluation, this script will launch a demo game human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). +After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). +```{eval-rst} +.. note:: + + This environment has a visual (3-dimensional) observation space, so we use a CNN feature extractor. +``` ## Environment Setup To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index ad8ec5dca..fc550f784 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -43,7 +43,7 @@ # TODO: test board_size to see if smaller go board is more easily solvable HARD_ENVS = [ chess_v6, # difficult to train because games take so long, 0.28 winrate even after 10x - go_v5, # difficult to train because games take so long, + go_v5, # difficult to train because games take so long ] From a45436235289c9ae69854602dd94bcf8a530b6b8 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 03:22:42 -0400 Subject: [PATCH 26/38] Try to fix missing module error from test file --- tutorials/SB3/requirements.txt | 2 +- tutorials/SB3/test_sb3_action_mask.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index 9010c7b29..53d42ef2c 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -1,4 +1,4 @@ stable-baselines3>=2.0.0 -pettingzoo[classic,butterfly]>=1.23.1 +pettingzoo[classic,butterfly,sisl]>=1.23.1 supersuit>=3.8.1 sb3-contrib>=2.0.0 diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index fc550f784..2c6cf1bd0 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -2,12 +2,14 @@ try: import pytest - from tutorials.SB3.sb3_connect_four_action_mask import ( - eval_action_mask, - train_action_mask, - ) -except ModuleNotFoundError: - pass +except ModuleNotFoundError as e: + print(e) + exit() + +from tutorials.SB3.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, +) from pettingzoo.classic import ( chess_v6, From 5c75d4c88f6edeef938ebe8dfea450882de79eba Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 10:54:21 -0400 Subject: [PATCH 27/38] Update test_sb3_action_mask.py --- tutorials/SB3/test_sb3_action_mask.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index 2c6cf1bd0..a3ba1e8c1 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -2,9 +2,8 @@ try: import pytest -except ModuleNotFoundError as e: - print(e) - exit() +except ModuleNotFoundError: + return from tutorials.SB3.sb3_connect_four_action_mask import ( eval_action_mask, From fd23175716ef162cb9de1ed554d5299044a5dd59 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 11:58:46 -0400 Subject: [PATCH 28/38] Add importorskip to each test, choose better hyperparameters --- tutorials/SB3/render_sb3_pistonball.py | 41 -------------------------- tutorials/SB3/render_sb3_rps.py | 35 ---------------------- tutorials/SB3/requirements.txt | 1 + tutorials/SB3/test_sb3_action_mask.py | 40 +++++++++++++------------ 4 files changed, 22 insertions(+), 95 deletions(-) delete mode 100644 tutorials/SB3/render_sb3_pistonball.py delete mode 100644 tutorials/SB3/render_sb3_rps.py diff --git a/tutorials/SB3/render_sb3_pistonball.py b/tutorials/SB3/render_sb3_pistonball.py deleted file mode 100644 index e42a375a1..000000000 --- a/tutorials/SB3/render_sb3_pistonball.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Uses Stable-Baselines3 to view trained agents playing Pistonball. - -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b - -Author: Elliot (https://github.com/elliottower) -""" -import glob -import os - -import supersuit as ss -from stable_baselines3 import PPO - -from pettingzoo.butterfly import pistonball_v6 - -if __name__ == "__main__": - # Watch a game between two trained agents - env = pistonball_v6.env(render_mode="human") - - env = ss.color_reduction_v0(env, mode="B") - env = ss.resize_v1(env, x_size=84, y_size=84) - env = ss.frame_stack_v1(env, 3) - - try: - latest_policy = max( - glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime - ) - except ValueError: - print("Policy not found.") - exit(0) - - model = PPO.load(latest_policy) - - env.reset() - for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - act = ( - model.predict(obs, deterministic=True)[0] - if not termination or truncation - else None - ) - env.step(act) diff --git a/tutorials/SB3/render_sb3_rps.py b/tutorials/SB3/render_sb3_rps.py deleted file mode 100644 index 20d280aa8..000000000 --- a/tutorials/SB3/render_sb3_rps.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Uses Stable-Baselines3 to view trained agents playing Rock-Paper-Scissors. - -Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b - -Author: Elliot (https://github.com/elliottower) -""" - -import glob -import os - -from stable_baselines3 import PPO - -from pettingzoo.classic import rps_v2 - -if __name__ == "__main__": - env = rps_v2.env(render_mode="human") - - try: - latest_policy = max( - glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime - ) - except ValueError: - print("Policy not found.") - exit(0) - - model = PPO.load(latest_policy) - - env.reset() - for agent in env.agent_iter(): - obs, reward, termination, truncation, info = env.last() - if termination or truncation: - act = None - else: - act = model.predict(obs, deterministic=True)[0] - env.step(act) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt index 53d42ef2c..60ca7deb2 100644 --- a/tutorials/SB3/requirements.txt +++ b/tutorials/SB3/requirements.txt @@ -2,3 +2,4 @@ stable-baselines3>=2.0.0 pettingzoo[classic,butterfly,sisl]>=1.23.1 supersuit>=3.8.1 sb3-contrib>=2.0.0 +pytest diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index a3ba1e8c1..f3fbf2d83 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -1,10 +1,6 @@ """Tests that action masking code works properly with all PettingZoo classic environments.""" -try: - import pytest -except ModuleNotFoundError: - return - +import pytest from tutorials.SB3.sb3_connect_four_action_mask import ( eval_action_mask, train_action_mask, @@ -28,8 +24,7 @@ EASY_ENVS = [ connect_four_v3, gin_rummy_v4, - # texas holdem likely broken, game ends instantly, but with random actions it works fine - texas_holdem_no_limit_v6, + texas_holdem_no_limit_v6, # texas holdem human rendered game ends instantly, but with random actions it works fine texas_holdem_v4, ] @@ -50,32 +45,38 @@ @pytest.mark.parametrize("env_fn", EASY_ENVS) def test_action_mask_easy(env_fn): + pytest.importorskip("stable_baselines3") + env_kwargs = {} - # Train a model against itself - train_action_mask(env_fn, steps=2048, seed=0, **env_kwargs) + # Leduc Hold`em takes slightly longer to outperform random + steps = 8192 if env_fn != leduc_holdem_v4 else 8192 * 4 + + # Train a model against itself (takes ~2 minutes on GPU) + train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs) # Evaluate 2 games against a random agent round_rewards, total_rewards, winrate, scores = eval_action_mask( env_fn, num_games=100, render_mode=None, **env_kwargs ) - assert ( - total_rewards[env_fn.env().possible_agents[0]] - > total_rewards[env_fn.env().possible_agents[1]] + assert winrate > 0.5 or ( + total_rewards[env_fn.env().possible_agents[1]] + > total_rewards[env_fn.env().possible_agents[0]] ), "Trained policy should outperform random actions" - # Watch two games + # Watch two games (disabled by default) # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) -# @pytest.mark.skip(reason="training for these environments can be compute intensive") @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): + pytest.importorskip("stable_baselines3") + env_kwargs = {} # Train a model against itself - train_action_mask(env_fn, steps=2048, seed=0, **env_kwargs) + train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs) # Evaluate 2 games against a random agent round_rewards, total_rewards, winrate, scores = eval_action_mask( @@ -86,17 +87,18 @@ def test_action_mask_medium(env_fn): winrate < 0.75 ), "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi, 0% for tic-tac-toe - # Watch two games + # Watch two games (disabled by default) # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) -# @pytest.mark.skip(reason="training for these environments can be compute intensive") @pytest.mark.parametrize("env_fn", HARD_ENVS) def test_action_mask_hard(env_fn): + pytest.importorskip("stable_baselines3") + env_kwargs = {} # Train a model against itself - train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs) + train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs) # Evaluate 2 games against a random agent round_rewards, total_rewards, winrate, scores = eval_action_mask( @@ -107,5 +109,5 @@ def test_action_mask_hard(env_fn): winrate < 0.5 ), "Policy should not perform better than 50% winrate" # 28% for chess, 0% for go - # Watch two games + # Watch two games (disabled by default) # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) From cefc86dc742d9f1f9157f31039635df52e73e808 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 12:46:10 -0400 Subject: [PATCH 29/38] Move pytest importorskip calls --- tutorials/SB3/test_sb3_action_mask.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index f3fbf2d83..715b82bbe 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -1,10 +1,6 @@ """Tests that action masking code works properly with all PettingZoo classic environments.""" import pytest -from tutorials.SB3.sb3_connect_four_action_mask import ( - eval_action_mask, - train_action_mask, -) from pettingzoo.classic import ( chess_v6, @@ -18,6 +14,9 @@ tictactoe_v3, ) +pytest.importorskip("stable_baselines3") +pytest.importorskip("sb3_contrib") + # Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself # These environments do better than random even after the minimum number of timesteps @@ -45,7 +44,10 @@ @pytest.mark.parametrize("env_fn", EASY_ENVS) def test_action_mask_easy(env_fn): - pytest.importorskip("stable_baselines3") + from tutorials.SB3.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) env_kwargs = {} @@ -71,7 +73,10 @@ def test_action_mask_easy(env_fn): @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): - pytest.importorskip("stable_baselines3") + from tutorials.SB3.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) env_kwargs = {} @@ -93,7 +98,10 @@ def test_action_mask_medium(env_fn): @pytest.mark.parametrize("env_fn", HARD_ENVS) def test_action_mask_hard(env_fn): - pytest.importorskip("stable_baselines3") + from tutorials.SB3.sb3_connect_four_action_mask import ( + eval_action_mask, + train_action_mask, + ) env_kwargs = {} From 142b155558f220648b288f09c98adaec70a165ff Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 14:30:06 -0400 Subject: [PATCH 30/38] Disable most of the tests on test_sb3_action_mask.py --- tutorials/SB3/test_sb3_action_mask.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test_sb3_action_mask.py index 715b82bbe..87087ca59 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test_sb3_action_mask.py @@ -4,7 +4,6 @@ from pettingzoo.classic import ( chess_v6, - connect_four_v3, gin_rummy_v4, go_v5, hanabi_v4, @@ -17,11 +16,11 @@ pytest.importorskip("stable_baselines3") pytest.importorskip("sb3_contrib") +# Note: Connect Four is tested in sb3_connect_four_action_mask.py # Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself # These environments do better than random even after the minimum number of timesteps EASY_ENVS = [ - connect_four_v3, gin_rummy_v4, texas_holdem_no_limit_v6, # texas holdem human rendered game ends instantly, but with random actions it works fine texas_holdem_v4, @@ -71,6 +70,9 @@ def test_action_mask_easy(env_fn): # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) +@pytest.mark.skip( + reason="training can be compute intensive and hyperparameters have not been tuned, disabled for CI" +) @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): from tutorials.SB3.sb3_connect_four_action_mask import ( @@ -96,6 +98,9 @@ def test_action_mask_medium(env_fn): # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) +@pytest.mark.skip( + reason="training can be compute intensive and hyperparameters have not been tuned, disabled for CI" +) @pytest.mark.parametrize("env_fn", HARD_ENVS) def test_action_mask_hard(env_fn): from tutorials.SB3.sb3_connect_four_action_mask import ( From 1a2d2ef65f3a26bb9b91d36143eee4a9f9123f6a Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 14:44:33 -0400 Subject: [PATCH 31/38] Split CI tests into separate actions (so they don't take 2 hours) --- .github/workflows/linux-tutorials-test.yml | 2 +- docs/tutorials/sb3/connect_four.md | 2 +- docs/tutorials/sb3/kaz.md | 2 +- docs/tutorials/sb3/multiwalker.md | 2 +- docs/tutorials/sb3/pistonball.md | 2 +- .../{ => connect_four}/sb3_connect_four_action_mask.py | 0 tutorials/SB3/{ => kaz}/sb3_kaz_vector.py | 9 ++------- .../SB3/{ => multiwalker}/sb3_multiwalker_vector.py | 0 tutorials/SB3/{ => pistonball}/sb3_pistonball_vector.py | 0 tutorials/SB3/{ => test}/test_sb3_action_mask.py | 6 +++--- 10 files changed, 10 insertions(+), 15 deletions(-) rename tutorials/SB3/{ => connect_four}/sb3_connect_four_action_mask.py (100%) rename tutorials/SB3/{ => kaz}/sb3_kaz_vector.py (96%) rename tutorials/SB3/{ => multiwalker}/sb3_multiwalker_vector.py (100%) rename tutorials/SB3/{ => pistonball}/sb3_pistonball_vector.py (100%) rename tutorials/SB3/{ => test}/test_sb3_action_mask.py (94%) diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml index 4e46ad714..22d838e0e 100644 --- a/.github/workflows/linux-tutorials-test.yml +++ b/.github/workflows/linux-tutorials-test.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10'] # '3.11' - broken due to numba - tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL', 'SB3'] # TODO: add back RLlib once it is fixed + tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL', 'SB3/connect_four', 'SB3/pistonball', 'SB3/kaz', 'SB3/multiwalker', 'SB3/test'] # TODO: add back RLlib once it is fixed steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/docs/tutorials/sb3/connect_four.md b/docs/tutorials/sb3/connect_four.md index 6a228ffdd..31afbbadc 100644 --- a/docs/tutorials/sb3/connect_four.md +++ b/docs/tutorials/sb3/connect_four.md @@ -37,7 +37,7 @@ The following code should run without any issues. The comments are designed to h ### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_connect_four_action_mask.py +.. literalinclude:: ../../../tutorials/SB3/connect_four/sb3_connect_four_action_mask.py :language: python ``` diff --git a/docs/tutorials/sb3/kaz.md b/docs/tutorials/sb3/kaz.md index 247d43c8c..9cb8b78e0 100644 --- a/docs/tutorials/sb3/kaz.md +++ b/docs/tutorials/sb3/kaz.md @@ -36,6 +36,6 @@ The following code should run without any issues. The comments are designed to h ### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_kaz_vector.py +.. literalinclude:: ../../../tutorials/SB3/kaz/sb3_kaz_vector.py :language: python ``` diff --git a/docs/tutorials/sb3/multiwalker.md b/docs/tutorials/sb3/multiwalker.md index 3393573ea..ad2fc3d26 100644 --- a/docs/tutorials/sb3/multiwalker.md +++ b/docs/tutorials/sb3/multiwalker.md @@ -28,6 +28,6 @@ The following code should run without any issues. The comments are designed to h ### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_multiwalker_vector.py +.. literalinclude:: ../../../tutorials/SB3/multiwalker/sb3_multiwalker_vector.py :language: python ``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md index 5b75404f3..ee52dc8ac 100644 --- a/docs/tutorials/sb3/pistonball.md +++ b/docs/tutorials/sb3/pistonball.md @@ -27,6 +27,6 @@ The following code should run without any issues. The comments are designed to h ### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball_vector.py +.. literalinclude:: ../../../tutorials/SB3/pistonball/sb3_pistonball_vector.py :language: python ``` diff --git a/tutorials/SB3/sb3_connect_four_action_mask.py b/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py similarity index 100% rename from tutorials/SB3/sb3_connect_four_action_mask.py rename to tutorials/SB3/connect_four/sb3_connect_four_action_mask.py diff --git a/tutorials/SB3/sb3_kaz_vector.py b/tutorials/SB3/kaz/sb3_kaz_vector.py similarity index 96% rename from tutorials/SB3/sb3_kaz_vector.py rename to tutorials/SB3/kaz/sb3_kaz_vector.py index b02e72f10..fda1bd59a 100644 --- a/tutorials/SB3/sb3_kaz_vector.py +++ b/tutorials/SB3/kaz/sb3_kaz_vector.py @@ -23,19 +23,14 @@ def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs): # Train a single agent to play both sides in an AEC environment env = env_fn.parallel_env(**env_kwargs) + # Add black death wrapper so the number of agents stays constant + # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True env = ss.black_death_v3(env) - # Convert into a Parallel environment in order to vectorize it (SuperSuit does not currently support vectorized AEC envs) - # env = turn_based_aec_to_parallel(env) - # Pre-process using SuperSuit (color reduction, resizing and frame stacking) - # env = ss.color_reduction_v0(env, mode="B") env = ss.resize_v1(env, x_size=84, y_size=84) env = ss.frame_stack_v1(env, 3) - # Add black death wrapper so the number of agents stays constant - # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True - env.reset(seed=seed) print(f"Starting training on {str(env.metadata['name'])}.") diff --git a/tutorials/SB3/sb3_multiwalker_vector.py b/tutorials/SB3/multiwalker/sb3_multiwalker_vector.py similarity index 100% rename from tutorials/SB3/sb3_multiwalker_vector.py rename to tutorials/SB3/multiwalker/sb3_multiwalker_vector.py diff --git a/tutorials/SB3/sb3_pistonball_vector.py b/tutorials/SB3/pistonball/sb3_pistonball_vector.py similarity index 100% rename from tutorials/SB3/sb3_pistonball_vector.py rename to tutorials/SB3/pistonball/sb3_pistonball_vector.py diff --git a/tutorials/SB3/test_sb3_action_mask.py b/tutorials/SB3/test/test_sb3_action_mask.py similarity index 94% rename from tutorials/SB3/test_sb3_action_mask.py rename to tutorials/SB3/test/test_sb3_action_mask.py index 87087ca59..8d52462c9 100644 --- a/tutorials/SB3/test_sb3_action_mask.py +++ b/tutorials/SB3/test/test_sb3_action_mask.py @@ -43,7 +43,7 @@ @pytest.mark.parametrize("env_fn", EASY_ENVS) def test_action_mask_easy(env_fn): - from tutorials.SB3.sb3_connect_four_action_mask import ( + from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( eval_action_mask, train_action_mask, ) @@ -75,7 +75,7 @@ def test_action_mask_easy(env_fn): ) @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): - from tutorials.SB3.sb3_connect_four_action_mask import ( + from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( eval_action_mask, train_action_mask, ) @@ -103,7 +103,7 @@ def test_action_mask_medium(env_fn): ) @pytest.mark.parametrize("env_fn", HARD_ENVS) def test_action_mask_hard(env_fn): - from tutorials.SB3.sb3_connect_four_action_mask import ( + from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( eval_action_mask, train_action_mask, ) From 35addaa8cd16e1940e05dcf63c9756bace5e9dbd Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 15:43:18 -0400 Subject: [PATCH 32/38] Add separate requirements files for different sb3 tutorials --- tutorials/SB3/connect_four/requirements.txt | 3 +++ tutorials/SB3/kaz/requirements.txt | 3 +++ tutorials/SB3/multiwalker/requirements.txt | 3 +++ tutorials/SB3/pistonball/requirements.txt | 3 +++ tutorials/SB3/test/requirements.txt | 4 ++++ tutorials/SB3/test/test_sb3_action_mask.py | 4 ++-- 6 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 tutorials/SB3/connect_four/requirements.txt create mode 100644 tutorials/SB3/kaz/requirements.txt create mode 100644 tutorials/SB3/multiwalker/requirements.txt create mode 100644 tutorials/SB3/pistonball/requirements.txt create mode 100644 tutorials/SB3/test/requirements.txt diff --git a/tutorials/SB3/connect_four/requirements.txt b/tutorials/SB3/connect_four/requirements.txt new file mode 100644 index 000000000..30917f7b2 --- /dev/null +++ b/tutorials/SB3/connect_four/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[classic]>=1.23.1 +stable-baselines3>=2.0.0 +sb3-contrib>=2.0.0 diff --git a/tutorials/SB3/kaz/requirements.txt b/tutorials/SB3/kaz/requirements.txt new file mode 100644 index 000000000..6199e7131 --- /dev/null +++ b/tutorials/SB3/kaz/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[butterfly]>=1.23.1 +stable-baselines3>=2.0.0 +supersuit>=3.8.1 diff --git a/tutorials/SB3/multiwalker/requirements.txt b/tutorials/SB3/multiwalker/requirements.txt new file mode 100644 index 000000000..4baeed307 --- /dev/null +++ b/tutorials/SB3/multiwalker/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[sisl]>=1.23.1 +stable-baselines3>=2.0.0 +supersuit>=3.8.1 diff --git a/tutorials/SB3/pistonball/requirements.txt b/tutorials/SB3/pistonball/requirements.txt new file mode 100644 index 000000000..6199e7131 --- /dev/null +++ b/tutorials/SB3/pistonball/requirements.txt @@ -0,0 +1,3 @@ +pettingzoo[butterfly]>=1.23.1 +stable-baselines3>=2.0.0 +supersuit>=3.8.1 diff --git a/tutorials/SB3/test/requirements.txt b/tutorials/SB3/test/requirements.txt new file mode 100644 index 000000000..838ea192b --- /dev/null +++ b/tutorials/SB3/test/requirements.txt @@ -0,0 +1,4 @@ +pettingzoo[classic]>=1.23.1 +stable-baselines3>=2.0.0 +sb3-contrib>=2.0.0 +pytest diff --git a/tutorials/SB3/test/test_sb3_action_mask.py b/tutorials/SB3/test/test_sb3_action_mask.py index 8d52462c9..0bd54d067 100644 --- a/tutorials/SB3/test/test_sb3_action_mask.py +++ b/tutorials/SB3/test/test_sb3_action_mask.py @@ -71,7 +71,7 @@ def test_action_mask_easy(env_fn): @pytest.mark.skip( - reason="training can be compute intensive and hyperparameters have not been tuned, disabled for CI" + reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" ) @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): @@ -99,7 +99,7 @@ def test_action_mask_medium(env_fn): @pytest.mark.skip( - reason="training can be compute intensive and hyperparameters have not been tuned, disabled for CI" + reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" ) @pytest.mark.parametrize("env_fn", HARD_ENVS) def test_action_mask_hard(env_fn): From 996274e33c0b3366d5558db8803e60b58df360f3 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 16:03:47 -0400 Subject: [PATCH 33/38] Fix workflow for tutorials to always install from root dir --- .github/workflows/linux-tutorials-test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml index 22d838e0e..e58439bf5 100644 --- a/.github/workflows/linux-tutorials-test.yml +++ b/.github/workflows/linux-tutorials-test.yml @@ -29,8 +29,9 @@ jobs: - name: Install dependencies and run tutorials run: | sudo apt-get install python3-opengl xvfb + export root_dir=$(pwd) cd tutorials/${{ matrix.tutorial }} pip install -r requirements.txt pip uninstall -y pettingzoo - pip install -e ../.. + pip install -e $root_dir for f in *.py; do xvfb-run -a -s "-screen 0 1024x768x24" python "$f"; done From c4834b58f14e5ae5e44c473d1cfb8cf8dc8daba4 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 16:23:46 -0400 Subject: [PATCH 34/38] Un-skip the rest of the action mask tests, as the longest one is pistonball (11 minutes) vs 2.5 mins for the easy tests --- tutorials/SB3/test/test_sb3_action_mask.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tutorials/SB3/test/test_sb3_action_mask.py b/tutorials/SB3/test/test_sb3_action_mask.py index 0bd54d067..f8f803766 100644 --- a/tutorials/SB3/test/test_sb3_action_mask.py +++ b/tutorials/SB3/test/test_sb3_action_mask.py @@ -70,9 +70,9 @@ def test_action_mask_easy(env_fn): # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) -@pytest.mark.skip( - reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" -) +# @pytest.mark.skip( +# reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" +# ) @pytest.mark.parametrize("env_fn", MEDIUM_ENVS) def test_action_mask_medium(env_fn): from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( @@ -98,9 +98,9 @@ def test_action_mask_medium(env_fn): # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs) -@pytest.mark.skip( - reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" -) +# @pytest.mark.skip( +# reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI" +# ) @pytest.mark.parametrize("env_fn", HARD_ENVS) def test_action_mask_hard(env_fn): from tutorials.SB3.connect_four.sb3_connect_four_action_mask import ( From 1dfe96bb05e5a40db76a4ae4a0508bf0dafed5e8 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 16:38:35 -0400 Subject: [PATCH 35/38] Remove pistonball env.close() line to avoid SuperSuit issue --- tutorials/SB3/pistonball/sb3_pistonball_vector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tutorials/SB3/pistonball/sb3_pistonball_vector.py b/tutorials/SB3/pistonball/sb3_pistonball_vector.py index ec102edc9..135761abe 100644 --- a/tutorials/SB3/pistonball/sb3_pistonball_vector.py +++ b/tutorials/SB3/pistonball/sb3_pistonball_vector.py @@ -104,7 +104,9 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa act = model.predict(obs, deterministic=True)[0] env.step(act) - env.close() + + # TODO: fix SuperSuit bug where closing the vector env can sometimes crash (disabled for CI) + # env.close() avg_reward = sum(rewards.values()) / len(rewards.values()) print(f"Avg reward: {avg_reward}") From 5f65af0479986c9537bcb264ae6b75c4f10d75fc Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 21:32:02 -0400 Subject: [PATCH 36/38] Change multiwalker to waterworld (actually trains), remove pistonball (bugged) --- .github/workflows/linux-tutorials-test.yml | 2 +- docs/tutorials/sb3/connect_four.md | 2 +- docs/tutorials/sb3/index.md | 17 ++++------ docs/tutorials/sb3/kaz.md | 2 +- docs/tutorials/sb3/pistonball.md | 32 ------------------- .../sb3/{multiwalker.md => waterworld.md} | 4 +-- tutorials/SB3/kaz/sb3_kaz_vector.py | 25 --------------- .../SB3/pistonball/sb3_pistonball_vector.py | 6 ++-- .../requirements.txt | 0 .../sb3_waterworld_vector.py} | 25 +++++++-------- 10 files changed, 27 insertions(+), 88 deletions(-) delete mode 100644 docs/tutorials/sb3/pistonball.md rename docs/tutorials/sb3/{multiwalker.md => waterworld.md} (89%) rename tutorials/SB3/{multiwalker => waterworld}/requirements.txt (100%) rename tutorials/SB3/{multiwalker/sb3_multiwalker_vector.py => waterworld/sb3_waterworld_vector.py} (77%) diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml index e58439bf5..569e87ee8 100644 --- a/.github/workflows/linux-tutorials-test.yml +++ b/.github/workflows/linux-tutorials-test.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10'] # '3.11' - broken due to numba - tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL', 'SB3/connect_four', 'SB3/pistonball', 'SB3/kaz', 'SB3/multiwalker', 'SB3/test'] # TODO: add back RLlib once it is fixed + tutorial: ['Tianshou', 'EnvironmentCreation', 'CleanRL', 'SB3/kaz', 'SB3/waterworld', 'SB3/connect_four', 'SB3/test'] # TODO: add back RLlib once it is fixed steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/docs/tutorials/sb3/connect_four.md b/docs/tutorials/sb3/connect_four.md index 31afbbadc..fc96af2cf 100644 --- a/docs/tutorials/sb3/connect_four.md +++ b/docs/tutorials/sb3/connect_four.md @@ -27,7 +27,7 @@ After training and evaluation, this script will launch a demo game using human r ## Environment Setup To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt +.. literalinclude:: ../../../tutorials/SB3/connect_four/requirements.txt :language: text ``` diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index 6c1c123f2..617d74a2f 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -4,19 +4,17 @@ title: "Stable-Baselines3" # Stable-Baselines3 Tutorial -These tutorials show you how to use the [SB3](https://stable-baselines3.readthedocs.io/en/master/) library to train agents in PettingZoo environments. +These tutorials show you how to use the [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) library to train agents in PettingZoo environments. -For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/) +For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/). -* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train agents using PPO in vectorized Parallel environment_ +* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized environment with visual observations_ -* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized AEC environment_ +For non-visual environments, we use [MLP](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.MlpPolicy) policies and do not perform any pre-processing steps. -For non-visual environments, we use [Actor Critic](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) or [Maskable Actor Critic](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html#maskableppo-policies) policies and do not perform any pre-processing steps. +* [PPO for Waterworld](/tutorials/sb3/waterworld/): _Train agents using PPO in a vectorized environment with discrete observations_ -* [PPO for Multiwalker](/tutorials/sb3/multiwalker/): _Train agents using PPO in a vectorized AEC environment_ - -* [Action Masked PPO for Connect Four](/tutorials/sb3/connect_four/): _Train an agent using Action Masked PPO in an AEC environment_ +* [Action Masked PPO for Connect Four](/tutorials/sb3/connect_four/): _Train agents using Action Masked PPO in an AEC environment_ ## Stable-Baselines Overview @@ -42,8 +40,7 @@ For more information, see the [Stable-Baselines3 v1.0 Blog Post](https://araffin :hidden: :caption: SB3 -pistonball kaz -multiwalker +waterworld connect_four ``` diff --git a/docs/tutorials/sb3/kaz.md b/docs/tutorials/sb3/kaz.md index 9cb8b78e0..1714d66b2 100644 --- a/docs/tutorials/sb3/kaz.md +++ b/docs/tutorials/sb3/kaz.md @@ -26,7 +26,7 @@ After training and evaluation, this script will launch a demo game using human r ## Environment Setup To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt +.. literalinclude:: ../../../tutorials/SB3/kaz/requirements.txt :language: text ``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md deleted file mode 100644 index ee52dc8ac..000000000 --- a/docs/tutorials/sb3/pistonball.md +++ /dev/null @@ -1,32 +0,0 @@ ---- -title: "SB3: PPO for Pistonball (Parallel)" ---- - -# SB3: PPO for Pistonball - -This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)). - -After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information). - -```{eval-rst} -.. note:: - - This environment has a visual (3-dimensional) observation space, so we use a CNN feature extractor. -``` - -## Environment Setup -To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt - :language: text -``` - -## Code -The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). - -### Training and Evaluation - -```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/pistonball/sb3_pistonball_vector.py - :language: python -``` diff --git a/docs/tutorials/sb3/multiwalker.md b/docs/tutorials/sb3/waterworld.md similarity index 89% rename from docs/tutorials/sb3/multiwalker.md rename to docs/tutorials/sb3/waterworld.md index ad2fc3d26..519079a5f 100644 --- a/docs/tutorials/sb3/multiwalker.md +++ b/docs/tutorials/sb3/waterworld.md @@ -18,7 +18,7 @@ After training and evaluation, this script will launch a demo game using human r ## Environment Setup To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/requirements.txt +.. literalinclude:: ../../../tutorials/SB3/waterworld/requirements.txt :language: text ``` @@ -28,6 +28,6 @@ The following code should run without any issues. The comments are designed to h ### Training and Evaluation ```{eval-rst} -.. literalinclude:: ../../../tutorials/SB3/multiwalker/sb3_multiwalker_vector.py +.. literalinclude:: ../../../tutorials/SB3/multiwalker/sb3_waterworld_vector.py :language: python ``` diff --git a/tutorials/SB3/kaz/sb3_kaz_vector.py b/tutorials/SB3/kaz/sb3_kaz_vector.py index fda1bd59a..b65df30e0 100644 --- a/tutorials/SB3/kaz/sb3_kaz_vector.py +++ b/tutorials/SB3/kaz/sb3_kaz_vector.py @@ -90,31 +90,6 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa rewards = {agent: 0 for agent in env.possible_agents} - # TODO: figure out why Parallel performs differently at test time (my guess is maybe the way it counts num_cycles is different?) - # # It seems to make the rewards worse, the same policy scores 2/3 points per archer vs 6/7 with AEC. n - - # from pettingzoo.utils.wrappers import RecordEpisodeStatistics - # - # env = env_fn.parallel_env(render_mode=render_mode, **env_kwargs) - # - # # Pre-process using SuperSuit (color reduction, resizing and frame stacking) - # env = ss.resize_v1(env, x_size=84, y_size=84) - # env = ss.frame_stack_v1(env, 3) - # env = RecordEpisodeStatistics(env) - # - # stats = [] - # for i in range(num_games): - # observations, infos = env.reset(seed=i) - # done = False - # while not done: - # actions = {agent: model.predict(observations[agent], deterministic=True)[0] for agent in env.agents} - # obss, rews, terms, truncs, infos = env.step(actions) - # - # for agent in env.possible_agents: - # rewards[agent] += rews[agent] - # done = any(terms.values()) or any(truncs.values()) - # stats.append(infos["episode"]) - # Note: we evaluate here using an AEC environments, to allow for easy A/B testing against random policies # For example, we can see here that using a random agent for archer_0 results in less points than the trained agent for i in range(num_games): diff --git a/tutorials/SB3/pistonball/sb3_pistonball_vector.py b/tutorials/SB3/pistonball/sb3_pistonball_vector.py index 135761abe..a6a0f723e 100644 --- a/tutorials/SB3/pistonball/sb3_pistonball_vector.py +++ b/tutorials/SB3/pistonball/sb3_pistonball_vector.py @@ -60,7 +60,7 @@ def train_butterfly_supersuit( print(f"Finished training on {str(env.unwrapped.metadata['name'])}.") # TODO: fix SuperSuit bug where closing the vector env can sometimes crash (disabled for CI) - # env.close() + env.close() def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs): @@ -106,7 +106,7 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa env.step(act) # TODO: fix SuperSuit bug where closing the vector env can sometimes crash (disabled for CI) - # env.close() + env.close() avg_reward = sum(rewards.values()) / len(rewards.values()) print(f"Avg reward: {avg_reward}") @@ -130,7 +130,7 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa # Train a model (takes ~3 minutes on a laptop CPU) # Note: stochastic environment makes training difficult, for better results try order of 2 million (~2 hours on GPU) - train_butterfly_supersuit(env_fn, steps=40_960, seed=0, **env_kwargs) + train_butterfly_supersuit(env_fn, steps=40_960 * 2, seed=0, **env_kwargs) # Evaluate 10 games (takes ~10 seconds on a laptop CPU) eval(env_fn, num_games=10, render_mode=None, **env_kwargs) diff --git a/tutorials/SB3/multiwalker/requirements.txt b/tutorials/SB3/waterworld/requirements.txt similarity index 100% rename from tutorials/SB3/multiwalker/requirements.txt rename to tutorials/SB3/waterworld/requirements.txt diff --git a/tutorials/SB3/multiwalker/sb3_multiwalker_vector.py b/tutorials/SB3/waterworld/sb3_waterworld_vector.py similarity index 77% rename from tutorials/SB3/multiwalker/sb3_multiwalker_vector.py rename to tutorials/SB3/waterworld/sb3_waterworld_vector.py index 8ca27c44a..edaa81afd 100644 --- a/tutorials/SB3/multiwalker/sb3_multiwalker_vector.py +++ b/tutorials/SB3/waterworld/sb3_waterworld_vector.py @@ -1,4 +1,4 @@ -"""Uses Stable-Baselines3 to train agents to play the Multiwalker environment using SuperSuit vector envs. +"""Uses Stable-Baselines3 to train agents to play the Waterworld environment using SuperSuit vector envs. For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html @@ -14,13 +14,13 @@ from stable_baselines3 import PPO from stable_baselines3.ppo import MlpPolicy -from pettingzoo.sisl import multiwalker_v9 +from pettingzoo.sisl import waterworld_v4 def train_butterfly_supersuit( env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs ): - # Train a single agent to play both sides in a Parallel environment, + # Train a single a model to play as each agent in a cooperative Parallel environment, env = env_fn.parallel_env(**env_kwargs) env.reset(seed=seed) @@ -30,7 +30,7 @@ def train_butterfly_supersuit( env = ss.pettingzoo_env_to_vec_env_v1(env) env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") - # Note: Multiwalker's observation space is discrete, therefore we use an MLP policy rather than CNN + # Note: Waterworld's observation space is discrete (242,) so we use an MLP policy rather than CNN model = PPO( MlpPolicy, env, @@ -78,9 +78,9 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa for agent in env.agent_iter(): obs, reward, termination, truncation, info = env.last() + for agent in env.agents: + rewards[agent] += env.rewards[agent] if termination or truncation: - for agent in env.agents: - rewards[agent] += env.rewards[agent] break else: act = model.predict(obs, deterministic=True)[0] @@ -89,21 +89,20 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa env.close() avg_reward = sum(rewards.values()) / len(rewards.values()) + print("Rewards: ", rewards) print(f"Avg reward: {avg_reward}") return avg_reward if __name__ == "__main__": - env_fn = multiwalker_v9 - + env_fn = waterworld_v4 env_kwargs = {} - # Train a model (takes ~3 minutes on a laptop CPU) - # Note: stochastic environment makes training difficult, hyperparameters have not been fully tuned for this example - train_butterfly_supersuit(env_fn, steps=49_152 * 4, seed=0, **env_kwargs) + # Train a model (takes ~3 minutes on GPU) + train_butterfly_supersuit(env_fn, steps=196_608, seed=0, **env_kwargs) - # Evaluate 10 games (takes ~10 seconds on a laptop CPU) + # Evaluate 10 games (average reward should be positive but can vary significantly) eval(env_fn, num_games=10, render_mode=None, **env_kwargs) - # Watch 2 games (takes ~10 seconds on a laptop CPU) + # Watch 2 games eval(env_fn, num_games=2, render_mode="human", **env_kwargs) From 7403e029fa45a3ca04e8ea5c67cc17c837d1285c Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 21:59:42 -0400 Subject: [PATCH 37/38] Add pymunk dependency to sisl waterworld (modulenotfound error) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a11409697..e139a5d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classic = [ ] butterfly = ["pygame==2.3.0", "pymunk==6.2.0"] mpe = ["pygame==2.3.0"] -sisl = ["pygame==2.3.0", "box2d-py==2.3.5", "scipy>=1.4.1"] +sisl = ["pygame==2.3.0", "pymunk==6.2.0", "box2d-py==2.3.5", "scipy>=1.4.1"] other = ["pillow>=8.0.1"] testing = [ "pynput", From a92b2a3cecd6772a4b604976b36542eaced53a94 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 10 Jul 2023 22:07:23 -0400 Subject: [PATCH 38/38] Add pymunk req --- tutorials/SB3/waterworld/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/SB3/waterworld/requirements.txt b/tutorials/SB3/waterworld/requirements.txt index 4baeed307..cb7dac213 100644 --- a/tutorials/SB3/waterworld/requirements.txt +++ b/tutorials/SB3/waterworld/requirements.txt @@ -1,3 +1,4 @@ pettingzoo[sisl]>=1.23.1 stable-baselines3>=2.0.0 supersuit>=3.8.1 +pymunk