Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: have access to terminal_observation in the infos. #233

Merged
merged 8 commits into from
Nov 28, 2023
14 changes: 13 additions & 1 deletion supersuit/vector/markov_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def step_wait(self):
return self.step(self._saved_actions)

def reset(self, seed=None, options=None):
if seed is None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused, won't this pass the same seed to all of the environments, therefore, this does the opposite of what you want.
Even so, this should be part of a second PR, not this one if possible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe this will create the same seed. Each time np.random.randint is called, np internally updates its internal state, meaning that a new seed is created in the next call - docs explaining this.

This would only create the same seed for all envs if np.random.seed is used in each process to set them all to have the same seed.

Nonetheless, I agree this should be removed from this PR.

# To ensure that subprocesses have different seeds,
# we still populate the seed variable when no argument is passed.
# Otherwise parallel vec env workers could have identical seeds (env could default to seed if no seed is passed)
# when reset is called as part of line 101.
seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32))

# TODO: should this be changed to infos?
_observations, infos = self.par_env.reset(seed=seed, options=options)
observations = self.concat_obs(_observations)
Expand Down Expand Up @@ -91,9 +98,14 @@ def step(self, actions):
infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]

if env_done:
observations, infs = self.reset()
observations, reset_infs = self.reset()
else:
observations = self.concat_obs(observations)
# empty infos for reset infs
reset_infs = [{} for _ in range(len(self.par_env.possible_agents))]
# combine standard infos and reset infos
infs = [{**inf, **reset_inf} for inf, reset_inf in zip(infs, reset_infs)]

assert (
self.black_death or self.par_env.agents == self.par_env.possible_agents
), "MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True"
Expand Down
26 changes: 26 additions & 0 deletions test/test_vector/test_pettingzoo_to_vec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy

import numpy as np
import pytest
from pettingzoo.butterfly import knights_archers_zombies_v10
from pettingzoo.mpe import simple_spread_v3, simple_world_comm_v3
Expand Down Expand Up @@ -89,3 +90,28 @@ def test_env_black_death_wrapper():
for i in range(300):
actions = [env.action_space.sample() for i in range(env.num_envs)]
obss, rews, terms, truncs, infos = env.step(actions)


def test_terminal_obs_are_returned():
"""
If we reach (and pass) the end of the episode, the last observation is returned in the info dict.
"""
max_cycles = 300
env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=300)
env = black_death_v3(env)
env = pettingzoo_env_to_vec_env_v1(env)
env.reset(seed=42)

# run past max_cycles or until terminated - causing the env to reset and continue
for _ in range(0, max_cycles + 10):
actions = [env.action_space.sample() for i in range(env.num_envs)]
_, _, terms, truncs, infos = env.step(actions)

env_done = (np.array(terms) | np.array(truncs)).all()

if env_done:
# check we have infos for all agents
assert len(infos) == len(env.par_env.possible_agents)
# check infos contain terminal_observation
for info in infos:
assert "terminal_observation" in info
Loading