Fix seed test to work with action_mask in info, add tests to ensure info action masking works (#1134)

Co-authored-by: Kenny <[email protected]>
elliottower and Kennjaro authored Nov 15, 2023
1 parent 2b65b5f commit 23c4242
Showing 6 changed files with 415 additions and 2 deletions.
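In brief: an AEC environment can publish its legal-action mask either inside a Dict observation or in the per-agent info dict, and the seed test updated below now checks the observation first and falls back to info. A minimal helper sketch of that fallback (hypothetical, not part of the commit; the name get_action_mask is illustrative):

def get_action_mask(obs, info):
    # Prefer a mask embedded in a dict observation; otherwise fall back to
    # the info dict. Returns None when the environment provides no mask.
    if isinstance(obs, dict) and "action_mask" in obs:
        return obs["action_mask"]
    return info.get("action_mask")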
@@ -0,0 +1,180 @@
from typing import Union

import gymnasium
import numpy as np

from pettingzoo import AECEnv
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import agent_selector


def env():
    env = raw_env()
    env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)
    env = wrappers.AssertOutOfBoundsWrapper(env)
    env = wrappers.OrderEnforcingWrapper(env)
    return env


def get_type(agent):
    return agent[: agent.rfind("_")]


class raw_env(AECEnv[str, np.ndarray, Union[int, None]]):
    metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"}

    def __init__(self, max_cycles=100, render_mode=None):
        super().__init__()
        self._obs_spaces = {}
        self._act_spaces = {}

        # dummy state space, not actually used
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        self.max_cycles = max_cycles
        self._seed()
        self.render_mode = render_mode
        for i in range(3):
            self.add_type()

    def observation_space(self, agent):
        return self._obs_spaces[get_type(agent)]

    def action_space(self, agent):
        return self._act_spaces[get_type(agent)]

    def state(self) -> np.ndarray:
        return self._state

    def observe(self, agent):
        return self.observation_space(agent).sample()

    def add_type(self):
        type_id = len(self.types)
        num_actions = self.np_random.integers(3, 10)
        obs_size = self.np_random.integers(10, 50)
        obs_space = gymnasium.spaces.Box(low=0, high=1, shape=(obs_size,))
        act_space = gymnasium.spaces.Discrete(num_actions)
        new_type = f"type{type_id}"
        self.types.append(new_type)
        self._obs_spaces[new_type] = obs_space
        self._act_spaces[new_type] = act_space
        self._agent_counters[new_type] = 0
        return new_type

    def add_agent(self, type):
        agent_id = self._agent_counters[type]
        self._agent_counters[type] += 1
        agent = f"{type}_{agent_id}"
        self.agents.append(agent)
        self.terminations[agent] = False
        self.truncations[agent] = False
        self.rewards[agent] = 0
        self._cumulative_rewards[agent] = 0
        num_actions = self._act_spaces[type].n
        self.infos[agent] = {
            "action_mask": np.eye(num_actions)[
                self.np_random.choice(num_actions)
            ].astype(np.int8)
        }
        return agent

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        self.agents = []
        self.rewards = {}
        self._cumulative_rewards = {}
        self.terminations = {}
        self.truncations = {}
        self.infos = {}
        self.num_steps = 0

        self._obs_spaces = {}
        self._act_spaces = {}
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        for i in range(3):
            self.add_type()
        for i in range(5):
            self.add_agent(self.np_random.choice(self.types))

        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()

        # seed observation and action spaces
        for i, agent in enumerate(self.agents):
            self.observation_space(agent).seed(seed)
        for i, agent in enumerate(self.agents):
            self.action_space(agent).seed(seed)

    def _seed(self, seed=None):
        self.np_random, _ = gymnasium.utils.seeding.np_random(seed)

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            return self._was_dead_step(action)

        self._clear_rewards()
        self._cumulative_rewards[self.agent_selection] = 0

        if self._agent_selector.is_last():
            for i in range(5):
                if self.np_random.random() < 0.1:
                    if self.np_random.random() < 0.1:
                        type = self.add_type()
                    else:
                        type = self.np_random.choice(self.types)

                    agent = self.add_agent(type)
                    if len(self.agents) >= 20:
                        self.terminations[self.np_random.choice(self.agents)] = True

        if self._agent_selector.is_last():
            self.num_steps += 1

        if self.num_steps > self.max_cycles:
            for agent in self.agents:
                self.truncations[agent] = True

        self.rewards[self.np_random.choice(self.agents)] = 1

        self._state = self.state_space.sample()

        self._accumulate_rewards()
        self._deads_step_first()

        # Sample info action mask randomly
        type = self.agent_selection.split("_")[0]
        num_actions = self._act_spaces[type].n
        self.infos[self.agent_selection] = {
            "action_mask": np.eye(num_actions)[
                self.np_random.choice(num_actions)
            ].astype(np.int8)
        }

        # Cycle agents
        self.agent_selection = self._agent_selector.next()

        if self.render_mode == "human":
            self.render()

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
        else:
            print(self.agents)

    def close(self):
        pass
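The file above defines an example environment whose observations are plain Box arrays and whose legal-action mask is reported through infos[agent]["action_mask"] as a random one-hot vector. A rough consumer sketch, assuming the raw_env class defined above (the driver loop is illustrative, not part of the commit):

example = raw_env()
example.reset(seed=42)
for agent in example.agent_iter():
    obs, reward, termination, truncation, info = example.last()
    if termination or truncation:
        action = None  # agents that are done must receive a None action
    else:
        # the mask arrives via info, not via the observation
        action = example.action_space(agent).sample(info["action_mask"])
    example.step(action)
example.close()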
@@ -0,0 +1,173 @@
from typing import Union

import gymnasium
import numpy as np

from pettingzoo import AECEnv
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import agent_selector


def env():
    env = raw_env()
    env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1)
    env = wrappers.AssertOutOfBoundsWrapper(env)
    env = wrappers.OrderEnforcingWrapper(env)
    return env


def get_type(agent):
    return agent[: agent.rfind("_")]


class raw_env(AECEnv[str, np.ndarray, Union[int, None]]):
    metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"}

    def __init__(self, max_cycles=100, render_mode=None):
        super().__init__()
        self._obs_spaces = {}
        self._act_spaces = {}

        # dummy state space, not actually used
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        self.max_cycles = max_cycles
        self._seed()
        self.render_mode = render_mode
        for i in range(3):
            self.add_type()

    def observation_space(self, agent):
        return self._obs_spaces[get_type(agent)]

    def action_space(self, agent):
        return self._act_spaces[get_type(agent)]

    def state(self) -> np.ndarray:
        return self._state

    def observe(self, agent):
        return self.observation_space(agent).sample()

    def add_type(self):
        type_id = len(self.types)
        num_actions = self.np_random.integers(3, 10)
        obs_size = self.np_random.integers(10, 50)
        obs_space = gymnasium.spaces.Dict(
            {
                "observation": gymnasium.spaces.Box(low=0, high=1, shape=(obs_size,)),
                "action_mask": gymnasium.spaces.Box(
                    low=0, high=1, shape=(num_actions,), dtype=np.int8
                ),
            }
        )
        act_space = gymnasium.spaces.Discrete(num_actions)
        new_type = f"type{type_id}"
        self.types.append(new_type)
        self._obs_spaces[new_type] = obs_space
        self._act_spaces[new_type] = act_space
        self._agent_counters[new_type] = 0
        return new_type

    def add_agent(self, type):
        agent_id = self._agent_counters[type]
        self._agent_counters[type] += 1
        agent = f"{type}_{agent_id}"
        self.agents.append(agent)
        self.terminations[agent] = False
        self.truncations[agent] = False
        self.rewards[agent] = 0
        self._cumulative_rewards[agent] = 0
        self.infos[agent] = {}
        return agent

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        self.agents = []
        self.rewards = {}
        self._cumulative_rewards = {}
        self.terminations = {}
        self.truncations = {}
        self.infos = {}
        self.num_steps = 0

        self._obs_spaces = {}
        self._act_spaces = {}
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        for i in range(3):
            self.add_type()
        for i in range(5):
            self.add_agent(self.np_random.choice(self.types))

        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()

        # seed observation and action spaces
        for i, agent in enumerate(self.agents):
            self.observation_space(agent).seed(seed)
        for i, agent in enumerate(self.agents):
            self.action_space(agent).seed(seed)

    def _seed(self, seed=None):
        self.np_random, _ = gymnasium.utils.seeding.np_random(seed)

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            return self._was_dead_step(action)

        self._clear_rewards()
        self._cumulative_rewards[self.agent_selection] = 0

        if self._agent_selector.is_last():
            for i in range(5):
                if self.np_random.random() < 0.1:
                    if self.np_random.random() < 0.1:
                        type = self.add_type()
                    else:
                        type = self.np_random.choice(self.types)

                    agent = self.add_agent(type)
                    if len(self.agents) >= 20:
                        self.terminations[self.np_random.choice(self.agents)] = True

        if self._agent_selector.is_last():
            self.num_steps += 1

        if self.num_steps > self.max_cycles:
            for agent in self.agents:
                self.truncations[agent] = True

        self.rewards[self.np_random.choice(self.agents)] = 1

        self._state = self.state_space.sample()

        self._accumulate_rewards()
        self._deads_step_first()

        # Cycle agents
        self.agent_selection = self._agent_selector.next()

        if self.render_mode == "human":
            self.render()

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
        else:
            print(self.agents)

    def close(self):
        pass
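This second example environment carries the mask inside a Dict observation (keys "observation" and "action_mask") and leaves info empty, mirroring how the classic environments report legal moves. A matching sketch, again assuming the raw_env class defined directly above (the loop itself is illustrative):

example = raw_env()
example.reset(seed=42)
for agent in example.agent_iter():
    obs, reward, termination, truncation, info = example.last()
    if termination or truncation:
        action = None
    else:
        # the observation dict holds both the features and the mask
        action = example.action_space(agent).sample(obs["action_mask"])
    example.step(action)
example.close()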
@@ -146,6 +146,10 @@ def step(self, action):

        self._accumulate_rewards()
        self._deads_step_first()

        # Cycle agents
        self.agent_selection = self._agent_selector.next()

        if self.render_mode == "human":
            self.render()

4 changes: 4 additions & 0 deletions pettingzoo/test/example_envs/generated_agents_env_v0.py
@@ -146,6 +146,10 @@ def step(self, action):

        self._accumulate_rewards()
        self._deads_step_first()

        # Cycle agents
        self.agent_selection = self._agent_selector.next()

        if self.render_mode == "human":
            self.render()

13 changes: 11 additions & 2 deletions pettingzoo/test/seed_test.py
@@ -47,8 +47,17 @@ def check_environment_deterministic(env1, env2, num_cycles):
         if termination1 or truncation1:
             break
 
-        mask1 = obs1.get("action_mask") if isinstance(obs1, dict) else None
-        mask2 = obs2.get("action_mask") if isinstance(obs2, dict) else None
+        mask1 = (
+            obs1.get("action_mask")
+            if (isinstance(obs1, dict) and "action_mask" in obs1)
+            else (info1.get("action_mask") if "action_mask" in info1 else None)
+        )
+        mask2 = (
+            obs2.get("action_mask")
+            if (isinstance(obs2, dict) and "action_mask" in obs2)
+            else (info2.get("action_mask") if "action_mask" in info2 else None)
+        )
+
         assert data_equivalence(mask1, mask2), f"Incorrect action mask: {mask1} {mask2}"
 
         action1 = env1.action_space(agent1).sample(mask1)
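With the fallback in place, the determinism check also covers environments that only provide their mask through info. A hedged sketch of running it via the public test helper (assuming the info-masked example env above is importable as env; its module path is not shown in this diff):

from pettingzoo.test import seed_test

# Builds two identically seeded copies of the environment and asserts that
# the two runs stay in lockstep, including the action masks checked above.
seed_test(env)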