From 208b6fdf3a9aa24de44b2e068c6e5f6a07d97fcc Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 5 Nov 2024 21:47:24 +0100 Subject: [PATCH] Add support for gymnasium v1.0 (#475) * Add support for gymnasium v1.0 * Update versions * Fix requirements * Ignore mypy for gym 0.29 * Add explicit shimmy dep * Patch obs space and update trained agents * Comment out auto-fix obs space * Fix vecnormalize stats --- .github/workflows/ci.yml | 21 +++++++++++++-------- .github/workflows/trained_agents.yml | 21 ++++++++++++++------- CHANGELOG.md | 7 +++++-- hyperparams/a2c.yml | 2 +- hyperparams/ars.yml | 3 +-- hyperparams/crossq.yml | 2 +- hyperparams/ddpg.yml | 2 +- hyperparams/ppo.yml | 2 +- hyperparams/sac.yml | 2 +- hyperparams/td3.yml | 2 +- hyperparams/tqc.yml | 2 +- hyperparams/trpo.yml | 2 +- requirements.txt | 6 +++--- rl-trained-agents | 2 +- rl_zoo3/enjoy.py | 11 +++++++++++ rl_zoo3/gym_patches.py | 5 ++++- rl_zoo3/import_envs.py | 10 +++++++++- rl_zoo3/version.txt | 2 +- setup.py | 9 +++++---- tests/test_enjoy.py | 6 ++++++ tests/test_train.py | 12 ++++++++++-- 21 files changed, 91 insertions(+), 40 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0f55b3749..528e7fa2c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,12 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - + include: + # Default version + - gymnasium-version: "1.0.0" + # Add a new config to test gym<1.0 + - python-version: "3.10" + gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v3 with: @@ -32,15 +37,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - # Use uv for faster downloads pip install uv - # Install Atari Roms - uv pip install --system autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - + # cpu version of pytorch # See https://github.com/astral-sh/uv/issues/1497 uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu # Install full requirements (for additional envs and test tools) @@ -48,6 +47,12 @@ jobs: # Use headless version uv pip install --system opencv-python-headless uv pip install --system -e .[plots,tests] + + - name: Install specific version of gym + run: | + uv pip install --system gymnasium==${{ matrix.gymnasium-version }} + # Only run for python 3.10, downgrade gym to 0.29.1 + - name: Lint with ruff run: | make lint diff --git a/.github/workflows/trained_agents.yml b/.github/workflows/trained_agents.yml index f73c3b342..8199ca671 100644 --- a/.github/workflows/trained_agents.yml +++ b/.github/workflows/trained_agents.yml @@ -21,7 +21,12 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - + include: + # Default version + - gymnasium-version: "1.0.0" + # Add a new config to test gym<1.0 + - python-version: "3.10" + gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v3 with: @@ -36,19 +41,21 @@ jobs: # Use uv for faster downloads pip install uv - # Install Atari Roms - uv pip install --system autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - + # cpu version of pytorch # See https://github.com/astral-sh/uv/issues/1497 uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu # Install full requirements (for additional envs and test tools) + # Install full requirements (for additional envs and test tools) uv pip install --system -r requirements.txt # Use headless version uv pip install --system opencv-python-headless uv pip install --system -e .[plots,tests] + + - name: Install specific version of gym + run: | + uv pip install --system gymnasium==${{ matrix.gymnasium-version }} + # Only run for python 3.10, downgrade gym to 0.29.1 + - name: Check trained agents run: | make check-trained-agents diff --git a/CHANGELOG.md b/CHANGELOG.md index 88f078b36..95f56c118 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,16 @@ -## Release 2.4.0a10 (WIP) +## Release 2.4.0a11 (WIP) -**New algorithm: CrossQ, and better defaults for SAC/TQC on Swimmer-v4 env** +**New algorithm: CrossQ, Gymnasium v1.0 support, and better defaults for SAC/TQC on Swimmer-v4 env** ### Breaking Changes - Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) (@JacobHA) [W&B report](https://wandb.ai/openrlbenchmark/sbx/reports/SAC-MuJoCo-Swimmer-v4--Vmlldzo3NzM5OTk2) - Upgraded to SB3 >= 2.4.0 +- Renamed `LunarLander-v2` to `LunarLander-v3` in hyperparameters ### New Features - Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen) +- Added Gymnasium v1.0 support +- `--custom-objects` in `enjoy.py` now also patches obs space (when bounds are changed) to solve "Observation spaces do not match" errors ### Bug fixes - Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz) diff --git a/hyperparams/a2c.yml b/hyperparams/a2c.yml index 02e8c780b..ea6fb71f7 100644 --- a/hyperparams/a2c.yml +++ b/hyperparams/a2c.yml @@ -61,7 +61,7 @@ Pendulum-v1: policy_kwargs: "dict(log_std_init=-2, ortho_init=False)" # Tuned -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: normalize: true n_envs: 4 n_timesteps: !!float 5e6 diff --git a/hyperparams/ars.yml b/hyperparams/ars.yml index 9e365cf35..d404d8c3a 100644 --- a/hyperparams/ars.yml +++ b/hyperparams/ars.yml @@ -26,7 +26,7 @@ LunarLander-v2: n_timesteps: !!float 2e6 # Tuned -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: <<: *pendulum-params n_timesteps: !!float 2e6 @@ -215,4 +215,3 @@ A1Jumping-v0: # alive_bonus_offset: -1 normalize: "dict(norm_obs=True, norm_reward=False)" # policy_kwargs: "dict(net_arch=[16])" - diff --git a/hyperparams/crossq.yml b/hyperparams/crossq.yml index 12284efb2..93a14f7c3 100644 --- a/hyperparams/crossq.yml +++ b/hyperparams/crossq.yml @@ -18,7 +18,7 @@ Pendulum-v1: policy_kwargs: "dict(net_arch=[256, 256])" -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: n_timesteps: !!float 2e5 policy: 'MlpPolicy' buffer_size: 1000000 diff --git a/hyperparams/ddpg.yml b/hyperparams/ddpg.yml index bb78fdae1..38862e23d 100644 --- a/hyperparams/ddpg.yml +++ b/hyperparams/ddpg.yml @@ -23,7 +23,7 @@ Pendulum-v1: learning_rate: !!float 1e-3 policy_kwargs: "dict(net_arch=[400, 300])" -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: n_timesteps: !!float 3e5 policy: 'MlpPolicy' gamma: 0.98 diff --git a/hyperparams/ppo.yml b/hyperparams/ppo.yml index 9339eea8e..138fd4fd8 100644 --- a/hyperparams/ppo.yml +++ b/hyperparams/ppo.yml @@ -122,7 +122,7 @@ LunarLander-v2: n_epochs: 4 ent_coef: 0.01 -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: n_envs: 16 n_timesteps: !!float 1e6 policy: 'MlpPolicy' diff --git a/hyperparams/sac.yml b/hyperparams/sac.yml index 88cf7d1f9..d6c235e9b 100644 --- a/hyperparams/sac.yml +++ b/hyperparams/sac.yml @@ -22,7 +22,7 @@ Pendulum-v1: learning_rate: !!float 1e-3 -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: n_timesteps: !!float 5e5 policy: 'MlpPolicy' batch_size: 256 diff --git a/hyperparams/td3.yml b/hyperparams/td3.yml index 068a6ce4c..a27b465da 100644 --- a/hyperparams/td3.yml +++ b/hyperparams/td3.yml @@ -23,7 +23,7 @@ Pendulum-v1: learning_rate: !!float 1e-3 policy_kwargs: "dict(net_arch=[400, 300])" -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: n_timesteps: !!float 3e5 policy: 'MlpPolicy' gamma: 0.98 diff --git a/hyperparams/tqc.yml b/hyperparams/tqc.yml index 64f55915a..055e29f8a 100644 --- a/hyperparams/tqc.yml +++ b/hyperparams/tqc.yml @@ -19,7 +19,7 @@ Pendulum-v1: policy: 'MlpPolicy' learning_rate: !!float 1e-3 -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: n_timesteps: !!float 5e5 policy: 'MlpPolicy' learning_rate: lin_7.3e-4 diff --git a/hyperparams/trpo.yml b/hyperparams/trpo.yml index dbf49b89b..d75598a0d 100644 --- a/hyperparams/trpo.yml +++ b/hyperparams/trpo.yml @@ -35,7 +35,7 @@ LunarLander-v2: n_critic_updates: 15 # Tuned -LunarLanderContinuous-v2: +LunarLanderContinuous-v3: normalize: true n_envs: 2 n_timesteps: !!float 1e5 diff --git a/requirements.txt b/requirements.txt index 6d4a36009..cda9d4521 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ gym==0.26.2 -stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a10,<3.0 +stable-baselines3[extra,tests,docs]>=2.4.0a11,<3.0 box2d-py==2.3.8 -pybullet_envs_gymnasium>=0.4.0 +pybullet_envs_gymnasium>=0.5.0 # minigrid cloudpickle>=2.2.1 # optuna plots: @@ -9,4 +9,4 @@ plotly # need to upgrade to gymnasium: # panda-gym~=3.0.1 wandb -moviepy +moviepy>=1.0.0 diff --git a/rl-trained-agents b/rl-trained-agents index ca4371d8e..cd35bde61 160000 --- a/rl-trained-agents +++ b/rl-trained-agents @@ -1 +1 @@ -Subproject commit ca4371d8eef7c2eece81461f3d138d23743b2296 +Subproject commit cd35bde610f4045bf2e0731c8f4c88d22df8fc85 diff --git a/rl_zoo3/enjoy.py b/rl_zoo3/enjoy.py index 4cb717a7d..86225650a 100644 --- a/rl_zoo3/enjoy.py +++ b/rl_zoo3/enjoy.py @@ -184,12 +184,23 @@ def enjoy() -> None: # noqa: C901 "learning_rate": 0.0, "lr_schedule": lambda _: 0.0, "clip_range": lambda _: 0.0, + # load models with different obs bounds + # Note: doesn't work with channel last envs + # "observation_space": env.observation_space, } if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""): kwargs["env"] = env model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs) + # Uncomment to save patched file (for instance gym -> gymnasium) + # model.save(model_path) + # Patch VecNormalize (gym -> gymnasium) + # from pathlib import Path + # env.observation_space = model.observation_space + # env.action_space = model.action_space + # env.save(Path(model_path).parent / env_name / "vecnormalize.pkl") + obs = env.reset() # Deterministic by default except for atari games diff --git a/rl_zoo3/gym_patches.py b/rl_zoo3/gym_patches.py index 5011d1d29..a95a3a0d9 100644 --- a/rl_zoo3/gym_patches.py +++ b/rl_zoo3/gym_patches.py @@ -39,5 +39,8 @@ def step(self, action): # Patch Gymnasium TimeLimit gymnasium.wrappers.TimeLimit = PatchedTimeLimit # type: ignore[misc] -gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc] +try: + gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc] +except AttributeError: + gymnasium.wrappers.common.TimeLimit = PatchedTimeLimit # type: ignore gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc,attr-defined] diff --git a/rl_zoo3/import_envs.py b/rl_zoo3/import_envs.py index 7d2b40447..f8a3599b6 100644 --- a/rl_zoo3/import_envs.py +++ b/rl_zoo3/import_envs.py @@ -1,7 +1,7 @@ from typing import Callable, Optional import gymnasium as gym -from gymnasium.envs.registration import register +from gymnasium.envs.registration import register, register_envs from rl_zoo3.wrappers import MaskVelocityWrapper @@ -10,6 +10,14 @@ except ImportError: pass +try: + import ale_py + + # no-op + gym.register_envs(ale_py) +except ImportError: + pass + try: import highway_env except ImportError: diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt index 852a32b3f..d5cafdb5a 100644 --- a/rl_zoo3/version.txt +++ b/rl_zoo3/version.txt @@ -1 +1 @@ -2.4.0a10 +2.4.0a11 diff --git a/setup.py b/setup.py index 336043552..6699e0e04 100644 --- a/setup.py +++ b/setup.py @@ -15,21 +15,22 @@ See https://github.com/DLR-RM/rl-baselines3-zoo """ install_requires = [ - "sb3_contrib>=2.4.0a10,<3.0", - "gymnasium~=0.29.1", + "sb3_contrib>=2.4.0a11,<3.0", + "gymnasium>=0.29.1,<1.1.0", "huggingface_sb3>=3.0,<4.0", "tqdm", "rich", "optuna>=3.0", "pyyaml>=5.1", "pytablewriter~=1.2", + "shimmy~=2.0", ] plots_requires = ["seaborn", "rliable~=1.2.0", "scipy~=1.10"] test_requires = [ # for MuJoCo envs v4: - "mujoco~=2.3", + "mujoco>=2.3,<4", # install parking-env to test HER - "highway-env==1.8.2", + "highway-env>=1.10.1,<1.11.0", ] setup( diff --git a/tests/test_enjoy.py b/tests/test_enjoy.py index 629f6e031..d3f0e7244 100644 --- a/tests/test_enjoy.py +++ b/tests/test_enjoy.py @@ -1,6 +1,7 @@ import os import shlex import subprocess +from importlib.metadata import version import pytest @@ -40,6 +41,11 @@ def test_trained_agents(trained_model): if "Panda" in env_id: return + # TODO: rename trained agents once we drop support for gymnasium v0.29 + if "Lander" in env_id and version("gymnasium") > "0.29.1": + # LunarLander-v2 is now LunarLander-v3 + return + # Skip mujoco envs if "Fetch" in trained_model or "-v3" in trained_model: return diff --git a/tests/test_train.py b/tests/test_train.py index d0780acc1..0894cf669 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,6 +1,7 @@ import os import shlex import subprocess +from importlib.metadata import version import pytest @@ -36,10 +37,17 @@ def test_train(tmp_path, experiment): def test_continue_training(tmp_path): - algo, env_id = "a2c", "CartPole-v1" + algo = "a2c" + if version("gymnasium") > "0.29.1": + # See https://github.com/DLR-RM/stable-baselines3/pull/1837#issuecomment-2457322341 + # obs bounds have changed... + env_id = "CartPole-v1" + else: + env_id = "Pendulum-v1" + cmd = ( f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} " - "-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip" + f"-i rl-trained-agents/a2c/{env_id}_1/{env_id}.zip" ) return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0)