Skip to content

Commit

Permalink
Merge pull request #5 from DLR-RM/check_test_env
Browse files Browse the repository at this point in the history
Check test env in tests
  • Loading branch information
araffin authored Jan 25, 2023
2 parents 669ef02 + 7460782 commit 0431c7a
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 5 deletions.
12 changes: 10 additions & 2 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from gym import spaces

from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.utils import get_device
Expand All @@ -19,7 +20,7 @@ class DummyEnv(gym.Env):
def __init__(self):
self.action_space = spaces.Box(1, 5, (1,))
self.observation_space = spaces.Box(1, 5, (1,))
self._observations = [1, 2, 3, 4, 5]
self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
self._rewards = [1, 2, 3, 4, 5]
self._t = 0
self._ep_length = 100
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(self):
self.action_space = spaces.Box(1, 5, (1,))
space = spaces.Box(1, 5, (1,))
self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space})
self._observations = [1, 2, 3, 4, 5]
self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
self._rewards = [1, 2, 3, 4, 5]
self._t = 0
self._ep_length = 100
Expand All @@ -66,6 +67,13 @@ def step(self, action):
return obs, reward, done, truncated, {}


@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv])
def test_env(env_cls):
# Check the env used for testing
# Do not warn for assymetric space
check_env(env_cls(), warn=False, skip_render_check=True)


@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
env = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls]
Expand Down
17 changes: 14 additions & 3 deletions tests/test_dict_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from gym import spaces

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
from stable_baselines3.common.evaluation import evaluate_policy
Expand Down Expand Up @@ -71,9 +72,6 @@ def step(self, action):
done = truncated = False
return self.observation_space.sample(), reward, done, truncated, {}

def compute_reward(self, achieved_goal, desired_goal, info):
return np.zeros((len(achieved_goal),))

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
self.observation_space.seed(seed)
Expand All @@ -83,6 +81,19 @@ def render(self):
pass


@pytest.mark.parametrize("use_discrete_actions", [True, False])
@pytest.mark.parametrize("channel_last", [True, False])
@pytest.mark.parametrize("nested_dict_obs", [True, False])
@pytest.mark.parametrize("vec_only", [True, False])
def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
# Check the env used for testing
if nested_dict_obs:
with pytest.warns(UserWarning, match="Nested observation spaces are not supported"):
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
else:
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))


@pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"])
def test_policy_hint(policy):
# Common mistake: using the wrong policy
Expand Down
7 changes: 7 additions & 0 deletions tests/test_gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.policies import ActorCriticPolicy


Expand Down Expand Up @@ -121,6 +122,12 @@ def forward(self, obs, deterministic=False):
return actions, values, log_prob


@pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv])
def test_env(env_cls):
# Check the env used for testing
check_env(env_cls(), skip_render_check=True)


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("gae_lambda", [1.0, 0.9])
@pytest.mark.parametrize("gamma", [1.0, 0.99])
Expand Down
7 changes: 7 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pandas.errors import EmptyDataError

from stable_baselines3 import A2C, DQN
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.logger import (
DEBUG,
INFO,
Expand Down Expand Up @@ -363,6 +364,12 @@ def step(self, action):
return obs, 0.0, True, False, {}


@pytest.mark.parametrize("env_cls", [TimeDelayEnv])
def test_env(env_cls):
# Check the env used for testing
check_env(env_cls(), skip_render_check=True)


class InMemoryLogger(Logger):
"""
Logger that keeps key/value pairs in memory without any writers.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from gym import spaces

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import IdentityEnv
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
Expand Down Expand Up @@ -36,6 +37,12 @@ def step(self, action):
return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {}


@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv])
def test_env(env_cls):
# Check the env used for testing
check_env(env_cls(), skip_render_check=True)


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_auto_wrap(model_class):
"""Test auto wrapping of env into a VecEnv."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from gym import spaces

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

Expand Down Expand Up @@ -53,6 +54,12 @@ def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}


@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
def test_env(env):
# Check the env used for testing
check_env(env, skip_render_check=True)


@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
def test_identity_spaces(model_class, env):
Expand Down

0 comments on commit 0431c7a

Please sign in to comment.