Skip to content

Commit

Permalink
Support only new step API (while retaining compatibility functions) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
arjun-kg authored Aug 30, 2022
1 parent 884ba08 commit 54b406b
Show file tree
Hide file tree
Showing 58 changed files with 379 additions and 560 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ observation, info = env.reset(seed=42)

for _ in range(1000):
action = env.action_space.sample()
observation, reward, done, info = env.step(action)
observation, reward, terminated, truncarted, info = env.step(action)

if done:
if terminated or truncated:
observation, info = env.reset()
env.close()
```
Expand Down
50 changes: 10 additions & 40 deletions gym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np

from gym import spaces
from gym.logger import deprecation, warn
from gym.logger import warn
from gym.utils import seeding

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,16 +83,11 @@ def np_random(self) -> np.random.Generator:
def np_random(self, value: np.random.Generator):
self._np_random = value

def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Run one timestep of the environment's dynamics.
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple
(observation, reward, done, info). The latter is deprecated and will be removed in future versions.
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`.
Args:
action (ActType): an action provided by the agent
Expand Down Expand Up @@ -226,25 +221,18 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""

def __init__(self, env: Env, new_step_api: bool = False):
def __init__(self, env: Env):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
Args:
env: The environment to wrap
new_step_api: Whether the wrapper's step method will output in new or old step API
"""
self.env = env

self._action_space: Optional[spaces.Space] = None
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self.new_step_api = new_step_api

if not self.new_step_api:
deprecation(
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
)

def __getattr__(self, name):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
Expand Down Expand Up @@ -326,17 +314,9 @@ def _np_random(self):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)

def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment with action."""
from gym.utils.step_api_compatibility import ( # avoid circular import
step_api_compatibility,
)

return step_api_compatibility(self.env.step(action), self.new_step_api)
return self.env.step(action)

def reset(self, **kwargs) -> Tuple[ObsType, dict]:
"""Resets the environment with kwargs."""
Expand Down Expand Up @@ -401,13 +381,8 @@ def reset(self, **kwargs):

def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return self.observation(observation), reward, terminated, truncated, info
else:
observation, reward, done, info = step_returns
return self.observation(observation), reward, done, info
observation, reward, terminated, truncated, info = self.env.step(action)
return self.observation(observation), reward, terminated, truncated, info

def observation(self, observation):
"""Returns a modified observation."""
Expand Down Expand Up @@ -440,13 +415,8 @@ def reward(self, reward):

def step(self, action):
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return observation, self.reward(reward), terminated, truncated, info
else:
observation, reward, done, info = step_returns
return observation, self.reward(reward), done, info
observation, reward, terminated, truncated, info = self.env.step(action)
return observation, self.reward(reward), terminated, truncated, info

def reward(self, reward):
"""Returns a modified ``reward``."""
Expand Down
18 changes: 10 additions & 8 deletions gym/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class EnvSpec:
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
new_step_api: bool = field(default=False)
apply_step_compatibility: bool = field(default=False)

# Environment arguments
kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -547,7 +547,7 @@ def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
autoreset: bool = False,
new_step_api: bool = False,
apply_step_compatibility: bool = False,
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> Env:
Expand All @@ -557,7 +557,7 @@ def make(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
apply_step_compatibility: Whether to use apply compatibility wrapper that converts step method to return two bools (StepAPICompatibility wrapper)
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
(which is by default False, running the environment checker),
otherwise will run according to this parameter (`True` = not run, `False` = run)
Expand Down Expand Up @@ -684,26 +684,28 @@ def make(
):
env = PassiveEnvChecker(env)

env = StepAPICompatibility(env, new_step_api)

# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)

# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps, new_step_api)
env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps, new_step_api)
env = TimeLimit(env, spec_.max_episode_steps)

# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env, new_step_api)
env = AutoResetWrapper(env)

# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)

# Add step API wrapper
if apply_step_compatibility:
env = StepAPICompatibility(env, True)

return env


Expand Down
16 changes: 10 additions & 6 deletions gym/utils/passive_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,16 @@ def env_reset_passive_checker(env, **kwargs):
logger.warn(
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
)

obs, info = result
check_obs(obs, env.observation_space, "reset")
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
elif len(result) != 2:
logger.warn(
"The result returned by `env.reset()` should be `(obs, info)` by default, , where `obs` is a observation and `info` is a dictionary containing additional information."
)
else:
obs, info = result
check_obs(obs, env.observation_space, "reset")
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
return result


Expand Down
30 changes: 16 additions & 14 deletions gym/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def play(
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
for last 150 steps.
>>> def callback(obs_t, obs_tp1, action, rew, done, info):
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
... return [rew,]
>>> plotter = PlayPlot(callback, 150, ["reward"])
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
Expand All @@ -187,7 +187,8 @@ def play(
obs_tp1: observation after performing action
action: action that was executed
rew: reward that was received
done: whether the environment is done or not
terminated: whether the environment is terminated or not
truncated: whether the environment is truncated or not
info: debug info
keys_to_action: Mapping from keys pressed to action performed.
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
Expand Down Expand Up @@ -219,11 +220,6 @@ def play(
deprecation(
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
)
if env.render_mode not in {"rgb_array", "single_rgb_array"}:
logger.error(
"play method works only with rgb_array and single_rgb_array render modes, "
f"but your environment render_mode = {env.render_mode}."
)

env.reset(seed=seed)

Expand Down Expand Up @@ -261,9 +257,10 @@ def play(
else:
action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
prev_obs = obs
obs, rew, done, info = env.step(action)
obs, rew, terminated, truncated, info = env.step(action)
done = terminated or truncated
if callback is not None:
callback(prev_obs, obs, action, rew, done, info)
callback(prev_obs, obs, action, rew, terminated, truncated, info)
if obs is not None:
rendered = env.render()
if isinstance(rendered, List):
Expand All @@ -290,13 +287,14 @@ class PlayPlot:
- obs_tp1: observation after performing action
- action: action that was executed
- rew: reward that was received
- done: whether the environment is done or not
- terminated: whether the environment is terminated or not
- truncated: whether the environment is truncated or not
- info: debug info
It should return a list of metrics that are computed from this data.
For instance, the function may look like this::
>>> def compute_metrics(obs_t, obs_tp, action, reward, done, info):
>>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info):
... return [reward, info["cumulative_reward"], np.linalg.norm(action)]
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
Expand Down Expand Up @@ -353,7 +351,8 @@ def callback(
obs_tp1: ObsType,
action: ActType,
rew: float,
done: bool,
terminated: bool,
truncated: bool,
info: dict,
):
"""The callback that calls the provided data callback and adds the data to the plots.
Expand All @@ -363,10 +362,13 @@ def callback(
obs_tp1: The observation at time step t+1
action: The action
rew: The reward
done: If the environment is done
terminated: If the environment is terminated
truncated: If the environment is truncated
info: The information from the environment
"""
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
points = self.data_callback(
obs_t, obs_tp1, action, rew, terminated, truncated, info
)
for point, data_series in zip(points, self.data):
data_series.append(point)
self.t += 1
Expand Down
45 changes: 23 additions & 22 deletions gym/utils/step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0."""
"""Contains methods for step compatibility, from old-to-new and new-to-old API."""
from typing import Tuple, Union

import numpy as np

from gym.core import ObsType

OldStepType = Tuple[
DoneStepType = Tuple[
Union[ObsType, np.ndarray],
Union[float, np.ndarray],
Union[bool, np.ndarray],
Union[dict, list],
]

NewStepType = Tuple[
TerminatedTruncatedStepType = Tuple[
Union[ObsType, np.ndarray],
Union[float, np.ndarray],
Union[bool, np.ndarray],
Expand All @@ -21,9 +21,9 @@
]


def step_to_new_api(
step_returns: Union[OldStepType, NewStepType], is_vector_env=False
) -> NewStepType:
def convert_to_terminated_truncated_step_api(
step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False
) -> TerminatedTruncatedStepType:
"""Function to transform step returns to new step API irrespective of input API.
Args:
Expand Down Expand Up @@ -73,9 +73,10 @@ def step_to_new_api(
)


def step_to_old_api(
step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False
) -> OldStepType:
def convert_to_done_step_api(
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
is_vector_env: bool = False,
) -> DoneStepType:
"""Function to transform step returns to old step API irrespective of input API.
Args:
Expand Down Expand Up @@ -128,33 +129,33 @@ def step_to_old_api(


def step_api_compatibility(
step_returns: Union[NewStepType, OldStepType],
new_step_api: bool = False,
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
output_truncation_bool: bool = True,
is_vector_env: bool = False,
) -> Union[NewStepType, OldStepType]:
"""Function to transform step returns to the API specified by `new_step_api` bool.
) -> Union[TerminatedTruncatedStepType, DoneStepType]:
"""Function to transform step returns to the API specified by `output_truncation_bool` bool.
Old step API refers to step() method returning (observation, reward, done, info)
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
Done (old) step API refers to step() method returning (observation, reward, done, info)
Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info)
(Refer to docs for details on the API change)
Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
new_step_api (bool): Whether the output should be in new step API or old (False by default)
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default)
is_vector_env (bool): Whether the step_returns are from a vector environment
Returns:
step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
Examples:
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
wrapper is written in new API, and the final step output is desired to be in old API.
>>> obs, rew, done, info = step_api_compatibility(env.step(action))
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True)
>>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False)
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True)
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
"""
if new_step_api:
return step_to_new_api(step_returns, is_vector_env)
if output_truncation_bool:
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)
else:
return step_to_old_api(step_returns, is_vector_env)
return convert_to_done_step_api(step_returns, is_vector_env)
Loading

0 comments on commit 54b406b

Please sign in to comment.