Add sticky actions for Atari games #1286

Merged · 19 commits · Jan 26, 2023
Changes from 14 commits
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -14,6 +14,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``action_repeat_probability`` argument in ``AtariWrapper``.

`SB3-Contrib`_
^^^^^^^^^^^^^^
68 changes: 50 additions & 18 deletions stable_baselines3/common/atari_wrappers.py
@@ -12,13 +12,36 @@
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn


class StickyActionEnv(gym.Wrapper):
"""
Sticky action wrapper: the previous action is repeated with probability ``action_repeat_probability``.

:param env: Environment to wrap
:param action_repeat_probability: Probability of repeating the last action
"""

def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
super().__init__(env)
self.action_repeat_probability = action_repeat_probability
assert env.unwrapped.get_action_meanings()[0] == "NOOP"

def reset(self, **kwargs) -> GymObs:
self._sticky_action = 0 # NOOP
return self.env.reset(**kwargs)

def step(self, action: int) -> GymStepReturn:
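# With probability `action_repeat_probability`, keep the previous (sticky) action;
# otherwise switch to the newly selected action.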
if self.np_random.random() >= self.action_repeat_probability:
self._sticky_action = action
return self.env.step(self._sticky_action)
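
As a usage sketch (illustrative, not part of this diff; assumes ``gym`` with the Atari extras installed and uses ``PongNoFrameskip-v4`` only as an example id):

import gym
from stable_baselines3.common.atari_wrappers import StickyActionEnv

env = StickyActionEnv(gym.make("PongNoFrameskip-v4"), action_repeat_probability=0.25)
obs = env.reset()
# With probability 0.25, the step below executes the previous action instead of the sampled one
obs, reward, done, info = env.step(env.action_space.sample())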


class NoopResetEnv(gym.Wrapper):
"""
Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.

:param env: the environment to wrap
:param noop_max: the maximum value of no-ops to run
:param env: Environment to wrap
:param noop_max: Maximum value of no-ops to run
"""

def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
@@ -47,7 +70,7 @@ class FireResetEnv(gym.Wrapper):
"""
Take action on reset for environments that are fixed until firing.

:param env: the environment to wrap
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
@@ -71,7 +94,7 @@ class EpisodicLifeEnv(gym.Wrapper):
Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.

:param env: the environment to wrap
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
@@ -115,8 +138,8 @@ class MaxAndSkipEnv(gym.Wrapper):
"""
Return only every ``skip``-th frame (frameskipping)

:param env: the environment
:param skip: number of ``skip``-th frame
:param env: Environment to wrap
:param skip: Number of frames to skip (the same action is repeated for ``skip`` frames)
"""

def __init__(self, env: gym.Env, skip: int = 4) -> None:
@@ -156,9 +179,9 @@ def reset(self, **kwargs) -> GymObs:

class ClipRewardEnv(gym.RewardWrapper):
"""
Clips the reward to {+1, 0, -1} by its sign.
Clip the reward to {+1, 0, -1} by its sign.

:param env: the environment
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
@@ -179,9 +202,9 @@ class WarpFrame(gym.ObservationWrapper):
Convert to grayscale and warp frames to 84x84 (default)
as done in the Nature paper and later work.

:param env: the environment
:param width:
:param height:
:param env: Environment to wrap
:param width: New frame width
:param height: New frame height
"""

def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
@@ -210,32 +233,41 @@ class AtariWrapper(gym.Wrapper):

Specifically:

* NoopReset: obtain initial state by taking random number of no-ops on reset.
* Sticky actions: disabled by default
* Noop reset: obtain initial state by taking random number of no-ops on reset.
* Frame skipping: 4 by default
* Max-pooling: most recent two observations
* Termination signal when a life is lost.
* Resize to a square image: 84x84 by default
* Grayscale observation
* Clip reward to {-1, 0, 1}

:param env: gym environment
:param noop_max: max number of no-ops
:param frame_skip: the frequency at which the agent experiences the game.
:param screen_size: resize Atari frame
:param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
.. warning::
Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.

:param env: Environment to wrap
:param action_repeat_probability: Probability of repeating the last action
:param noop_max: Max number of no-ops
:param frame_skip: Frequency at which the agent experiences the game.
:param screen_size: Resize Atari frame
:param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
:param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
"""

def __init__(
self,
env: gym.Env,
action_repeat_probability: float = 0.0,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = True,
clip_reward: bool = True,
) -> None:
env = NoopResetEnv(env, noop_max=noop_max)
if action_repeat_probability > 0.0:
env = StickyActionEnv(env, action_repeat_probability)
if noop_max > 0:
env = NoopResetEnv(env, noop_max=noop_max)
env = MaxAndSkipEnv(env, skip=frame_skip)
Review comment (Member): maybe skip that one if ``frame_skip <= 1`` (need to check whether the threshold is ``<= 0`` or ``<= 1``); if I recall correctly, it should have been called action repeat.
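
A sketch of what that suggestion could look like (illustrative only, not part of this revision of the diff):

# Only wrap with MaxAndSkipEnv when frame skipping is actually requested
if frame_skip > 1:
    env = MaxAndSkipEnv(env, skip=frame_skip)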

if terminal_on_life_loss:
env = EpisodicLifeEnv(env)
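
For reference, a minimal usage sketch of the wrapper with the new argument (illustrative only; assumes ``gym`` with Atari support and uses ``BreakoutNoFrameskip-v4`` as an example id):

import gym
from stable_baselines3.common.atari_wrappers import AtariWrapper

# DeepMind-style preprocessing plus sticky actions (previous action repeated with probability 0.25)
env = AtariWrapper(gym.make("BreakoutNoFrameskip-v4"), action_repeat_probability=0.25)
obs = env.reset()
print(obs.shape)  # (84, 84, 1): grayscale frames resized to 84x84 by default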
62 changes: 43 additions & 19 deletions tests/test_utils.py
@@ -55,30 +55,54 @@ def test_make_vec_env_func_checker():
env.close()


@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
# Use Asterix as it does not require fire reset
@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"])
@pytest.mark.parametrize("noop_max", [0, 10])
@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25])
@pytest.mark.parametrize("frame_skip", [1, 4])
@pytest.mark.parametrize("screen_size", [60])
@pytest.mark.parametrize("terminal_on_life_loss", [True, False])
@pytest.mark.parametrize("clip_reward", [True])
def test_make_atari_env(
env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward
):
n_envs = 2
wrapper_kwargs = {
"noop_max": noop_max,
"action_repeat_probability": action_repeat_probability,
"frame_skip": frame_skip,
"screen_size": screen_size,
"terminal_on_life_loss": terminal_on_life_loss,
"clip_reward": clip_reward,
}
venv = make_atari_env(
env_id,
n_envs=2,
wrapper_kwargs=wrapper_kwargs,
monitor_dir=None,
seed=0,
)

assert env.num_envs == n_envs
assert venv.num_envs == n_envs

obs = env.reset()
needs_fire_reset = env_id == "BreakoutNoFrameskip-v4"
expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0 # FIRE - UP on reset
expected_frame_number_high = expected_frame_number_low + noop_max
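# e.g. with FIRE reset (Breakout), frame_skip=4 and noop_max=10:
# low = 2 * 4 = 8 frames (FIRE then UP), high = 8 + 10 = 18 frames after reset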
expected_shape = (n_envs, screen_size, screen_size, 1)

new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
obs = venv.reset()
frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
for frame_number in frame_numbers:
assert expected_frame_number_low <= frame_number <= expected_frame_number_high
assert obs.shape == expected_shape

assert obs.shape == new_obs.shape
new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)])

# Wrapped into DummyVecEnv
wrapped_atari_env = env.envs[0]
if wrapper_kwargs is not None:
assert obs.shape == (n_envs, 60, 60, 1)
assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
assert not isinstance(wrapped_atari_env.env, ClipRewardEnv)
else:
assert obs.shape == (n_envs, 84, 84, 1)
assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
assert isinstance(wrapped_atari_env.env, ClipRewardEnv)
new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers):
assert new_frame_number - frame_number == frame_skip
assert new_obs.shape == expected_shape
if clip_reward:
assert np.max(np.abs(reward)) <= 1.0

