diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index c68653934..ebbb9f262 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
 ==========
 
-Release 1.8.0a2 (WIP)
+Release 1.8.0a3 (WIP)
 --------------------------
 
@@ -14,6 +14,8 @@ Breaking Changes:
 
 New Features:
 ^^^^^^^^^^^^^
+- Added ``action_repeat_probability`` argument in ``AtariWrapper``.
+- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``.
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
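For context, ``action_repeat_probability`` enables ALE-style sticky actions (Machado et al., 2018, https://arxiv.org/abs/1709.06009): at each step, the previous action is repeated with the given probability. A minimal usage sketch of the new argument (hypothetical probability value, assuming Atari ROMs are installed)::

    import gym

    from stable_baselines3.common.atari_wrappers import AtariWrapper

    # DeepMind-style Atari preprocessing plus sticky actions
    env = AtariWrapper(gym.make("BreakoutNoFrameskip-v4"), action_repeat_probability=0.25)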
diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py
index 32c1bda63..1e06a7fb7 100644
--- a/stable_baselines3/common/atari_wrappers.py
+++ b/stable_baselines3/common/atari_wrappers.py
@@ -12,13 +12,39 @@
 from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
 
 
+class StickyActionEnv(gym.Wrapper):
+    """
+    Sticky action.
+
+    Paper: https://arxiv.org/abs/1709.06009
+    Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
+
+    :param env: Environment to wrap
+    :param action_repeat_probability: Probability of repeating the last action
+    """
+
+    def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
+        super().__init__(env)
+        self.action_repeat_probability = action_repeat_probability
+        assert env.unwrapped.get_action_meanings()[0] == "NOOP"
+
+    def reset(self, **kwargs) -> GymObs:
+        self._sticky_action = 0  # NOOP
+        return self.env.reset(**kwargs)
+
+    def step(self, action: int) -> GymStepReturn:
+        if self.np_random.random() >= self.action_repeat_probability:
+            self._sticky_action = action
+        return self.env.step(self._sticky_action)
+
+
 class NoopResetEnv(gym.Wrapper):
     """
     Sample initial states by taking random number of no-ops on reset.
     No-op is assumed to be action 0.
 
-    :param env: the environment to wrap
-    :param noop_max: the maximum value of no-ops to run
+    :param env: Environment to wrap
+    :param noop_max: Maximum value of no-ops to run
     """
 
     def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
@@ -47,7 +73,7 @@ class FireResetEnv(gym.Wrapper):
     """
     Take action on reset for environments that are fixed until firing.
 
-    :param env: the environment to wrap
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -71,7 +97,7 @@ class EpisodicLifeEnv(gym.Wrapper):
     Make end-of-life == end-of-episode, but only reset on true game over.
     Done by DeepMind for the DQN and co. since it helps value estimation.
 
-    :param env: the environment to wrap
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -120,9 +146,11 @@ def reset(self, **kwargs) -> np.ndarray:
 class MaxAndSkipEnv(gym.Wrapper):
     """
     Return only every ``skip``-th frame (frameskipping)
+    and return the max between the last two frames.
 
-    :param env: the environment
-    :param skip: number of ``skip``-th frame
+    :param env: Environment to wrap
+    :param skip: Number of frames to skip.
+        The same action will be repeated ``skip`` times.
     """
 
     def __init__(self, env: gym.Env, skip: int = 4) -> None:
@@ -159,9 +187,9 @@ def step(self, action: int) -> GymStepReturn:
 
 class ClipRewardEnv(gym.RewardWrapper):
     """
-    Clips the reward to {+1, 0, -1} by its sign.
+    Clip the reward to {+1, 0, -1} by its sign.
 
-    :param env: the environment
+    :param env: Environment to wrap
     """
 
     def __init__(self, env: gym.Env) -> None:
@@ -182,9 +210,9 @@ class WarpFrame(gym.ObservationWrapper):
     Convert to grayscale and warp frames to 84x84 (default)
     as done in the Nature paper and later work.
 
-    :param env: the environment
-    :param width:
-    :param height:
+    :param env: Environment to wrap
+    :param width: New frame width
+    :param height: New frame height
     """
 
     def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
@@ -213,20 +241,29 @@ class AtariWrapper(gym.Wrapper):
 
     Specifically:
 
-    * NoopReset: obtain initial state by taking random number of no-ops on reset.
+    * Noop reset: obtain initial state by taking a random number of no-ops on reset.
     * Frame skipping: 4 by default
     * Max-pooling: most recent two observations
     * Termination signal when a life is lost.
     * Resize to a square image: 84x84 by default
     * Grayscale observation
     * Clip reward to {-1, 0, 1}
+    * Sticky actions: disabled by default
+
+    See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
+    for a visual explanation.
+
+    .. warning::
+        Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.
 
-    :param env: gym environment
-    :param noop_max: max number of no-ops
-    :param frame_skip: the frequency at which the agent experiences the game.
-    :param screen_size: resize Atari frame
-    :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
+    :param env: Environment to wrap
+    :param noop_max: Max number of no-ops
+    :param frame_skip: Frequency at which the agent experiences the game.
+        This corresponds to repeating the action ``frame_skip`` times.
+    :param screen_size: Resize Atari frame
+    :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
     :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign.
+    :param action_repeat_probability: Probability of repeating the last action
     """
 
     def __init__(
@@ -237,9 +274,15 @@ def __init__(
         screen_size: int = 84,
         terminal_on_life_loss: bool = True,
         clip_reward: bool = True,
+        action_repeat_probability: float = 0.0,
     ) -> None:
-        env = NoopResetEnv(env, noop_max=noop_max)
-        env = MaxAndSkipEnv(env, skip=frame_skip)
+        if action_repeat_probability > 0.0:
+            env = StickyActionEnv(env, action_repeat_probability)
+        if noop_max > 0:
+            env = NoopResetEnv(env, noop_max=noop_max)
+        # frame_skip=1 is the same as no frame-skip (action repeat)
+        if frame_skip > 1:
+            env = MaxAndSkipEnv(env, skip=frame_skip)
         if terminal_on_life_loss:
             env = EpisodicLifeEnv(env)
         if "FIRE" in env.unwrapped.get_action_meanings():
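With this patch, ``NoopResetEnv`` and ``MaxAndSkipEnv`` are only applied when they actually do something (``noop_max > 0``, ``frame_skip > 1``). A minimal sketch of how the resulting wrapper chain could be checked, using the existing ``is_wrapped`` helper from ``stable_baselines3.common.env_util``::

    import gym

    from stable_baselines3.common.atari_wrappers import (
        AtariWrapper,
        MaxAndSkipEnv,
        NoopResetEnv,
        StickyActionEnv,
    )
    from stable_baselines3.common.env_util import is_wrapped

    # Sticky actions on, noop reset and frame-skip off:
    # the corresponding wrappers should not be added at all
    env = AtariWrapper(
        gym.make("BreakoutNoFrameskip-v4"),
        noop_max=0,
        frame_skip=1,
        action_repeat_probability=0.25,
    )
    assert is_wrapped(env, StickyActionEnv)
    assert not is_wrapped(env, NoopResetEnv)
    assert not is_wrapped(env, MaxAndSkipEnv)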
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index c3d22c01c..f5e92647d 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.8.0a2
+1.8.0a3
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 83d695afd..e5236e8d7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -9,7 +9,7 @@
 import stable_baselines3 as sb3
 from stable_baselines3 import A2C
-from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
+from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
 from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.monitor import Monitor
@@ -55,30 +55,54 @@ def test_make_vec_env_func_checker():
     env.close()
 
 
-@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
-@pytest.mark.parametrize("n_envs", [1, 2])
-@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
-def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
-    env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
+# Use Asterix as it does not require fire reset
+@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"])
+@pytest.mark.parametrize("noop_max", [0, 10])
+@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25])
+@pytest.mark.parametrize("frame_skip", [1, 4])
+@pytest.mark.parametrize("screen_size", [60])
+@pytest.mark.parametrize("terminal_on_life_loss", [True, False])
+@pytest.mark.parametrize("clip_reward", [True])
+def test_make_atari_env(
+    env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward
+):
+    n_envs = 2
+    wrapper_kwargs = {
+        "noop_max": noop_max,
+        "action_repeat_probability": action_repeat_probability,
+        "frame_skip": frame_skip,
+        "screen_size": screen_size,
+        "terminal_on_life_loss": terminal_on_life_loss,
+        "clip_reward": clip_reward,
+    }
+    venv = make_atari_env(
+        env_id,
+        n_envs=n_envs,
+        wrapper_kwargs=wrapper_kwargs,
+        monitor_dir=None,
+        seed=0,
+    )
 
-    assert env.num_envs == n_envs
+    assert venv.num_envs == n_envs
 
-    obs = env.reset()
+    needs_fire_reset = env_id == "BreakoutNoFrameskip-v4"
+    expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0  # FIRE - UP on reset
+    expected_frame_number_high = expected_frame_number_low + noop_max
+    expected_shape = (n_envs, screen_size, screen_size, 1)
 
-    new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
+    obs = venv.reset()
+    frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
+    for frame_number in frame_numbers:
+        assert expected_frame_number_low <= frame_number <= expected_frame_number_high
+    assert obs.shape == expected_shape
 
-    assert obs.shape == new_obs.shape
+    new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)])
 
-    # Wrapped into DummyVecEnv
-    wrapped_atari_env = env.envs[0]
-    if wrapper_kwargs is not None:
-        assert obs.shape == (n_envs, 60, 60, 1)
-        assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
-        assert not isinstance(wrapped_atari_env.env, ClipRewardEnv)
-    else:
-        assert obs.shape == (n_envs, 84, 84, 1)
-        assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
-        assert isinstance(wrapped_atari_env.env, ClipRewardEnv)
+    new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
+    for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers):
+        assert new_frame_number - frame_number == frame_skip
+    assert new_obs.shape == expected_shape
+    if clip_reward:
+        assert np.max(np.abs(reward)) <= 1.0
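For completeness, a sketch of how the new argument can be exercised end to end through ``make_atari_env``, as in the updated test (hypothetical hyperparameters; ``wrapper_kwargs`` is forwarded to ``AtariWrapper``)::

    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_atari_env

    # Vectorized Atari envs with ALE-style sticky actions (p=0.25)
    venv = make_atari_env(
        "BreakoutNoFrameskip-v4",
        n_envs=4,
        seed=0,
        wrapper_kwargs=dict(action_repeat_probability=0.25),
    )
    model = A2C("CnnPolicy", venv, verbose=0)
    model.learn(total_timesteps=5_000)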