Update SB3 tutorial (action masking, tests) #1017

Merged Jul 11, 2023 · 38 commits

Changes from 1 commit

Commits
cf2bef7
Update SB3 tutorial to have __main__ (error on macOS)
elliottower Jul 7, 2023
86916e6
Add SB3 tests for tutorials
elliottower Jul 7, 2023
f3e0cc9
Add action masking tutorial, fix typos/update documentation
elliottower Jul 7, 2023
d0784be
Add try catch for test sb3 action mask (pytest -v shouldn't require sb3)
elliottower Jul 7, 2023
87c46ef
Clean up documentation
elliottower Jul 8, 2023
baf59fa
Fix requirements.txt to specify pettingzoo[classic]
elliottower Jul 8, 2023
db8331a
Add try catch for render action mask
elliottower Jul 8, 2023
05b3dcc
Add try catch for render action mask
elliottower Jul 8, 2023
ecb96bf
Add try catch for other render files
elliottower Jul 8, 2023
d887db2
Fix code which doesn't work due to modules (tutorials not included)
elliottower Jul 8, 2023
085ed0a
Switch userwarnings to print statements and exit (so it doesn't fail)
elliottower Jul 8, 2023
18eca55
Add butterfly requirement to sb3 tutorial
elliottower Jul 8, 2023
429cbd8
Switch default timesteps to be more reasonable (10,000)
elliottower Jul 8, 2023
c9f0024
Switch default timesteps to be lower (2048), just so CI runs faster
elliottower Jul 8, 2023
a64022d
Switch num cpus to 2 by default (GitHub Actions runners only get 2 cores)
elliottower Jul 8, 2023
ee08317
Fix print statements logic
elliottower Jul 8, 2023
bd83f30
Update tutorials to evaluate, add KAZ example, test hyperparameters
elliottower Jul 9, 2023
0185af8
Update code to check more in depth statistics like winrate and total …
elliottower Jul 10, 2023
c977d89
Pre-commit
elliottower Jul 10, 2023
0625296
Un-comment training code for KAZ
elliottower Jul 10, 2023
459cc86
Update hyperparameters and fix pistonball crashing issue
elliottower Jul 10, 2023
9546c9c
Add hyperparameter notes
elliottower Jul 10, 2023
8cfd867
Add multiwalker tutorial for MLP example
elliottower Jul 10, 2023
41e26fc
Fix typo in docs
elliottower Jul 10, 2023
6af9e18
Polish up documentation and add sphinx warnings/notes
elliottower Jul 10, 2023
a454362
Try to fix missing module error from test file
elliottower Jul 10, 2023
5c75d4c
Update test_sb3_action_mask.py
elliottower Jul 10, 2023
fd23175
Add importorskip to each test, choose better hyperparameters
elliottower Jul 10, 2023
cefc86d
Move pytest importorskip calls
elliottower Jul 10, 2023
142b155
Disable most of the tests on test_sb3_action_mask.py
elliottower Jul 10, 2023
1a2d2ef
Split CI tests into separate actions (so they don't take 2 hours)
elliottower Jul 10, 2023
35addaa
Add separate requirements files for different sb3 tutorials
elliottower Jul 10, 2023
996274e
Fix workflow for tutorials to always install from root dir
elliottower Jul 10, 2023
c4834b5
Un-skip the rest of the action mask tests, as the longest one is pist…
elliottower Jul 10, 2023
1dfe96b
Remove pistonball env.close() line to avoid SuperSuit issue
elliottower Jul 10, 2023
5f65af0
Change multiwalker to waterworld (actually trains), remove pistonball…
elliottower Jul 11, 2023
7403e02
Add pymunk dependency to sisl waterworld (modulenotfound error)
elliottower Jul 11, 2023
a92b2a3
Add pymunk req
elliottower Jul 11, 2023
Add multiwalker tutorial for MLP example
elliottower committed Jul 10, 2023

Verified: this commit was signed with the committer’s verified signature (Elliot Tower).
commit 8cfd867292d4ec629e029869220b0ed7919ba754
13 changes: 10 additions & 3 deletions docs/tutorials/sb3/index.md
@@ -6,11 +6,17 @@ title: "Stable-Baselines3"

These tutorials show you how to use the [SB3](https://stable-baselines3.readthedocs.io/en/master/) library to train agents in PettingZoo environments.

* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in vectorized Parallel environment_
For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/).

* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train a PPO model in a vectorized AEC environment_
* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train agents using PPO in a vectorized Parallel environment_

* [Action Masked PPO for Chess](/tutorials/sb3/connect_four/): _Train an action masked PPO model in an AEC environment_
* [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/): _Train agents using PPO in a vectorized AEC environment_

For non-visual environments, we use [Actor Critic](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.MlpPolicy) or [Maskable Actor Critic](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html#maskableppo-policies) policies and do not perform any pre-processing steps.

* [PPO for Multiwalker](/tutorials/sb3/multiwalker/): _Train agents using PPO in a vectorized Parallel environment_

* [Action Masked PPO for Connect Four](/tutorials/sb3/connect_four/): _Train an agent using Action Masked PPO in an AEC environment_


## Stable-Baselines Overview
@@ -33,5 +39,6 @@ Note: SB3 does not officially support PettingZoo, as it is designed for single-a

pistonball
kaz
multiwalker
connect_four
```
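
The visual pre-processing pipeline described above can be sketched as follows. This is a minimal illustration, not part of this diff; it assumes the Pistonball environment, and wrapper version suffixes may differ across SuperSuit releases:

```python
import supersuit as ss
from pettingzoo.butterfly import pistonball_v6

# Visual observations: reduce color channels, shrink frames, and stack them
# so the CNN policy can see short-term motion.
env = pistonball_v6.parallel_env()
env = ss.color_reduction_v0(env, mode="B")  # keep only the blue channel
env = ss.resize_v1(env, x_size=84, y_size=84)
env = ss.frame_stack_v1(env, 3)

# Convert to an SB3-compatible vectorized environment.
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
```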
29 changes: 29 additions & 0 deletions docs/tutorials/sb3/multiwalker.md
@@ -0,0 +1,29 @@
---
title: "SB3: PPO for Multiwalker (Parallel)"
---

# SB3: PPO for Multiwalker

This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Multiwalker](https://pettingzoo.farama.org/environments/sisl/multiwalker/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)).

Note: this environment uses a continuous 1-dimensional vector observation space, so we use an MLP policy rather than a CNN.

After training and evaluation, this script will launch a demo game with human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information).


## Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/requirements.txt
:language: text
```

## Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX).

### Training and Evaluation

```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/sb3_multiwalker_vector.py
:language: python
```
2 changes: 1 addition & 1 deletion docs/tutorials/sb3/pistonball.md
@@ -4,7 +4,7 @@ title: "SB3: PPO for Pistonball (Parallel)"

# SB3: PPO for Pistonball

This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)).
This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([Parallel](https://pettingzoo.farama.org/api/parallel/)).

After training and evaluation, this script will launch a demo game with human rendering. Trained models are saved and loaded from disk (see SB3's [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information).

2 changes: 1 addition & 1 deletion tutorials/SB3/sb3_connect_four_action_mask.py
@@ -1,4 +1,4 @@
"""Uses Stable-Baselines3 to train agents to play Connect Four using invalid action masking.
"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.

For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking
For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html
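
As a minimal sketch of where the mask comes from (not part of this diff; the full tutorial wraps the environment for MaskablePPO), a PettingZoo classic observation carries the legal-move mask alongside the board:

```python
from pettingzoo.classic import connect_four_v3

env = connect_four_v3.env()
env.reset(seed=42)

# Classic envs return a dict observation holding both the board and the mask.
obs, reward, termination, truncation, info = env.last()
observation, action_mask = obs["observation"], obs["action_mask"]  # 1 = legal column

# With a trained sb3_contrib MaskablePPO model, the mask is passed at predict time:
# act = model.predict(observation, action_masks=action_mask, deterministic=True)[0]
```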
2 changes: 1 addition & 1 deletion tutorials/SB3/sb3_kaz_vector.py
@@ -1,4 +1,4 @@
"""Uses Stable-Baselines3 to train agents to play Knights-Archers-Zombies using SuperSuit vector envs.
"""Uses Stable-Baselines3 to train agents in the Knights-Archers-Zombies environment using SuperSuit vector envs.

This environment requires using SuperSuit's Black Death wrapper to handle agent death.
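
A minimal sketch of applying the Black Death wrapper mentioned above (not part of this diff; assumes the KAZ parallel env, and the wrapper version may differ in your SuperSuit release):

```python
import supersuit as ss
from pettingzoo.butterfly import knights_archers_zombies_v10

env = knights_archers_zombies_v10.parallel_env()
# Black Death: dead agents return zeroed observations instead of being removed,
# keeping the agent count fixed as SB3's vectorized training requires.
env = ss.black_death_v3(env)
```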

109 changes: 109 additions & 0 deletions tutorials/SB3/sb3_multiwalker_vector.py
@@ -0,0 +1,109 @@
"""Uses Stable-Baselines3 to train agents to play the Multiwalker environment using SuperSuit vector envs.

For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

Author: Elliot (https://github.com/elliottower)
"""
from __future__ import annotations

import glob
import os
import time

import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

from pettingzoo.sisl import multiwalker_v9


def train_butterfly_supersuit(
    env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    # Train a single model to control every agent in a Parallel environment.
    env = env_fn.parallel_env(**env_kwargs)

    env.reset(seed=seed)

    print(f"Starting training on {str(env.metadata['name'])}.")

    # Convert the PettingZoo Parallel env to an SB3 VecEnv, then run 8 copies
    # of it concurrently (num_cpus=2 matches GitHub Actions' 2-core runners).
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")

    # Note: Multiwalker's observation space is a 1-dimensional vector rather
    # than an image, therefore we use an MLP policy rather than a CNN
    model = PPO(
        MlpPolicy,
        env,
        verbose=3,
        learning_rate=1e-3,
        batch_size=256,
    )

    model.learn(total_timesteps=steps)

    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

    print("Model has been saved.")

    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.")

    env.close()


def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
    # Evaluate a trained agent vs a random agent
    env = env_fn.env(render_mode=render_mode, **env_kwargs)

    print(
        f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)

    model = PPO.load(latest_policy)

    rewards = {agent: 0 for agent in env.possible_agents}

    # Note: we train using the Parallel API but evaluate using the AEC API.
    # SB3 models are designed for single-agent settings; we get around this
    # by using the same model for every agent.
    for i in range(num_games):
        env.reset(seed=i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            if termination or truncation:
                # Accumulate final rewards for all remaining agents, then end the game.
                for a in env.agents:
                    rewards[a] += env.rewards[a]
                break
            else:
                act = model.predict(obs, deterministic=True)[0]

            env.step(act)
    env.close()

    avg_reward = sum(rewards.values()) / len(rewards.values())
    print(f"Avg reward: {avg_reward}")
    return avg_reward


if __name__ == "__main__":
    env_fn = multiwalker_v9

    env_kwargs = {}

    # Train a model (takes ~3 minutes on a laptop CPU)
    # Note: the stochastic environment makes training difficult; hyperparameters
    # have not been fully tuned for this example
    train_butterfly_supersuit(env_fn, steps=49_152 * 4, seed=0, **env_kwargs)

    # Evaluate 10 games (takes ~10 seconds on a laptop CPU)
    eval(env_fn, num_games=10, render_mode=None, **env_kwargs)

    # Watch 2 games (takes ~10 seconds on a laptop CPU)
    eval(env_fn, num_games=2, render_mode="human", **env_kwargs)
2 changes: 1 addition & 1 deletion tutorials/SB3/sb3_pistonball_vector.py
@@ -1,4 +1,4 @@
"""Uses Stable-Baselines3 to train agents to play PettingZoo Butterfly (cooprative) environments using SuperSuit vector envs.
"""Uses Stable-Baselines3 to train agents in the Pistonball environment using SuperSuit vector envs.

For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

2 changes: 1 addition & 1 deletion tutorials/SB3/test_sb3_action_mask.py
@@ -1,4 +1,4 @@
"""Test file to ensure that action masking code works for all PettingZoo classic environments (except rps, which has no action mask)."""
"""Tests that action masking code works properly with all PettingZoo classic environments."""

try:
import pytest
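
The importorskip pattern referenced in the commit log can be sketched as follows (a hypothetical minimal test, not this file's actual contents):

```python
import pytest


def test_connect_four_action_mask():
    # Skip gracefully when optional dependencies are missing,
    # so a plain `pytest -v` does not require SB3 to be installed.
    pytest.importorskip("stable_baselines3")
    pytest.importorskip("sb3_contrib")

    from sb3_contrib import MaskablePPO  # imported only after the skip checks

    assert MaskablePPO is not None
```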