New Step API with terminated, truncated bools instead of done #2752
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,11 +66,16 @@ def np_random(self) -> RandomNumberGenerator: | |
def np_random(self, value: RandomNumberGenerator): | ||
self._np_random = value | ||
|
||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: | ||
def step( | ||
self, action: ActType | ||
) -> Union[ | ||
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] | ||
]: | ||
"""Run one timestep of the environment's dynamics. | ||
|
||
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. | ||
Accepts an action and returns a tuple `(observation, reward, done, info)`. | ||
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple | ||
(observation, reward, done, info). The latter is deprecated and will be removed in future versions. | ||
|
||
Args: | ||
action (ActType): an action provided by the agent | ||
|
@@ -79,14 +84,18 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: | |
observation (object): this will be an element of the environment's :attr:`observation_space`. | ||
This may, for instance, be a numpy array containing the positions and velocities of certain objects. | ||
reward (float): The amount of reward returned as a result of taking the action. | ||
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results. | ||
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, | ||
a certain timelimit was exceeded, or the physics simulation has entered an invalid state. | ||
terminated (bool): whether the episode has ended due to reaching a terminal state intrinsic to the core environment, in which case further step() calls will return undefined results | ||
truncated (bool): whether the episode has ended due to a truncation, i.e., a timelimit outside the scope of the problem defined in the environment. | ||
info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal. | ||
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). | ||
This might, for instance, contain: metrics that describe the agent's performance state, variables that are | ||
hidden from observations, information that distinguishes truncation and termination or individual reward terms | ||
that are combined to produce the total reward | ||
|
||
(deprecated) | ||
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results. | ||
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, | ||
a certain timelimit was exceeded, or the physics simulation has entered an invalid state. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
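The split between `terminated` and `truncated` in the docstring above matters most to value-based learners: a terminal state has no future return, whereas a time-limit cutoff does not change the value of the state that was reached. A minimal sketch of how calling code might use the two flags, assuming a one-step TD target (the helper names `td_target` and `episode_over` are hypothetical and not part of this PR):

```python
def td_target(reward: float, next_value: float, terminated: bool,
              gamma: float = 0.99) -> float:
    """One-step TD target that bootstraps unless the episode truly terminated."""
    # A truncated episode still bootstraps from next_value: the time limit is
    # outside the task definition, so the next state keeps its value.
    bootstrap = 0.0 if terminated else gamma * next_value
    return reward + bootstrap


def episode_over(terminated: bool, truncated: bool) -> bool:
    """Either flag means the caller must reset the environment before stepping again."""
    return terminated or truncated
```

With the old single `done` flag, the caller could not tell these two cases apart without inspecting `info`.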
@@ -242,18 +251,20 @@ class Wrapper(Env[ObsType, ActType]): | |
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. | ||
""" | ||
|
||
def __init__(self, env: Env): | ||
def __init__(self, env: Env, new_step_api: bool = False): | ||
Review comment: Personally, I'm really unsure of this approach to adding new_step_api in wrappers. Even if we go with this approach, I don't like this code style, as every wrapper needs to add this parameter. Reply: They don't need to override the |
||
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. | ||
|
||
Args: | ||
env: The environment to wrap | ||
new_step_api: Whether the wrapper's step method will output in new or old step API | ||
""" | ||
self.env = env | ||
|
||
self._action_space: Optional[spaces.Space] = None | ||
self._observation_space: Optional[spaces.Space] = None | ||
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None | ||
self._metadata: Optional[dict] = None | ||
self.new_step_api = new_step_api | ||
|
||
def __getattr__(self, name): | ||
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" | ||
|
@@ -315,9 +326,17 @@ def metadata(self) -> dict: | |
def metadata(self, value): | ||
self._metadata = value | ||
|
||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: | ||
def step( | ||
self, action: ActType | ||
) -> Union[ | ||
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] | ||
]: | ||
"""Steps through the environment with action.""" | ||
return self.env.step(action) | ||
from gym.utils.step_api_compatibility import ( # avoid circular import | ||
step_api_compatibility, | ||
) | ||
|
||
return step_api_compatibility(self.env.step(action), self.new_step_api) | ||
|
||
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]: | ||
"""Resets the environment with kwargs.""" | ||
|
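The body of `step_api_compatibility` is not shown in this diff. As a rough sketch only, a conversion between the two tuple shapes could look like the following, assuming old-style environments flag time-limit endings under an `info["TimeLimit.truncated"]` key. The function names here are hypothetical and the PR's actual implementation may differ.

```python
def to_new_step_api(step_returns: tuple) -> tuple:
    """Convert a 4-tuple (obs, reward, done, info) to the 5-tuple form."""
    if len(step_returns) == 5:
        return step_returns  # already new-style
    obs, reward, done, info = step_returns
    # Assumption: truncation is signalled through this info key.
    truncated = bool(info.get("TimeLimit.truncated", False))
    terminated = done and not truncated
    return obs, reward, terminated, truncated, info


def to_old_step_api(step_returns: tuple) -> tuple:
    """Convert a 5-tuple (obs, reward, terminated, truncated, info) to the 4-tuple form."""
    if len(step_returns) == 4:
        return step_returns  # already old-style
    obs, reward, terminated, truncated, info = step_returns
    if truncated:
        # Copy the dict so the caller's info is not mutated.
        info = {**info, "TimeLimit.truncated": True}
    return obs, reward, terminated or truncated, info


def convert_step(step_returns: tuple, new_step_api: bool = False) -> tuple:
    """Pick the output shape from a flag, mirroring the wrapper argument above."""
    if new_step_api:
        return to_new_step_api(step_returns)
    return to_old_step_api(step_returns)
```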
@@ -387,8 +406,13 @@ def reset(self, **kwargs): | |
|
||
def step(self, action): | ||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" | ||
observation, reward, done, info = self.env.step(action) | ||
return self.observation(observation), reward, done, info | ||
step_returns = self.env.step(action) | ||
if len(step_returns) == 5: | ||
observation, reward, terminated, truncated, info = step_returns | ||
return self.observation(observation), reward, terminated, truncated, info | ||
else: | ||
observation, reward, done, info = step_returns | ||
return self.observation(observation), reward, done, info | ||
|
||
|
||
def observation(self, observation): | ||
"""Returns a modified observation.""" | ||
|
@@ -421,8 +445,13 @@ def reward(self, reward): | |
|
||
def step(self, action): | ||
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" | ||
observation, reward, done, info = self.env.step(action) | ||
return observation, self.reward(reward), done, info | ||
step_returns = self.env.step(action) | ||
if len(step_returns) == 5: | ||
observation, reward, terminated, truncated, info = step_returns | ||
return observation, self.reward(reward), terminated, truncated, info | ||
else: | ||
observation, reward, done, info = step_returns | ||
return observation, self.reward(reward), done, info | ||
|
||
|
||
def reward(self, reward): | ||
"""Returns a modified ``reward``.""" | ||
|
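Both wrapper hunks above branch on `len(step_returns)` only to pass the trailing fields through unchanged. One possible way to avoid the duplicated branches, shown purely as an illustration and not as what the PR does, is star-unpacking. The helper names are hypothetical.

```python
from typing import Callable


def apply_to_observation(step_returns: tuple, fn: Callable) -> tuple:
    """Transform only the observation; works for 4- and 5-tuples alike."""
    observation, *rest = step_returns
    return (fn(observation), *rest)


def apply_to_reward(step_returns: tuple, fn: Callable) -> tuple:
    """Transform only the reward; the remaining fields pass through untouched."""
    observation, reward, *rest = step_returns
    return (observation, fn(reward), *rest)
```

The trade-off is that the explicit branches in the diff name each field, which documents both tuple shapes at the point of use.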
I don't know if I like this approach to backwards compatibility. If this is the official state of (for example) 0.24.0, then you can't reliably write an algorithm that will work for all valid 0.24.0 environments. I think we should just say that an environment should have the signature of (ObsType, float, bool, bool, dict), and then provide a wrapper-like compatibility layer that can convert an old-style environment to a new-style environment.
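A minimal sketch of the kind of wrapper-like compatibility layer suggested in the comment above, so that downstream code only ever sees the 5-tuple. Both classes are hypothetical and use no Gym imports. `OldStyleCounter` is a toy stand-in for an old-style environment, and the `TimeLimit.truncated` info key is an assumption about how truncation is signalled.

```python
class OldStyleCounter:
    """Toy old-style env: returns a 4-tuple and stops at a time limit."""

    def __init__(self, limit: int = 3):
        self.limit = limit
        self.t = 0

    def reset(self):
        self.t = 0
        return 0

    def step(self, action):
        self.t += 1
        done = self.t >= self.limit
        # The only way this toy episode ends is the time limit.
        info = {"TimeLimit.truncated": True} if done else {}
        return self.t, 1.0, done, info


class NewStepAdapter:
    """Wraps an old-style env and always emits (obs, reward, terminated, truncated, info)."""

    def __init__(self, env):
        self.env = env

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        truncated = bool(info.get("TimeLimit.truncated", False))
        terminated = done and not truncated
        return obs, reward, terminated, truncated, info
```

Under this design an algorithm targets one signature, and old environments opt in by being wrapped once at construction time.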