diff --git a/supersuit/vector/markov_vector_wrapper.py b/supersuit/vector/markov_vector_wrapper.py
index e900e8a..598a908 100644
--- a/supersuit/vector/markov_vector_wrapper.py
+++ b/supersuit/vector/markov_vector_wrapper.py
@@ -92,9 +92,13 @@ def step(self, actions):
         infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]
 
         if env_done:
-            observations, infs = self.reset()
+            observations, reset_infs = self.reset()
+            # merge instead of overwrite: the infos of the final step must survive
+            # the automatic reset; reset infos take precedence on key collisions
+            infs = [{**inf, **reset_inf} for inf, reset_inf in zip(infs, reset_infs)]
         else:
             observations = self.concat_obs(observations)
+
         assert (
             self.black_death or self.par_env.agents == self.par_env.possible_agents
         ), "MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True"
diff --git a/test/test_vector/test_pettingzoo_to_vec.py b/test/test_vector/test_pettingzoo_to_vec.py
index 9202a1b..bd379e7 100644
--- a/test/test_vector/test_pettingzoo_to_vec.py
+++ b/test/test_vector/test_pettingzoo_to_vec.py
@@ -1,5 +1,6 @@
 import copy
 
+import numpy as np
 import pytest
 from pettingzoo.butterfly import knights_archers_zombies_v10
 from pettingzoo.mpe import simple_spread_v3, simple_world_comm_v3
@@ -89,3 +90,33 @@ def test_env_black_death_wrapper():
     for i in range(300):
         actions = [env.action_space.sample() for i in range(env.num_envs)]
         obss, rews, terms, truncs, infos = env.step(actions)
+
+
+def test_terminal_obs_are_returned():
+    """
+    If we reach (and pass) the end of the episode, the last observation is returned in the info dict.
+    """
+    max_cycles = 300
+    env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=max_cycles)
+    env = black_death_v3(env)
+    env = pettingzoo_env_to_vec_env_v1(env)
+    env.reset(seed=42)
+
+    episode_ended = False
+    # run past max_cycles or until terminated - causing the env to reset and continue
+    for _ in range(max_cycles + 10):
+        actions = [env.action_space.sample() for _ in range(env.num_envs)]
+        _, _, terms, truncs, infos = env.step(actions)
+
+        env_done = (np.array(terms) | np.array(truncs)).all()
+
+        if env_done:
+            episode_ended = True
+            # check we have infos for all agents
+            assert len(infos) == len(env.par_env.possible_agents)
+            # check infos contain terminal_observation
+            for info in infos:
+                assert "terminal_observation" in info
+
+    # guard against the test passing without ever checking a finished episode
+    assert episode_ended, "episode never finished; terminal infos were not checked"