Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: have access to terminal_observation in the infos. #233

Merged
merged 8 commits into from
Nov 28, 2023
7 changes: 6 additions & 1 deletion supersuit/vector/markov_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,14 @@ def step(self, actions):
infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]

if env_done:
observations, infs = self.reset()
observations, reset_infs = self.reset()
else:
observations = self.concat_obs(observations)
# empty infos for reset infs
reset_infs = [{} for _ in range(len(self.par_env.possible_agents))]
# combine standard infos and reset infos
infs = [{**inf, **reset_inf} for inf, reset_inf in zip(infs, reset_infs)]

assert (
self.black_death or self.par_env.agents == self.par_env.possible_agents
), "MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True"
Expand Down
26 changes: 26 additions & 0 deletions test/test_vector/test_pettingzoo_to_vec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy

import numpy as np
import pytest
from pettingzoo.butterfly import knights_archers_zombies_v10
from pettingzoo.mpe import simple_spread_v3, simple_world_comm_v3
Expand Down Expand Up @@ -89,3 +90,28 @@ def test_env_black_death_wrapper():
for i in range(300):
actions = [env.action_space.sample() for i in range(env.num_envs)]
obss, rews, terms, truncs, infos = env.step(actions)


def test_terminal_obs_are_returned():
    """
    When a step finishes the episode for every agent, each per-agent info dict
    returned by that step must carry a ``terminal_observation`` entry.
    """
    max_cycles = 300
    extra_steps = 10

    par_env = knights_archers_zombies_v10.parallel_env(
        spawn_rate=50, max_cycles=max_cycles
    )
    env = pettingzoo_env_to_vec_env_v1(black_death_v3(par_env))
    env.reset(seed=42)

    # Keep stepping beyond max_cycles so that any episode end (termination or
    # truncation) is stepped over and the vector env carries on afterwards.
    for _ in range(max_cycles + extra_steps):
        actions = [env.action_space.sample() for _ in range(env.num_envs)]
        _, _, terms, truncs, infos = env.step(actions)

        finished = np.array(terms) | np.array(truncs)
        if not finished.all():
            continue

        # One info dict is expected per possible agent ...
        assert len(infos) == len(env.par_env.possible_agents)
        # ... and each of them must expose the terminal observation.
        assert all("terminal_observation" in info for info in infos)
Loading