From d702f422dee68bdc6a6fbdc15ff6f7e43a88fffa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 18 Jan 2023 16:47:46 +0100 Subject: [PATCH 01/16] repeat_action_probability --- docs/misc/changelog.rst | 1 + stable_baselines3/common/atari_wrappers.py | 4 ++++ stable_baselines3/common/env_util.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index db9d6c627..bd4d31f5d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -13,6 +13,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added ``repeat_action_probability`` argument in ``make_atari_env``. `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 785d911f0..10eca15cf 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -218,6 +218,10 @@ class AtariWrapper(gym.Wrapper): * Grayscale observation * Clip reward to {-1, 0, 1} + This wrapper does not implement sticky actions (repeat action probability) as this is handled + at the simulator level. If you want to change the default value of `repeat_action_probability` + please refer to the documentation of `common.env_util.make_atari_env`. + :param env: gym environment :param noop_max: max number of no-ops :param frame_skip: the frequency at which the agent experiences the game. diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index c85d1472b..63a0d7c48 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -119,6 +119,7 @@ def make_atari_env( vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None, vec_env_kwargs: Optional[Dict[str, Any]] = None, monitor_kwargs: Optional[Dict[str, Any]] = None, + repeat_action_probability: float = 0.0, ) -> VecEnv: """ Create a wrapped, monitored VecEnv for Atari. @@ -136,8 +137,11 @@ def make_atari_env( :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. + :param repeat_action_probability: Sticky actions, i.e. 
action repeat probability in ALE configuration :return: The wrapped environment """ + env_kwargs = {} if env_kwargs is None else env_kwargs + env_kwargs["repeat_action_probability"] = repeat_action_probability return make_vec_env( env_id, n_envs=n_envs, From ffa0818cc27a96233fa6048577f057a0447f18e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 18 Jan 2023 17:45:01 +0100 Subject: [PATCH 02/16] Add test --- tests/test_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 83d695afd..b4f75d6d4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -58,8 +58,16 @@ def test_make_vec_env_func_checker(): @pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"]) @pytest.mark.parametrize("n_envs", [1, 2]) @pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)]) -def test_make_atari_env(env_id, n_envs, wrapper_kwargs): - env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0) +@pytest.mark.parametrize("repeat_action_probability", [0.0, 0.25]) +def test_make_atari_env(env_id, n_envs, wrapper_kwargs, repeat_action_probability): + env = make_atari_env( + env_id, + n_envs, + wrapper_kwargs=wrapper_kwargs, + monitor_dir=None, + seed=0, + repeat_action_probability=repeat_action_probability, + ) assert env.num_envs == n_envs From 8668077e1bf8e2b2717fff1470db8be410b2fb6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 18 Jan 2023 18:01:26 +0100 Subject: [PATCH 03/16] Undo atari wrapper doc change since CI fails --- stable_baselines3/common/atari_wrappers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 10eca15cf..785d911f0 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -218,10 +218,6 @@ class AtariWrapper(gym.Wrapper): * Grayscale observation * Clip reward to {-1, 0, 1} - This wrapper does not implement sticky actions (repeat action probability) as this is handled - at the simulator level. If you want to change the default value of `repeat_action_probability` - please refer to the documentation of `common.env_util.make_atari_env`. - :param env: gym environment :param noop_max: max number of no-ops :param frame_skip: the frequency at which the agent experiences the game. From a87cb7d8ca1c838bdbc18da699fd37a624852769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 11:06:24 +0100 Subject: [PATCH 04/16] remove action_repeat_probability from make_atari_env --- stable_baselines3/common/env_util.py | 4 ---- tests/test_utils.py | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 63a0d7c48..c85d1472b 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -119,7 +119,6 @@ def make_atari_env( vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None, vec_env_kwargs: Optional[Dict[str, Any]] = None, monitor_kwargs: Optional[Dict[str, Any]] = None, - repeat_action_probability: float = 0.0, ) -> VecEnv: """ Create a wrapped, monitored VecEnv for Atari. 
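
Note on the mechanism this patch removes: ``make_vec_env`` forwards ``env_kwargs`` to ``gym.make``, and ALE's ``AtariEnv`` accepts ``repeat_action_probability`` directly as a constructor argument, which is how PATCH 01 enabled sticky actions at the simulator level. A minimal sketch of that mechanism outside SB3, assuming an ALE-based Atari install; the env id and probability value are illustrative:

    import gym

    # ALE applies the sticky-action rule internally: with probability 0.25
    # the previously executed action is repeated instead of the new one.
    env = gym.make("BreakoutNoFrameskip-v4", repeat_action_probability=0.25)
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())

The follow-up commits move this logic into a dedicated wrapper instead, so the behaviour is explicit and testable rather than hidden in the simulator configuration.
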
@@ -137,11 +136,8 @@ def make_atari_env( :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. - :param repeat_action_probability: Sticky actions, i.e. action repeat probability in ALE configuration :return: The wrapped environment """ - env_kwargs = {} if env_kwargs is None else env_kwargs - env_kwargs["repeat_action_probability"] = repeat_action_probability return make_vec_env( env_id, n_envs=n_envs, diff --git a/tests/test_utils.py b/tests/test_utils.py index b4f75d6d4..cb01d7f71 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -58,15 +58,13 @@ def test_make_vec_env_func_checker(): @pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"]) @pytest.mark.parametrize("n_envs", [1, 2]) @pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)]) -@pytest.mark.parametrize("repeat_action_probability", [0.0, 0.25]) -def test_make_atari_env(env_id, n_envs, wrapper_kwargs, repeat_action_probability): +def test_make_atari_env(env_id, n_envs, wrapper_kwargs): env = make_atari_env( env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0, - repeat_action_probability=repeat_action_probability, ) assert env.num_envs == n_envs From 3592a344ed8d428d49e9c83719e9cad2fc1f46f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 11:06:50 +0100 Subject: [PATCH 05/16] Add sticky action wrapper and improve documentation --- stable_baselines3/common/atari_wrappers.py | 62 ++++++++++++++++------ 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 785d911f0..739b3f407 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -17,8 +17,8 @@ class NoopResetEnv(gym.Wrapper): Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. - :param env: the environment to wrap - :param noop_max: the maximum value of no-ops to run + :param env: Environment to wrap + :param noop_max: Maximum value of no-ops to run """ def __init__(self, env: gym.Env, noop_max: int = 30) -> None: @@ -47,7 +47,7 @@ class FireResetEnv(gym.Wrapper): """ Take action on reset for environments that are fixed until firing. - :param env: the environment to wrap + :param env: Environment to wrap """ def __init__(self, env: gym.Env) -> None: @@ -66,12 +66,36 @@ def reset(self, **kwargs) -> np.ndarray: return obs +class StickyActionEnv(gym.Wrapper): + """ + Sticky action. + + :param env: Environment to wrap + :action_repeat_probability: Probability of repeating the last action + """ + + def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: + super().__init__(env) + self.action_repeat_probability = action_repeat_probability + + def reset(self, **kwargs) -> GymObs: + self._last_action = None + return self.env.reset(**kwargs) + + def step(self, action: int) -> GymStepReturn: + if self._last_action is not None: # _last_action is set to None when reset + if self.np_random.random() < self.action_repeat_probability: + action = self._last_action + self._last_action = action + return self.env.step(action) + + class EpisodicLifeEnv(gym.Wrapper): """ Make end-of-life == end-of-episode, but only reset on true game over. Done by DeepMind for the DQN and co. 
since it helps value estimation. - :param env: the environment to wrap + :param env: Environment to wrap """ def __init__(self, env: gym.Env) -> None: @@ -115,8 +139,8 @@ class MaxAndSkipEnv(gym.Wrapper): """ Return only every ``skip``-th frame (frameskipping) - :param env: the environment - :param skip: number of ``skip``-th frame + :param env: Environment to wrap + :param skip: Number of ``skip``-th frame """ def __init__(self, env: gym.Env, skip: int = 4) -> None: @@ -156,9 +180,9 @@ def reset(self, **kwargs) -> GymObs: class ClipRewardEnv(gym.RewardWrapper): """ - Clips the reward to {+1, 0, -1} by its sign. + Clip the reward to {+1, 0, -1} by its sign. - :param env: the environment + :param env: Environment to wrap """ def __init__(self, env: gym.Env) -> None: @@ -179,9 +203,9 @@ class WarpFrame(gym.ObservationWrapper): Convert to grayscale and warp frames to 84x84 (default) as done in the Nature paper and later work. - :param env: the environment - :param width: - :param height: + :param env: Environment to wrap + :param width: New frame width + :param height: New frame height """ def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None: @@ -210,19 +234,21 @@ class AtariWrapper(gym.Wrapper): Specifically: - * NoopReset: obtain initial state by taking random number of no-ops on reset. + * Noop reset: obtain initial state by taking random number of no-ops on reset. * Frame skipping: 4 by default + * Sticky actions: disabled by default * Max-pooling: most recent two observations * Termination signal when a life is lost. * Resize to a square image: 84x84 by default * Grayscale observation * Clip reward to {-1, 0, 1} - :param env: gym environment - :param noop_max: max number of no-ops - :param frame_skip: the frequency at which the agent experiences the game. - :param screen_size: resize Atari frame - :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost. + :param env: Environment to wrap + :param noop_max: Max number of no-ops + :param action_repeat_probability: + :param frame_skip: Frequency at which the agent experiences the game. + :param screen_size: Resize Atari frame + :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost. :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign. """ @@ -230,12 +256,14 @@ def __init__( self, env: gym.Env, noop_max: int = 30, + action_repeat_probability: float = 0.0, frame_skip: int = 4, screen_size: int = 84, terminal_on_life_loss: bool = True, clip_reward: bool = True, ) -> None: env = NoopResetEnv(env, noop_max=noop_max) + env = StickyActionEnv(env, action_repeat_probability) env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) From 47e4fb671fa5df9352d2ab05a571a094b3111eab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 11:07:01 +0100 Subject: [PATCH 06/16] Update changelog --- docs/misc/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bd4d31f5d..fb47d0879 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -13,7 +13,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ -- Added ``repeat_action_probability`` argument in ``make_atari_env``. +- Added ``repeat_action_probability`` argument in ``AtariWrapper``. 
`SB3-Contrib`_ ^^^^^^^^^^^^^^ From 970c7539c0322d3ddeef038d73f755826061abd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 13:32:14 +0100 Subject: [PATCH 07/16] handle the case noop_max=0 --- stable_baselines3/common/atari_wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 739b3f407..b50462664 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -262,7 +262,8 @@ def __init__( terminal_on_life_loss: bool = True, clip_reward: bool = True, ) -> None: - env = NoopResetEnv(env, noop_max=noop_max) + if noop_max > 0: + env = NoopResetEnv(env, noop_max=noop_max) env = StickyActionEnv(env, action_repeat_probability) env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: From 8fc77a0b8dbfd4a5bd980dd348ed42719c77d730 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 13:32:24 +0100 Subject: [PATCH 08/16] Update tests --- tests/test_utils.py | 58 +++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index cb01d7f71..cb8dddf46 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -55,36 +55,54 @@ def test_make_vec_env_func_checker(): env.close() -@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"]) -@pytest.mark.parametrize("n_envs", [1, 2]) -@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)]) -def test_make_atari_env(env_id, n_envs, wrapper_kwargs): - env = make_atari_env( +# Use Asterix as it does not requires fire reset +@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"]) +@pytest.mark.parametrize("noop_max", [0, 10]) +@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25]) +@pytest.mark.parametrize("frame_skip", [1, 4]) +@pytest.mark.parametrize("screen_size", [60]) +@pytest.mark.parametrize("terminal_on_life_loss", [True, False]) +@pytest.mark.parametrize("clip_reward", [True]) +def test_make_atari_env( + env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward +): + n_envs = 2 + wrapper_kwargs = { + "noop_max": noop_max, + "action_repeat_probability": action_repeat_probability, + "frame_skip": frame_skip, + "screen_size": screen_size, + "terminal_on_life_loss": terminal_on_life_loss, + "clip_reward": clip_reward, + } + venv = make_atari_env( env_id, - n_envs, + n_envs=2, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0, ) - assert env.num_envs == n_envs + assert venv.num_envs == n_envs - obs = env.reset() + needs_fire_reset = env_id == "BreakoutNoFrameskip-v4" + expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0 # FIRE - UP on reset + expected_frame_number_high = expected_frame_number_low + noop_max + expected_shape = (n_envs, screen_size, screen_size, 1) - new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)]) + obs = venv.reset() + frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs] + for frame_number in frame_numbers: + assert expected_frame_number_low <= frame_number <= expected_frame_number_high + assert obs.shape == expected_shape - assert obs.shape == new_obs.shape + new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)]) - # Wrapped into DummyVecEnv - 
wrapped_atari_env = env.envs[0] - if wrapper_kwargs is not None: - assert obs.shape == (n_envs, 60, 60, 1) - assert wrapped_atari_env.observation_space.shape == (60, 60, 1) - assert not isinstance(wrapped_atari_env.env, ClipRewardEnv) - else: - assert obs.shape == (n_envs, 84, 84, 1) - assert wrapped_atari_env.observation_space.shape == (84, 84, 1) - assert isinstance(wrapped_atari_env.env, ClipRewardEnv) + new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs] + for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers): + assert new_frame_number - frame_number == frame_skip + assert new_obs.shape == expected_shape + if clip_reward: assert np.max(np.abs(reward)) < 1.0 From c588a37aecdca4c286136ad81e922644d87d59c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 15:18:40 +0100 Subject: [PATCH 09/16] Comply to ALE implementation --- stable_baselines3/common/atari_wrappers.py | 49 +++++++++++----------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index b50462664..38475839e 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -12,6 +12,29 @@ from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +class StickyActionEnv(gym.Wrapper): + """ + Sticky action. + + :param env: Environment to wrap + :action_repeat_probability: Probability of repeating the last action + """ + + def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: + super().__init__(env) + self.action_repeat_probability = action_repeat_probability + assert env.unwrapped.get_action_meanings()[0] == "NOOP" + + def reset(self, **kwargs) -> GymObs: + self._sticky_action = 0 # NOOP + return self.env.reset(**kwargs) + + def step(self, action: int) -> GymStepReturn: + if self.np_random.random() >= self.action_repeat_probability: + self._sticky_action = action + return self.env.step(self._sticky_action) + + class NoopResetEnv(gym.Wrapper): """ Sample initial states by taking random number of no-ops on reset. @@ -66,30 +89,6 @@ def reset(self, **kwargs) -> np.ndarray: return obs -class StickyActionEnv(gym.Wrapper): - """ - Sticky action. - - :param env: Environment to wrap - :action_repeat_probability: Probability of repeating the last action - """ - - def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: - super().__init__(env) - self.action_repeat_probability = action_repeat_probability - - def reset(self, **kwargs) -> GymObs: - self._last_action = None - return self.env.reset(**kwargs) - - def step(self, action: int) -> GymStepReturn: - if self._last_action is not None: # _last_action is set to None when reset - if self.np_random.random() < self.action_repeat_probability: - action = self._last_action - self._last_action = action - return self.env.step(action) - - class EpisodicLifeEnv(gym.Wrapper): """ Make end-of-life == end-of-episode, but only reset on true game over. 
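
Note on the semantics of the ``StickyActionEnv`` rewritten above: on reset the sticky action is initialised to NOOP, and at each step the wrapper keeps the previously executed action with probability ``action_repeat_probability``, matching the ALE implementation (Machado et al., 2018). A small self-contained check of that rule; the action-space size and seed are chosen arbitrarily for illustration:

    import numpy as np

    rng = np.random.default_rng(seed=0)
    p = 0.25  # action_repeat_probability
    sticky_action = 0  # NOOP, as initialised on reset
    n_steps, overridden = 100_000, 0
    for _ in range(n_steps):
        intended = int(rng.integers(4))  # stand-in for the agent's action
        if rng.random() >= p:
            sticky_action = intended
        if sticky_action != intended:
            overridden += 1
    # With 4 actions, the executed action differs from the intended one in
    # roughly p * 3/4 of steps (about 0.19 here).
    print(overridden / n_steps)
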
@@ -262,9 +261,9 @@ def __init__( terminal_on_life_loss: bool = True, clip_reward: bool = True, ) -> None: + env = StickyActionEnv(env, action_repeat_probability) if noop_max > 0: env = NoopResetEnv(env, noop_max=noop_max) - env = StickyActionEnv(env, action_repeat_probability) env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) From 7a14fec72b48e7f9e8aa52802d11249ccfa2edfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 15:23:06 +0100 Subject: [PATCH 10/16] Reorder doc --- stable_baselines3/common/atari_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 38475839e..7c1ff5a7a 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -233,9 +233,9 @@ class AtariWrapper(gym.Wrapper): Specifically: + * Sticky actions: disabled by default * Noop reset: obtain initial state by taking random number of no-ops on reset. * Frame skipping: 4 by default - * Sticky actions: disabled by default * Max-pooling: most recent two observations * Termination signal when a life is lost. * Resize to a square image: 84x84 by default From b287bb4ccd045f09038f565113cae78858c2b957 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 15:41:36 +0100 Subject: [PATCH 11/16] Add doc warning and don't wrap with sticky action when not needed --- stable_baselines3/common/atari_wrappers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 7c1ff5a7a..2e8812210 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -242,6 +242,9 @@ class AtariWrapper(gym.Wrapper): * Grayscale observation * Clip reward to {-1, 0, 1} + .. warning:: + Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``. + :param env: Environment to wrap :param noop_max: Max number of no-ops :param action_repeat_probability: @@ -261,7 +264,8 @@ def __init__( terminal_on_life_loss: bool = True, clip_reward: bool = True, ) -> None: - env = StickyActionEnv(env, action_repeat_probability) + if action_repeat_probability > 0.0: + env = StickyActionEnv(env, action_repeat_probability) if noop_max > 0: env = NoopResetEnv(env, noop_max=noop_max) env = MaxAndSkipEnv(env, skip=frame_skip) From 35d4eb20febf293a99d597597abe0efe0c72c2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Fri, 20 Jan 2023 15:59:44 +0100 Subject: [PATCH 12/16] fix docstring and reorder --- stable_baselines3/common/atari_wrappers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 2e8812210..7e3e3f602 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -17,7 +17,7 @@ class StickyActionEnv(gym.Wrapper): Sticky action. :param env: Environment to wrap - :action_repeat_probability: Probability of repeating the last action + :param action_repeat_probability: Probability of repeating the last action """ def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: @@ -246,8 +246,8 @@ class AtariWrapper(gym.Wrapper): Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``. 
:param env: Environment to wrap + :param action_repeat_probability: Probability of repeating the last action :param noop_max: Max number of no-ops - :param action_repeat_probability: :param frame_skip: Frequency at which the agent experiences the game. :param screen_size: Resize Atari frame :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost. @@ -257,8 +257,8 @@ class AtariWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - noop_max: int = 30, action_repeat_probability: float = 0.0, + noop_max: int = 30, frame_skip: int = 4, screen_size: int = 84, terminal_on_life_loss: bool = True, From 819d7dba5cd95c249e05cf61228a4bdfe865d063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 23 Jan 2023 18:23:40 +0100 Subject: [PATCH 13/16] Move `action_repeat_probability` args at the last position --- stable_baselines3/common/atari_wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 7e3e3f602..42f7501f7 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -246,23 +246,23 @@ class AtariWrapper(gym.Wrapper): Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``. :param env: Environment to wrap - :param action_repeat_probability: Probability of repeating the last action :param noop_max: Max number of no-ops :param frame_skip: Frequency at which the agent experiences the game. :param screen_size: Resize Atari frame :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost. :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign. + :param action_repeat_probability: Probability of repeating the last action """ def __init__( self, env: gym.Env, - action_repeat_probability: float = 0.0, noop_max: int = 30, frame_skip: int = 4, screen_size: int = 84, terminal_on_life_loss: bool = True, clip_reward: bool = True, + action_repeat_probability: float = 0.0, ) -> None: if action_repeat_probability > 0.0: env = StickyActionEnv(env, action_repeat_probability) From 5351c0eb9c6e161b404b41f1e34a17b9b44a9d2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 25 Jan 2023 16:12:14 +0100 Subject: [PATCH 14/16] Add ref --- stable_baselines3/common/atari_wrappers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 42f7501f7..c3a75497d 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -16,6 +16,9 @@ class StickyActionEnv(gym.Wrapper): """ Sticky action. 
+ Paper: https://arxiv.org/abs/1709.06009 + Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment + :param env: Environment to wrap :param action_repeat_probability: Probability of repeating the last action """ From f504eb6149abc18fd59372ac4074f29c98fb2659 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 26 Jan 2023 00:16:09 +0100 Subject: [PATCH 15/16] Update doc and wrap with frameskip only if needed --- stable_baselines3/common/atari_wrappers.py | 12 ++++++++++-- tests/test_utils.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index c3a75497d..016003b3b 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -140,9 +140,11 @@ def reset(self, **kwargs) -> np.ndarray: class MaxAndSkipEnv(gym.Wrapper): """ Return only every ``skip``-th frame (frameskipping) + and return the max between the two last frames. :param env: Environment to wrap :param skip: Number of ``skip``-th frame + The same action will be taken ``skip`` times. """ def __init__(self, env: gym.Env, skip: int = 4) -> None: @@ -236,7 +238,6 @@ class AtariWrapper(gym.Wrapper): Specifically: - * Sticky actions: disabled by default * Noop reset: obtain initial state by taking random number of no-ops on reset. * Frame skipping: 4 by default * Max-pooling: most recent two observations @@ -244,6 +245,10 @@ class AtariWrapper(gym.Wrapper): * Resize to a square image: 84x84 by default * Grayscale observation * Clip reward to {-1, 0, 1} + * Sticky actions: disabled by default + + See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/ + for a visual explanation. .. warning:: Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``. @@ -251,6 +256,7 @@ class AtariWrapper(gym.Wrapper): :param env: Environment to wrap :param noop_max: Max number of no-ops :param frame_skip: Frequency at which the agent experiences the game. + This correspond to repeating the action ``frame_skip`` times. :param screen_size: Resize Atari frame :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost. :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign. 
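
For readers unfamiliar with the max-pooling detail documented above, here is the core of the max-and-skip step in abridged form (a sketch of the existing ``MaxAndSkipEnv`` logic, not a drop-in replacement): the same action is repeated ``skip`` times, only the last two raw frames are buffered, and their element-wise maximum is returned, which removes the flicker of sprites that ALE renders only on alternate frames.

    import numpy as np

    def max_and_skip_step(env, action, skip=4):
        # Repeat `action` for `skip` frames, accumulating the reward.
        obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=np.uint8)
        total_reward, done, info = 0.0, False, {}
        for i in range(skip):
            obs, reward, done, info = env.step(action)
            if i == skip - 2:
                obs_buffer[0] = obs
            if i == skip - 1:
                obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Max-pool the two most recent frames to undo sprite flickering.
        return obs_buffer.max(axis=0), total_reward, done, info
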
@@ -271,7 +277,9 @@ def __init__( env = StickyActionEnv(env, action_repeat_probability) if noop_max > 0: env = NoopResetEnv(env, noop_max=noop_max) - env = MaxAndSkipEnv(env, skip=frame_skip) + # frame_skip=1 is the same as no frame-skip (action repeat) + if frame_skip > 1: + env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) if "FIRE" in env.unwrapped.get_action_meanings(): diff --git a/tests/test_utils.py b/tests/test_utils.py index cb8dddf46..e5236e8d7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,7 @@ import stable_baselines3 as sb3 from stable_baselines3 import A2C -from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv +from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor From 50192df2a8af4c8155e82e86d6014828a370eb0f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 26 Jan 2023 00:18:38 +0100 Subject: [PATCH 16/16] Update changelog --- docs/misc/changelog.rst | 3 ++- stable_baselines3/version.txt | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 30aab7f2e..8cd52fe17 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.8.0a2 (WIP) +Release 1.8.0a3 (WIP) -------------------------- @@ -15,6 +15,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added ``repeat_action_probability`` argument in ``AtariWrapper``. +- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper`` `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index c3d22c01c..f5e92647d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a2 +1.8.0a3
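
With the series applied, sticky actions are configured through the wrapper rather than through ``make_atari_env`` itself: ``wrapper_kwargs`` is forwarded to ``AtariWrapper``, which inserts ``StickyActionEnv`` only when ``action_repeat_probability > 0``. A short end-to-end sketch of the resulting API; the hyperparameters are placeholders, and the frame stacking typically used for Atari is omitted for brevity:

    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_atari_env

    # Sticky actions are now a wrapper-level option.
    venv = make_atari_env(
        "BreakoutNoFrameskip-v4",
        n_envs=2,
        seed=0,
        wrapper_kwargs=dict(action_repeat_probability=0.25),
    )
    model = A2C("CnnPolicy", venv, verbose=1)
    model.learn(total_timesteps=1_000)
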