Skip to content

Commit

Permalink
Add support for gymnasium v1.0 (#475)
Browse files Browse the repository at this point in the history
* Add support for gymnasium v1.0

* Update versions

* Fix requirements

* Ignore mypy for gym 0.29

* Add explicit shimmy dep

* Patch obs space and update trained agents

* Comment out auto-fix obs space

* Fix vecnormalize stats
  • Loading branch information
araffin authored Nov 5, 2024
1 parent b1288ed commit 208b6fd
Show file tree
Hide file tree
Showing 21 changed files with 91 additions and 40 deletions.
21 changes: 13 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
with:
Expand All @@ -32,22 +37,22 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1

- name: Lint with ruff
run: |
make lint
Expand Down
21 changes: 14 additions & 7 deletions .github/workflows/trained_agents.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
with:
Expand All @@ -36,19 +41,21 @@ jobs:
# Use uv for faster downloads
pip install uv
# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1

- name: Check trained agents
run: |
make check-trained-agents
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
## Release 2.4.0a10 (WIP)
## Release 2.4.0a11 (WIP)

**New algorithm: CrossQ, and better defaults for SAC/TQC on Swimmer-v4 env**
**New algorithm: CrossQ, Gymnasium v1.0 support, and better defaults for SAC/TQC on Swimmer-v4 env**

### Breaking Changes
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) (@JacobHA) [W&B report](https://wandb.ai/openrlbenchmark/sbx/reports/SAC-MuJoCo-Swimmer-v4--Vmlldzo3NzM5OTk2)
- Upgraded to SB3 >= 2.4.0
- Renamed `LunarLander-v2` to `LunarLander-v3` in hyperparameters

### New Features
- Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen)
- Added Gymnasium v1.0 support
- `--custom-objects` in `enjoy.py` now also patches obs space (when bounds are changed) to solve "Observation spaces do not match" errors

### Bug fixes
- Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz)
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/a2c.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Pendulum-v1:
policy_kwargs: "dict(log_std_init=-2, ortho_init=False)"

# Tuned
LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
normalize: true
n_envs: 4
n_timesteps: !!float 5e6
Expand Down
3 changes: 1 addition & 2 deletions hyperparams/ars.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ LunarLander-v2:
n_timesteps: !!float 2e6

# Tuned
LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
<<: *pendulum-params
n_timesteps: !!float 2e6

Expand Down Expand Up @@ -215,4 +215,3 @@ A1Jumping-v0:
# alive_bonus_offset: -1
normalize: "dict(norm_obs=True, norm_reward=False)"
# policy_kwargs: "dict(net_arch=[16])"

2 changes: 1 addition & 1 deletion hyperparams/crossq.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Pendulum-v1:
policy_kwargs: "dict(net_arch=[256, 256])"


LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 2e5
policy: 'MlpPolicy'
buffer_size: 1000000
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/ddpg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Pendulum-v1:
learning_rate: !!float 1e-3
policy_kwargs: "dict(net_arch=[400, 300])"

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
gamma: 0.98
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/ppo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ LunarLander-v2:
n_epochs: 4
ent_coef: 0.01

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_envs: 16
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/sac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Pendulum-v1:
learning_rate: !!float 1e-3


LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
batch_size: 256
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/td3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Pendulum-v1:
learning_rate: !!float 1e-3
policy_kwargs: "dict(net_arch=[400, 300])"

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
gamma: 0.98
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/tqc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Pendulum-v1:
policy: 'MlpPolicy'
learning_rate: !!float 1e-3

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
learning_rate: lin_7.3e-4
Expand Down
2 changes: 1 addition & 1 deletion hyperparams/trpo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ LunarLander-v2:
n_critic_updates: 15

# Tuned
LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
normalize: true
n_envs: 2
n_timesteps: !!float 1e5
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
gym==0.26.2
stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a10,<3.0
stable-baselines3[extra,tests,docs]>=2.4.0a11,<3.0
box2d-py==2.3.8
pybullet_envs_gymnasium>=0.4.0
pybullet_envs_gymnasium>=0.5.0
# minigrid
cloudpickle>=2.2.1
# optuna plots:
plotly
# need to upgrade to gymnasium:
# panda-gym~=3.0.1
wandb
moviepy
moviepy>=1.0.0
11 changes: 11 additions & 0 deletions rl_zoo3/enjoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,23 @@ def enjoy() -> None: # noqa: C901
"learning_rate": 0.0,
"lr_schedule": lambda _: 0.0,
"clip_range": lambda _: 0.0,
# load models with different obs bounds
# Note: doesn't work with channel last envs
# "observation_space": env.observation_space,
}

if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
kwargs["env"] = env

model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs)
# Uncomment to save patched file (for instance gym -> gymnasium)
# model.save(model_path)
# Patch VecNormalize (gym -> gymnasium)
# from pathlib import Path
# env.observation_space = model.observation_space
# env.action_space = model.action_space
# env.save(Path(model_path).parent / env_name / "vecnormalize.pkl")

obs = env.reset()

# Deterministic by default except for atari games
Expand Down
5 changes: 4 additions & 1 deletion rl_zoo3/gym_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,8 @@ def step(self, action):

# Patch Gymnasium TimeLimit
gymnasium.wrappers.TimeLimit = PatchedTimeLimit # type: ignore[misc]
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
try:
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
except AttributeError:
gymnasium.wrappers.common.TimeLimit = PatchedTimeLimit # type: ignore
gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc,attr-defined]
10 changes: 9 additions & 1 deletion rl_zoo3/import_envs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Optional

import gymnasium as gym
from gymnasium.envs.registration import register
from gymnasium.envs.registration import register, register_envs

from rl_zoo3.wrappers import MaskVelocityWrapper

Expand All @@ -10,6 +10,14 @@
except ImportError:
pass

try:
import ale_py

# no-op
gym.register_envs(ale_py)
except ImportError:
pass

try:
import highway_env
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a10
2.4.0a11
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@
See https://github.com/DLR-RM/rl-baselines3-zoo
"""
install_requires = [
"sb3_contrib>=2.4.0a10,<3.0",
"gymnasium~=0.29.1",
"sb3_contrib>=2.4.0a11,<3.0",
"gymnasium>=0.29.1,<1.1.0",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
"rich",
"optuna>=3.0",
"pyyaml>=5.1",
"pytablewriter~=1.2",
"shimmy~=2.0",
]
plots_requires = ["seaborn", "rliable~=1.2.0", "scipy~=1.10"]
test_requires = [
# for MuJoCo envs v4:
"mujoco~=2.3",
"mujoco>=2.3,<4",
# install parking-env to test HER
"highway-env==1.8.2",
"highway-env>=1.10.1,<1.11.0",
]

setup(
Expand Down
6 changes: 6 additions & 0 deletions tests/test_enjoy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import shlex
import subprocess
from importlib.metadata import version

import pytest

Expand Down Expand Up @@ -40,6 +41,11 @@ def test_trained_agents(trained_model):
if "Panda" in env_id:
return

# TODO: rename trained agents once we drop support for gymnasium v0.29
if "Lander" in env_id and version("gymnasium") > "0.29.1":
# LunarLander-v2 is now LunarLander-v3
return

# Skip mujoco envs
if "Fetch" in trained_model or "-v3" in trained_model:
return
Expand Down
12 changes: 10 additions & 2 deletions tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import shlex
import subprocess
from importlib.metadata import version

import pytest

Expand Down Expand Up @@ -36,10 +37,17 @@ def test_train(tmp_path, experiment):


def test_continue_training(tmp_path):
algo, env_id = "a2c", "CartPole-v1"
algo = "a2c"
if version("gymnasium") > "0.29.1":
# See https://github.com/DLR-RM/stable-baselines3/pull/1837#issuecomment-2457322341
# obs bounds have changed...
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v1"

cmd = (
f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} "
"-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip"
f"-i rl-trained-agents/a2c/{env_id}_1/{env_id}.zip"
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)
Expand Down

0 comments on commit 208b6fd

Please sign in to comment.