From 40e0b9d2c88efed9713cb3bca97c7c9893922e19 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 14 Apr 2023 13:13:59 +0200 Subject: [PATCH] Add Gymnasium support (#1327) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix failing set_env test * Fix test failiing due to deprectation of env.seed * Adjust mean reward threshold in failing test * Fix her test failing due to rng * Change seed and revert reward threshold to 90 * Pin gym version * Make VecEnv compatible with gym seeding change * Revert change to VecEnv reset signature * Change subprocenv seed cmd to call reset instead * Fix type check * Add backward compat * Add `compat_gym_seed` helper * Add goal env checks in env_checker * Add docs on HER requirements for envs * Capture user warning in test with inverted box space * Update ale-py version * Fix randint * Allow noop_max to be zero * Update changelog * Update docker image * Update doc conda env and dockerfile * Custom envs should not have any warnings * Fix test for numpy >= 1.21 * Add check for vectorized compute reward * Bump to gym 0.24 * Fix gym default step docstring * Test downgrading gym * Revert "Test downgrading gym" This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb. * Fix protobuf error * Fix in dependencies * Fix protobuf dep * Use newest version of cartpole * Update gym * Fix warning * Loosen required scipy version * Scipy no longer needed * Try gym 0.25 * Silence warnings from gym * Filter warnings during tests * Update doc * Update requirements * Add gym 26 compat in vec env * Fixes in envs and tests for gym 0.26+ * Enforce gym 0.26 api * format * Fix formatting * Fix dependencies * Fix syntax * Cleanup doc and warnings * Faster tests * Higher budget for HER perf test (revert prev change) * Fixes and update doc * Fix doc build * Fix breaking change * Fixes for rendering * Rename variables in monitor * update render method for gym 0.26 API backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation) * update tests and docs to new gym render API * undo removal of render modes metatadata check * set rgb_array as default render mode for gym.make * undo changes & raise warning if not 'rgb_array' * Fix type check * Remove recursion and fix type checking * Remove hacks for protobuf and gym 0.24 * Fix type annotations * reuse existing render_mode attribute * return tiled images for 'human' render mode * Allow to use opencv for human render, fix typos * Add warning when using non-zero start with Discrete (fixes #1197) * Fix type checking * Bug fixes and handle more cases * Throw proper warnings * Update test * Fix new metadata name * Ignore numpy warnings * Fixes in vec recorder * Global ignore * Filter local warning too * Monkey patch not needed for gym 26 * Add doc of VecEnv vs Gym API * Add render test * Fix return type * Update VecEnv vs Gym API doc * Fix for custom render mode * Fix return type * Fix type checking * check test env test_buffer * skip render check * check env test_dict_env * test_env test_gae * check envs in remaining tests * Update tests * Add warning for Discrete action space with non-zero (#1295) * Fix atari annotation * ignore get_action_meanings [attr-defined] * Fix mypy issues * Add patch for gym/gymnasium transition * Switch to gymnasium * Rely on signature instead of version * More patches * Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39 * Fix doc build * Fix pytype errors * Fix atari requirement * Update env checker due to change in dtype for Discrete * Fix type hint * Convert spaces for saved models * Ignore pytype * Remove gitlab CI * Disable pytype for convert space * Fix undefined info * Fix undefined info * Upgrade shimmy * Fix wrappers type annotation (need PR from Gymnasium) * Fix gymnasium dependency * Fix dependency declaration * Cap pygame version for python 3.7 * Point to master branch (v0.28.0) * Fix: use main not master branch * Rename done to terminated * Fix pygame dependency for python 3.7 * Rename gym to gymnasium * Update Gymnasium * Fix test * Fix tests * Forks don't have access to private variables * Fix linter warnings * Update read the doc env * Fix env checker for GoalEnv * Fix import * Update env checker (more info) and fix dtype * Use micromamab for Docker * Update dependencies * Clarify VecEnv doc * Fix Gymnasium version * Copy file only after mamba install * [ci skip] Update docker doc * Polish code * Reformat * Remove deprecated features * Ignore warning * Update doc * Update examples and changelog * Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436) * Fix SAC type hints, improve DQN ones * Fix A2C and TD3 type hints * Fix PPO type hints * Fix on-policy type hints * Fix base class type annotation, do not use defaults * Update version * Disable mypy for python 3.7 * Rename Gym26StepReturn * Update continuous critic type annotation * Fix pytype complain --------- Co-authored-by: Carlos Luis Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com> Co-authored-by: tlips Co-authored-by: tlpss Co-authored-by: Quentin GALLOUÉDEC --- .github/ISSUE_TEMPLATE/custom_env.yml | 9 +- .github/workflows/ci.yml | 2 + Dockerfile | 38 ++---- Makefile | 6 + docs/conda_env.yml | 6 +- docs/guide/callbacks.rst | 10 +- docs/guide/checking_nan.rst | 6 +- docs/guide/custom_env.rst | 8 +- docs/guide/custom_policy.rst | 8 +- docs/guide/examples.rst | 106 ++++++++------- docs/guide/install.rst | 19 +-- docs/guide/integrations.rst | 4 +- docs/guide/quickstart.rst | 8 +- docs/guide/tensorboard.rst | 2 +- docs/guide/vec_envs.rst | 52 ++++++++ docs/misc/changelog.rst | 36 ++++- docs/modules/a2c.rst | 2 +- docs/modules/ddpg.rst | 2 +- docs/modules/dqn.rst | 2 +- docs/modules/ppo.rst | 2 +- docs/modules/sac.rst | 2 +- docs/modules/td3.rst | 2 +- pyproject.toml | 18 +-- scripts/build_docker.sh | 6 +- scripts/run_docker_cpu.sh | 4 +- scripts/run_docker_gpu.sh | 4 +- setup.py | 17 ++- stable_baselines3/__init__.py | 6 - stable_baselines3/a2c/a2c.py | 2 +- stable_baselines3/common/atari_wrappers.py | 115 ++++++++-------- stable_baselines3/common/base_class.py | 86 +++++++----- stable_baselines3/common/buffers.py | 13 +- stable_baselines3/common/callbacks.py | 4 +- stable_baselines3/common/distributions.py | 2 +- stable_baselines3/common/env_checker.py | 121 +++++++++++------ stable_baselines3/common/env_util.py | 39 ++++-- .../common/envs/bit_flipping_env.py | 32 +++-- stable_baselines3/common/envs/identity_env.py | 35 +++-- .../common/envs/multi_input_envs.py | 18 ++- stable_baselines3/common/evaluation.py | 11 +- stable_baselines3/common/monitor.py | 27 ++-- .../common/off_policy_algorithm.py | 9 +- .../common/on_policy_algorithm.py | 36 +++-- stable_baselines3/common/policies.py | 11 +- stable_baselines3/common/preprocessing.py | 13 +- stable_baselines3/common/save_util.py | 2 +- stable_baselines3/common/torch_layers.py | 10 +- stable_baselines3/common/type_aliases.py | 11 +- stable_baselines3/common/utils.py | 24 +++- .../common/vec_env/base_vec_env.py | 91 ++++++++++--- .../common/vec_env/dummy_vec_env.py | 49 ++++--- stable_baselines3/common/vec_env/patch_gym.py | 100 ++++++++++++++ .../common/vec_env/stacked_observations.py | 44 ++---- .../common/vec_env/subproc_vec_env.py | 57 +++++--- stable_baselines3/common/vec_env/util.py | 2 +- .../common/vec_env/vec_check_nan.py | 2 +- .../common/vec_env/vec_frame_stack.py | 5 +- .../common/vec_env/vec_normalize.py | 2 +- .../common/vec_env/vec_transpose.py | 2 +- .../common/vec_env/vec_video_recorder.py | 5 +- stable_baselines3/dqn/dqn.py | 15 +-- stable_baselines3/dqn/policies.py | 21 +-- stable_baselines3/her/her_replay_buffer.py | 11 +- stable_baselines3/ppo/ppo.py | 6 +- stable_baselines3/sac/policies.py | 38 ++++-- stable_baselines3/sac/sac.py | 25 ++-- stable_baselines3/td3/policies.py | 29 ++-- stable_baselines3/td3/td3.py | 16 ++- stable_baselines3/version.txt | 2 +- tests/test_buffers.py | 30 +++-- tests/test_callbacks.py | 2 +- tests/test_cnn.py | 6 +- tests/test_dict_env.py | 39 ++++-- tests/test_distributions.py | 2 +- tests/test_env_checker.py | 24 ++-- tests/test_envs.py | 75 ++++++++--- tests/test_gae.py | 38 ++++-- tests/test_her.py | 2 +- tests/test_identity.py | 17 ++- tests/test_logger.py | 15 ++- tests/test_monitor.py | 20 +-- tests/test_predict.py | 21 ++- tests/test_preprocessing.py | 2 +- tests/test_run.py | 2 +- tests/test_save_load.py | 2 +- tests/test_spaces.py | 33 +++-- tests/test_train_eval_mode.py | 2 +- tests/test_utils.py | 21 +-- tests/test_vec_check_nan.py | 10 +- tests/test_vec_envs.py | 125 ++++++++++++++++-- tests/test_vec_extract_dict_obs.py | 4 +- tests/test_vec_monitor.py | 6 +- tests/test_vec_normalize.py | 36 +++-- tests/test_vec_stacked_obs.py | 2 +- 94 files changed, 1333 insertions(+), 733 deletions(-) create mode 100644 stable_baselines3/common/vec_env/patch_gym.py diff --git a/.github/ISSUE_TEMPLATE/custom_env.yml b/.github/ISSUE_TEMPLATE/custom_env.yml index cf624c03b..f90210858 100644 --- a/.github/ISSUE_TEMPLATE/custom_env.yml +++ b/.github/ISSUE_TEMPLATE/custom_env.yml @@ -49,15 +49,16 @@ body: self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) self.action_space = spaces.Box(low=-1, high=1, shape=(6,)) - def reset(self): - return self.observation_space.sample() + def reset(self, seed=None): + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = 1.0 - done = False + terminated = False + truncated = False info = {} - return obs, reward, done, info + return obs, reward, terminated, truncated, info env = CustomEnv() check_env(env) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 52f5f3895..7c238cdd4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,6 +55,8 @@ jobs: - name: Type check run: | make type + # skip mypy type check for python3.7 (result is different to all other versions) + if: "!(matrix.python-version == '3.7')" - name: Test with pytest run: | make pytest diff --git a/Dockerfile b/Dockerfile index 8dfbbbf4c..421324dff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,41 +1,25 @@ ARG PARENT_IMAGE FROM $PARENT_IMAGE ARG PYTORCH_DEPS=cpuonly -ARG PYTHON_VERSION=3.7 +ARG PYTHON_VERSION=3.8 +ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found) -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - curl \ - ca-certificates \ - libjpeg-dev \ - libpng-dev \ - libglib2.0-0 && \ - rm -rf /var/lib/apt/lists/* +# Install micromamba env and dependencies +RUN micromamba install -n base -y python=$PYTHON_VERSION \ + pytorch $PYTORCH_DEPS -c conda-forge -c pytorch -c nvidia && \ + micromamba clean --all --yes -# Install Anaconda and dependencies -RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh && \ - /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \ - /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \ - /opt/conda/bin/conda clean -ya -ENV PATH /opt/conda/bin:$PATH - -ENV CODE_DIR /root/code +ENV CODE_DIR /home/$MAMBA_USER # Copy setup file only to install dependencies -COPY ./setup.py ${CODE_DIR}/stable-baselines3/setup.py -COPY ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt +COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py +COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt -RUN \ - cd ${CODE_DIR}/stable-baselines3 3&& \ +RUN cd ${CODE_DIR}/stable-baselines3 && \ pip install -e .[extra,tests,docs] && \ # Use headless version for docker pip uninstall -y opencv-python && \ pip install opencv-python-headless && \ - rm -rf $HOME/.cache/pip + pip cache purge CMD /bin/bash diff --git a/Makefile b/Makefile index 29ac5e70e..4f477d066 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,12 @@ pytype: mypy: mypy ${LINT_PATHS} +missing-annotations: + mypy --disallow-untyped-calls --disallow-untyped-defs --ignore-missing-imports stable_baselines3 + +# missing docstrings +# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4 + type: pytype mypy lint: diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 98a550820..0545eef3c 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -4,11 +4,11 @@ channels: - defaults dependencies: - cpuonly=1.0=0 - - pip=21.1 + - pip=22.3.1 - python=3.7 - - pytorch=1.11=py3.7_cpu_0 + - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym==0.21 + - gymnasium - cloudpickle - opencv-python-headless - pandas diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 632743bbf..e08c00367 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -210,7 +210,7 @@ It will save the best model if ``best_model_save_path`` folder is specified and .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback @@ -260,7 +260,7 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback @@ -290,7 +290,7 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold @@ -322,7 +322,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps @@ -379,7 +379,7 @@ It must be used with the :ref:`EvalCallback` and use the event triggered after e .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement diff --git a/docs/guide/checking_nan.rst b/docs/guide/checking_nan.rst index ef3762c41..8d0de36e8 100644 --- a/docs/guide/checking_nan.rst +++ b/docs/guide/checking_nan.rst @@ -100,8 +100,8 @@ It will monitor the actions, observations, and rewards, indicating what action o .. code-block:: python - import gym - from gym import spaces + import gymnasium as gym + from gymnasium import spaces import numpy as np from stable_baselines3 import PPO @@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o def reset(self): return [0.0] - def render(self, mode="human", close=False): + def render(self, close=False): pass # Create environment diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index d2878c376..822c8215e 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -26,9 +26,9 @@ That is to say, your environment must implement the following methods (and inher .. code-block:: python - import gym + import gymnasium as gym import numpy as np - from gym import spaces + from gymnasium import spaces class CustomEnv(gym.Env): @@ -54,7 +54,7 @@ That is to say, your environment must implement the following methods (and inher ... return observation # reward, done, info can't be included - def render(self, mode="human"): + def render(self): ... def close(self): @@ -91,7 +91,7 @@ Optionally, you can also register the environment with gym, that will allow you .. code-block:: python - from gym.envs.registration import register + from gymnasium.envs.registration import register # Example for the CartPole environment register( # unique identifier for the env `name-version` diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 80f69396a..af7d9ef0a 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -101,7 +101,7 @@ using ``policy_kwargs`` parameter: .. code-block:: python - import gym + import gymnasium as gym import torch as th from stable_baselines3 import PPO @@ -143,7 +143,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t import torch as th import torch.nn as nn - from gym import spaces + from gymnasium import spaces from stable_baselines3 import PPO from stable_baselines3.common.torch_layers import BaseFeaturesExtractor @@ -208,7 +208,7 @@ downsampling and "vector" with a single linear layer. .. code-block:: python - import gym + import gymnasium as gym import torch as th from torch import nn @@ -308,7 +308,7 @@ If your task requires even more granular control over the policy/value architect from typing import Callable, Dict, List, Optional, Tuple, Type, Union - from gym import spaces + from gymnasium import spaces import torch as th from torch import nn diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a3f1dc6f0..5935f504a 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -64,7 +64,7 @@ In the following example, we will train, save and load a DQN model on the Lunar .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import DQN from stable_baselines3.common.evaluation import evaluate_policy @@ -115,7 +115,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import PPO @@ -123,18 +123,18 @@ Multiprocessing: Unleashing the Power of Vectorized Environments from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.utils import set_random_seed - def make_env(env_id, rank, seed=0): + def make_env(env_id: str, rank: int, seed: int = 0): """ Utility function for multiprocessed env. - :param env_id: (str) the environment ID - :param num_env: (int) the number of environments you wish to have in subprocesses - :param seed: (int) the inital seed for RNG - :param rank: (int) index of the subprocess + :param env_id: the environment ID + :param num_env: the number of environments you wish to have in subprocesses + :param seed: the inital seed for RNG + :param rank: index of the subprocess """ def _init(): - env = gym.make(env_id) - env.seed(seed + rank) + env = gym.make(env_id, render_mode="human") + env.reset(seed=seed + rank) return env set_random_seed(seed) return _init @@ -143,21 +143,21 @@ Multiprocessing: Unleashing the Power of Vectorized Environments env_id = "CartPole-v1" num_cpu = 4 # Number of processes to use # Create the vectorized environment - env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)]) + vec_env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)]) # Stable Baselines provides you with make_vec_env() helper # which does exactly the previous steps for you. # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv` # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv) - model = PPO("MlpPolicy", env, verbose=1) + model = PPO("MlpPolicy", vec_env, verbose=1) model.learn(total_timesteps=25_000) - obs = env.reset() + obs = vec_env.reset() for _ in range(1000): action, _states = model.predict(obs) - obs, rewards, dones, info = env.step(action) - env.render() + obs, rewards, dones, info = vec_env.step(action) + vec_env.render() Multiprocessing with off-policy algorithms @@ -173,17 +173,17 @@ Multiprocessing with off-policy algorithms .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.env_util import make_vec_env - env = make_vec_env("Pendulum-v0", n_envs=4, seed=0) + vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0) # We collect 4 transitions per call to `ènv.step()` # and performs 2 gradient steps per call to `ènv.step()` # if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()` - model = SAC("MlpPolicy", env, train_freq=1, gradient_steps=2, verbose=1) + model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1) model.learn(total_timesteps=10_000) @@ -229,7 +229,7 @@ If your callback returns False, training is aborted early. import os - import gym + import gymnasium as gym import numpy as np import matplotlib.pyplot as plt @@ -337,18 +337,18 @@ and multiprocessing for you. To install the Atari environments, run the command # There already exists an environment generator # that will make and wrap atari environments correctly. # Here we are also multi-worker training (n_envs=4 => 4 environments) - env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0) + vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0) # Frame-stacking with 4 frames - env = VecFrameStack(env, n_stack=4) + vec_env = VecFrameStack(vec_env, n_stack=4) - model = A2C("CnnPolicy", env, verbose=1) + model = A2C("CnnPolicy", vec_env, verbose=1) model.learn(total_timesteps=25_000) - obs = env.reset() + obs = vec_env.reset() while True: - action, _states = model.predict(obs) - obs, rewards, dones, info = env.step(action) - env.render() + action, _states = model.predict(obs, deterministic=False) + obs, rewards, dones, info = vec_env.step(action) + vec_env.render("human") PyBullet: Normalizing input features @@ -372,18 +372,22 @@ will compute a running average and standard deviation of input features (it can .. code-block:: python import os - import gym + import gymnasium as gym import pybullet_envs from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize from stable_baselines3 import PPO - env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) + # Note: pybullet is not compatible yet with Gymnasium + # you might need to use `import rl_zoo3.gym_patches` + # and use gym (not Gymnasium) to instanciate the env + # Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4" + vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) # Automatically normalize the input features and reward - env = VecNormalize(env, norm_obs=True, norm_reward=True, + vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.) - model = PPO("MlpPolicy", env) + model = PPO("MlpPolicy", vec_env) model.learn(total_timesteps=2000) # Don't forget to save the VecNormalize statistics when saving the agent @@ -393,18 +397,18 @@ will compute a running average and standard deviation of input features (it can env.save(stats_path) # To demonstrate loading - del model, env + del model, vec_env # Load the saved statistics - env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) - env = VecNormalize.load(stats_path, env) + vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) + vec_env = VecNormalize.load(stats_path, vec_env) # do not update them at test time - env.training = False + vec_env.training = False # reward normalization is not needed at test time - env.norm_reward = False + vec_env.norm_reward = False # Load the agent - model = PPO.load(log_dir + "ppo_halfcheetah", env=env) + model = PPO.load(log_dir + "ppo_halfcheetah", env=vec_env) Hindsight Experience Replay (HER) @@ -430,7 +434,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi .. code-block:: python - import gym + import gymnasium as gym import highway_env import numpy as np @@ -467,19 +471,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi # HER must be loaded with the env model = SAC.load("her_sac_highway", env=env) - obs = env.reset() + obs, info = env.reset() # Evaluate the agent episode_reward = 0 for _ in range(100): action, _ = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) env.render() episode_reward += reward - if done or info.get("is_success", False): + if terminated or truncated or info.get("is_success", False): print("Reward:", episode_reward, "Success?", info.get("is_success", False)) episode_reward = 0.0 - obs = env.reset() + obs, info = env.reset() Learning Rate Schedule @@ -621,7 +625,7 @@ A2C policy gradient updates on the model. from typing import Dict - import gym + import gymnasium as gym import numpy as np import torch as th @@ -662,7 +666,7 @@ A2C policy gradient updates on the model. # Keep top 10% n_elite = pop_size // 10 # Retrieve the environment - env = model.get_env() + vec_env = model.get_env() for iteration in range(10): # Create population of candidates and evaluate them @@ -674,7 +678,7 @@ A2C policy gradient updates on the model. # we give it (policy parameters) model.policy.load_state_dict(candidate, strict=False) # Evaluate the candidate - fitness, _ = evaluate_policy(model, env) + fitness, _ = evaluate_policy(model, vec_env) population.append((candidate, fitness)) # Take top 10% and use average over their parameters as next mean parameter top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite] @@ -738,28 +742,28 @@ Record a mp4 video (here using a random agent). .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv env_id = "CartPole-v1" video_folder = "logs/videos/" video_length = 100 - env = DummyVecEnv([lambda: gym.make(env_id)]) + vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")]) - obs = env.reset() + obs = vec_env.reset() # Record the video starting at the first step - env = VecVideoRecorder(env, video_folder, + vec_env = VecVideoRecorder(vec_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=video_length, name_prefix=f"random-agent-{env_id}") - env.reset() + vec_env.reset() for _ in range(video_length + 1): - action = [env.action_space.sample()] - obs, _, _, _ = env.step(action) + action = [vec_env.action_space.sample()] + obs, _, _, _ = vec_env.step(action) # Save the video - env.close() + vec_env.close() Bonus: Make a GIF of a Trained Agent diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 312b86fcb..68f8f764f 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -54,19 +54,6 @@ Bleeding-edge version pip install git+https://github.com/DLR-RM/stable-baselines3 -.. note:: - - If you want to use Gymnasium (or the latest Gym version 0.24+), you have to use - - .. code-block:: bash - - pip install git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support - pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support - - - See `PR #1327 `_ for more information. - - Development version ------------------- @@ -131,7 +118,7 @@ Run the nvidia-docker GPU image .. code-block:: bash - docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /root/code/stable-baselines3/ && pytest tests/' + docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/' Or, with the shell file: @@ -143,7 +130,7 @@ Run the docker CPU image .. code-block:: bash - docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /root/code/stable-baselines3/ && pytest tests/' + docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/' Or, with the shell file: @@ -165,7 +152,7 @@ Explanation of the docker command: - ``--name test`` give explicitly the name ``test`` to the container, otherwise it will be assigned a random name - ``--mount src=...`` give access of the local directory (``pwd`` - command) to the container (it will be map to ``/root/code/stable-baselines``), so + command) to the container (it will be map to ``/home/mamba/stable-baselines``), so all the logs created in the container in this folder will be kept - ``bash -c '...'`` Run command inside the docker image, here run the tests (``pytest tests/``) diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 49bbdb248..14573cdec 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -13,7 +13,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati .. code-block:: python - import gym + import gymnasium as gym import wandb from wandb.integration.sb3 import WandbCallback @@ -86,7 +86,7 @@ For instance ``sb3/demo-hf-CartPole-v1``: .. code-block:: python - import gym + import gymnasium as gym from huggingface_sb3 import load_from_hub from stable_baselines3 import PPO diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 5d1055ac9..b22ac54da 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -4,13 +4,19 @@ Getting Started =============== +.. note:: + + Stable-Baselines3 (SB3) uses :ref:`vectorized environments (VecEnv) ` internally. + Please read the associated section to learn more about its features and differences compared to a single Gym environment. + + Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms. Here is a quick example of how to train and run A2C on a CartPole environment: .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import A2C diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst index 2699d4a86..720c3ded2 100644 --- a/docs/guide/tensorboard.rst +++ b/docs/guide/tensorboard.rst @@ -190,7 +190,7 @@ Here is an example of how to render an episode and log the resulting video to Te from typing import Any, Dict - import gym + import gymnasium as gym import torch as th from stable_baselines3 import A2C diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index d84781122..ea99444d1 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -44,6 +44,58 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️ For more information, see Python's `multiprocessing guidelines `_. +VecEnv API vs Gym API +--------------------- + +For consistency across Stable-Baselines3 (SB3) versions and because of its special requirements and features, +SB3 VecEnv API is not the same as Gym API. +SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: + +- the ``reset()`` method only returns the observation (``obs = vec_env.reset()``) and not a tuple, the info at reset are stored in ``vec_env.reset_infos``. + +- only the initial call to ``vec_env.reset()`` is required, environments are reset automatically afterward (and ``reset_infos`` is updated automatically). + +- the ``vec_env.step(actions)`` method expects an array as input + (with a batch size corresponding to the number of environments) and returns a 4-tuple (and not a 5-tuple): ``obs, rewards, dones, infos`` instead of ``obs, reward, terminated, truncated, info`` + where ``dones = terminated or truncated`` (for each env). + ``obs, rewards, dones`` are numpy arrays with shape ``(n_envs, shape_for_single_env)`` (so with a batch dimension). + Additional information is passed via the ``infos`` value which is a list of dictionaries. + +- at the end of an episode, ``infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated`` + tells the user if an episode was truncated or not: + you should bootstrap if ``infos[env_idx]["TimeLimit.truncated"] is True`` (episode over due to a timeout/truncation) + or ``dones[env_idx] is False`` (episode not finished). + Note: compared to Gym 0.26+ ``infos[env_idx]["TimeLimit.truncated"]`` and ``terminated`` `are mutually exclusive `_. + The conversion from SB3 to Gym API is + + .. code-block:: python + + # done is True at the end of an episode + # dones[env_idx] = terminated[env_idx] or truncated[env_idx] + # In SB3, truncated and terminated are mutually exclusive + # infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated + # terminated[env_idx] tells you whether you should bootstrap or not: + # when the episode has not ended or when the termination was a timeout/truncation + terminated[env_idx] = dones[env_idx] and not infos[env_idx]["TimeLimit.truncated"] + should_bootstrap[env_idx] = not terminated[env_idx] + + +- at the end of an episode, because the environment resets automatically, + we provide ``infos[env_idx]["terminal_observation"]`` which contains the last observation + of an episode (and can be used when bootstrapping, see note in the previous section) + +- to overcome the current Gymnasium limitation (only one render mode allowed per env instance, see `issue #100 `_), + we recommend using ``render_mode="rgb_array"`` since we can both have the image as a numpy array and display it with OpenCV. + if no mode is passed or ``mode="rgb_array"`` is passed when calling ``vec_env.render`` then we use the default mode, otherwise, we use the OpenCV display. + Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). + +- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator, + you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward. + +- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, + ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. + + Vectorized Environments Wrappers -------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4ae5e471d..5f2d0b760 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,16 +3,29 @@ Changelog ========== -Release 1.8.1a0 (WIP) +Release 2.0.0a4 (WIP) -------------------------- +**Gymnasium support** + +.. warning:: + + Stable-Baselines3 (SB3) v2.0 will be the last one supporting python 3.7 (end of life in June 2023). + We highly recommended you to upgrade to Python >= 3.8. + + Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss) +- The deprecated ``online_sampling`` argument of ``HerReplayBuffer`` was removed +- Removed deprecated ``stack_observation_space`` method of ``StackedObservations`` - Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit) +- Upgraded wrappers and custom environment to Gymnasium New Features: ^^^^^^^^^^^^^ + `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -28,9 +41,22 @@ Deprecations: Others: ^^^^^^^ +- Fixed ``stable_baselines3/a2c/*.py`` type hints +- Fixed ``stable_baselines3/ppo/*.py`` type hints +- Fixed ``stable_baselines3/sac/*.py`` type hints +- Fixed ``stable_baselines3/td3/*.py`` type hints +- Fixed ``stable_baselines3/common/base_class.py`` type hints +- Upgraded docker images to use mamba/micromamba and CUDA 11.7 +- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks +- Improve type annotation of wrappers +- Tests envs are now checked too +- Added render test for ``VecEnv`` Documentation: ^^^^^^^^^^^^^^ +- Added documentation about ``VecEnv`` API vs Gym API +- Upgraded tutorials to Gymnasium API +- Make it more explicit when using ``VecEnv`` vs Gym env Release 1.8.0 (2023-04-07) @@ -328,6 +354,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ + `SB3-Contrib`_ ^^^^^^^^^^^^^^ - Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53 @@ -373,7 +400,7 @@ Release 1.5.0 (2022-03-25) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.21.0. +- Switched minimum Gym version to 0.21.0 New Features: ^^^^^^^^^^^^^ @@ -1298,6 +1325,7 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi +@carlosluis @arjun-kg @tlpss +@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel \ No newline at end of file +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 670da617c..84a94eaab 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -53,7 +53,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import A2C from stable_baselines3.common.env_util import make_vec_env diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index c484a1c93..4ac28ccb3 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -61,7 +61,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import DDPG diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 8648606cc..0569aa528 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -56,7 +56,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import DQN diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index d0c425fb5..a822cb436 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -65,7 +65,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments. .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index e7f9057d5..0e9bb3f64 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -68,7 +68,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import SAC diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index d039ae71c..7c17e644d 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -61,7 +61,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import TD3 diff --git a/pyproject.toml b/pyproject.toml index 2f98aa4f4..b44edf529 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,9 +35,7 @@ ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - stable_baselines3/a2c/a2c.py$ - | stable_baselines3/common/base_class.py$ - | stable_baselines3/common/buffers.py$ + stable_baselines3/common/buffers.py$ | stable_baselines3/common/callbacks.py$ | stable_baselines3/common/distributions.py$ | stable_baselines3/common/envs/bit_flipping_env.py$ @@ -45,7 +43,6 @@ exclude = """(?x)( | stable_baselines3/common/envs/multi_input_envs.py$ | stable_baselines3/common/logger.py$ | stable_baselines3/common/off_policy_algorithm.py$ - | stable_baselines3/common/on_policy_algorithm.py$ | stable_baselines3/common/policies.py$ | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ @@ -62,11 +59,6 @@ exclude = """(?x)( | stable_baselines3/common/vec_env/vec_transpose.py$ | stable_baselines3/common/vec_env/vec_video_recorder.py$ | stable_baselines3/her/her_replay_buffer.py$ - | stable_baselines3/ppo/ppo.py$ - | stable_baselines3/sac/policies.py$ - | stable_baselines3/sac/sac.py$ - | stable_baselines3/td3/policies.py$ - | stable_baselines3/td3/td3.py$ | tests/test_logger.py$ | tests/test_train_eval_mode.py$ )""" @@ -80,12 +72,8 @@ env = [ filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", - # Gym warnings - "ignore:Parameters to load are deprecated.:DeprecationWarning", - "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning", - "ignore::UserWarning:gym", - "ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning", - "ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning", + # Gymnasium warnings + "ignore::UserWarning:gymnasium", ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 13ac86b17..c1a4a5608 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,14 +1,14 @@ #!/bin/bash -CPU_PARENT=ubuntu:18.04 -GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 +CPU_PARENT=mambaorg/micromamba:1.4-kinetic +GPU_PARENT=mambaorg/micromamba:1.4.1-focal-cuda-11.7.1 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) if [[ ${USE_GPU} == "True" ]]; then PARENT=${GPU_PARENT} - PYTORCH_DEPS="cudatoolkit=10.1" + PYTORCH_DEPS="pytorch-cuda=11.7" else PARENT=${CPU_PARENT} PYTORCH_DEPS="cpuonly" diff --git a/scripts/run_docker_cpu.sh b/scripts/run_docker_cpu.sh index 6dfafd2b9..db6c6493b 100755 --- a/scripts/run_docker_cpu.sh +++ b/scripts/run_docker_cpu.sh @@ -7,5 +7,5 @@ echo "Executing in the docker (cpu image):" echo $cmd_line docker run -it --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ - bash -c "cd /root/code/stable-baselines3/ && $cmd_line" + --mount src=$(pwd),target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ + bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" diff --git a/scripts/run_docker_gpu.sh b/scripts/run_docker_gpu.sh index 19e16067a..fa8aae9c4 100755 --- a/scripts/run_docker_gpu.sh +++ b/scripts/run_docker_gpu.sh @@ -15,5 +15,5 @@ else fi docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ - bash -c "cd /root/code/stable-baselines3/ && $cmd_line" + --mount src=$(pwd),target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ + bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" diff --git a/setup.py b/setup.py index dd7b69637..a7a3fcc8b 100644 --- a/setup.py +++ b/setup.py @@ -39,11 +39,11 @@ Here is a quick example of how to train and run PPO on a cartpole environment: ```python -import gym +import gymnasium from stable_baselines3 import PPO -env = gym.make("CartPole-v1") +env = gymnasium.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) @@ -60,7 +60,7 @@ ``` -Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): +Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): ```python from stable_baselines3 import PPO @@ -76,6 +76,9 @@ extra_no_roms = [ # For render "opencv-python", + 'pygame; python_version >= "3.8.0"', + # See https://github.com/pygame/pygame/issues/3572 + 'pygame>=2.0,<2.1.3; python_version < "3.8.0"', # Tensorboard support "tensorboard>=2.9.1", # Checking memory taken by replay buffer @@ -84,7 +87,7 @@ "tqdm", "rich", # For atari games, - "ale-py==0.7.4", + "shimmy[atari]~=0.2.1", "pillow", ] @@ -99,7 +102,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.21", # Fixed version due to breaking changes in 0.22 + "gymnasium==0.28.1", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', @@ -109,8 +112,6 @@ "pandas", # Plotting learning curves "matplotlib", - # gym not compatible with importlib-metadata>5.0 - "importlib-metadata~=4.13", ], extras_require={ "tests": [ @@ -128,8 +129,6 @@ "isort>=5.0", # Reformat "black", - # For toy text Gym envs - "scipy>=1.4.1", ], "docs": [ "sphinx", diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 680e25453..0775a8ec5 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,7 +1,5 @@ import os -import numpy as np - from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info from stable_baselines3.ddpg import DDPG @@ -11,10 +9,6 @@ from stable_baselines3.sac import SAC from stable_baselines3.td3 import TD3 -# Small monkey patch so gym 0.21 is compatible with numpy >= 1.24 -# TODO: remove when upgrading to gym 0.26 -np.bool = bool # type: ignore[attr-defined] - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file) as file_handler: diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 9ecde5b85..996f17c19 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index ad29a3142..ea78c5b7c 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,6 +1,10 @@ -import gym +from typing import Dict, SupportsFloat + +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces + +from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn try: import cv2 # pytype:disable=import-error @@ -9,10 +13,8 @@ except ImportError: cv2 = None -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn - -class StickyActionEnv(gym.Wrapper): +class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Sticky action. @@ -26,19 +28,19 @@ class StickyActionEnv(gym.Wrapper): def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: super().__init__(env) self.action_repeat_probability = action_repeat_probability - assert env.unwrapped.get_action_meanings()[0] == "NOOP" + assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> AtariResetReturn: self._sticky_action = 0 # NOOP return self.env.reset(**kwargs) - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> AtariStepReturn: if self.np_random.random() >= self.action_repeat_probability: self._sticky_action = action return self.env.step(self._sticky_action) -class NoopResetEnv(gym.Wrapper): +class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. @@ -52,24 +54,25 @@ def __init__(self, env: gym.Env, noop_max: int = 30) -> None: self.noop_max = noop_max self.override_num_noops = None self.noop_action = 0 - assert env.unwrapped.get_action_meanings()[0] == "NOOP" + assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> AtariResetReturn: self.env.reset(**kwargs) if self.override_num_noops is not None: noops = self.override_num_noops else: - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) + info: Dict = {} for _ in range(noops): - obs, _, done, _ = self.env.step(self.noop_action) - if done: - obs = self.env.reset(**kwargs) - return obs + obs, _, terminated, truncated, info = self.env.step(self.noop_action) + if terminated or truncated: + obs, info = self.env.reset(**kwargs) + return obs, info -class FireResetEnv(gym.Wrapper): +class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Take action on reset for environments that are fixed until firing. @@ -78,21 +81,21 @@ class FireResetEnv(gym.Wrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) - assert env.unwrapped.get_action_meanings()[1] == "FIRE" - assert len(env.unwrapped.get_action_meanings()) >= 3 + assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined] + assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined] - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> AtariResetReturn: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(1) - if done: + obs, _, terminated, truncated, _ = self.env.step(1) + if terminated or truncated: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(2) - if done: + obs, _, terminated, truncated, _ = self.env.step(2) + if terminated or truncated: self.env.reset(**kwargs) - return obs + return obs, {} -class EpisodicLifeEnv(gym.Wrapper): +class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Make end-of-life == end-of-episode, but only reset on true game over. Done by DeepMind for the DQN and co. since it helps value estimation. @@ -105,21 +108,21 @@ def __init__(self, env: gym.Env) -> None: self.lives = 0 self.was_real_done = True - def step(self, action: int) -> GymStepReturn: - obs, reward, done, info = self.env.step(action) - self.was_real_done = done + def step(self, action: int) -> AtariStepReturn: + obs, reward, terminated, truncated, info = self.env.step(action) + self.was_real_done = terminated or truncated # check current lives, make loss of life terminal, # then update lives to handle bonus lives - lives = self.env.unwrapped.ale.lives() + lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] if 0 < lives < self.lives: # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. - done = True + terminated = True self.lives = lives - return obs, reward, done, info + return obs, reward, terminated, truncated, info - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> AtariResetReturn: """ Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, @@ -129,21 +132,21 @@ def reset(self, **kwargs) -> np.ndarray: :return: the first observation of the environment """ if self.was_real_done: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, done, _ = self.env.step(0) + obs, _, terminated, truncated, info = self.env.step(0) # The no-op step can lead to a game over, so we need to check it again # to see if we should reset the environment and avoid the # monitor.py `RuntimeError: Tried to step environment that needs reset` - if done: - obs = self.env.reset(**kwargs) - self.lives = self.env.unwrapped.ale.lives() - return obs + if terminated or truncated: + obs, info = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] + return obs, info -class MaxAndSkipEnv(gym.Wrapper): +class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Return only every ``skip``-th frame (frameskipping) and return the max between the two last frames. @@ -156,33 +159,36 @@ class MaxAndSkipEnv(gym.Wrapper): def __init__(self, env: gym.Env, skip: int = 4) -> None: super().__init__(env) # most recent raw observations (for max pooling across time steps) + assert env.observation_space.dtype is not None, "No dtype specified for the observation space" + assert env.observation_space.shape is not None, "No shape defined for the observation space" self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype) self._skip = skip - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> AtariStepReturn: """ Step the environment with the given action Repeat action, sum reward, and max over last observations. :param action: the action - :return: observation, reward, done, information + :return: observation, reward, terminated, truncated, information """ total_reward = 0.0 - done = False + terminated = truncated = False for i in range(self._skip): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) + done = terminated or truncated if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: self._obs_buffer[1] = obs - total_reward += reward + total_reward += float(reward) if done: break # Note that the observation on the done=True frame # doesn't matter max_frame = self._obs_buffer.max(axis=0) - return max_frame, total_reward, done, info + return max_frame, total_reward, terminated, truncated, info class ClipRewardEnv(gym.RewardWrapper): @@ -195,17 +201,17 @@ class ClipRewardEnv(gym.RewardWrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) - def reward(self, reward: float) -> float: + def reward(self, reward: SupportsFloat) -> float: """ Bin reward to {+1, 0, -1} by its sign. :param reward: :return: """ - return np.sign(reward) + return np.sign(float(reward)) -class WarpFrame(gym.ObservationWrapper): +class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]): """ Convert to grayscale and warp frames to 84x84 (default) as done in the Nature paper and later work. @@ -219,8 +225,13 @@ def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None: super().__init__(env) self.width = width self.height = height + assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}" + self.observation_space = spaces.Box( - low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype + low=0, + high=255, + shape=(self.height, self.width, 1), + dtype=env.observation_space.dtype, # type: ignore[arg-type] ) def observation(self, frame: np.ndarray) -> np.ndarray: @@ -235,7 +246,7 @@ def observation(self, frame: np.ndarray) -> np.ndarray: return frame[:, :, None] -class AtariWrapper(gym.Wrapper): +class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Atari 2600 preprocessings @@ -285,7 +296,7 @@ def __init__( env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) - if "FIRE" in env.unwrapped.get_action_meanings(): + if "FIRE" in env.unwrapped.get_action_meanings(): # type: ignore[attr-defined] env = FireResetEnv(env) env = WarpFrame(env, width=screen_size, height=screen_size) if clip_reward: diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index bec075699..ece511a24 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -8,10 +8,10 @@ from collections import deque from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union -import gym +import gymnasium as gym import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common import utils from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback @@ -22,7 +22,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict from stable_baselines3.common.utils import ( check_for_correct_spaces, get_device, @@ -39,11 +39,12 @@ is_vecenv_wrapped, unwrap_vec_normalize, ) +from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm") -def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]: +def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv: """If env is a string, make the environment; otherwise, return env. :param env: The environment to learn from. @@ -51,9 +52,14 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE :return A Gym (vector) environment. """ if isinstance(env, str): + env_id = env if verbose >= 1: - print(f"Creating environment from the given name '{env}'") - env = gym.make(env) + print(f"Creating environment from the given name '{env_id}'") + # Set render_mode to `rgb_array` as default, so we can record video + try: + env = gym.make(env_id, render_mode="rgb_array") + except TypeError: + env = gym.make(env_id) return env @@ -90,6 +96,11 @@ class BaseAlgorithm(ABC): # Policy aliases (see _get_policy_from_name()) policy_aliases: Dict[str, Type[BasePolicy]] = {} policy: BasePolicy + observation_space: spaces.Space + action_space: spaces.Space + n_envs: int + lr_schedule: Schedule + _logger: Logger def __init__( self, @@ -106,8 +117,8 @@ def __init__( seed: Optional[int] = None, use_sde: bool = False, sde_sample_freq: int = -1, - supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, - ): + supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + ) -> None: if isinstance(policy, str): self.policy_class = self._get_policy_from_name(policy) else: @@ -117,14 +128,9 @@ def __init__( if verbose >= 1: print(f"Using {self.device} device") - self.env = None # type: Optional[GymEnv] - # get VecNormalize object if needed - self._vec_normalize_env = unwrap_vec_normalize(env) self.verbose = verbose self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs - self.observation_space: spaces.Space - self.action_space: spaces.Space - self.n_envs: int + self.num_timesteps = 0 # Used for updating schedules self._total_timesteps = 0 @@ -132,10 +138,9 @@ def __init__( self._num_timesteps_at_start = 0 self.seed = seed self.action_noise: Optional[ActionNoise] = None - self.start_time = None + self.start_time = 0.0 self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log - self.lr_schedule = None # type: Optional[Schedule] self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] self._last_episode_starts = None # type: Optional[np.ndarray] # When using VecNormalize: @@ -146,17 +151,17 @@ def __init__( self.sde_sample_freq = sde_sample_freq # Track the training progress remaining (from 1 to 0) # this is used to update the learning rate - self._current_progress_remaining = 1 + self._current_progress_remaining = 1.0 # Buffers for logging self._stats_window_size = stats_window_size self.ep_info_buffer = None # type: Optional[deque] self.ep_success_buffer = None # type: Optional[deque] # For logging (and TD3 delayed updates) self._n_updates = 0 # type: int - # The logger object - self._logger = None # type: Logger # Whether the user passed a custom logger or not self._custom_logger = False + self.env: Optional[VecEnv] = None + self._vec_normalize_env: Optional[VecNormalize] = None # Create and wrap the env if needed if env is not None: @@ -168,6 +173,9 @@ def __init__( self.n_envs = env.num_envs self.env = env + # get VecNormalize object if needed + self._vec_normalize_env = unwrap_vec_normalize(env) + if supported_action_spaces is not None: assert isinstance(self.action_space, supported_action_spaces), ( f"The algorithm only supports {supported_action_spaces} as action spaces " @@ -204,13 +212,15 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve :return: The wrapped environment. """ if not isinstance(env, VecEnv): + # Patch to support gym 0.21/0.26 and gymnasium + env = _patch_env(env) if not is_wrapped(env, Monitor) and monitor_wrapper: if verbose >= 1: print("Wrapping the env with a `Monitor` wrapper") env = Monitor(env) if verbose >= 1: print("Wrapping the env in a DummyVecEnv.") - env = DummyVecEnv([lambda: env]) + env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value] # Make sure that dict-spaces are not nested (not supported) check_for_nested_spaces(env.observation_space) @@ -223,11 +233,11 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve # the other channel last), VecTransposeImage will throw an error for space in env.observation_space.spaces.values(): wrap_with_vectranspose = wrap_with_vectranspose or ( - is_image_space(space) and not is_image_space_channels_first(space) + is_image_space(space) and not is_image_space_channels_first(space) # type: ignore[arg-type] ) else: wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first( - env.observation_space + env.observation_space # type: ignore[arg-type] ) if wrap_with_vectranspose: @@ -409,7 +419,10 @@ def _setup_learn( # Avoid resetting the environment when calling ``.learn()`` consecutive times if reset_num_timesteps or self._last_obs is None: - self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch + assert self.env is not None + # pytype: disable=annotation-type-mismatch + self._last_obs = self.env.reset() # type: ignore[assignment] + # pytype: enable=annotation-type-mismatch self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) # Retrieve unnormalized observation for saving into the buffer if self._vec_normalize_env is not None: @@ -432,6 +445,9 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd :param infos: List of additional information about the transition. :param dones: Termination signals """ + assert self.ep_info_buffer is not None + assert self.ep_success_buffer is not None + if dones is None: dones = np.array([False] * len(infos)) for idx, info in enumerate(infos): @@ -555,7 +571,7 @@ def set_random_seed(self, seed: Optional[int] = None) -> None: def set_parameters( self, - load_path_or_dict: Union[str, Dict[str, Dict]], + load_path_or_dict: Union[str, TensorDict], exact_match: bool = True, device: Union[th.device, str] = "auto", ) -> None: @@ -571,7 +587,7 @@ def set_parameters( can be used to update only specific parameters. :param device: Device on which the code should run. """ - params = None + params = {} if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: @@ -609,7 +625,7 @@ def set_parameters( # # Solution: Just load the state-dict as is, and trust # the user has provided a sensible state dictionary. - attr.load_state_dict(params[name]) + attr.load_state_dict(params[name]) # type: ignore[arg-type] else: # Assume attr is th.nn.Module attr.load_state_dict(params[name], strict=exact_match) @@ -667,6 +683,9 @@ def load( # noqa: C901 print_system_info=print_system_info, ) + assert data is not None, "No data found in the saved file" + assert params is not None, "No params found in the saved file" + # Remove stored device information and replace with ours if "policy_kwargs" in data: if "device" in data["policy_kwargs"]: @@ -686,6 +705,10 @@ def load( # noqa: C901 if "observation_space" not in data or "action_space" not in data: raise KeyError("The observation_space and action_space were not given, can't verify new environments") + # Gym -> Gymnasium space conversion + for key in {"observation_space", "action_space"}: + data[key] = _convert_space(data[key]) # pytype: disable=unsupported-operands + if env is not None: # Wrap first if needed env = cls._wrap_env(env, data["verbose"]) @@ -703,13 +726,14 @@ def load( # noqa: C901 if "env" in data: env = data["env"] - # noinspection PyArgumentList - model = cls( # pytype: disable=not-instantiable,wrong-keyword-args + # pytype: disable=not-instantiable,wrong-keyword-args + model = cls( policy=data["policy_class"], env=env, device=device, - _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args + _init_setup_model=False, # type: ignore[call-arg] ) + # pytype: enable=not-instantiable,wrong-keyword-args # load parameters model.__dict__.update(data) @@ -747,12 +771,12 @@ def load( # noqa: C901 continue # Set the data attribute directly to avoid issue when using optimizers # See https://github.com/DLR-RM/stable-baselines3/issues/391 - recursive_setattr(model, name + ".data", pytorch_variables[name].data) + recursive_setattr(model, f"{name}.data", pytorch_variables[name].data) # Sample gSDE exploration matrix, so it uses the right device # see issue #44 if model.use_sde: - model.policy.reset_noise() # pytype: disable=attribute-error + model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error return model def get_parameters(self) -> Dict[str, Dict]: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 95a9c1e86..e52f08f69 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -4,7 +4,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape from stable_baselines3.common.type_aliases import ( @@ -335,6 +335,15 @@ class RolloutBuffer(BaseBuffer): :param n_envs: Number of parallel environments """ + observations: np.ndarray + actions: np.ndarray + rewards: np.ndarray + advantages: np.ndarray + returns: np.ndarray + episode_starts: np.ndarray + log_probs: np.ndarray + values: np.ndarray + def __init__( self, buffer_size: int, @@ -348,8 +357,6 @@ def __init__( super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma - self.observations, self.actions, self.rewards, self.advantages = None, None, None, None - self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None self.generator_ready = False self.reset() diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 28d07e006..c9b8a3367 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Union -import gym +import gymnasium as gym import numpy as np from stable_baselines3.common.logger import Logger @@ -313,7 +313,7 @@ class ConvertCallback(BaseCallback): :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0): + def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0): super().__init__(verbose) self.callback = callback diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 2b942e8f8..8a8e9f903 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -5,7 +5,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from torch.distributions import Bernoulli, Categorical, Normal diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index b71454b1c..cc8be48ef 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -1,9 +1,9 @@ import warnings from typing import Any, Dict, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -60,9 +60,15 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act if isinstance(observation_space, spaces.Dict): nested_dict = False - for space in observation_space.spaces.values(): + for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True + if isinstance(space, spaces.Discrete) and space.start != 0: + warnings.warn( + f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your observation space." + ) + if nested_dict: warnings.warn( "Nested observation spaces are not supported by Stable Baselines3 " @@ -81,6 +87,18 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) + if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0: + warnings.warn( + "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your observation space." + ) + + if isinstance(action_space, spaces.Discrete) and action_space.start != 0: + warnings.warn( + "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your action space." + ) + if not _is_numpy_array_space(action_space): warnings.warn( "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " @@ -101,9 +119,8 @@ def _is_goal_env(env: gym.Env) -> bool: """ Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface) """ - if isinstance(env, gym.Wrapper): # We need to unwrap the env since gym.Wrapper has the compute_reward method - return _is_goal_env(env.unwrapped) - return hasattr(env, "compute_reward") + # We need to unwrap the env since gym.Wrapper has the compute_reward method + return hasattr(env.unwrapped, "compute_reward") def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None: @@ -131,7 +148,7 @@ def _check_goal_env_compute_reward( env: gym.Env, reward: float, info: Dict[str, Any], -): +) -> None: """ Check that reward is computed with `compute_reward` and that the implementation is vectorized. @@ -165,7 +182,9 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # The check for a GoalEnv is done by the base class if isinstance(observation_space, spaces.Discrete): - assert isinstance(obs, int), f"The observation returned by `{method_name}()` method must be an int" + # Since https://github.com/Farama-Foundation/Gymnasium/pull/141, + # `sample()` will return a np.int64 instead of an int + assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int" elif _is_numpy_array_space(observation_space): assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" @@ -174,27 +193,32 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # check obs dimensions, dtype and bounds assert observation_space.shape == obs.shape, ( f"The observation returned by the `{method_name}()` method does not match the shape " - f"of the given observation space. Expected: {observation_space.shape}, actual shape: {obs.shape}" + f"of the given observation space {observation_space}. " + f"Expected: {observation_space.shape}, actual shape: {obs.shape}" ) - assert observation_space.dtype == obs.dtype, ( - f"The observation returned by the `{method_name}()` method does not match the data type " - f"of the given observation space. Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" + assert np.can_cast(obs.dtype, observation_space.dtype), ( + f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) " + f"of the given observation space {observation_space}. " + f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" ) if isinstance(observation_space, spaces.Box): assert np.all(obs >= observation_space.low), ( f"The observation returned by the `{method_name}()` method does not match the lower bound " - f"of the given observation space. Expected: obs >= {np.min(observation_space.low)}, " + f"of the given observation space {observation_space}." + f"Expected: obs >= {np.min(observation_space.low)}, " f"actual min value: {np.min(obs)} at index {np.argmin(obs)}" ) assert np.all(obs <= observation_space.high), ( f"The observation returned by the `{method_name}()` method does not match the upper bound " - f"of the given observation space. Expected: obs <= {np.max(observation_space.high)}, " + f"of the given observation space {observation_space}. " + f"Expected: obs <= {np.max(observation_space.high)}, " f"actual max value: {np.max(obs)} at index {np.argmax(obs)}" ) - assert observation_space.contains( - obs - ), f"The observation returned by the `{method_name}()` method does not match the given observation space" + assert observation_space.contains(obs), ( + f"The observation returned by the `{method_name}()` method " + f"does not match the given observation space {observation_space}" + ) def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: @@ -221,8 +245,12 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action """ Check the returned values by the env when calling `.reset()` or `.step()` methods. """ - # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists - obs = env.reset() + # because env inherits from gymnasium.Env, we assume that `reset()` and `step()` methods exists + reset_returns = env.reset() + assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" + assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" + obs, info = reset_returns + assert isinstance(info, dict), f"The second element of the tuple return by `reset()` must be a dictionary not {info}" if _is_goal_env(env): # Make mypy happy, already checked @@ -249,19 +277,24 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action action = action_space.sample() data = env.step(action) - assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info" + assert len(data) == 5, ( + "The `step()` method must return five values: " + f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned." + ) # Unpack - obs, reward, done, info = data + obs, reward, terminated, truncated, info = data - if _is_goal_env(env): - # Make mypy happy, already checked - assert isinstance(observation_space, spaces.Dict) - _check_goal_env_obs(obs, observation_space, "step") - _check_goal_env_compute_reward(obs, env, reward, info) - elif isinstance(observation_space, spaces.Dict): + if isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" + # Additional checks for GoalEnvs + if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) + _check_goal_env_obs(obs, observation_space, "step") + _check_goal_env_compute_reward(obs, env, float(reward), info) + if not obs.keys() == observation_space.spaces.keys(): raise AssertionError( "The observation keys returned by `step()` must match the observation " @@ -279,11 +312,14 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" - assert isinstance(done, bool), "The `done` signal must be a boolean" + assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" + assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" # Goal conditioned env if _is_goal_env(env): + # for mypy, env.unwrapped was checked by _is_goal_env() + assert hasattr(env, "compute_reward") assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) @@ -299,8 +335,10 @@ def _check_spaces(env: gym.Env) -> None: assert hasattr(env, "observation_space"), "You must specify an observation space (cf gym.spaces)" + gym_spaces assert hasattr(env, "action_space"), "You must specify an action space (cf gym.spaces)" + gym_spaces - assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces - assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces + assert isinstance(env.observation_space, spaces.Space), ( + "The observation space must inherit from gymnasium.spaces" + gym_spaces + ) + assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" + gym_spaces if _is_goal_env(env): assert isinstance( @@ -309,9 +347,9 @@ def _check_spaces(env: gym.Env) -> None: # Check render cannot be covered by CI -def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover +def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover """ - Check the declared render modes and the `render()`/`close()` + Check the instantiated render mode (if any) by calling the `render()`/`close()` method of the environment. :param env: The environment to check @@ -319,24 +357,19 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No :param headless: Whether to disable render modes that require a graphical interface. False by default. """ - render_modes = env.metadata.get("render.modes") + render_modes = env.metadata.get("render_modes") if render_modes is None: if warn: warnings.warn( "No render modes was declared in the environment " - " (env.metadata['render.modes'] is None or not defined), " + "(env.metadata['render_modes'] is None or not defined), " "you may have trouble when calling `.render()`" ) - else: - # Don't check render mode that require a - # graphical interface (useful for CI) - if headless and "human" in render_modes: - render_modes.remove("human") - # Check all declared render modes - for render_mode in render_modes: - env.render(mode=render_mode) - env.close() + # Only check currrent render mode + if env.render_mode: + env.render() + env.close() def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: @@ -401,7 +434,7 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - # ==== Check the render method and the declared render modes ==== if not skip_render_check: - _check_render(env, warn=warn) # pragma: no cover + _check_render(env, warn) # pragma: no cover try: check_for_nested_spaces(env.observation_space) diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index c85d1472b..c3b73909e 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -1,11 +1,13 @@ import os from typing import Any, Callable, Dict, Optional, Type, Union -import gym +import gymnasium as gym from stable_baselines3.common.atari_wrappers import AtariWrapper from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv +from stable_baselines3.common.vec_env.patch_gym import _patch_env def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]: @@ -24,7 +26,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g return None -def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: +def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool: """ Check if a given environment has been wrapped with a given wrapper. @@ -72,25 +74,40 @@ def make_vec_env( :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. :return: The wrapped environment """ - env_kwargs = {} if env_kwargs is None else env_kwargs - vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs - monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs - wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs + env_kwargs = env_kwargs or {} + vec_env_kwargs = vec_env_kwargs or {} + monitor_kwargs = monitor_kwargs or {} + wrapper_kwargs = wrapper_kwargs or {} + assert vec_env_kwargs is not None # for mypy + + def make_env(rank: int) -> Callable[[], gym.Env]: + def _init() -> gym.Env: + # For type checker: + assert monitor_kwargs is not None + assert wrapper_kwargs is not None + assert env_kwargs is not None - def make_env(rank): - def _init(): if isinstance(env_id, str): - env = gym.make(env_id, **env_kwargs) + # if the render mode was not specified, we set it to `rgb_array` as default. + kwargs = {"render_mode": "rgb_array"} + kwargs.update(env_kwargs) + try: + env = gym.make(env_id, **kwargs) # type: ignore[arg-type] + except TypeError: + env = gym.make(env_id, **env_kwargs) else: env = env_id(**env_kwargs) + # Patch to support gym 0.21/0.26 and gymnasium + env = _patch_env(env) + if seed is not None: - env.seed(seed + rank) + compat_gym_seed(env, seed=seed + rank) env.action_space.seed(seed + rank) # Wrap the env in a Monitor wrapper # to have additional training information monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None # Create the monitor folder if needed - if monitor_path is not None: + if monitor_path is not None and monitor_dir is not None: os.makedirs(monitor_dir, exist_ok=True) env = Monitor(env, filename=monitor_path, **monitor_kwargs) # Optionally, wrap the environment with the provided wrapper diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index d6724c9cc..ec0de2bf9 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -1,9 +1,9 @@ from collections import OrderedDict -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from gym import Env, spaces -from gym.envs.registration import EnvSpec +from gymnasium import Env, spaces +from gymnasium.envs.registration import EnvSpec from stable_baselines3.common.type_aliases import GymStepReturn @@ -25,7 +25,7 @@ class BitFlippingEnv(Env): :param channel_first: Whether to use channel-first or last image. """ - spec = EnvSpec("BitFlippingEnv-v0") + spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point") def __init__( self, @@ -96,7 +96,7 @@ def __init__( self.discrete_obs_space = discrete_obs_space self.image_obs_space = image_obs_space self.state = None - self.desired_goal = np.ones((n_bits,)) + self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype) if max_steps is None: max_steps = n_bits self.max_steps = max_steps @@ -157,24 +157,34 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ] ) - def reset(self) -> Dict[str, Union[int, np.ndarray]]: + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict] = None + ) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: + if seed is not None: + self.obs_space.seed(seed) self.current_step = 0 self.state = self.obs_space.sample() - return self._get_obs() + return self._get_obs(), {} def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + """ + Step into the env. + + :param action: + :return: + """ if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: self.state[action] = 1 - self.state[action] obs = self._get_obs() reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)) - done = reward == 0 + terminated = reward == 0 self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps - info = {"is_success": done} - done = done or self.current_step >= self.max_steps - return obs, reward, done, info + info = {"is_success": terminated} + truncated = self.current_step >= self.max_steps + return obs, reward, terminated, truncated, info def compute_reward( self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index a8bed175a..99a664999 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -1,8 +1,8 @@ from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.type_aliases import GymStepReturn @@ -34,18 +34,21 @@ def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = No self.num_resets = -1 # Becomes 0 after __init__ exits. self.reset() - def reset(self) -> T: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 self.num_resets += 1 self._choose_next_state() - return self.state + return self.state, {} - def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]: + def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _choose_next_state(self) -> None: self.state = self.action_space.sample() @@ -71,12 +74,13 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l super().__init__(ep_length=ep_length, space=space) self.eps = eps - def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: + def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _get_reward(self, action: np.ndarray) -> float: return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 @@ -138,15 +142,18 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self) -> np.ndarray: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: reward = 0.0 self.current_step += 1 - done = self.current_step >= self.ep_length - return self.observation_space.sample(), reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.observation_space.sample(), reward, terminated, truncated, {} def render(self, mode: str = "human") -> None: pass diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 166c6991a..3bb07106a 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,8 +1,8 @@ -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.type_aliases import GymStepReturn @@ -153,11 +153,12 @@ def step(self, action: Union[float, np.ndarray]) -> GymStepReturn: got_to_end = self.state == self.max_state reward = 1 if got_to_end else reward - done = self.count > self.max_count or got_to_end + truncated = self.count > self.max_count + terminated = got_to_end self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" - return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} + return self.get_state_mapping(), reward, terminated, truncated, {"got_to_end": got_to_end} def render(self, mode: str = "human") -> None: """ @@ -167,15 +168,18 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self) -> Dict[str, np.ndarray]: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]: """ Resets the environment state and step count and returns reset observation. + :param seed: :return: observation dict {'vec': ..., 'img': ...} """ + if seed is not None: + super().reset(seed=seed) self.count = 0 if not self.random_start: self.state = 0 else: self.state = np.random.randint(0, self.max_state) - return self.state_mapping[self.state] + return self.state_mapping[self.state], {} diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index 3b634375b..c9253a899 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -1,7 +1,7 @@ import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np from stable_baselines3.common import type_aliases @@ -59,7 +59,7 @@ def evaluate_policy( from stable_baselines3.common.monitor import Monitor if not isinstance(env, VecEnv): - env = DummyVecEnv([lambda: env]) + env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value] is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] @@ -85,7 +85,12 @@ def evaluate_policy( states = None episode_starts = np.ones((env.num_envs,), dtype=bool) while (episode_counts < episode_count_targets).any(): - actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) + actions, states = model.predict( + observations, # type: ignore[arg-type] + state=states, + episode_start=episode_starts, + deterministic=deterministic, + ) new_observations, rewards, dones, infos = env.step(actions) current_rewards += rewards current_lengths += 1 diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index b8ebc2bac..9c3bd54a8 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -5,16 +5,14 @@ import os import time from glob import glob -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union -import gym -import numpy as np +import gymnasium as gym import pandas +from gymnasium.core import ActType, ObsType -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn - -class Monitor(gym.Wrapper): +class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """ A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. @@ -43,9 +41,10 @@ def __init__( self.t_start = time.time() self.results_writer = None if filename is not None: + env_id = env.spec.id if env.spec is not None else None self.results_writer = ResultsWriter( filename, - header={"t_start": self.t_start, "env_id": env.spec and env.spec.id}, + header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=reset_keywords + info_keywords, override_existing=override_existing, ) @@ -62,7 +61,7 @@ def __init__( # extra info about the current episode, that was passed in during reset() self.current_reset_info: Dict[str, Any] = {} - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]: """ Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True @@ -83,18 +82,18 @@ def reset(self, **kwargs) -> GymObs: self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """ Step the environment with the given action :param action: the action - :return: observation, reward, done, information + :return: observation, reward, terminated, truncated, information """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") - observation, reward, done, info = self.env.step(action) - self.rewards.append(reward) - if done: + observation, reward, terminated, truncated, info = self.env.step(action) + self.rewards.append(float(reward)) + if terminated or truncated: self.needs_reset = True ep_rew = sum(self.rewards) ep_len = len(self.rewards) @@ -109,7 +108,7 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: self.results_writer.write_row(ep_info) info["episode"] = ep_info self.total_steps += 1 - return observation, reward, done, info + return observation, reward, terminated, truncated, info def close(self) -> None: """ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index a4524517d..e3e6c594a 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -8,7 +8,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer @@ -75,6 +75,8 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param supported_action_spaces: The action spaces supported by the algorithm. """ + actor: th.nn.Module + def __init__( self, policy: Union[str, Type[BasePolicy]], @@ -103,7 +105,7 @@ def __init__( sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True, - supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, + supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, ): super().__init__( policy=policy, @@ -129,6 +131,7 @@ def __init__( self.gradient_steps = gradient_steps self.action_noise = action_noise self.optimize_memory_usage = optimize_memory_usage + self.replay_buffer: Optional[ReplayBuffer] = None self.replay_buffer_class = replay_buffer_class self.replay_buffer_kwargs = replay_buffer_kwargs or {} self._episode_storage = None @@ -136,8 +139,6 @@ def __init__( # Save train freq parameter, will be converted later to TrainFreq object self.train_freq = train_freq - self.actor = None # type: Optional[th.nn.Module] - self.replay_buffer: Optional[ReplayBuffer] = None # Update policy keyword arguments if sde_support: self.policy_kwargs["use_sde"] = self.use_sde diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 0e6f31616..87e192990 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -4,7 +4,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer @@ -52,6 +52,9 @@ class OnPolicyAlgorithm(BaseAlgorithm): :param supported_action_spaces: The action spaces supported by the algorithm. """ + rollout_buffer: RolloutBuffer + policy: ActorCriticPolicy + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], @@ -73,7 +76,7 @@ def __init__( seed: Optional[int] = None, device: Union[th.device, str] = "auto", _init_setup_model: bool = True, - supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, + supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, ): super().__init__( policy=policy, @@ -97,7 +100,6 @@ def __init__( self.ent_coef = ent_coef self.vf_coef = vf_coef self.max_grad_norm = max_grad_norm - self.rollout_buffer = None if _init_setup_model: self._setup_model() @@ -117,13 +119,11 @@ def _setup_model(self) -> None: gae_lambda=self.gae_lambda, n_envs=self.n_envs, ) - self.policy = self.policy_class( # pytype:disable=not-instantiable - self.observation_space, - self.action_space, - self.lr_schedule, - use_sde=self.use_sde, - **self.policy_kwargs # pytype:disable=not-instantiable + # pytype:disable=not-instantiable + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs ) + # pytype:enable=not-instantiable self.policy = self.policy.to(self.device) def collect_rollouts( @@ -201,16 +201,23 @@ def collect_rollouts( ): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): - terminal_value = self.policy.predict_values(terminal_obs)[0] + terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] rewards[idx] += self.gamma * terminal_value - rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs) - self._last_obs = new_obs + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs, + ) + self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones with th.no_grad(): # Compute value for the last timestep - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type] rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -246,6 +253,8 @@ def learn( callback.on_training_start(locals(), globals()) + assert self.env is not None + while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) @@ -257,6 +266,7 @@ def learn( # Display training infos if log_interval is not None and iteration % log_interval == 0: + assert self.ep_info_buffer is not None time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/iterations", iteration, exclude="tensorboard") diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index c67c45cf2..21d2034d6 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -9,7 +9,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import ( @@ -58,6 +58,8 @@ class BaseModel(nn.Module): excluding the learning rate, to pass to the optimizer """ + optimizer: th.optim.Optimizer + def __init__( self, observation_space: spaces.Space, @@ -84,7 +86,6 @@ def __init__( self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs - self.optimizer: th.optim.Optimizer self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs @@ -279,6 +280,8 @@ class BasePolicy(BaseModel, ABC): or not using a ``tanh()`` function. """ + features_extractor: BaseFeaturesExtractor + def __init__(self, *args, squash_output: bool = False, **kwargs): super().__init__(*args, **kwargs) self._squash_output = squash_output @@ -898,9 +901,9 @@ class ContinuousCritic(BaseModel): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], - features_extractor: nn.Module, + features_extractor: BaseFeaturesExtractor, features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index e280ed731..bc0959480 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -3,7 +3,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F @@ -158,10 +158,7 @@ def get_obs_shape( return (int(len(observation_space.nvec)),) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - if type(observation_space.n) in [tuple, list, np.ndarray]: - return tuple(observation_space.n) - else: - return (int(observation_space.n),) + return observation_space.shape elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] @@ -205,18 +202,20 @@ def get_action_dim(action_space: spaces.Space) -> int: return int(len(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): # Number of binary actions + assert isinstance( + action_space.n, int + ), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead." return int(action_space.n) else: raise NotImplementedError(f"{action_space} action space is not supported") -def check_for_nested_spaces(obs_space: spaces.Space): +def check_for_nested_spaces(obs_space: spaces.Space) -> None: """ Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). If so, raise an Exception informing that there is no support for this. :param obs_space: an observation space - :return: """ if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index e5aeb662b..3c01a3f26 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -367,7 +367,7 @@ def load_from_zip_file( device: Union[th.device, str] = "auto", verbose: int = 0, print_system_info: bool = False, -) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]: +) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]: """ Load model data from a .zip archive diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 44714d6fe..ad6c7eef1 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,8 +1,8 @@ from typing import Dict, List, Tuple, Type, Union -import gym +import gymnasium as gym import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space @@ -63,10 +63,14 @@ class NatureCNN(BaseFeaturesExtractor): def __init__( self, - observation_space: spaces.Box, + observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False, ) -> None: + assert isinstance(observation_space, spaces.Box), ( + "NatureCNN must be used with a gym.spaces.Box ", + f"observation space, not {observation_space}", + ) super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 7227667a1..d38d7cf73 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -2,9 +2,9 @@ import sys from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, SupportsFloat, Tuple, Union -import gym +import gymnasium as gym import numpy as np import torch as th @@ -17,8 +17,11 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] -GymStepReturn = Tuple[GymObs, float, bool, Dict] -TensorDict = Dict[Union[str, int], th.Tensor] +GymResetReturn = Tuple[GymObs, Dict] +AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] +GymStepReturn = Tuple[GymObs, float, bool, bool, Dict] +AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] +TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 1234fba79..d20cc8529 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -4,13 +4,14 @@ import random import re from collections import deque +from inspect import signature from itertools import zip_longest from typing import Dict, Iterable, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces import stable_baselines3 as sb3 @@ -470,9 +471,7 @@ def polyak_update( th.add(target_param.data, param.data, alpha=tau, out=target_param.data) -def obs_as_tensor( - obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device -) -> Union[th.Tensor, TensorDict]: +def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]: """ Moves the observation to the given device. @@ -541,3 +540,18 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: if print_info: print(env_info_str) return env_info, env_info_str + + +def compat_gym_seed(env: GymEnv, seed: int) -> None: + """ + Compatibility helper to seed Gym envs. + + :param env: The Gym environment. + :param seed: The seed for the pseudo random generator + """ + if "seed" in signature(env.unwrapped.reset).parameters: + # gym >= 0.23.1 + env.reset(seed=seed) + else: + # VecEnv and backward compatibility + env.seed(seed) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 708c021aa..a15750360 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -4,9 +4,9 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import cloudpickle -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces # Define type aliases here to avoid circular import # Used when we want to access one or more VecEnv @@ -54,12 +54,20 @@ class VecEnv(ABC): :param action_space: Action space """ - metadata = {"render.modes": ["human", "rgb_array"]} + metadata = {"render_modes": ["human", "rgb_array"]} - def __init__(self, num_envs: int, observation_space: spaces.Space, action_space: spaces.Space): + def __init__( + self, + num_envs: int, + observation_space: spaces.Space, + action_space: spaces.Space, + render_mode: Optional[str] = None, + ): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space + self.render_mode = render_mode + self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method @abstractmethod def reset(self) -> VecEnvObs: @@ -162,35 +170,72 @@ def step(self, actions: np.ndarray) -> VecEnvStepReturn: self.step_async(actions) return self.step_wait() - def get_images(self) -> Sequence[np.ndarray]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: """ - Return RGB images from each environment + Return RGB images from each environment when available """ raise NotImplementedError - def render(self, mode: str = "human") -> Optional[np.ndarray]: + def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: """ Gym environment rendering :param mode: the rendering type """ - try: - imgs = self.get_images() - except NotImplementedError: - warnings.warn(f"Render not defined for {self}") + + if mode == "human" and self.render_mode != mode: + # Special case, if the render_mode="rgb_array" + # we can still display that image using opencv + if self.render_mode != "rgb_array": + warnings.warn( + f"You tried to render a VecEnv with mode='{mode}' " + "but the render mode defined when initializing the environment must be " + f"'human' or 'rgb_array', not '{self.render_mode}'." + ) + return + + elif mode and self.render_mode != mode: + warnings.warn( + f"""Starting from gymnasium v0.26, render modes are determined during the initialization of the environment. + We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) + has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" + ) + return + + mode = mode or self.render_mode + + if mode is None: + warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.") + return + + # mode == self.render_mode == "human" + # In that case, we try to call `self.env.render()` but it might + # crash for subprocesses + if self.render_mode == "human": + self.env_method("render") return - # Create a big image by tiling images from subprocesses - bigimg = tile_images(imgs) - if mode == "human": - import cv2 # pytype:disable=import-error + if mode == "rgb_array" or mode == "human": + # call the render method of the environments + images = self.get_images() + # Create a big image by tiling images from subprocesses + bigimg = tile_images(images) + + if mode == "human": + # Display it using OpenCV + import cv2 # pytype:disable=import-error + + cv2.imshow("vecenv", bigimg[:, :, ::-1]) + cv2.waitKey(1) + else: + return bigimg - cv2.imshow("vecenv", bigimg[:, :, ::-1]) - cv2.waitKey(1) - elif mode == "rgb_array": - return bigimg else: - raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs") + # Other render modes: + # In that case, we try to call `self.env.render()` but it might + # crash for subprocesses + # and we don't return the values + self.env_method("render") @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -251,6 +296,7 @@ def __init__( venv: VecEnv, observation_space: Optional[spaces.Space] = None, action_space: Optional[spaces.Space] = None, + render_mode: Optional[str] = None, ): self.venv = venv VecEnv.__init__( @@ -258,6 +304,7 @@ def __init__( num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, action_space=action_space or venv.action_space, + render_mode=render_mode, ) self.class_attributes = dict(inspect.getmembers(self.__class__)) @@ -278,10 +325,10 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: def close(self) -> None: return self.venv.close() - def render(self, mode: str = "human") -> Optional[np.ndarray]: + def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: return self.venv.render(mode=mode) - def get_images(self) -> Sequence[np.ndarray]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: return self.venv.get_images() def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 5b9fc8b40..7f092e040 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,18 +1,20 @@ +import warnings from collections import OrderedDict from copy import deepcopy from typing import Any, Callable, List, Optional, Sequence, Type, Union -import gym +import gymnasium as gym import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn +from stable_baselines3.common.vec_env.patch_gym import _patch_env from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): """ Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current - Python process. This is useful for computationally simple environment such as ``cartpole-v1``, + Python process. This is useful for computationally simple environment such as ``Cartpole-v1``, as the overhead of multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that require a vectorized environment, but that you want a single environments to train with. @@ -23,7 +25,7 @@ class DummyVecEnv(VecEnv): """ def __init__(self, env_fns: List[Callable[[], gym.Env]]): - self.envs = [fn() for fn in env_fns] + self.envs = [_patch_env(fn()) for fn in env_fns] if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): raise ValueError( "You tried to create multiple environments, but the function to create them returned the same instance " @@ -35,7 +37,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]): "Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information." ) env = self.envs[0] - VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) + VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode) obs_space = env.observation_space self.keys, shapes, dtypes = obs_space_info(obs_space) @@ -50,28 +52,38 @@ def step_async(self, actions: np.ndarray) -> None: self.actions = actions def step_wait(self) -> VecEnvStepReturn: + # Avoid circular imports for env_idx in range(self.num_envs): - obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( + obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step( self.actions[env_idx] ) + # convert to SB3 VecEnv api + self.buf_dones[env_idx] = terminated or truncated + # See https://github.com/openai/gym/issues/3102 + # Gym 0.26 introduces a breaking change + self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated + if self.buf_dones[env_idx]: # save final observation where user can get it, then reset self.buf_infos[env_idx]["terminal_observation"] = obs - obs = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + # Avoid circular import + from stable_baselines3.common.utils import compat_gym_seed + if seed is None: seed = np.random.randint(0, 2**32 - 1) seeds = [] for idx, env in enumerate(self.envs): - seeds.append(env.seed(seed + idx)) + seeds.append(compat_gym_seed(env, seed=seed + idx)) return seeds def reset(self) -> VecEnvObs: for env_idx in range(self.num_envs): - obs = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return self._obs_from_buf() @@ -79,25 +91,22 @@ def close(self) -> None: for env in self.envs: env.close() - def get_images(self) -> Sequence[np.ndarray]: - return [env.render(mode="rgb_array") for env in self.envs] + def get_images(self) -> Sequence[Optional[np.ndarray]]: + if self.render_mode != "rgb_array": + warnings.warn( + f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." + ) + return [None for _ in self.envs] + return [env.render() for env in self.envs] - def render(self, mode: str = "human") -> Optional[np.ndarray]: + def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: """ Gym environment rendering. If there are multiple environments then they are tiled together in one image via ``BaseVecEnv.render()``. - Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the - underlying environment. - - Therefore, some arguments such as ``mode`` will have values that are valid - only when ``num_envs == 1``. :param mode: The rendering type. """ - if self.num_envs == 1: - return self.envs[0].render(mode=mode) - else: - return super().render(mode=mode) + return super().render(mode=mode) def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: for key in self.keys: diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py new file mode 100644 index 000000000..b86c52236 --- /dev/null +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -0,0 +1,100 @@ +import warnings +from inspect import signature +from typing import Union + +import gymnasium + +try: + import gym # pytype: disable=import-error + + gym_installed = True +except ImportError: + gym_installed = False + + +def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma: no cover + """ + Adapted from https://github.com/thu-ml/tianshou. + + Takes an environment and patches it to return Gymnasium env. + This function takes the environment object and returns a patched + env, using shimmy wrapper to convert it to Gymnasium, + if necessary. + + :param env: A gym/gymnasium env + :return: Patched env (gymnasium env) + """ + + # Gymnasium env, no patching to be done + if isinstance(env, gymnasium.Env): + return env + + if not gym_installed or not isinstance(env, gym.Env): + raise ValueError( + f"The environment is of type {type(env)}, not a Gymnasium " + f"environment. In this case, we expect OpenAI Gym to be " + f"installed and the environment to be an OpenAI Gym environment." + ) + + try: + import shimmy # pytype: disable=import-error + except ImportError as e: + raise ImportError( + "Missing shimmy installation. You an OpenAI Gym environment. " + "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " + "In order to use OpenAI Gym environments with SB3, you need to " + "install shimmy (`pip install 'shimmy>=0.2.1'`)." + ) from e + + warnings.warn( + "You provided an OpenAI Gym environment. " + "We strongly recommend transitioning to Gymnasium environments. " + "Stable-Baselines3 is automatically wrapping your environments in a compatibility " + "layer, which could potentially cause issues." + ) + + if "seed" in signature(env.unwrapped.reset).parameters: + # Gym 0.26+ env + return shimmy.GymV26CompatibilityV0(env=env) + # Gym 0.21 env + return shimmy.GymV21CompatibilityV0(env=env) + + +def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space: # pragma: no cover + """ + Takes a space and patches it to return Gymnasium Space. + This function takes the space object and returns a patched + space, using shimmy wrapper to convert it to Gymnasium, + if necessary. + + :param env: A gym/gymnasium Space + :return: Patched space (gymnasium Space) + """ + + # Gymnasium space, no convertion to be done + if isinstance(space, gymnasium.Space): + return space + + if not gym_installed or not isinstance(space, gym.Space): + raise ValueError( + f"The space is of type {type(space)}, not a Gymnasium " + f"space. In this case, we expect OpenAI Gym to be " + f"installed and the space to be an OpenAI Gym space." + ) + + try: + import shimmy # pytype: disable=import-error + except ImportError as e: + raise ImportError( + "Missing shimmy installation. You provided an OpenAI Gym space. " + "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " + "In order to use OpenAI Gym space with SB3, you need to " + "install shimmy (`pip install 'shimmy>=0.2.1'`)." + ) from e + + warnings.warn( + "You loaded a model that was trained using OpenAI Gym. " + "We strongly recommend transitioning to Gymnasium by saving that model again." + ) + + return shimmy.openai_gym_compatibility._convert_space(space) diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index e674ee08a..bf375e165 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first @@ -31,7 +31,7 @@ def __init__( self, num_envs: int, n_stack: int, - observation_space: Union[spaces.Box, spaces.Dict], # Replace by Space[TObs] in gym>=0.26 + observation_space: Union[spaces.Box, spaces.Dict], channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None, ) -> None: self.n_stack = n_stack @@ -40,12 +40,12 @@ def __init__( if not isinstance(channels_order, Mapping): channels_order = {key: channels_order for key in observation_space.spaces.keys()} self.sub_stacked_observations = { - key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) + key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) # type: ignore[arg-type] for key, subspace in observation_space.spaces.items() } self.stacked_observation_space = spaces.Dict( {key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()} - ) # type: spaces.Dict # make mypy happy + ) # type: Union[spaces.Dict, spaces.Box] # make mypy happy elif isinstance(observation_space, spaces.Box): if isinstance(channels_order, Mapping): raise TypeError("When the observation space is Box, channels_order can't be a dict.") @@ -55,7 +55,11 @@ def __init__( ) low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis) - self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) + self.stacked_observation_space = spaces.Box( + low=low, + high=high, + dtype=observation_space.dtype, # type: ignore[arg-type] + ) self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype) else: raise TypeError( @@ -97,36 +101,6 @@ def compute_stacking( stacked_shape[repeat_axis] *= n_stack return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis - def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Dict]) -> Union[spaces.Box, spaces.Dict]: - """ - This function is deprecated. - - As an alternative, use - - .. code-block:: python - - low = np.repeat(observation_space.low, stacked_observation.n_stack, axis=stacked_observation.repeat_axis) - high = np.repeat(observation_space.high, stacked_observation.n_stack, axis=stacked_observation.repeat_axis) - stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) - - :return: New observation space with stacked dimensions - """ - warnings.warn( - "stack_observation_space is deprecated and will be removed in the next SB3 release. " - "Please refer to the docstring for a workaround.", - DeprecationWarning, - ) - if isinstance(observation_space, spaces.Dict): - return spaces.Dict( - { - key: sub_stacked_observation.stack_observation_space(sub_stacked_observation.observation_space) - for key, sub_stacked_observation in self.sub_stacked_observations.items() - } - ) - low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) - high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) - return spaces.Box(low=low, high=high, dtype=observation_space.dtype) - def reset(self, observation: TObs) -> TObs: """ Reset the stacked_obs, add the reset observation to the stack, and return the stack. diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 7ff579d30..ccefd2078 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,10 +1,11 @@ import multiprocessing as mp +import warnings from collections import OrderedDict from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import ( CloudpickleWrapper, @@ -13,33 +14,41 @@ VecEnvObs, VecEnvStepReturn, ) +from stable_baselines3.common.vec_env.patch_gym import _patch_env def _worker( - remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper + remote: mp.connection.Connection, + parent_remote: mp.connection.Connection, + env_fn_wrapper: CloudpickleWrapper, ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped + from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() - env = env_fn_wrapper.var() + env = _patch_env(env_fn_wrapper.var()) + reset_info = {} while True: try: cmd, data = remote.recv() if cmd == "step": - observation, reward, done, info = env.step(data) + observation, reward, terminated, truncated, info = env.step(data) + # convert to SB3 VecEnv api + done = terminated or truncated + info["TimeLimit.truncated"] = truncated and not terminated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation - observation = env.reset() - remote.send((observation, reward, done, info)) + observation, reset_info = env.reset() + remote.send((observation, reward, done, info, reset_info)) elif cmd == "seed": - remote.send(env.seed(data)) + remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": - observation = env.reset() - remote.send(observation) + observation, reset_info = env.reset() + remote.send((observation, reset_info)) elif cmd == "render": - remote.send(env.render(data)) + remote.send(env.render()) elif cmd == "close": env.close() remote.close() @@ -110,7 +119,10 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[ self.remotes[0].send(("get_spaces", None)) observation_space, action_space = self.remotes[0].recv() - VecEnv.__init__(self, len(env_fns), observation_space, action_space) + + self.remotes[0].send(("get_attr", "render_mode")) + render_mode = self.remotes[0].recv() + VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode) def step_async(self, actions: np.ndarray) -> None: for remote, action in zip(self.remotes, actions): @@ -120,7 +132,7 @@ def step_async(self, actions: np.ndarray) -> None: def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False - obs, rews, dones, infos = zip(*results) + obs, rews, dones, infos, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -133,7 +145,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: def reset(self) -> VecEnvObs: for remote in self.remotes: remote.send(("reset", None)) - obs = [remote.recv() for remote in self.remotes] + results = [remote.recv() for remote in self.remotes] + obs, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space) def close(self) -> None: @@ -148,13 +161,17 @@ def close(self) -> None: process.join() self.closed = True - def get_images(self) -> Sequence[np.ndarray]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: + if self.render_mode != "rgb_array": + warnings.warn( + f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." + ) + return [None for _ in self.remotes] for pipe in self.remotes: - # gather images from subprocesses - # `mode` will be taken into account later - pipe.send(("render", "rgb_array")) - imgs = [pipe.recv() for pipe in self.remotes] - return imgs + # gather render return from subprocesses + pipe.send(("render", None)) + outputs = [pipe.recv() for pipe in self.remotes] + return outputs def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 7d318acff..6d55db817 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Tuple import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import check_for_nested_spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs diff --git a/stable_baselines3/common/vec_env/vec_check_nan.py b/stable_baselines3/common/vec_env/vec_check_nan.py index 98ad217f6..170f36ec8 100644 --- a/stable_baselines3/common/vec_env/vec_check_nan.py +++ b/stable_baselines3/common/vec_env/vec_check_nan.py @@ -2,7 +2,7 @@ from typing import List, Tuple import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 8a020ddd6..200201f06 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.stacked_observations import StackedObservations @@ -35,6 +35,9 @@ def step_wait( return observations, rewards, dones, infos def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Reset all environments + """ observation = self.venv.reset() # pytype:disable=annotation-type-mismatch observation = self.stacked_obs.reset(observation) return observation diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index ffa813fa0..ebefa82d1 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common import utils from stable_baselines3.common.preprocessing import is_image_space diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index b6b0ad832..beb603961 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -2,7 +2,7 @@ from typing import Dict, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 83d058abc..6f670054b 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,7 +1,7 @@ import os from typing import Callable -from gym.wrappers.monitoring import video_recorder +from gymnasium.wrappers.monitoring import video_recorder from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -47,6 +47,7 @@ def __init__( metadata = temp_env.metadata self.env.metadata = metadata + assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" self.record_video_trigger = record_video_trigger self.video_recorder = None @@ -109,4 +110,4 @@ def close(self) -> None: self.close_video_recorder() def __del__(self): - self.close() + self.close_video_recorder() diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 59909105f..b85a30f80 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -3,7 +3,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer @@ -11,7 +11,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update -from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy +from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork SelfDQN = TypeVar("SelfDQN", bound="DQN") @@ -67,6 +67,11 @@ class DQN(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + # Linear schedule will be defined in `_setup_model()` + exploration_schedule: Schedule + q_net: QNetwork + q_net_target: QNetwork + policy: DQNPolicy def __init__( self, @@ -131,10 +136,6 @@ def __init__( self.max_grad_norm = max_grad_norm # "epsilon" for the epsilon-greedy exploration self.exploration_rate = 0.0 - # Linear schedule will be defined in `_setup_model()` - self.exploration_schedule: Schedule - self.q_net: th.nn.Module - self.q_net_target: th.nn.Module if _init_setup_model: self._setup_model() @@ -164,8 +165,6 @@ def _setup_model(self) -> None: self.target_update_interval = max(self.target_update_interval // self.n_envs, 1) def _create_aliases(self) -> None: - # For type checker: - assert isinstance(self.policy, DQNPolicy) self.q_net = self.policy.q_net self.q_net_target = self.policy.q_net_target diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 6991b8a80..fcdb95890 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.policies import BasePolicy @@ -27,10 +27,12 @@ class QNetwork(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Discrete + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, features_extractor: BaseFeaturesExtractor, features_dim: int, net_arch: Optional[List[int]] = None, @@ -50,7 +52,7 @@ def __init__( self.net_arch = net_arch self.activation_fn = activation_fn self.features_dim = features_dim - action_dim = self.action_space.n # number of actions + action_dim = int(self.action_space.n) # number of actions q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) self.q_net = nn.Sequential(*q_net) @@ -61,8 +63,6 @@ def forward(self, obs: th.Tensor) -> th.Tensor: :param obs: Observation :return: The estimated Q-Value for each action. """ - # For type checker: - assert isinstance(self.features_extractor, BaseFeaturesExtractor) return self.q_net(self.extract_features(obs, self.features_extractor)) def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: @@ -105,10 +105,13 @@ class DQNPolicy(BasePolicy): excluding the learning rate, to pass to the optimizer """ + q_net: QNetwork + q_net_target: QNetwork + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -145,8 +148,6 @@ def __init__( "normalize_images": normalize_images, } - self.q_net: QNetwork - self.q_net_target: QNetwork self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: @@ -234,7 +235,7 @@ class CnnPolicy(DQNPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Discrete, lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -279,7 +280,7 @@ class MultiInputPolicy(DQNPolicy): def __init__( self, observation_space: spaces.Dict, - action_space: spaces.Space, + action_space: spaces.Discrete, lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 5a438b411..9e06d6444 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -4,7 +4,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import DictReplayBuffer from stable_baselines3.common.type_aliases import DictReplayBufferSamples, TensorDict @@ -58,7 +58,6 @@ def __init__( n_sampled_goal: int = 4, goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future", copy_info_dict: bool = False, - online_sampling: Optional[bool] = None, ): super().__init__( buffer_size, @@ -72,14 +71,6 @@ def __init__( self.env = env self.copy_info_dict = copy_info_dict - if online_sampling is not None: - assert online_sampling is True, "Since v1.8.0, SB3 only supports online sampling with HerReplayBuffer." - warnings.warn( - "Since v1.8.0, the `online_sampling` argument is deprecated " - "as SB3 only supports online sampling with HerReplayBuffer. It will be removed in v2.0", - stacklevel=1, - ) - # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()] diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index e01c54ab8..0df51dc61 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -3,7 +3,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm @@ -183,10 +183,10 @@ def train(self) -> None: # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range - clip_range = self.clip_range(self._current_progress_remaining) + clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator] # Optional: clip range for the value function if self.clip_range_vf is not None: - clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] entropy_losses = [] pg_losses, value_losses = [], [] diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index e756097b1..8902629d4 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution @@ -45,10 +45,12 @@ class Actor(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Box + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -96,9 +98,9 @@ def __init__( if clip_mean > 0.0: self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean)) else: - self.action_dist = SquashedDiagGaussianDistribution(action_dim) + self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment] self.mu = nn.Linear(last_layer_dim, action_dim) - self.log_std = nn.Linear(last_layer_dim, action_dim) + self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment] def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -157,7 +159,7 @@ def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, if self.use_sde: return mean_actions, self.log_std, dict(latent_sde=latent_pi) # Unstructured exploration (Original implementation) - log_std = self.log_std(latent_pi) + log_std = self.log_std(latent_pi) # type: ignore[operator] # Original Implementation to cap the standard deviation log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} @@ -205,10 +207,14 @@ class SACPolicy(BasePolicy): between the actor and the critic (this saves computation time) """ + actor: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -267,15 +273,17 @@ def __init__( } ) - self.actor, self.actor_target = None, None - self.critic, self.critic_target = None, None self.share_features_extractor = share_features_extractor self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: self.actor = self.make_actor() - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class( + self.actor.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) @@ -286,13 +294,17 @@ def _build(self, lr_schedule: Schedule) -> None: # Create a separate features extractor for the critic # this requires more memory and computation self.critic = self.make_critic(features_extractor=None) - critic_parameters = self.critic.parameters() + critic_parameters = list(self.critic.parameters()) # Critic target should not share the features extractor with critic self.critic_target = self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + self.critic.optimizer = self.optimizer_class( + critic_parameters, + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) # Target networks should always be in eval mode self.critic_target.set_training_mode(False) @@ -386,7 +398,7 @@ class CnnPolicy(SACPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -452,7 +464,7 @@ class MultiInputPolicy(SACPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 9b989850a..de344a453 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -2,16 +2,16 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update -from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy +from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy SelfSAC = TypeVar("SelfSAC", bound="SAC") @@ -82,6 +82,10 @@ class SAC(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + policy: SACPolicy + actor: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic def __init__( self, @@ -137,7 +141,7 @@ def __init__( sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(spaces.Box), + supported_action_spaces=(spaces.Box,), support_multi_env=True, ) @@ -147,7 +151,7 @@ def __init__( # Inverse of the reward scale self.ent_coef = ent_coef self.target_update_interval = target_update_interval - self.ent_coef_optimizer = None + self.ent_coef_optimizer: Optional[th.optim.Adam] = None if _init_setup_model: self._setup_model() @@ -161,7 +165,7 @@ def _setup_model(self) -> None: # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": # automatically set target entropy if needed - self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) + self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32)) # type: ignore else: # Force conversion # this will also throw an error for unexpected string @@ -208,7 +212,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: for gradient_step in range(gradient_steps): # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] # We need to sample because `log_std` may have changed between two gradient steps if self.use_sde: @@ -219,7 +223,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: log_prob = log_prob.reshape(-1, 1) ent_coef_loss = None - if self.ent_coef_optimizer is not None: + if self.ent_coef_optimizer is not None and self.log_ent_coef is not None: # Important: detach the variable from the graph # so we don't change it with other losses # see https://github.com/rail-berkeley/softlearning/issues/60 @@ -233,7 +237,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Optimize entropy coefficient, also called # entropy temperature or alpha in the paper - if ent_coef_loss is not None: + if ent_coef_loss is not None and self.ent_coef_optimizer is not None: self.ent_coef_optimizer.zero_grad() ent_coef_loss.backward() self.ent_coef_optimizer.step() @@ -255,7 +259,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Compute critic loss critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) - critic_losses.append(critic_loss.item()) + assert isinstance(critic_loss, th.Tensor) # for type checker + critic_losses.append(critic_loss.item()) # type: ignore[union-attr] # Optimize the critic self.critic.optimizer.zero_grad() diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 6c4a1e9c3..12117df89 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type, Union import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.policies import BasePolicy, ContinuousCritic @@ -35,7 +35,7 @@ class Actor(BasePolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -106,10 +106,15 @@ class TD3Policy(BasePolicy): between the actor and the critic (this saves computation time) """ + actor: Actor + actor_target: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -160,8 +165,6 @@ def __init__( } ) - self.actor, self.actor_target = None, None - self.critic, self.critic_target = None, None self.share_features_extractor = share_features_extractor self._build(lr_schedule) @@ -174,7 +177,11 @@ def _build(self, lr_schedule: Schedule) -> None: # Initialize the target to have the same weights as the actor self.actor_target.load_state_dict(self.actor.state_dict()) - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class( + self.actor.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) @@ -190,7 +197,11 @@ def _build(self, lr_schedule: Schedule) -> None: self.critic_target = self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.critic.optimizer = self.optimizer_class( + self.critic.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) # Target networks should always be in eval mode self.actor_target.set_training_mode(False) @@ -272,7 +283,7 @@ class CnnPolicy(TD3Policy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -326,7 +337,7 @@ class MultiInputPolicy(TD3Policy): def __init__( self, observation_space: spaces.Dict, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 809d39e33..10cea8efa 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -2,16 +2,16 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update -from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy +from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy SelfTD3 = TypeVar("SelfTD3", bound="TD3") @@ -70,6 +70,11 @@ class TD3(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + policy: TD3Policy + actor: Actor + actor_target: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic def __init__( self, @@ -120,7 +125,7 @@ def __init__( seed=seed, sde_support=False, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(spaces.Box), + supported_action_spaces=(spaces.Box,), support_multi_env=True, ) @@ -157,7 +162,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: for _ in range(gradient_steps): self._n_updates += 1 # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] with th.no_grad(): # Select action according to policy and add clipped noise @@ -175,6 +180,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Compute critic loss critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) + assert isinstance(critic_loss, th.Tensor) critic_losses.append(critic_loss.item()) # Optimize the critics diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 9eba1a13a..997bba239 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.1a0 +2.0.0a4 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 9dc294c6a..825002c92 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -1,10 +1,11 @@ -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples from stable_baselines3.common.utils import get_device @@ -19,7 +20,7 @@ class DummyEnv(gym.Env): def __init__(self): self.action_space = spaces.Box(1, 5, (1,)) self.observation_space = spaces.Box(1, 5, (1,)) - self._observations = [1, 2, 3, 4, 5] + self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 self._ep_length = 100 @@ -27,15 +28,16 @@ def __init__(self): def reset(self): self._t = 0 obs = self._observations[0] - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = self._observations[index] - done = self._t >= self._ep_length + terminated = False + truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, terminated, truncated, {} class DummyDictEnv(gym.Env): @@ -48,7 +50,7 @@ def __init__(self): self.action_space = spaces.Box(1, 5, shape=(10, 7)) space = spaces.Box(1, 5, (1,)) self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space}) - self._observations = [1, 2, 3, 4, 5] + self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 self._ep_length = 100 @@ -56,15 +58,23 @@ def __init__(self): def reset(self): self._t = 0 obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()} - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()} - done = self._t >= self._ep_length + terminated = False + truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, terminated, truncated, {} + + +@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) +def test_env(env_cls): + # Check the env used for testing + # Do not warn for assymetric space + check_env(env_cls(), warn=False, skip_render_check=True) @pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer]) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index a9bdc431c..f8b0e5486 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,7 +1,7 @@ import os import shutil -import gym +import gymnasium as gym import numpy as np import pytest diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 55c55ca9c..e32438c27 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -4,7 +4,7 @@ import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.envs import FakeImageEnv @@ -45,7 +45,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor): # FakeImageEnv is channel last by default and should be wrapped assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() # Test stochastic predict with channel last input if model_class == DQN: @@ -248,7 +248,7 @@ def test_channel_first_env(tmp_path): assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs, deterministic=True) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 2c114f613..14777452e 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,9 +1,12 @@ -import gym +from typing import Dict, Optional + +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy @@ -13,7 +16,7 @@ class DummyDictEnv(gym.Env): """Custom Environment for testing purposes only""" - metadata = {"render.modes": ["human"]} + metadata = {"render_modes": ["human"]} def __init__( self, @@ -66,19 +69,31 @@ def seed(self, seed=None): def step(self, action): reward = 0.0 - done = False - return self.observation_space.sample(), reward, done, {} - - def compute_reward(self, achieved_goal, desired_goal, info): - return np.zeros((len(achieved_goal),)) + terminated = truncated = False + return self.observation_space.sample(), reward, terminated, truncated, {} - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.observation_space.seed(seed) + return self.observation_space.sample(), {} - def render(self, mode="human"): + def render(self): pass +@pytest.mark.parametrize("use_discrete_actions", [True, False]) +@pytest.mark.parametrize("channel_last", [True, False]) +@pytest.mark.parametrize("nested_dict_obs", [True, False]) +@pytest.mark.parametrize("vec_only", [True, False]) +def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only): + # Check the env used for testing + if nested_dict_obs: + with pytest.warns(UserWarning, match="Nested observation spaces are not supported"): + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + else: + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + + @pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"]) def test_policy_hint(policy): # Common mistake: using the wrong policy @@ -105,7 +120,7 @@ def test_consistency(model_class): dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) dict_env.seed(10) - obs = dict_env.reset() + obs, _ = dict_env.reset() kwargs = {} n_steps = 256 diff --git a/tests/test_distributions.py b/tests/test_distributions.py index e782182f4..48eae12d0 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,7 +1,7 @@ from copy import deepcopy from typing import Tuple -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 94aeb3c97..1050e866e 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,26 +1,32 @@ -import gym +from typing import Dict, Optional + +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.env_checker import check_env class ActionDictTestEnv(gym.Env): + metadata = {"render_modes": ["human"]} + render_mode = None + action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)}) observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) reward = 1 - done = True + terminated = True + truncated = False info = {} - return observation, reward, done, info + return observation, reward, terminated, truncated, info def reset(self): - return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) + return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} - def render(self, mode="human"): + def render(self): pass @@ -94,12 +100,12 @@ def test_check_env_detailed_error(obs_tuple, method): class TestEnv(gym.Env): action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32) - def reset(self): - return wrong_obs if method == "reset" else good_obs + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + return wrong_obs if method == "reset" else good_obs, {} def step(self, action): obs = wrong_obs if method == "step" else good_obs - return obs, 0.0, True, {} + return obs, 0.0, True, False, {} TestEnv.observation_space = observation_space diff --git a/tests/test_envs.py b/tests/test_envs.py index 1281bb45f..aeb248fbb 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,10 +1,10 @@ import types import warnings -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import ( @@ -75,6 +75,17 @@ def test_bit_flipping(kwargs): # No warnings for custom envs assert len(record) == 0 + # Remove a key, must throw an error + obs_space = env.observation_space.spaces["observation"] + del env.observation_space.spaces["observation"] + with pytest.raises(AssertionError): + check_env(env) + + # Rename a key, must throw an error + env.observation_space.spaces["obs"] = obs_space + with pytest.raises(AssertionError): + check_env(env) + def test_high_dimension_action_space(): """ @@ -87,7 +98,7 @@ def test_high_dimension_action_space(): # Patch to avoid error def patched_step(_action): - return env.observation_space.sample(), 0.0, False, {} + return env.observation_space.sample(), 0.0, False, False, {} env.step = patched_step check_env(env) @@ -110,16 +121,20 @@ def patched_step(_action): spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), # Small image inside a dict spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), + # Non zero start index + spaces.Discrete(3, start=-1), + # Non zero start index inside a Dict + spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], ) def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space # Patch methods to avoid errors - env.reset = new_obs_space.sample + env.reset = lambda: (new_obs_space.sample(), {}) def patched_step(_action): - return new_obs_space.sample(), 0.0, False, {} + return new_obs_space.sample(), 0.0, False, False, {} env.step = patched_step with pytest.warns(UserWarning): @@ -145,6 +160,8 @@ def patched_step(_action): spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), + # Non zero start index + spaces.Discrete(3, start=-1), ], ) def test_non_default_action_spaces(new_action_space): @@ -155,14 +172,26 @@ def test_non_default_action_spaces(new_action_space): # No warnings for custom envs assert len(record) == 0 + # Change the action space env.action_space = new_action_space + # Discrete action space + if isinstance(new_action_space, spaces.Discrete): + with pytest.warns(UserWarning): + check_env(env) + return + + low, high = new_action_space.low[0], new_action_space.high[0] # Unbounded action space throws an error, # the rest only warning if not np.all(np.isfinite(env.action_space.low)): with pytest.raises(AssertionError), pytest.warns(UserWarning): check_env(env) + # numpy >= 1.21 raises a ValueError + elif int(np.__version__.split(".")[1]) >= 21 and (low > high): + with pytest.raises(ValueError), pytest.warns(UserWarning): + check_env(env) else: with pytest.warns(UserWarning): check_env(env) @@ -176,7 +205,7 @@ def check_reset_assert_error(env, new_reset_return): """ def wrong_reset(): - return new_reset_return + return new_reset_return, {} # Patch the reset method with a wrong one env.reset = wrong_reset @@ -194,6 +223,11 @@ def test_common_failures_reset(): # The observation is not a numpy array check_reset_assert_error(env, 1) + # Return only obs (gym < 0.26) + env.reset = env.observation_space.sample + with pytest.raises(AssertionError): + check_env(env) + # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) @@ -206,10 +240,10 @@ def test_common_failures_reset(): wrong_obs = {**env.observation_space.sample(), "extra_key": None} check_reset_assert_error(env, wrong_obs) - obs = env.reset() + obs, _ = env.reset() def wrong_reset(self): - return {"img": obs["img"], "vec": obs["img"]} + return {"img": obs["img"], "vec": obs["img"]}, {} env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError) as excinfo: @@ -242,33 +276,38 @@ def test_common_failures_step(): env = IdentityEnvBox() # Wrong shape for the observation - check_step_assert_error(env, (np.ones((4,)), 1.0, False, {})) + check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {})) # Obs is not a numpy array - check_step_assert_error(env, (1, 1.0, False, {})) + check_step_assert_error(env, (1, 1.0, False, False, {})) # Return a wrong reward - check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {})) + check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, False, {})) # Info dict is not returned - check_step_assert_error(env, (env.observation_space.sample(), 0.0, False)) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, False)) + + # Truncated is not returned (gym < 0.26) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, {})) # Done is not a boolean - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, False, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, False, {})) + # Truncated is not a boolean + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, 1.0, {})) env = SimpleMultiObsEnv() # Observation keys and observation space keys must match wrong_obs = env.observation_space.sample() wrong_obs.pop("img") - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) wrong_obs = {**env.observation_space.sample(), "extra_key": None} - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) - obs = env.reset() + obs, _ = env.reset() def wrong_step(self, action): - return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {} + return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {} env.step = types.MethodType(wrong_step, env) with pytest.raises(AssertionError) as excinfo: diff --git a/tests/test_gae.py b/tests/test_gae.py index c90470f00..83b95a4c0 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,11 +1,14 @@ -import gym +from typing import Dict, Optional + +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, PPO, SAC from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.policies import ActorCriticPolicy @@ -20,20 +23,26 @@ def __init__(self, max_steps=8): def seed(self, seed): self.observation_space.seed(seed) - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.observation_space.seed(seed) self.n_steps = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): self.n_steps += 1 - done = False + terminated = truncated = False reward = 0.0 if self.n_steps >= self.max_steps: reward = 1.0 - done = True + terminated = True + # To simplify GAE computation checks, + # we do not consider truncation here. + # Truncations are checked in InfiniteHorizonEnv + truncated = False - return self.observation_space.sample(), reward, done, {} + return self.observation_space.sample(), reward, terminated, truncated, {} class InfiniteHorizonEnv(gym.Env): @@ -44,13 +53,16 @@ def __init__(self, n_states=4): self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + self.current_state = 0 - return self.current_state + return self.current_state, {} def step(self, action): self.current_state = (self.current_state + 1) % self.n_states - return self.current_state, 1.0, False, {} + return self.current_state, 1.0, False, False, {} class CheckGAECallback(BaseCallback): @@ -110,6 +122,12 @@ def forward(self, obs, deterministic=False): return actions, values, log_prob +@pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) + + @pytest.mark.parametrize("model_class", [A2C, PPO]) @pytest.mark.parametrize("gae_lambda", [1.0, 0.9]) @pytest.mark.parametrize("gamma", [1.0, 0.99]) diff --git a/tests/test_her.py b/tests/test_her.py index 57a8f0ff0..e17336bc5 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -251,7 +251,7 @@ def env_fn(): train_freq=4, buffer_size=int(2e4), policy_kwargs=dict(net_arch=[64]), - seed=1, + seed=0, ) model.learn(200) old_replay_buffer = deepcopy(model.replay_buffer) diff --git a/tests/test_identity.py b/tests/test_identity.py index cc7746bc7..3118a4d7a 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -15,21 +15,17 @@ def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} - n_steps = 3000 + n_steps = 2500 if model_class == DQN: kwargs = dict(learning_starts=0) - n_steps = 4000 # DQN only support discrete actions if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): return - elif model_class == A2C: - # slightly higher budget - n_steps = 3500 - model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) - obs = env.reset() + obs, _ = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) @@ -38,16 +34,19 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class] + n_steps = {A2C: 2000, PPO: 2000, SAC: 400, TD3: 400, DDPG: 400}[model_class] kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) + if model_class in [TD3]: n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) kwargs["action_noise"] = action_noise elif model_class in [A2C]: kwargs["policy_kwargs"]["log_std_init"] = -0.5 + elif model_class == PPO: + kwargs = dict(n_steps=512, n_epochs=5) - model = model_class("MlpPolicy", env, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) diff --git a/tests/test_logger.py b/tests/test_logger.py index 54fc864c8..c33bd4c03 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -5,15 +5,16 @@ from typing import Sequence from unittest import mock -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from matplotlib import pyplot as plt from pandas.errors import EmptyDataError from stable_baselines3 import A2C, DQN +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.logger import ( DEBUG, INFO, @@ -352,12 +353,18 @@ def __init__(self, delay: float = 0.01): self.action_space = spaces.Discrete(2) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): time.sleep(self.delay) obs = self.observation_space.sample() - return obs, 0.0, True, {} + return obs, 0.0, True, False, {} + + +@pytest.mark.parametrize("env_cls", [TimeDelayEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) class InMemoryLogger(Logger): diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 17002f39a..2428d2216 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -2,7 +2,7 @@ import os import uuid -import gym +import gymnasium as gym import pandas from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results @@ -13,7 +13,7 @@ def test_monitor(tmp_path): Test the monitor wrapper """ env = gym.make("CartPole-v1") - env.seed(0) + env.reset(seed=0) monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env = Monitor(env, monitor_file) monitor_env.reset() @@ -22,10 +22,10 @@ def test_monitor(tmp_path): ep_lengths = [] ep_len, ep_reward = 0, 0 for _ in range(total_steps): - _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample()) + _, reward, terminated, truncated, _ = monitor_env.step(monitor_env.action_space.sample()) ep_len += 1 ep_reward += reward - if done: + if terminated or truncated: ep_rewards.append(ep_reward) ep_lengths.append(ep_len) monitor_env.reset() @@ -64,7 +64,7 @@ def test_monitor_load_results(tmp_path): """ tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") - env1.seed(0) + env1.reset(seed=0) monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env1 = Monitor(env1, monitor_file1) @@ -75,8 +75,8 @@ def test_monitor_load_results(tmp_path): monitor_env1.reset() episode_count1 = 0 for _ in range(1000): - _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample()) - if done: + _, _, terminated, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample()) + if terminated or truncated: episode_count1 += 1 monitor_env1.reset() @@ -84,7 +84,7 @@ def test_monitor_load_results(tmp_path): assert results_size1 == episode_count1 env2 = gym.make("CartPole-v1") - env2.seed(0) + env2.reset(seed=0) monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) @@ -98,8 +98,8 @@ def test_monitor_load_results(tmp_path): monitor_env2 = Monitor(env2, monitor_file2, override_existing=False) monitor_env2.reset() for _ in range(1000): - _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample()) - if done: + _, _, terminated, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample()) + if terminated or truncated: episode_count2 += 1 monitor_env2.reset() diff --git a/tests/test_predict.py b/tests/test_predict.py index 579abff77..247fe9172 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,10 +1,11 @@ -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import IdentityEnv from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -30,10 +31,16 @@ def __init__(self): self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, {} + return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {} + + +@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -70,7 +77,7 @@ def test_predict(model_class, env_id, device): env = gym.make(env_id) vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs) assert isinstance(action, np.ndarray) assert action.shape == env.action_space.shape @@ -96,7 +103,7 @@ def test_dqn_epsilon_greedy(): env = IdentityEnv(2) model = DQN("MlpPolicy", env) model.exploration_rate = 1.0 - obs = env.reset() + obs, _ = env.reset() # is vectorized should not crash with discrete obs action, _ = model.predict(obs, deterministic=False) assert env.action_space.contains(action) @@ -107,5 +114,5 @@ def test_subclassed_space_env(model_class): env = CustomSubClassedSpaceEnv() model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32])) model.learn(300) - obs = env.reset() + obs, _ = env.reset() env.step(model.predict(obs)) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 89f869b45..b8a5891c7 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,5 +1,5 @@ import torch -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs diff --git a/tests/test_run.py b/tests/test_run.py index ca7548ff9..31c7b956e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import numpy as np import pytest diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9d3d537b7..2f227adf6 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -8,7 +8,7 @@ from collections import OrderedDict from copy import deepcopy -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 6dd6dc419..6d18fcef8 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,9 +1,12 @@ -import gym +from typing import Dict, Optional + +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import evaluate_policy @@ -14,11 +17,13 @@ def __init__(self, nvec): self.observation_space = spaces.MultiDiscrete(nvec) self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultiBinary(gym.Env): @@ -27,11 +32,13 @@ def __init__(self, n): self.observation_space = spaces.MultiBinary(n) self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultidimensionalAction(gym.Env): @@ -41,10 +48,16 @@ def __init__(self): self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} + + +@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) +def test_env(env): + # Check the env used for testing + check_env(env, skip_render_check=True) @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index f3a012fd4..dcbda74e1 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -1,6 +1,6 @@ from typing import Union -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_utils.py b/tests/test_utils.py index 88de942e5..02128edb7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,11 @@ import os import shutil -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces import stable_baselines3 as sb3 from stable_baselines3 import A2C @@ -28,7 +28,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) @pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv]) -@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit]) +@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.RecordEpisodeStatistics]) def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class): env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, wrapper_class=wrapper_class, monitor_dir=None, seed=0) @@ -194,7 +194,7 @@ def dummy_callback(locals_, _globals): policy.n_callback_calls = 0 # type: ignore[assignment, attr-defined] _, episode_lengths = evaluate_policy( policy, # type: ignore[arg-type] - model.get_env(), + model.get_env(), # type: ignore[arg-type] n_eval_episodes, deterministic=True, render=False, @@ -213,7 +213,7 @@ def dummy_callback(locals_, _globals): episode_rewards, _ = evaluate_policy( policy, # type: ignore[arg-type] - model.get_env(), + model.get_env(), # type: ignore[arg-type] n_eval_episodes, return_episode_rewards=True, ) @@ -239,17 +239,18 @@ def __init__(self, env): self.needs_reset = True def step(self, action): - obs, reward, done, info = self.env.step(action) - self.needs_reset = done + obs, reward, terminated, truncated, info = self.env.step(action) + self.needs_reset = terminated or truncated self.last_obs = obs - return obs, reward, True, info + return obs, reward, True, truncated, info def reset(self, **kwargs): + info = {} if self.needs_reset: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) self.last_obs = obs self.needs_reset = False - return self.last_obs + return self.last_obs, info @pytest.mark.parametrize("n_envs", [1, 2, 5, 7]) diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 962355782..1253be6e5 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -1,7 +1,7 @@ -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -24,13 +24,13 @@ def step(action): obs = float("inf") else: obs = 0 - return [obs], 0.0, False, {} + return [obs], 0.0, False, False, {} @staticmethod def reset(): - return [0.0] + return [0.0], {} - def render(self, mode="human", close=False): + def render(self): pass diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index ae05947c5..6bc7e74db 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,12 +2,16 @@ import functools import itertools import multiprocessing +import os +import warnings +from typing import Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize @@ -17,7 +21,7 @@ class CustomGymEnv(gym.Env): - def __init__(self, space): + def __init__(self, space, render_mode: str = "rgb_array"): """ Custom gym environment for testing purposes """ @@ -25,24 +29,28 @@ def __init__(self, space): self.observation_space = space self.current_step = 0 self.ep_length = 4 + self.render_mode = render_mode - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.seed(seed) self.current_step = 0 self._choose_next_state() - return self.state + return self.state, {} def step(self, action): reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _choose_next_state(self): self.state = self.observation_space.sample() - def render(self, mode="human"): - if mode == "rgb_array": + def render(self): + if self.render_mode == "rgb_array": return np.zeros((4, 4, 3)) def seed(self, seed=None): @@ -91,9 +99,20 @@ def make_env(): # Test seed method vec_env.seed(0) + # Test render method call - # vec_env.render() # we need a X server to test the "human" mode - vec_env.render(mode="rgb_array") + array_explicit_mode = vec_env.render(mode="rgb_array") + # test render without argument (new gym API style) + array_implicit_mode = vec_env.render() + assert np.array_equal(array_implicit_mode, array_explicit_mode) + + # test warning if you try different render mode + with pytest.warns(UserWarning): + vec_env.render(mode="something_else") + + # we need a X server to test the "human" mode (uses OpenCV) + # vec_env.render(mode="human") + env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set current_step to an arbitrary value @@ -155,13 +174,14 @@ def __init__(self, max_steps): def reset(self): self.current_step = 0 - return np.array([self.current_step], dtype="int") + return np.array([self.current_step], dtype="int"), {} def step(self, action): prev_step = self.current_step self.current_step += 1 - done = self.current_step >= self.max_steps - return np.array([prev_step], dtype="int"), 0.0, done, {} + terminated = False + truncated = self.current_step >= self.max_steps + return np.array([prev_step], dtype="int"), 0.0, terminated, truncated, {} @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) @@ -455,6 +475,23 @@ def make_monitored_env(): assert vec_env.env_is_wrapped(Monitor) == [False, True] +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_backward_compat_seed(vec_env_class): + def make_env(): + env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + # Patch reset function to remove seed param + env.reset = lambda: (env.observation_space.sample(), {}) + env.seed = env.observation_space.seed + return env + + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) + vec_env.seed(3) + obs = vec_env.reset() + vec_env.seed(3) + new_obs = vec_env.reset() + assert np.allclose(new_obs, obs) + + @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) def test_vec_seeding(vec_env_class): def make_env(): @@ -484,3 +521,63 @@ def make_env(): assert not np.allclose(rewards[1], rewards[2]) vec_env.close() + + +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_render(vec_env_class): + # Skip if no X-Server + if not os.environ.get("DISPLAY"): + pytest.skip("No X-Server") + + env_id = "Pendulum-v1" + # DummyVecEnv human render is currently + # buggy because of gym: + # https://github.com/carlosluis/stable-baselines3/pull/3#issuecomment-1356863808 + n_envs = 2 + # Human render + vec_env = make_vec_env( + env_id, + n_envs, + vec_env_cls=vec_env_class, + env_kwargs=dict(render_mode="human"), + ) + + vec_env.reset() + vec_env.render() + + with pytest.warns(UserWarning): + vec_env.render("rgb_array") + + with pytest.warns(UserWarning): + vec_env.render(mode="blah") + + for _ in range(10): + vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) + vec_env.render() + + vec_env.close() + # rgb_array render, which allows human_render + # thanks to OpenCV + vec_env = make_vec_env( + env_id, + n_envs, + vec_env_cls=vec_env_class, + env_kwargs=dict(render_mode="rgb_array"), + ) + + vec_env.reset() + with warnings.catch_warnings(record=True) as record: + vec_env.render() + vec_env.render("rgb_array") + vec_env.render(mode="human") + + # No warnings for using human mode + assert len(record) == 0 + + with pytest.warns(UserWarning): + vec_env.render(mode="blah") + + for _ in range(10): + vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) + vec_env.render() + vec_env.close() diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py index 17728bbd3..8c8dccd47 100644 --- a/tests/test_vec_extract_dict_obs.py +++ b/tests/test_vec_extract_dict_obs.py @@ -1,5 +1,5 @@ import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3 import PPO from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor @@ -41,7 +41,7 @@ def reset(self): self.n_steps = 0 return {"rgb": np.zeros((self.num_envs, 86, 86))} - def render(self, mode="human", close=False): + def render(self, close=False): pass diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 0a146a057..1a0e94d90 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,8 +2,9 @@ import json import os import uuid +import warnings -import gym +import gymnasium as gym import pandas import pytest @@ -132,8 +133,9 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r".*passive_env_checker") env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) - env.seed(0) + env.seed(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") model.learn(total_timesteps=250) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 27bba9aad..ae5904795 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,10 +1,10 @@ import operator -from typing import Any, Dict +from typing import Any, Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3 import SAC, TD3, HerReplayBuffer from stable_baselines3.common.envs import FakeImageEnv @@ -35,11 +35,15 @@ def step(self, action): self.t += 1 index = (self.t + self.return_reward_idx) % len(self.returned_rewards) returned_value = self.returned_rewards[index] - return np.array([returned_value]), returned_value, self.t == len(self.returned_rewards), {} + terminated = False + truncated = self.t == len(self.returned_rewards) + return np.array([returned_value]), returned_value, terminated, truncated, {} - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) self.t = 0 - return np.array([self.returned_rewards[self.return_reward_idx]]) + return np.array([self.returned_rewards[self.return_reward_idx]]), {} class DummyDictEnv(gym.Env): @@ -58,14 +62,16 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {}) - done = np.random.rand() > 0.8 - return obs, reward, done, {} + terminated = np.random.rand() > 0.8 + return obs, reward, terminated, False, {} def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32: distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) @@ -88,13 +94,15 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() - done = np.random.rand() > 0.8 - return obs, 0.0, done, {} + terminated = np.random.rand() > 0.8 + return obs, 0.0, terminated, False, {} def allclose(obs_1, obs_2): diff --git a/tests/test_vec_stacked_obs.py b/tests/test_vec_stacked_obs.py index 0a7aa39f1..4b2c61444 100644 --- a/tests/test_vec_stacked_obs.py +++ b/tests/test_vec_stacked_obs.py @@ -1,5 +1,5 @@ import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.stacked_observations import StackedObservations