Skip to content

Commit

Permalink
Fix VecExtractDictObs does not handle terminal observation (#1443)
Browse files Browse the repository at this point in the history
* VecExtractDictObs handle terminal_observation

* Added VecExtractDictObs handle terminal_output to changelog

* Update changelog.rst

* Update test_vec_extract_dict_obs.py

Add random dones in env to test if terminal_observation is properly handled

* Made test deterministic

* Fixed bug in test

* Improved test

* Fix format in test

* Update test

* Fix type hint

* Ignore pytype warning

* Ignore pytype

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
WeberSamuel and araffin authored Apr 12, 2023
1 parent 4232f9d commit 15c9daa
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1299,4 +1300,4 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def getattr_recursive(self, name: str) -> Any:

return attr

def getattr_depth_check(self, name: str, already_found: bool) -> str:
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""See base class.
:return: name of module whose attribute is being shadowed, if any.
Expand Down
6 changes: 4 additions & 2 deletions stable_baselines3/common/vec_env/stacked_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,16 @@ def reset(self, observation: TObs) -> TObs:
:return: The stacked reset observation
"""
if isinstance(observation, dict):
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
return {
key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()
} # pytype: disable=bad-return-type

self.stacked_obs[...] = 0
if self.channels_first:
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
else:
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
return self.stacked_obs
return self.stacked_obs # pytype: disable=bad-return-type

def update(
self,
Expand Down
7 changes: 5 additions & 2 deletions stable_baselines3/common/vec_env/vec_extract_dict_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ def reset(self) -> np.ndarray:
return obs[self.key]

def step_wait(self) -> VecEnvStepReturn:
obs, reward, done, info = self.venv.step_wait()
return obs[self.key], reward, done, info
obs, reward, done, infos = self.venv.step_wait()
for info in infos:
if "terminal_observation" in info:
info["terminal_observation"] = info["terminal_observation"][self.key]
return obs[self.key], reward, done, infos
26 changes: 24 additions & 2 deletions tests/test_vec_extract_dict_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,31 @@ def __init__(self):
self.num_envs = 4
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
self.n_steps = 0
self.max_steps = 5

def step_async(self, actions):
self.actions = actions

def step_wait(self):
self.n_steps += 1
done = self.n_steps >= self.max_steps
if done:
infos = [
{"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True}
for _ in range(self.num_envs)
]
else:
infos = []
return (
{"rgb": np.zeros((self.num_envs, 86, 86))},
np.zeros((self.num_envs,)),
np.zeros((self.num_envs,), dtype=bool),
[{} for _ in range(self.num_envs)],
np.ones((self.num_envs,), dtype=bool) * done,
infos,
)

def reset(self):
self.n_steps = 0
return {"rgb": np.zeros((self.num_envs, 86, 86))}

def render(self, mode="human", close=False):
Expand All @@ -40,6 +52,16 @@ def test_extract_dict_obs():
env = VecExtractDictObs(env, "rgb")
assert env.reset().shape == (4, 86, 86)

for _ in range(10):
obs, _, dones, infos = env.step([env.action_space.sample() for _ in range(env.num_envs)])
assert obs.shape == (4, 86, 86)
for idx, info in enumerate(infos):
if "terminal_observation" in info:
assert dones[idx]
assert info["terminal_observation"].shape == (86, 86)
else:
assert not dones[idx]


def test_vec_with_ppo():
"""
Expand Down

0 comments on commit 15c9daa

Please sign in to comment.