Fix seed test to work with action_mask in info, add tests to ensure info action masking works (#1134)
Co-authored-by: Kenny <[email protected]>
1 parent 2b65b5f, commit 23c4242
Showing 6 changed files with 415 additions and 2 deletions.
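Context for the tests: PettingZoo environments expose legal-action masks in one of two conventions, either inside a Dict observation under observation["action_mask"], or alongside a flat observation under info["action_mask"]. This commit adds example environments and tests covering the info convention. A minimal, hypothetical helper (not part of this commit) that reads whichever mask an env provides might look like:

import numpy as np

def get_action_mask(observation, info):
    # Hypothetical convenience helper, not part of this commit: prefer a
    # Dict-observation mask, then an info-based mask; None means unmasked.
    if isinstance(observation, dict) and "action_mask" in observation:
        return np.asarray(observation["action_mask"], dtype=np.int8)
    if info.get("action_mask") is not None:
        return np.asarray(info["action_mask"], dtype=np.int8)
    return None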
pettingzoo/test/example_envs/generated_agents_env_action_mask_info_v0.py (180 additions, 0 deletions)
@@ -0,0 +1,180 @@
from typing import Union

import gymnasium
import numpy as np

from pettingzoo import AECEnv
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import agent_selector


def env():
    env = raw_env()
    env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)
    env = wrappers.AssertOutOfBoundsWrapper(env)
    env = wrappers.OrderEnforcingWrapper(env)
    return env


def get_type(agent):
    return agent[: agent.rfind("_")]


class raw_env(AECEnv[str, np.ndarray, Union[int, None]]):
    metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"}

    def __init__(self, max_cycles=100, render_mode=None):
        super().__init__()
        self._obs_spaces = {}
        self._act_spaces = {}

        # dummy state space, not actually used
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        self.max_cycles = max_cycles
        self._seed()
        self.render_mode = render_mode
        for i in range(3):
            self.add_type()

    def observation_space(self, agent):
        return self._obs_spaces[get_type(agent)]

    def action_space(self, agent):
        return self._act_spaces[get_type(agent)]

    def state(self) -> np.ndarray:
        return self._state

    def observe(self, agent):
        return self.observation_space(agent).sample()

    def add_type(self):
        type_id = len(self.types)
        num_actions = self.np_random.integers(3, 10)
        obs_size = self.np_random.integers(10, 50)
        obs_space = gymnasium.spaces.Box(low=0, high=1, shape=(obs_size,))
        act_space = gymnasium.spaces.Discrete(num_actions)
        new_type = f"type{type_id}"
        self.types.append(new_type)
        self._obs_spaces[new_type] = obs_space
        self._act_spaces[new_type] = act_space
        self._agent_counters[new_type] = 0
        return new_type

    def add_agent(self, type):
        agent_id = self._agent_counters[type]
        self._agent_counters[type] += 1
        agent = f"{type}_{agent_id}"
        self.agents.append(agent)
        self.terminations[agent] = False
        self.truncations[agent] = False
        self.rewards[agent] = 0
        self._cumulative_rewards[agent] = 0
        num_actions = self._act_spaces[type].n
        self.infos[agent] = {
            "action_mask": np.eye(num_actions)[
                self.np_random.choice(num_actions)
            ].astype(np.int8)
        }
        return agent

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        self.agents = []
        self.rewards = {}
        self._cumulative_rewards = {}
        self.terminations = {}
        self.truncations = {}
        self.infos = {}
        self.num_steps = 0

        self._obs_spaces = {}
        self._act_spaces = {}
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        for i in range(3):
            self.add_type()
        for i in range(5):
            self.add_agent(self.np_random.choice(self.types))

        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()

        # seed observation and action spaces
        for i, agent in enumerate(self.agents):
            self.observation_space(agent).seed(seed)
        for i, agent in enumerate(self.agents):
            self.action_space(agent).seed(seed)

    def _seed(self, seed=None):
        self.np_random, _ = gymnasium.utils.seeding.np_random(seed)

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            return self._was_dead_step(action)

        self._clear_rewards()
        self._cumulative_rewards[self.agent_selection] = 0

        if self._agent_selector.is_last():
            for i in range(5):
                if self.np_random.random() < 0.1:
                    if self.np_random.random() < 0.1:
                        type = self.add_type()
                    else:
                        type = self.np_random.choice(self.types)

                    agent = self.add_agent(type)
                    if len(self.agents) >= 20:
                        self.terminations[self.np_random.choice(self.agents)] = True

        if self._agent_selector.is_last():
            self.num_steps += 1

        if self.num_steps > self.max_cycles:
            for agent in self.agents:
                self.truncations[agent] = True

        self.rewards[self.np_random.choice(self.agents)] = 1

        self._state = self.state_space.sample()

        self._accumulate_rewards()
        self._deads_step_first()

        # Sample info action mask randomly
        type = self.agent_selection.split("_")[0]
        num_actions = self._act_spaces[type].n
        self.infos[self.agent_selection] = {
            "action_mask": np.eye(num_actions)[
                self.np_random.choice(num_actions)
            ].astype(np.int8)
        }

        # Cycle agents
        self.agent_selection = self._agent_selector.next()

        if self.render_mode == "human":
            self.render()

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
        else:
            print(self.agents)

    def close(self):
        pass
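As a rough usage sketch (not part of the diff), the info-mask variant above can be driven with the standard AEC loop, sampling only actions the info mask marks legal; the import path is assumed from the file location shown in the diff:

import numpy as np
from pettingzoo.test.example_envs import generated_agents_env_action_mask_info_v0

env = generated_agents_env_action_mask_info_v0.env()
env.reset(seed=42)
rng = np.random.default_rng(42)
for agent in env.agent_iter(max_iter=100):
    observation, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None  # dead agents must step with None in the AEC API
    else:
        # This env samples a one-hot mask, so exactly one action is legal.
        action = int(rng.choice(np.flatnonzero(info["action_mask"])))
    env.step(action)
env.close()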
pettingzoo/test/example_envs/generated_agents_env_action_mask_obs_v0.py (173 additions, 0 deletions)
@@ -0,0 +1,173 @@
from typing import Union

import gymnasium
import numpy as np

from pettingzoo import AECEnv
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import agent_selector


def env():
    env = raw_env()
    env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)
    env = wrappers.AssertOutOfBoundsWrapper(env)
    env = wrappers.OrderEnforcingWrapper(env)
    return env


def get_type(agent):
    return agent[: agent.rfind("_")]


class raw_env(AECEnv[str, np.ndarray, Union[int, None]]):
    metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"}

    def __init__(self, max_cycles=100, render_mode=None):
        super().__init__()
        self._obs_spaces = {}
        self._act_spaces = {}

        # dummy state space, not actually used
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        self.max_cycles = max_cycles
        self._seed()
        self.render_mode = render_mode
        for i in range(3):
            self.add_type()

    def observation_space(self, agent):
        return self._obs_spaces[get_type(agent)]

    def action_space(self, agent):
        return self._act_spaces[get_type(agent)]

    def state(self) -> np.ndarray:
        return self._state

    def observe(self, agent):
        return self.observation_space(agent).sample()

    def add_type(self):
        type_id = len(self.types)
        num_actions = self.np_random.integers(3, 10)
        obs_size = self.np_random.integers(10, 50)
        obs_space = gymnasium.spaces.Dict(
            {
                "observation": gymnasium.spaces.Box(low=0, high=1, shape=(obs_size,)),
                "action_mask": gymnasium.spaces.Box(
                    low=0, high=1, shape=(num_actions,), dtype=np.int8
                ),
            }
        )
        act_space = gymnasium.spaces.Discrete(num_actions)
        new_type = f"type{type_id}"
        self.types.append(new_type)
        self._obs_spaces[new_type] = obs_space
        self._act_spaces[new_type] = act_space
        self._agent_counters[new_type] = 0
        return new_type

    def add_agent(self, type):
        agent_id = self._agent_counters[type]
        self._agent_counters[type] += 1
        agent = f"{type}_{agent_id}"
        self.agents.append(agent)
        self.terminations[agent] = False
        self.truncations[agent] = False
        self.rewards[agent] = 0
        self._cumulative_rewards[agent] = 0
        self.infos[agent] = {}
        return agent

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        self.agents = []
        self.rewards = {}
        self._cumulative_rewards = {}
        self.terminations = {}
        self.truncations = {}
        self.infos = {}
        self.num_steps = 0

        self._obs_spaces = {}
        self._act_spaces = {}
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        for i in range(3):
            self.add_type()
        for i in range(5):
            self.add_agent(self.np_random.choice(self.types))

        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()

        # seed observation and action spaces
        for i, agent in enumerate(self.agents):
            self.observation_space(agent).seed(seed)
        for i, agent in enumerate(self.agents):
            self.action_space(agent).seed(seed)

    def _seed(self, seed=None):
        self.np_random, _ = gymnasium.utils.seeding.np_random(seed)

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            return self._was_dead_step(action)

        self._clear_rewards()
        self._cumulative_rewards[self.agent_selection] = 0

        if self._agent_selector.is_last():
            for i in range(5):
                if self.np_random.random() < 0.1:
                    if self.np_random.random() < 0.1:
                        type = self.add_type()
                    else:
                        type = self.np_random.choice(self.types)

                    agent = self.add_agent(type)
                    if len(self.agents) >= 20:
                        self.terminations[self.np_random.choice(self.agents)] = True

        if self._agent_selector.is_last():
            self.num_steps += 1

        if self.num_steps > self.max_cycles:
            for agent in self.agents:
                self.truncations[agent] = True

        self.rewards[self.np_random.choice(self.agents)] = 1

        self._state = self.state_space.sample()

        self._accumulate_rewards()
        self._deads_step_first()

        # Cycle agents
        self.agent_selection = self._agent_selector.next()

        if self.render_mode == "human":
            self.render()

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
        else:
            print(self.agents)

    def close(self):
        pass
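The obs variant differs only in where the mask lives: it is a field of the Dict observation rather than of info. A corresponding driver sketch (again illustrative only, with the import path assumed from the diff):

import numpy as np
from pettingzoo.test.example_envs import generated_agents_env_action_mask_obs_v0

env = generated_agents_env_action_mask_obs_v0.env()
env.reset(seed=42)
rng = np.random.default_rng(42)
for agent in env.agent_iter(max_iter=100):
    observation, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None
    else:
        legal = np.flatnonzero(observation["action_mask"])
        # observe() samples the Dict space, so the mask can be all zeros
        # here; fall back to an unmasked sample in that case.
        action = int(rng.choice(legal)) if legal.size else env.action_space(agent).sample()
    env.step(action)
env.close()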
(Diffs for the remaining changed files did not render.)