From e2859e719b86f7b014ec104744815349df972545 Mon Sep 17 00:00:00 2001 From: snow-fox Date: Mon, 3 Oct 2022 13:40:19 +0100 Subject: [PATCH 01/24] fix wrapper unwrapped thing --- pettingzoo/utils/conversions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pettingzoo/utils/conversions.py b/pettingzoo/utils/conversions.py index ff3ea6068..4a4c7388a 100644 --- a/pettingzoo/utils/conversions.py +++ b/pettingzoo/utils/conversions.py @@ -190,7 +190,12 @@ def __init__(self, parallel_env): self.metadata = {**parallel_env.metadata} self.metadata["is_parallelizable"] = True - self.render_mode = self.env.render_mode + try: + self.render_mode = self.env.render_mode + except: + warnings.warn( + f"The base environment `{parallel_env}` does not have a `render_mode` defined." + ) try: self.possible_agents = parallel_env.possible_agents From 4ead31abb6b9fde4a2cf4a2483dba59017bc5ccb Mon Sep 17 00:00:00 2001 From: snow-fox Date: Mon, 3 Oct 2022 13:49:49 +0100 Subject: [PATCH 02/24] no bare excepts --- pettingzoo/utils/conversions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pettingzoo/utils/conversions.py b/pettingzoo/utils/conversions.py index 4a4c7388a..6f04d2434 100644 --- a/pettingzoo/utils/conversions.py +++ b/pettingzoo/utils/conversions.py @@ -192,7 +192,7 @@ def __init__(self, parallel_env): try: self.render_mode = self.env.render_mode - except: + except AttributeError: warnings.warn( f"The base environment `{parallel_env}` does not have a `render_mode` defined." ) From 4c06d311f40c28d6f17d987e4cb0b028f6b3036f Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 12:32:16 +0100 Subject: [PATCH 03/24] add first basic tutorial --- tutorials/cleanrl.py | 231 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tutorials/cleanrl.py diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py new file mode 100644 index 000000000..d175363cb --- /dev/null +++ b/tutorials/cleanrl.py @@ -0,0 +1,231 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.distributions.categorical import Categorical + +from pettingzoo.butterfly import cooperative_pong_v5 +from supersuit import color_reduction_v0, resize_v1, frame_stack_v1 + + +class Agent(nn.Module): + def __init__(self, num_actions): + super().__init__() + + self.network = nn.Sequential( + self.layer_init(nn.Conv2d(4, 32, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + self.layer_init(nn.Conv2d(32, 64, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + self.layer_init(nn.Conv2d(64, 128, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + nn.Flatten(), + self.layer_init(nn.Linear(128 * 4 * 4, 512)), + nn.ReLU(), + ) + self.actor = self.layer_init(nn.Linear(512, num_actions), std=0.01) + self.critic = self.layer_init(nn.Linear(512, 1)) + + def layer_init(self, layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +def batchify_obs(obs, device): + """Converts PZ style observations to batch of torch arrays""" + # convert to list of np arrays + obs = np.stack([obs[a] for a in env.possible_agents], axis=0) + # transpose to be (batch, channel, height, width) + obs = obs.transpose(0, -1, 1, 2) + # convert to torch + obs = torch.tensor(obs).to(device) + + return obs + + +def batchify(x, device): + """Converts PZ style returns to batch of torch arrays""" + # convert to list of np arrays + x = np.stack([x[a] for a in env.possible_agents], axis=0) + # convert to torch + x = torch.tensor(x).to(device) + + return x + + +def unbatchify(x, env): + """Converts np array to PZ style arguments""" + x = x.cpu().numpy() + x = {a: x[i] for i, a in enumerate(env.possible_agents)} + + return x + + +if __name__ == "__main__": + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + """ ENV SETUP """ + env = cooperative_pong_v5.parallel_env() + env = color_reduction_v0(env) + env = resize_v1(env, 32, 32) + env = frame_stack_v1(env, stack_size=4) + num_agents = len(env.possible_agents) + num_actions = env.action_space(env.possible_agents[0]).n + observation_size = env.observation_space((env.possible_agents[0])).shape + + """ LEARNER SETUP """ + agent = Agent(num_actions=num_actions).to(device) + optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5) + + """ ALGO PARAMS """ + ent_coef = 0.1 + vf_coef = 0.1 + clip_coef = 0.1 + gamma = 0.99 + batch_size = 32 + + """ ALGO LOGIC: EPISODE STORAGE""" + num_steps = 900 + end_step = 0 + total_episodic_return = 0 + rb_obs = torch.zeros((num_steps, num_agents, 4, 32, 32)).to(device) + rb_actions = torch.zeros((num_steps, num_agents)).to(device) + rb_logprobs = torch.zeros((num_steps, num_agents)).to(device) + rb_rewards = torch.zeros((num_steps, num_agents)).to(device) + rb_terms = torch.zeros((num_steps, num_agents)).to(device) + rb_values = torch.zeros((num_steps, num_agents)).to(device) + + """ TRAINING LOGIC """ + # train for n number of episodes + for episode in range(1, 1000): + + # collect observations and convert to batch of torch tensors + next_obs = batchify_obs(env.reset(seed=None), device) + # get next dones and convert to batch of torch tensors + next_dones = torch.zeros(num_agents).to(device) + # reset the episodic return + total_episodic_return = 0 + + # each episode has num_steps + for step in range(0, num_steps): + # store the observation and done + rb_obs[step] = next_obs + rb_terms[step] = next_dones + + # ALGO LOGIC: action logic + with torch.no_grad(): + actions, logprobs, _, values = agent.get_action_and_value(next_obs) + rb_values[step] = values.flatten() + rb_actions[step] = actions + rb_logprobs[step] = logprobs + + # execute the environment and log data + next_obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) + next_obs = batchify_obs(next_obs, device) + rewards = batchify(rewards, device) + terms = batchify(terms, device) + truncs = batchify(truncs, device) + rb_rewards[step] = rewards + total_episodic_return += rewards.cpu().numpy() + + # if we reach termination or truncation, end + if any(terms) or any(truncs): + end_step = step + break + + # bootstrap value if not done + with torch.no_grad(): + rb_advantages = torch.zeros_like(rb_rewards).to(device) + for t in reversed(range(end_step + 1)): + next_V = rb_values[t] + delta = rb_rewards[t] + gamma * rb_values[t + 1] - rb_values[t] + rb_advantages[t] = delta + gamma * gamma * rb_advantages[t + 1] + rb_returns = rb_advantages + rb_values + + # convert our episodes to individual transitions + b_obs = torch.flatten(rb_obs[:end_step], start_dim=0, end_dim=1) + b_logprobs = torch.flatten(rb_logprobs[:end_step], start_dim=0, end_dim=1) + b_actions = torch.flatten(rb_actions[:end_step], start_dim=0, end_dim=1) + b_returns = torch.flatten(rb_returns[:end_step], start_dim=0, end_dim=1) + b_values = torch.flatten(rb_values[:end_step], start_dim=0, end_dim=1) + b_advantages = torch.flatten(rb_advantages[:end_step], start_dim=0, end_dim=1) + + # Optimizing the policy and value network + b_index = np.arange(len(b_obs)) + clip_fracs = [] + for repeat in range(3): + np.random.shuffle(b_index) + for start in range(0, len(b_obs), batch_size): + end = start + batch_size + batch_index = b_index[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value( + b_obs[batch_index], b_actions.long()[batch_index] + ) + logratio = newlogprob - b_logprobs[batch_index] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clip_fracs += [ + ((ratio - 1.0).abs() > clip_coef).float().mean().item() + ] + + # Policy loss + pg_loss1 = -b_advantages[batch_index] * ratio + pg_loss2 = -b_advantages[batch_index] * torch.clamp( + ratio, 1 - clip_coef, 1 + clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + v_loss_unclipped = (newvalue - b_returns[batch_index]) ** 2 + v_clipped = b_values[batch_index] + torch.clamp( + newvalue - b_values[batch_index], + -clip_coef, + clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + + entropy_loss = entropy.mean() + loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + print(f"Training episode {episode}") + print(f"Episodic Return: {np.mean(total_episodic_return)}") + print(f"Episode Length: {end_step}") + print("") + print(f"Value Loss: {v_loss.item()}") + print(f"Policy Loss: {pg_loss.item()}") + print(f"Old Approx KL: {old_approx_kl.item()}") + print(f"Approx KL: {approx_kl.item()}") + print(f"Clip Fraction: {np.mean(clip_fracs)}") + print(f"Explained Variance: {explained_var.item()}") + print("\n-------------------------------------------\n") From 185e54a3a2e1b629bc95b55ef577d8811ab83504 Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 14:09:00 +0100 Subject: [PATCH 04/24] fix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learnfix to pistonball cause easier to learn --- tutorials/cleanrl.py | 125 +++++++++++++++++++++++++++---------------- 1 file changed, 80 insertions(+), 45 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index d175363cb..5a59f72ca 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -4,7 +4,7 @@ import torch.optim as optim from torch.distributions.categorical import Categorical -from pettingzoo.butterfly import cooperative_pong_v5 +from pettingzoo.butterfly import pistonball_v6 from supersuit import color_reduction_v0, resize_v1, frame_stack_v1 @@ -23,7 +23,7 @@ def __init__(self, num_actions): nn.MaxPool2d(2), nn.ReLU(), nn.Flatten(), - self.layer_init(nn.Linear(128 * 4 * 4, 512)), + self.layer_init(nn.Linear(128 * 8 * 8, 512)), nn.ReLU(), ) self.actor = self.layer_init(nn.Linear(512, num_actions), std=0.01) @@ -49,7 +49,7 @@ def get_action_and_value(self, x, action=None): def batchify_obs(obs, device): """Converts PZ style observations to batch of torch arrays""" # convert to list of np arrays - obs = np.stack([obs[a] for a in env.possible_agents], axis=0) + obs = np.stack([obs[a] for a in obs], axis=0) # transpose to be (batch, channel, height, width) obs = obs.transpose(0, -1, 1, 2) # convert to torch @@ -61,7 +61,7 @@ def batchify_obs(obs, device): def batchify(x, device): """Converts PZ style returns to batch of torch arrays""" # convert to list of np arrays - x = np.stack([x[a] for a in env.possible_agents], axis=0) + x = np.stack([x[a] for a in x], axis=0) # convert to torch x = torch.tensor(x).to(device) @@ -81,9 +81,9 @@ def unbatchify(x, env): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") """ ENV SETUP """ - env = cooperative_pong_v5.parallel_env() + env = pistonball_v6.parallel_env(render_mode="rgb_array", continuous=False) env = color_reduction_v0(env) - env = resize_v1(env, 32, 32) + env = resize_v1(env, 64, 64) env = frame_stack_v1(env, stack_size=4) num_agents = len(env.possible_agents) num_actions = env.action_space(env.possible_agents[0]).n @@ -101,10 +101,10 @@ def unbatchify(x, env): batch_size = 32 """ ALGO LOGIC: EPISODE STORAGE""" - num_steps = 900 + num_steps = 125 end_step = 0 total_episodic_return = 0 - rb_obs = torch.zeros((num_steps, num_agents, 4, 32, 32)).to(device) + rb_obs = torch.zeros((num_steps, num_agents, 4, 64, 64)).to(device) rb_actions = torch.zeros((num_steps, num_agents)).to(device) rb_logprobs = torch.zeros((num_steps, num_agents)).to(device) rb_rewards = torch.zeros((num_steps, num_agents)).to(device) @@ -113,48 +113,55 @@ def unbatchify(x, env): """ TRAINING LOGIC """ # train for n number of episodes - for episode in range(1, 1000): - - # collect observations and convert to batch of torch tensors - next_obs = batchify_obs(env.reset(seed=None), device) - # get next dones and convert to batch of torch tensors - next_dones = torch.zeros(num_agents).to(device) - # reset the episodic return - total_episodic_return = 0 - - # each episode has num_steps - for step in range(0, num_steps): - # store the observation and done - rb_obs[step] = next_obs - rb_terms[step] = next_dones - - # ALGO LOGIC: action logic - with torch.no_grad(): - actions, logprobs, _, values = agent.get_action_and_value(next_obs) - rb_values[step] = values.flatten() + for episode in range(2): + + # collect an episode + with torch.no_grad(): + + # collect observations and convert to batch of torch tensors + next_obs = env.reset(seed=None) + # reset the episodic return + total_episodic_return = 0 + + # each episode has num_steps + for step in range(0, num_steps): + + # rollover the observation + obs = batchify_obs(next_obs, device) + + # get action from the agent + actions, logprobs, _, values = agent.get_action_and_value(obs) + + # execute the environment and log data + next_obs, rewards, terms, truncs, infos = env.step( + unbatchify(actions, env) + ) + + # add to episode storage + rb_obs[step] = obs + rb_rewards[step] = batchify(rewards, device) + rb_terms[step] = batchify(terms, device) rb_actions[step] = actions rb_logprobs[step] = logprobs + rb_values[step] = values.flatten() - # execute the environment and log data - next_obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) - next_obs = batchify_obs(next_obs, device) - rewards = batchify(rewards, device) - terms = batchify(terms, device) - truncs = batchify(truncs, device) - rb_rewards[step] = rewards - total_episodic_return += rewards.cpu().numpy() + # compute episodic return + total_episodic_return += rb_rewards[step].cpu().numpy() - # if we reach termination or truncation, end - if any(terms) or any(truncs): - end_step = step - break + # if we reach termination or truncation, end + if any([terms[a] for a in terms]) or any([truncs[a] for a in truncs]): + end_step = step + break # bootstrap value if not done with torch.no_grad(): rb_advantages = torch.zeros_like(rb_rewards).to(device) - for t in reversed(range(end_step + 1)): - next_V = rb_values[t] - delta = rb_rewards[t] + gamma * rb_values[t + 1] - rb_values[t] + for t in reversed(range(end_step)): + delta = ( + rb_rewards[t] + + gamma * rb_values[t + 1] * rb_terms[t + 1] + - rb_values[t] + ) rb_advantages[t] = delta + gamma * gamma * rb_advantages[t + 1] rb_returns = rb_advantages + rb_values @@ -170,12 +177,14 @@ def unbatchify(x, env): b_index = np.arange(len(b_obs)) clip_fracs = [] for repeat in range(3): + # shuffle the indices we use to access the data np.random.shuffle(b_index) for start in range(0, len(b_obs), batch_size): + # select the indices we want to train on end = start + batch_size batch_index = b_index[start:end] - _, newlogprob, entropy, newvalue = agent.get_action_and_value( + _, newlogprob, entropy, value = agent.get_action_and_value( b_obs[batch_index], b_actions.long()[batch_index] ) logratio = newlogprob - b_logprobs[batch_index] @@ -189,6 +198,10 @@ def unbatchify(x, env): ((ratio - 1.0).abs() > clip_coef).float().mean().item() ] + # normalize advantaegs + advantages = b_advantages[batch_index] + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + # Policy loss pg_loss1 = -b_advantages[batch_index] * ratio pg_loss2 = -b_advantages[batch_index] * torch.clamp( @@ -197,9 +210,10 @@ def unbatchify(x, env): pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss - v_loss_unclipped = (newvalue - b_returns[batch_index]) ** 2 + value = value.flatten() + v_loss_unclipped = (value - b_returns[batch_index]) ** 2 v_clipped = b_values[batch_index] + torch.clamp( - newvalue - b_values[batch_index], + value - b_values[batch_index], -clip_coef, clip_coef, ) @@ -229,3 +243,24 @@ def unbatchify(x, env): print(f"Clip Fraction: {np.mean(clip_fracs)}") print(f"Explained Variance: {explained_var.item()}") print("\n-------------------------------------------\n") + + """ RENDER THE POLICY """ + env = pistonball_v6.parallel_env(render_mode="human", continuous=False) + env = color_reduction_v0(env) + env = resize_v1(env, 64, 64) + env = frame_stack_v1(env, stack_size=4) + + agent.eval() + + # render 5 episodes out + for episode in range(1): + obs = batchify_obs(env.reset(seed=None), device) + terms = [False] + truncs = [False] + while not any(terms) and not any(truncs): + actions, logprobs, _, values = agent.get_action_and_value(obs) + obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) + obs = batchify_obs(obs, device) + terms = [terms[a] for a in terms] + truncs = [truncs[a] for a in truncs] + From 387fc1fa18cd47be8c2df62ad3ec633432a4efb4 Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 14:19:02 +0100 Subject: [PATCH 05/24] update top readme --- tutorials/cleanrl.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index 5a59f72ca..6f43b63de 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -1,3 +1,15 @@ +"""Basic code which shows what it's like to run PPO on the Pistonball env using the parallel API, this code is inspired by CleanRL. + +This code is exceedingly basic, with no logging or weights saving. +The intention was for users to have a (relatively clean) ~200 line file to refer to when they want to design their own learning algorithm. + +Dependencies: +- SuperSuit==3.6.0 +- numpy==1.23.2 +- torch==1.12.1 +- pettinzoo==1.22.0 +""" + import numpy as np import torch import torch.nn as nn @@ -252,15 +264,16 @@ def unbatchify(x, env): agent.eval() - # render 5 episodes out - for episode in range(1): - obs = batchify_obs(env.reset(seed=None), device) - terms = [False] - truncs = [False] - while not any(terms) and not any(truncs): - actions, logprobs, _, values = agent.get_action_and_value(obs) - obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) - obs = batchify_obs(obs, device) - terms = [terms[a] for a in terms] - truncs = [truncs[a] for a in truncs] + with torch.no_grad(): + # render 5 episodes out + for episode in range(1): + obs = batchify_obs(env.reset(seed=None), device) + terms = [False] + truncs = [False] + while not any(terms) and not any(truncs): + actions, logprobs, _, values = agent.get_action_and_value(obs) + obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) + obs = batchify_obs(obs, device) + terms = [terms[a] for a in terms] + truncs = [truncs[a] for a in truncs] From 18f53d3ed5cd23badb34f334bd02806d63224d11 Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 14:19:38 +0100 Subject: [PATCH 06/24] black isort --- tutorials/cleanrl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index 6f43b63de..1055e8331 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -14,10 +14,10 @@ import torch import torch.nn as nn import torch.optim as optim +from supersuit import color_reduction_v0, frame_stack_v1, resize_v1 from torch.distributions.categorical import Categorical from pettingzoo.butterfly import pistonball_v6 -from supersuit import color_reduction_v0, resize_v1, frame_stack_v1 class Agent(nn.Module): @@ -212,7 +212,9 @@ def unbatchify(x, env): # normalize advantaegs advantages = b_advantages[batch_index] - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) # Policy loss pg_loss1 = -b_advantages[batch_index] * ratio @@ -276,4 +278,3 @@ def unbatchify(x, env): obs = batchify_obs(obs, device) terms = [terms[a] for a in terms] truncs = [truncs[a] for a in truncs] - From 4ca2b004feaf426c859a1687dd4413425fce5263 Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 14:29:14 +0100 Subject: [PATCH 07/24] move parameters up --- tutorials/cleanrl.py | 54 ++++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index 1055e8331..eed001795 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -7,7 +7,7 @@ - SuperSuit==3.6.0 - numpy==1.23.2 - torch==1.12.1 -- pettinzoo==1.22.0 +- pettingzoo==1.22.0 """ import numpy as np @@ -59,7 +59,7 @@ def get_action_and_value(self, x, action=None): def batchify_obs(obs, device): - """Converts PZ style observations to batch of torch arrays""" + """Converts PZ style observations to batch of torch arrays.""" # convert to list of np arrays obs = np.stack([obs[a] for a in obs], axis=0) # transpose to be (batch, channel, height, width) @@ -71,7 +71,7 @@ def batchify_obs(obs, device): def batchify(x, device): - """Converts PZ style returns to batch of torch arrays""" + """Converts PZ style returns to batch of torch arrays.""" # convert to list of np arrays x = np.stack([x[a] for a in x], axis=0) # convert to torch @@ -81,7 +81,7 @@ def batchify(x, device): def unbatchify(x, env): - """Converts np array to PZ style arguments""" + """Converts np array to PZ style arguments.""" x = x.cpu().numpy() x = {a: x[i] for i, a in enumerate(env.possible_agents)} @@ -90,13 +90,25 @@ def unbatchify(x, env): if __name__ == "__main__": + """ALGO PARAMS""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ent_coef = 0.1 + vf_coef = 0.1 + clip_coef = 0.1 + gamma = 0.99 + batch_size = 32 + stack_size = 4 + frame_size = (64, 64) + max_cycles = 125 + total_episodes = 2 """ ENV SETUP """ - env = pistonball_v6.parallel_env(render_mode="rgb_array", continuous=False) + env = pistonball_v6.parallel_env( + render_mode="rgb_array", continuous=False, max_cycles=max_cycles + ) env = color_reduction_v0(env) - env = resize_v1(env, 64, 64) - env = frame_stack_v1(env, stack_size=4) + env = resize_v1(env, frame_size[0], frame_size[1]) + env = frame_stack_v1(env, stack_size=stack_size) num_agents = len(env.possible_agents) num_actions = env.action_space(env.possible_agents[0]).n observation_size = env.observation_space((env.possible_agents[0])).shape @@ -105,27 +117,19 @@ def unbatchify(x, env): agent = Agent(num_actions=num_actions).to(device) optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5) - """ ALGO PARAMS """ - ent_coef = 0.1 - vf_coef = 0.1 - clip_coef = 0.1 - gamma = 0.99 - batch_size = 32 - """ ALGO LOGIC: EPISODE STORAGE""" - num_steps = 125 end_step = 0 total_episodic_return = 0 - rb_obs = torch.zeros((num_steps, num_agents, 4, 64, 64)).to(device) - rb_actions = torch.zeros((num_steps, num_agents)).to(device) - rb_logprobs = torch.zeros((num_steps, num_agents)).to(device) - rb_rewards = torch.zeros((num_steps, num_agents)).to(device) - rb_terms = torch.zeros((num_steps, num_agents)).to(device) - rb_values = torch.zeros((num_steps, num_agents)).to(device) + rb_obs = torch.zeros((max_cycles, num_agents, stack_size, *frame_size)).to(device) + rb_actions = torch.zeros((max_cycles, num_agents)).to(device) + rb_logprobs = torch.zeros((max_cycles, num_agents)).to(device) + rb_rewards = torch.zeros((max_cycles, num_agents)).to(device) + rb_terms = torch.zeros((max_cycles, num_agents)).to(device) + rb_values = torch.zeros((max_cycles, num_agents)).to(device) """ TRAINING LOGIC """ # train for n number of episodes - for episode in range(2): + for episode in range(total_episodes): # collect an episode with torch.no_grad(): @@ -136,7 +140,7 @@ def unbatchify(x, env): total_episodic_return = 0 # each episode has num_steps - for step in range(0, num_steps): + for step in range(0, max_cycles): # rollover the observation obs = batchify_obs(next_obs, device) @@ -177,7 +181,7 @@ def unbatchify(x, env): rb_advantages[t] = delta + gamma * gamma * rb_advantages[t + 1] rb_returns = rb_advantages + rb_values - # convert our episodes to individual transitions + # convert our episodes to batch of individual transitions b_obs = torch.flatten(rb_obs[:end_step], start_dim=0, end_dim=1) b_logprobs = torch.flatten(rb_logprobs[:end_step], start_dim=0, end_dim=1) b_actions = torch.flatten(rb_actions[:end_step], start_dim=0, end_dim=1) @@ -268,7 +272,7 @@ def unbatchify(x, env): with torch.no_grad(): # render 5 episodes out - for episode in range(1): + for episode in range(5): obs = batchify_obs(env.reset(seed=None), device) terms = [False] truncs = [False] From 0f110f6655654900a83007ab0041ff9be82cf956 Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 16:06:23 +0100 Subject: [PATCH 08/24] fix typo, remove numpy dependency --- tutorials/cleanrl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index eed001795..4da0a4e66 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -5,7 +5,6 @@ Dependencies: - SuperSuit==3.6.0 -- numpy==1.23.2 - torch==1.12.1 - pettingzoo==1.22.0 """ @@ -100,7 +99,7 @@ def unbatchify(x, env): stack_size = 4 frame_size = (64, 64) max_cycles = 125 - total_episodes = 2 + total_episodes = 150 """ ENV SETUP """ env = pistonball_v6.parallel_env( From cef513a69e6591faa21dbaff40716694425ebad6 Mon Sep 17 00:00:00 2001 From: Jet Date: Wed, 5 Oct 2022 16:07:31 +0100 Subject: [PATCH 09/24] reorder dependencies --- tutorials/cleanrl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index 4da0a4e66..d7d0bc5d4 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -4,9 +4,9 @@ The intention was for users to have a (relatively clean) ~200 line file to refer to when they want to design their own learning algorithm. Dependencies: -- SuperSuit==3.6.0 -- torch==1.12.1 - pettingzoo==1.22.0 +- supersuit==3.6.0 +- torch==1.12.1 """ import numpy as np From 725dd99a90b2b16203f3cd664ebd60425fa94ec3 Mon Sep 17 00:00:00 2001 From: Jet Date: Thu, 6 Oct 2022 13:34:14 +0100 Subject: [PATCH 10/24] increase training episodes --- tutorials/cleanrl.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index d7d0bc5d4..778f06c53 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -99,7 +99,7 @@ def unbatchify(x, env): stack_size = 4 frame_size = (64, 64) max_cycles = 125 - total_episodes = 150 + total_episodes = 500 """ ENV SETUP """ env = pistonball_v6.parallel_env( @@ -261,23 +261,23 @@ def unbatchify(x, env): print(f"Explained Variance: {explained_var.item()}") print("\n-------------------------------------------\n") - """ RENDER THE POLICY """ - env = pistonball_v6.parallel_env(render_mode="human", continuous=False) - env = color_reduction_v0(env) - env = resize_v1(env, 64, 64) - env = frame_stack_v1(env, stack_size=4) - - agent.eval() - - with torch.no_grad(): - # render 5 episodes out - for episode in range(5): - obs = batchify_obs(env.reset(seed=None), device) - terms = [False] - truncs = [False] - while not any(terms) and not any(truncs): - actions, logprobs, _, values = agent.get_action_and_value(obs) - obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) - obs = batchify_obs(obs, device) - terms = [terms[a] for a in terms] - truncs = [truncs[a] for a in truncs] + # """ RENDER THE POLICY """ + # env = pistonball_v6.parallel_env(render_mode="human", continuous=False) + # env = color_reduction_v0(env) + # env = resize_v1(env, 64, 64) + # env = frame_stack_v1(env, stack_size=4) + + # agent.eval() + + # with torch.no_grad(): + # # render 5 episodes out + # for episode in range(5): + # obs = batchify_obs(env.reset(seed=None), device) + # terms = [False] + # truncs = [False] + # while not any(terms) and not any(truncs): + # actions, logprobs, _, values = agent.get_action_and_value(obs) + # obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) + # obs = batchify_obs(obs, device) + # terms = [terms[a] for a in terms] + # truncs = [truncs[a] for a in truncs] From fd71a97c41fe0ee3cc491e672b92963ccf34e327 Mon Sep 17 00:00:00 2001 From: Jet Date: Thu, 6 Oct 2022 16:59:07 +0100 Subject: [PATCH 11/24] long awaited pre-commit fix --- tutorials/cleanrl.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py index 778f06c53..d66c4181a 100644 --- a/tutorials/cleanrl.py +++ b/tutorials/cleanrl.py @@ -110,7 +110,7 @@ def unbatchify(x, env): env = frame_stack_v1(env, stack_size=stack_size) num_agents = len(env.possible_agents) num_actions = env.action_space(env.possible_agents[0]).n - observation_size = env.observation_space((env.possible_agents[0])).shape + observation_size = env.observation_space(env.possible_agents[0]).shape """ LEARNER SETUP """ agent = Agent(num_actions=num_actions).to(device) @@ -262,22 +262,22 @@ def unbatchify(x, env): print("\n-------------------------------------------\n") # """ RENDER THE POLICY """ - # env = pistonball_v6.parallel_env(render_mode="human", continuous=False) - # env = color_reduction_v0(env) - # env = resize_v1(env, 64, 64) - # env = frame_stack_v1(env, stack_size=4) - - # agent.eval() - - # with torch.no_grad(): - # # render 5 episodes out - # for episode in range(5): - # obs = batchify_obs(env.reset(seed=None), device) - # terms = [False] - # truncs = [False] - # while not any(terms) and not any(truncs): - # actions, logprobs, _, values = agent.get_action_and_value(obs) - # obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) - # obs = batchify_obs(obs, device) - # terms = [terms[a] for a in terms] - # truncs = [truncs[a] for a in truncs] + env = pistonball_v6.parallel_env(render_mode="human", continuous=False) + env = color_reduction_v0(env) + env = resize_v1(env, 64, 64) + env = frame_stack_v1(env, stack_size=4) + + agent.eval() + + with torch.no_grad(): + # render 5 episodes out + for episode in range(5): + obs = batchify_obs(env.reset(seed=None), device) + terms = [False] + truncs = [False] + while not any(terms) and not any(truncs): + actions, logprobs, _, values = agent.get_action_and_value(obs) + obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) + obs = batchify_obs(obs, device) + terms = [terms[a] for a in terms] + truncs = [truncs[a] for a in truncs] From 94ab14ab5baf47ed0d2dc487cf4ce2938be045f9 Mon Sep 17 00:00:00 2001 From: Jet Date: Fri, 14 Oct 2022 12:26:26 +0100 Subject: [PATCH 12/24] update image link and setup --- README.md | 2 +- setup.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c26ba9fac..c4cda1f82 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

- +

PettingZoo is a Python library for conducting research in multi-agent reinforcement learning, akin to a multi-agent version of [Gymnasium](https://github.com/Farama-Foundation/Gymnasium). diff --git a/setup.py b/setup.py index 0b5ceaefe..eeca020d8 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ def get_description(): + """Gets the description from the readme.""" with open("README.md") as fh: long_description = "" header_count = 0 @@ -65,11 +66,12 @@ def get_version(): + extras["tests"] ) +version = get_version() header_count, long_description = get_description() setup( name="PettingZoo", - version=get_version(), + version=version, author="Farama Foundation", author_email="contact@farama.org", description="Gymnasium for multi-agent reinforcement learning", From a79fb495011839484fcc82a61bd9a3f542ba4975 Mon Sep 17 00:00:00 2001 From: Jet Date: Fri, 14 Oct 2022 12:29:08 +0100 Subject: [PATCH 13/24] remove cleanrl from main tutorials section --- tutorials/cleanrl.py | 283 ------------------------------------------- 1 file changed, 283 deletions(-) delete mode 100644 tutorials/cleanrl.py diff --git a/tutorials/cleanrl.py b/tutorials/cleanrl.py deleted file mode 100644 index d66c4181a..000000000 --- a/tutorials/cleanrl.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Basic code which shows what it's like to run PPO on the Pistonball env using the parallel API, this code is inspired by CleanRL. - -This code is exceedingly basic, with no logging or weights saving. -The intention was for users to have a (relatively clean) ~200 line file to refer to when they want to design their own learning algorithm. - -Dependencies: -- pettingzoo==1.22.0 -- supersuit==3.6.0 -- torch==1.12.1 -""" - -import numpy as np -import torch -import torch.nn as nn -import torch.optim as optim -from supersuit import color_reduction_v0, frame_stack_v1, resize_v1 -from torch.distributions.categorical import Categorical - -from pettingzoo.butterfly import pistonball_v6 - - -class Agent(nn.Module): - def __init__(self, num_actions): - super().__init__() - - self.network = nn.Sequential( - self.layer_init(nn.Conv2d(4, 32, 3, padding=1)), - nn.MaxPool2d(2), - nn.ReLU(), - self.layer_init(nn.Conv2d(32, 64, 3, padding=1)), - nn.MaxPool2d(2), - nn.ReLU(), - self.layer_init(nn.Conv2d(64, 128, 3, padding=1)), - nn.MaxPool2d(2), - nn.ReLU(), - nn.Flatten(), - self.layer_init(nn.Linear(128 * 8 * 8, 512)), - nn.ReLU(), - ) - self.actor = self.layer_init(nn.Linear(512, num_actions), std=0.01) - self.critic = self.layer_init(nn.Linear(512, 1)) - - def layer_init(self, layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.orthogonal_(layer.weight, std) - torch.nn.init.constant_(layer.bias, bias_const) - return layer - - def get_value(self, x): - return self.critic(self.network(x / 255.0)) - - def get_action_and_value(self, x, action=None): - hidden = self.network(x / 255.0) - logits = self.actor(hidden) - probs = Categorical(logits=logits) - if action is None: - action = probs.sample() - return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) - - -def batchify_obs(obs, device): - """Converts PZ style observations to batch of torch arrays.""" - # convert to list of np arrays - obs = np.stack([obs[a] for a in obs], axis=0) - # transpose to be (batch, channel, height, width) - obs = obs.transpose(0, -1, 1, 2) - # convert to torch - obs = torch.tensor(obs).to(device) - - return obs - - -def batchify(x, device): - """Converts PZ style returns to batch of torch arrays.""" - # convert to list of np arrays - x = np.stack([x[a] for a in x], axis=0) - # convert to torch - x = torch.tensor(x).to(device) - - return x - - -def unbatchify(x, env): - """Converts np array to PZ style arguments.""" - x = x.cpu().numpy() - x = {a: x[i] for i, a in enumerate(env.possible_agents)} - - return x - - -if __name__ == "__main__": - - """ALGO PARAMS""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - ent_coef = 0.1 - vf_coef = 0.1 - clip_coef = 0.1 - gamma = 0.99 - batch_size = 32 - stack_size = 4 - frame_size = (64, 64) - max_cycles = 125 - total_episodes = 500 - - """ ENV SETUP """ - env = pistonball_v6.parallel_env( - render_mode="rgb_array", continuous=False, max_cycles=max_cycles - ) - env = color_reduction_v0(env) - env = resize_v1(env, frame_size[0], frame_size[1]) - env = frame_stack_v1(env, stack_size=stack_size) - num_agents = len(env.possible_agents) - num_actions = env.action_space(env.possible_agents[0]).n - observation_size = env.observation_space(env.possible_agents[0]).shape - - """ LEARNER SETUP """ - agent = Agent(num_actions=num_actions).to(device) - optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5) - - """ ALGO LOGIC: EPISODE STORAGE""" - end_step = 0 - total_episodic_return = 0 - rb_obs = torch.zeros((max_cycles, num_agents, stack_size, *frame_size)).to(device) - rb_actions = torch.zeros((max_cycles, num_agents)).to(device) - rb_logprobs = torch.zeros((max_cycles, num_agents)).to(device) - rb_rewards = torch.zeros((max_cycles, num_agents)).to(device) - rb_terms = torch.zeros((max_cycles, num_agents)).to(device) - rb_values = torch.zeros((max_cycles, num_agents)).to(device) - - """ TRAINING LOGIC """ - # train for n number of episodes - for episode in range(total_episodes): - - # collect an episode - with torch.no_grad(): - - # collect observations and convert to batch of torch tensors - next_obs = env.reset(seed=None) - # reset the episodic return - total_episodic_return = 0 - - # each episode has num_steps - for step in range(0, max_cycles): - - # rollover the observation - obs = batchify_obs(next_obs, device) - - # get action from the agent - actions, logprobs, _, values = agent.get_action_and_value(obs) - - # execute the environment and log data - next_obs, rewards, terms, truncs, infos = env.step( - unbatchify(actions, env) - ) - - # add to episode storage - rb_obs[step] = obs - rb_rewards[step] = batchify(rewards, device) - rb_terms[step] = batchify(terms, device) - rb_actions[step] = actions - rb_logprobs[step] = logprobs - rb_values[step] = values.flatten() - - # compute episodic return - total_episodic_return += rb_rewards[step].cpu().numpy() - - # if we reach termination or truncation, end - if any([terms[a] for a in terms]) or any([truncs[a] for a in truncs]): - end_step = step - break - - # bootstrap value if not done - with torch.no_grad(): - rb_advantages = torch.zeros_like(rb_rewards).to(device) - for t in reversed(range(end_step)): - delta = ( - rb_rewards[t] - + gamma * rb_values[t + 1] * rb_terms[t + 1] - - rb_values[t] - ) - rb_advantages[t] = delta + gamma * gamma * rb_advantages[t + 1] - rb_returns = rb_advantages + rb_values - - # convert our episodes to batch of individual transitions - b_obs = torch.flatten(rb_obs[:end_step], start_dim=0, end_dim=1) - b_logprobs = torch.flatten(rb_logprobs[:end_step], start_dim=0, end_dim=1) - b_actions = torch.flatten(rb_actions[:end_step], start_dim=0, end_dim=1) - b_returns = torch.flatten(rb_returns[:end_step], start_dim=0, end_dim=1) - b_values = torch.flatten(rb_values[:end_step], start_dim=0, end_dim=1) - b_advantages = torch.flatten(rb_advantages[:end_step], start_dim=0, end_dim=1) - - # Optimizing the policy and value network - b_index = np.arange(len(b_obs)) - clip_fracs = [] - for repeat in range(3): - # shuffle the indices we use to access the data - np.random.shuffle(b_index) - for start in range(0, len(b_obs), batch_size): - # select the indices we want to train on - end = start + batch_size - batch_index = b_index[start:end] - - _, newlogprob, entropy, value = agent.get_action_and_value( - b_obs[batch_index], b_actions.long()[batch_index] - ) - logratio = newlogprob - b_logprobs[batch_index] - ratio = logratio.exp() - - with torch.no_grad(): - # calculate approx_kl http://joschu.net/blog/kl-approx.html - old_approx_kl = (-logratio).mean() - approx_kl = ((ratio - 1) - logratio).mean() - clip_fracs += [ - ((ratio - 1.0).abs() > clip_coef).float().mean().item() - ] - - # normalize advantaegs - advantages = b_advantages[batch_index] - advantages = (advantages - advantages.mean()) / ( - advantages.std() + 1e-8 - ) - - # Policy loss - pg_loss1 = -b_advantages[batch_index] * ratio - pg_loss2 = -b_advantages[batch_index] * torch.clamp( - ratio, 1 - clip_coef, 1 + clip_coef - ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() - - # Value loss - value = value.flatten() - v_loss_unclipped = (value - b_returns[batch_index]) ** 2 - v_clipped = b_values[batch_index] + torch.clamp( - value - b_values[batch_index], - -clip_coef, - clip_coef, - ) - v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2 - v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) - v_loss = 0.5 * v_loss_max.mean() - - entropy_loss = entropy.mean() - loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() - var_y = np.var(y_true) - explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y - - print(f"Training episode {episode}") - print(f"Episodic Return: {np.mean(total_episodic_return)}") - print(f"Episode Length: {end_step}") - print("") - print(f"Value Loss: {v_loss.item()}") - print(f"Policy Loss: {pg_loss.item()}") - print(f"Old Approx KL: {old_approx_kl.item()}") - print(f"Approx KL: {approx_kl.item()}") - print(f"Clip Fraction: {np.mean(clip_fracs)}") - print(f"Explained Variance: {explained_var.item()}") - print("\n-------------------------------------------\n") - - # """ RENDER THE POLICY """ - env = pistonball_v6.parallel_env(render_mode="human", continuous=False) - env = color_reduction_v0(env) - env = resize_v1(env, 64, 64) - env = frame_stack_v1(env, stack_size=4) - - agent.eval() - - with torch.no_grad(): - # render 5 episodes out - for episode in range(5): - obs = batchify_obs(env.reset(seed=None), device) - terms = [False] - truncs = [False] - while not any(terms) and not any(truncs): - actions, logprobs, _, values = agent.get_action_and_value(obs) - obs, rewards, terms, truncs, infos = env.step(unbatchify(actions, env)) - obs = batchify_obs(obs, device) - terms = [terms[a] for a in terms] - truncs = [truncs[a] for a in truncs] From c4bff4a67bc9347f2335718f491d662f53bd9dee Mon Sep 17 00:00:00 2001 From: Jet Date: Fri, 14 Oct 2022 12:30:28 +0100 Subject: [PATCH 14/24] add CoC --- CODE_OF_CONDUCT.rst | 68 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 CODE_OF_CONDUCT.rst diff --git a/CODE_OF_CONDUCT.rst b/CODE_OF_CONDUCT.rst new file mode 100644 index 000000000..f91dd916a --- /dev/null +++ b/CODE_OF_CONDUCT.rst @@ -0,0 +1,68 @@ +================================= +Farama Foundation Code of Conduct +================================= + +The Farama Foundation is dedicated to providing a harassment-free experience for +everyone, regardless of gender, gender identity and expression, sexual +orientation, disability, physical appearance, body size, age, race, or +religion. We do not tolerate harassment of participants in any form. + +This code of conduct applies to all Farama Foundation repositories (including Gist +comments) both online and off. Anyone who violates this code of +conduct may be sanctioned or expelled from these spaces at the +discretion of the moderators. + +We may add additional rules over time, which will be made clearly +available to participants. Participants are responsible for knowing +and abiding by these rules. + +------------- +Our Standards +------------- +Members of the Farama Foundation community are **open**, **inclusive**, and **respectful**. +Examples of behavior that contributes to a positive environment for our community include: + +* **Being open**. Members of the community are open to collaboration, whether it's on issues, PRs, problems, or otherwise +* **Focusing on what is best for the community**. We're respectful of the processes set forth in the community, and we work within them to + improve the community. +* **Being respectful of differing viewpoints and experiences.** We're receptive to constructive comments and criticism, + as the experiences and skill sets of other members contribute to the whole of our efforts. +* **Showing empathy.** We're attentive in our communications, and we're tactful when approaching differing views. +* **Being respectful.** We're respectful of differing opinions, viewpoints, experiences, and efforts. +* **Gracefully accepting constructive criticism.** When we disagree, we are courteous in raising our issues. +* **Using welcoming and inclusive language.** We're accepting of all who wish to take part in our activities, fostering + an environment where anyone can participate and everyone can make a difference. + +Examples of unacceptable behavior include: + +* Harassment of any participants in any form. +* The use of sexual language or imagery, and sexual attention or advances of any kind. +* Insults, put downs, or jokes that are based upon stereotypes, that are exclusionary, or that hold others up for ridicule. +* Publishing others' private information, such as a physical or email address, without explicit permission. +* Incitement of violence or harassment towards any individual, including encouraging a person to commit suicide or to engage in self-harm. +* Sustained disruption of online community discussions, in-person presentations, or other in-person events. +* Creating additional online accounts in order to harass another person or circumvent a ban +* Other conduct which could reasonably be considered inappropriate in a professional setting including people of many different backgrounds. + +Members asked to stop any inappropriate behavior are expected to comply immediately. + +------------ +Consequences +------------ +If a participant engages in behavior that violates this code of conduct, the Farama Foundation team may take any action they deem +appropriate, including warning the offender or expulsion from the community. + +Thank you for helping make this a welcoming, friendly community for everyone. + +------- +License +------- +This Code of Conduct is licensed under the `Creative Commons Attribution-ShareAlike 3.0 Unported License +`_. + +----------- +Attribution +----------- +This Code of Conduct is adapted from `Python's Code of Conduct `_, which is under a `Creative Commons License +`_. + From c26a27878e10e710025da909572265e93229ddb8 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 22 Oct 2022 13:18:26 +0100 Subject: [PATCH 15/24] long awaited graph space --- .../knights_archers_zombies.py | 90 ++-- pettingzoo/test/state_test.py | 2 + test/all_parameter_combs_test.py | 411 +++++++++--------- 3 files changed, 263 insertions(+), 240 deletions(-) diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index 9bf581691..c3efae1a7 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -99,9 +99,9 @@ As a result, setting `use_typemask=True` results in the observation being a (N+1)x11 vector. -**Transformers** (Experimental) +**Graph Space** (Experimental) -There is an option to also pass `transformer=True` as a kwarg to the environment. This just removes all non-existent entities from the observation and state vectors. Note that this is **still experimental** as the state and observation size are no longer constant. In particular, `N` is now a +There is an option to also pass `graph_space=True` as a kwarg to the environment. This just removes all non-existent entities from the observation and state vectors. Note that this is **still experimental** as the state and observation size are no longer constant. In particular, `N` is now a variable number. #### Image-based @@ -135,7 +135,7 @@ max_cycles=900, vector_state=True, use_typemasks=False, - transformer=False, + graph_space=False, ``` `spawn_rate`: how many cycles before a new zombie is spawned. A lower number means zombies are spawned at a higher rate. @@ -160,7 +160,7 @@ `use_typemasks`: only relevant when `vector_state=True` is set, adds typemasks to the vectors. -`transformer`: **experimental**, only relevant when `vector_state=True` is set, removes non-existent entities in the vector state. +`graph_space`: **experimental**, only relevant when `vector_state=True` is set, removes non-existent entities in the vector state. ### Version History @@ -187,7 +187,7 @@ import numpy as np import pygame import pygame.gfxdraw -from gymnasium.spaces import Box, Discrete +from gymnasium.spaces import Box, Discrete, Graph, GraphInstance from gymnasium.utils import EzPickle, seeding from pettingzoo import AECEnv @@ -238,7 +238,7 @@ def __init__( max_cycles=900, vector_state=True, use_typemasks=False, - transformer=False, + graph_space=False, render_mode=None, ): EzPickle.__init__( @@ -255,11 +255,19 @@ def __init__( max_cycles, vector_state, use_typemasks, - transformer, + graph_space, render_mode, ) # variable state space - self.transformer = transformer + self.graph_space = graph_space + if self.graph_space: + assert ( + vector_state + ), "vector_state must be True if graph_space is True." + + assert ( + use_typemasks + ), "use_typemasks should be True if graph_space is True" # whether we want RGB state or vector state self.vector_state = vector_state @@ -267,7 +275,7 @@ def __init__( self.num_tracked = ( num_archers + num_knights + max_zombies + num_knights + max_arrows ) - self.use_typemasks = True if transformer else use_typemasks + self.use_typemasks = True if graph_space else use_typemasks self.typemask_width = 6 self.vector_width = 4 + self.typemask_width if use_typemasks else 4 @@ -318,15 +326,23 @@ def __init__( low = 0 if not self.vector_state else -1.0 high = 255 if not self.vector_state else 1.0 dtype = np.uint8 if not self.vector_state else np.float64 - self.observation_spaces = dict( - zip( - self.agents, - [ - Box(low=low, high=high, shape=shape, dtype=dtype) - for _ in enumerate(self.agents) - ], + if not self.graph_space: + obs_space = Box(low=low, high=high, shape=shape, dtype=dtype) + self.observation_spaces = dict( + zip( + self.agents, + [obs_space for _ in enumerate(self.agents)], + ) + ) + else: + box_space = Box(low=low, high=high, shape=[shape[-1]], dtype=dtype) + obs_space = Graph(node_space=box_space, edge_space=None) + self.observation_spaces = dict( + zip( + self.agents, + [obs_space for _ in enumerate(self.agents)], + ) ) - ) self.action_spaces = dict( zip(self.agents, [Discrete(6) for _ in enumerate(self.agents)]) @@ -570,6 +586,10 @@ def observe(self, agent): # prepend agent state to the observation state = np.concatenate([agent_state, state], axis=0) + if self.graph_space: + # remove pure zero rows if using graph space + state = state[~np.all(state == 0, axis=-1)] + state = GraphInstance(nodes=state, edges=None, edge_links=None) return state @@ -603,8 +623,7 @@ def get_vector_state(self): vector = np.concatenate((typemask, agent.vector_state), axis=0) state.append(vector) else: - if not self.transformer: - state.append(np.zeros(self.vector_width)) + state.append(np.zeros(self.vector_width)) # handle swords for agent in self.agent_list: @@ -618,13 +637,12 @@ def get_vector_state(self): state.append(vector) # handle empty swords - if not self.transformer: - state.extend( - repeat( - np.zeros(self.vector_width), - self.num_knights - self.num_active_swords, - ) + state.extend( + repeat( + np.zeros(self.vector_width), + self.num_knights - self.num_active_swords, ) + ) # handle arrows for agent in self.agent_list: @@ -638,13 +656,12 @@ def get_vector_state(self): state.append(vector) # handle empty arrows - if not self.transformer: - state.extend( - repeat( - np.zeros(self.vector_width), - self.max_arrows - self.num_active_arrows, - ) + state.extend( + repeat( + np.zeros(self.vector_width), + self.max_arrows - self.num_active_arrows, ) + ) # handle zombies for zombie in self.zombie_list: @@ -656,13 +673,12 @@ def get_vector_state(self): state.append(vector) # handle empty zombies - if not self.transformer: - state.extend( - repeat( - np.zeros(self.vector_width), - self.max_zombies - len(self.zombie_list), - ) + state.extend( + repeat( + np.zeros(self.vector_width), + self.max_zombies - len(self.zombie_list), ) + ) return np.stack(state, axis=0) diff --git a/pettingzoo/test/state_test.py b/pettingzoo/test/state_test.py index 62cdd95dd..d049e614c 100644 --- a/pettingzoo/test/state_test.py +++ b/pettingzoo/test/state_test.py @@ -55,6 +55,8 @@ def test_state(env, num_cycles): env.step(action) new_state = env.state() + print(new_state) + print(env.state_space) assert env.state_space.contains( new_state ), "Environment's state is outside of it's state space" diff --git a/test/all_parameter_combs_test.py b/test/all_parameter_combs_test.py index 8072d7a49..80447badf 100644 --- a/test/all_parameter_combs_test.py +++ b/test/all_parameter_combs_test.py @@ -8,212 +8,217 @@ from .all_modules import * # noqa: F403 parameterized_envs = [ - ["atari/boxing_v2", boxing_v2, dict(obs_type="grayscale_image")], - ["atari/boxing_v2", boxing_v2, dict(obs_type="ram")], - ["atari/boxing_v2", boxing_v2, dict(full_action_space=False)], - ["atari/combat_plane_v2", combat_plane_v2, dict(game_version="jet")], - ["atari/combat_plane_v2", combat_plane_v2, dict(guided_missile=True)], - ["atari/combat_tank_v2", combat_tank_v2, dict(has_maze=True)], - ["atari/combat_tank_v2", combat_tank_v2, dict(is_invisible=True)], - ["atari/combat_tank_v2", combat_tank_v2, dict(billiard_hit=True)], - ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="race")], - ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="capture")], - ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=1)], - ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=3)], - [ - "atari/space_invaders_v2", - space_invaders_v2, - dict( - alternating_control=True, - moving_shields=True, - zigzaging_bombs=True, - fast_bomb=True, - invisible_invaders=True, - ), - ], - ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=2)], - ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=3)], - ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=4)], - ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=2)], - ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=3)], - ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=4)], - ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=2)], - ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=3)], - ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=4)], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(spawn_rate=50), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(num_knights=4, num_archers=5), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(killable_knights=True, killable_archers=True), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(killable_knights=False, killable_archers=False), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(line_death=False), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(vector_state=False), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(vector_state=False, pad_observation=False), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(max_cycles=100), - ], + # ["atari/boxing_v2", boxing_v2, dict(obs_type="grayscale_image")], + # ["atari/boxing_v2", boxing_v2, dict(obs_type="ram")], + # ["atari/boxing_v2", boxing_v2, dict(full_action_space=False)], + # ["atari/combat_plane_v2", combat_plane_v2, dict(game_version="jet")], + # ["atari/combat_plane_v2", combat_plane_v2, dict(guided_missile=True)], + # ["atari/combat_tank_v2", combat_tank_v2, dict(has_maze=True)], + # ["atari/combat_tank_v2", combat_tank_v2, dict(is_invisible=True)], + # ["atari/combat_tank_v2", combat_tank_v2, dict(billiard_hit=True)], + # ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="race")], + # ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="capture")], + # ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=1)], + # ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=3)], + # [ + # "atari/space_invaders_v2", + # space_invaders_v2, + # dict( + # alternating_control=True, + # moving_shields=True, + # zigzaging_bombs=True, + # fast_bomb=True, + # invisible_invaders=True, + # ), + # ], + # ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=2)], + # ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=3)], + # ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=4)], + # ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=2)], + # ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=3)], + # ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=4)], + # ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=2)], + # ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=3)], + # ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=4)], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(spawn_rate=50), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(num_knights=4, num_archers=5), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(killable_knights=True, killable_archers=True), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(killable_knights=False, killable_archers=False), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(line_death=False), + # ], [ "butterfly/knights_archers_zombies_v10", knights_archers_zombies_v10, - dict(use_typemasks=False), - ], - [ - "butterfly/knights_archers_zombies_v10", - knights_archers_zombies_v10, - dict(max_zombies=2, max_arrows=60), - ], - ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=True)], - ["butterfly/pistonball_v6", pistonball_v6, dict(n_pistons=30)], - ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=False)], - [ - "butterfly/pistonball_v6", - pistonball_v6, - dict(random_drop=True, random_rotate=True), - ], - [ - "butterfly/pistonball_v6", - pistonball_v6, - dict(random_drop=False, random_rotate=False), - ], - ["classic/go_v5", go_v5, dict(board_size=13, komi=2.5)], - ["classic/go_v5", go_v5, dict(board_size=9, komi=0.0)], - ["classic/hanabi_v4", hanabi_v4, dict(colors=3)], - ["classic/hanabi_v4", hanabi_v4, dict(ranks=3)], - ["classic/hanabi_v4", hanabi_v4, dict(players=4)], - ["classic/hanabi_v4", hanabi_v4, dict(hand_size=5)], - ["classic/hanabi_v4", hanabi_v4, dict(max_information_tokens=3)], - ["classic/hanabi_v4", hanabi_v4, dict(max_life_tokens=2)], - [ - "classic/hanabi_v4", - hanabi_v4, - dict( - colors=5, - ranks=3, - players=4, - hand_size=5, - max_information_tokens=3, - max_life_tokens=2, - ), - ], - ["classic/hanabi_v4", hanabi_v4, dict(observation_type=0)], - ["classic/hanabi_v4", hanabi_v4, dict(observation_type=1)], - ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=False)], - ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=True)], - ["mpe/simple_adversary_v2", simple_adversary_v2, dict(N=4)], - ["mpe/simple_reference_v2", simple_reference_v2, dict(local_ratio=0.2)], - ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5)], - [ - "mpe/simple_tag_v2", - simple_tag_v2, - dict(num_good=5, num_adversaries=10, num_obstacles=4), - ], - [ - "mpe/simple_tag_v2", - simple_tag_v2, - dict(num_good=1, num_adversaries=1, num_obstacles=1), - ], - [ - "mpe/simple_world_comm_v2", - simple_world_comm_v2, - dict(num_good=5, num_adversaries=10, num_obstacles=4, num_food=3), - ], - [ - "mpe/simple_world_comm_v2", - simple_world_comm_v2, - dict(num_good=1, num_adversaries=1, num_obstacles=1, num_food=1), - ], - [ - "mpe/simple_adversary_v2", - simple_adversary_v2, - dict(N=4, continuous_actions=True), - ], - [ - "mpe/simple_reference_v2", - simple_reference_v2, - dict(local_ratio=0.2, continuous_actions=True), - ], - ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5, continuous_actions=True)], - [ - "mpe/simple_tag_v2", - simple_tag_v2, - dict(num_good=5, num_adversaries=10, num_obstacles=4, continuous_actions=True), - ], - [ - "mpe/simple_tag_v2", - simple_tag_v2, - dict(num_good=1, num_adversaries=1, num_obstacles=1, continuous_actions=True), - ], - [ - "mpe/simple_world_comm_v2", - simple_world_comm_v2, - dict( - num_good=5, - num_adversaries=10, - num_obstacles=4, - num_food=3, - continuous_actions=True, - ), - ], - [ - "mpe/simple_world_comm_v2", - simple_world_comm_v2, - dict( - num_good=1, - num_adversaries=1, - num_obstacles=1, - num_food=1, - continuous_actions=True, - ), - ], - ["sisl/multiwalker_v9", multiwalker_v9, dict(n_walkers=10)], - ["sisl/multiwalker_v9", multiwalker_v9, dict(shared_reward=False)], - ["sisl/multiwalker_v9", multiwalker_v9, dict(terminate_on_fall=False)], - [ - "sisl/multiwalker_v8", - multiwalker_v9, - dict(terminate_on_fall=False, remove_on_fall=False), - ], - ["sisl/pursuit_v4", pursuit_v4, dict(x_size=8, y_size=19)], - ["sisl/pursuit_v4", pursuit_v4, dict(shared_reward=True)], - ["sisl/pursuit_v4", pursuit_v4, dict(n_evaders=5, n_pursuers=16)], - ["sisl/pursuit_v4", pursuit_v4, dict(obs_range=15)], - ["sisl/pursuit_v4", pursuit_v4, dict(n_catch=3)], - ["sisl/pursuit_v4", pursuit_v4, dict(freeze_evaders=True)], - ["sisl/waterworld_v4", waterworld_v4, dict(n_pursuers=3, n_evaders=6)], - ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], - ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], - ["sisl/waterworld_v4", waterworld_v4, dict(n_poisons=4)], - ["sisl/waterworld_v4", waterworld_v4, dict(n_sensors=4)], - ["sisl/waterworld_v4", waterworld_v4, dict(local_ratio=0.5)], - ["sisl/waterworld_v4", waterworld_v4, dict(speed_features=False)], + dict(graph_space=True, use_typemasks=True), + ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(vector_state=False), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(vector_state=False, pad_observation=False), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(max_cycles=100), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(use_typemasks=False), + # ], + # [ + # "butterfly/knights_archers_zombies_v10", + # knights_archers_zombies_v10, + # dict(max_zombies=2, max_arrows=60), + # ], + # ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=True)], + # ["butterfly/pistonball_v6", pistonball_v6, dict(n_pistons=30)], + # ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=False)], + # [ + # "butterfly/pistonball_v6", + # pistonball_v6, + # dict(random_drop=True, random_rotate=True), + # ], + # [ + # "butterfly/pistonball_v6", + # pistonball_v6, + # dict(random_drop=False, random_rotate=False), + # ], + # ["classic/go_v5", go_v5, dict(board_size=13, komi=2.5)], + # ["classic/go_v5", go_v5, dict(board_size=9, komi=0.0)], + # ["classic/hanabi_v4", hanabi_v4, dict(colors=3)], + # ["classic/hanabi_v4", hanabi_v4, dict(ranks=3)], + # ["classic/hanabi_v4", hanabi_v4, dict(players=4)], + # ["classic/hanabi_v4", hanabi_v4, dict(hand_size=5)], + # ["classic/hanabi_v4", hanabi_v4, dict(max_information_tokens=3)], + # ["classic/hanabi_v4", hanabi_v4, dict(max_life_tokens=2)], + # [ + # "classic/hanabi_v4", + # hanabi_v4, + # dict( + # colors=5, + # ranks=3, + # players=4, + # hand_size=5, + # max_information_tokens=3, + # max_life_tokens=2, + # ), + # ], + # ["classic/hanabi_v4", hanabi_v4, dict(observation_type=0)], + # ["classic/hanabi_v4", hanabi_v4, dict(observation_type=1)], + # ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=False)], + # ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=True)], + # ["mpe/simple_adversary_v2", simple_adversary_v2, dict(N=4)], + # ["mpe/simple_reference_v2", simple_reference_v2, dict(local_ratio=0.2)], + # ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5)], + # [ + # "mpe/simple_tag_v2", + # simple_tag_v2, + # dict(num_good=5, num_adversaries=10, num_obstacles=4), + # ], + # [ + # "mpe/simple_tag_v2", + # simple_tag_v2, + # dict(num_good=1, num_adversaries=1, num_obstacles=1), + # ], + # [ + # "mpe/simple_world_comm_v2", + # simple_world_comm_v2, + # dict(num_good=5, num_adversaries=10, num_obstacles=4, num_food=3), + # ], + # [ + # "mpe/simple_world_comm_v2", + # simple_world_comm_v2, + # dict(num_good=1, num_adversaries=1, num_obstacles=1, num_food=1), + # ], + # [ + # "mpe/simple_adversary_v2", + # simple_adversary_v2, + # dict(N=4, continuous_actions=True), + # ], + # [ + # "mpe/simple_reference_v2", + # simple_reference_v2, + # dict(local_ratio=0.2, continuous_actions=True), + # ], + # ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5, continuous_actions=True)], + # [ + # "mpe/simple_tag_v2", + # simple_tag_v2, + # dict(num_good=5, num_adversaries=10, num_obstacles=4, continuous_actions=True), + # ], + # [ + # "mpe/simple_tag_v2", + # simple_tag_v2, + # dict(num_good=1, num_adversaries=1, num_obstacles=1, continuous_actions=True), + # ], + # [ + # "mpe/simple_world_comm_v2", + # simple_world_comm_v2, + # dict( + # num_good=5, + # num_adversaries=10, + # num_obstacles=4, + # num_food=3, + # continuous_actions=True, + # ), + # ], + # [ + # "mpe/simple_world_comm_v2", + # simple_world_comm_v2, + # dict( + # num_good=1, + # num_adversaries=1, + # num_obstacles=1, + # num_food=1, + # continuous_actions=True, + # ), + # ], + # ["sisl/multiwalker_v9", multiwalker_v9, dict(n_walkers=10)], + # ["sisl/multiwalker_v9", multiwalker_v9, dict(shared_reward=False)], + # ["sisl/multiwalker_v9", multiwalker_v9, dict(terminate_on_fall=False)], + # [ + # "sisl/multiwalker_v8", + # multiwalker_v9, + # dict(terminate_on_fall=False, remove_on_fall=False), + # ], + # ["sisl/pursuit_v4", pursuit_v4, dict(x_size=8, y_size=19)], + # ["sisl/pursuit_v4", pursuit_v4, dict(shared_reward=True)], + # ["sisl/pursuit_v4", pursuit_v4, dict(n_evaders=5, n_pursuers=16)], + # ["sisl/pursuit_v4", pursuit_v4, dict(obs_range=15)], + # ["sisl/pursuit_v4", pursuit_v4, dict(n_catch=3)], + # ["sisl/pursuit_v4", pursuit_v4, dict(freeze_evaders=True)], + # ["sisl/waterworld_v4", waterworld_v4, dict(n_pursuers=3, n_evaders=6)], + # ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], + # ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], + # ["sisl/waterworld_v4", waterworld_v4, dict(n_poisons=4)], + # ["sisl/waterworld_v4", waterworld_v4, dict(n_sensors=4)], + # ["sisl/waterworld_v4", waterworld_v4, dict(local_ratio=0.5)], + # ["sisl/waterworld_v4", waterworld_v4, dict(speed_features=False)], ] From 717e4fc6f34a79d75da20e12e37f4e25c5a290cc Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 22 Oct 2022 13:18:55 +0100 Subject: [PATCH 16/24] reenable tests --- test/all_parameter_combs_test.py | 412 +++++++++++++++---------------- 1 file changed, 206 insertions(+), 206 deletions(-) diff --git a/test/all_parameter_combs_test.py b/test/all_parameter_combs_test.py index 80447badf..d8ee83a6c 100644 --- a/test/all_parameter_combs_test.py +++ b/test/all_parameter_combs_test.py @@ -8,217 +8,217 @@ from .all_modules import * # noqa: F403 parameterized_envs = [ - # ["atari/boxing_v2", boxing_v2, dict(obs_type="grayscale_image")], - # ["atari/boxing_v2", boxing_v2, dict(obs_type="ram")], - # ["atari/boxing_v2", boxing_v2, dict(full_action_space=False)], - # ["atari/combat_plane_v2", combat_plane_v2, dict(game_version="jet")], - # ["atari/combat_plane_v2", combat_plane_v2, dict(guided_missile=True)], - # ["atari/combat_tank_v2", combat_tank_v2, dict(has_maze=True)], - # ["atari/combat_tank_v2", combat_tank_v2, dict(is_invisible=True)], - # ["atari/combat_tank_v2", combat_tank_v2, dict(billiard_hit=True)], - # ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="race")], - # ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="capture")], - # ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=1)], - # ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=3)], - # [ - # "atari/space_invaders_v2", - # space_invaders_v2, - # dict( - # alternating_control=True, - # moving_shields=True, - # zigzaging_bombs=True, - # fast_bomb=True, - # invisible_invaders=True, - # ), - # ], - # ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=2)], - # ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=3)], - # ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=4)], - # ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=2)], - # ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=3)], - # ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=4)], - # ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=2)], - # ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=3)], - # ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=4)], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(spawn_rate=50), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(num_knights=4, num_archers=5), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(killable_knights=True, killable_archers=True), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(killable_knights=False, killable_archers=False), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(line_death=False), - # ], + ["atari/boxing_v2", boxing_v2, dict(obs_type="grayscale_image")], + ["atari/boxing_v2", boxing_v2, dict(obs_type="ram")], + ["atari/boxing_v2", boxing_v2, dict(full_action_space=False)], + ["atari/combat_plane_v2", combat_plane_v2, dict(game_version="jet")], + ["atari/combat_plane_v2", combat_plane_v2, dict(guided_missile=True)], + ["atari/combat_tank_v2", combat_tank_v2, dict(has_maze=True)], + ["atari/combat_tank_v2", combat_tank_v2, dict(is_invisible=True)], + ["atari/combat_tank_v2", combat_tank_v2, dict(billiard_hit=True)], + ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="race")], + ["atari/maze_craze_v3", maze_craze_v3, dict(game_version="capture")], + ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=1)], + ["atari/maze_craze_v3", maze_craze_v3, dict(visibilty_level=3)], + [ + "atari/space_invaders_v2", + space_invaders_v2, + dict( + alternating_control=True, + moving_shields=True, + zigzaging_bombs=True, + fast_bomb=True, + invisible_invaders=True, + ), + ], + ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=2)], + ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=3)], + ["classic/leduc_holdem_v4", leduc_holdem_v4, dict(num_players=4)], + ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=2)], + ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=3)], + ["classic/texas_holdem_v4", texas_holdem_v4, dict(num_players=4)], + ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=2)], + ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=3)], + ["classic/texas_holdem_no_limit_v6", texas_holdem_no_limit_v6, dict(num_players=4)], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(spawn_rate=50), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(num_knights=4, num_archers=5), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(killable_knights=True, killable_archers=True), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(killable_knights=False, killable_archers=False), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(line_death=False), + ], [ "butterfly/knights_archers_zombies_v10", knights_archers_zombies_v10, dict(graph_space=True, use_typemasks=True), ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(vector_state=False), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(vector_state=False, pad_observation=False), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(max_cycles=100), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(use_typemasks=False), - # ], - # [ - # "butterfly/knights_archers_zombies_v10", - # knights_archers_zombies_v10, - # dict(max_zombies=2, max_arrows=60), - # ], - # ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=True)], - # ["butterfly/pistonball_v6", pistonball_v6, dict(n_pistons=30)], - # ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=False)], - # [ - # "butterfly/pistonball_v6", - # pistonball_v6, - # dict(random_drop=True, random_rotate=True), - # ], - # [ - # "butterfly/pistonball_v6", - # pistonball_v6, - # dict(random_drop=False, random_rotate=False), - # ], - # ["classic/go_v5", go_v5, dict(board_size=13, komi=2.5)], - # ["classic/go_v5", go_v5, dict(board_size=9, komi=0.0)], - # ["classic/hanabi_v4", hanabi_v4, dict(colors=3)], - # ["classic/hanabi_v4", hanabi_v4, dict(ranks=3)], - # ["classic/hanabi_v4", hanabi_v4, dict(players=4)], - # ["classic/hanabi_v4", hanabi_v4, dict(hand_size=5)], - # ["classic/hanabi_v4", hanabi_v4, dict(max_information_tokens=3)], - # ["classic/hanabi_v4", hanabi_v4, dict(max_life_tokens=2)], - # [ - # "classic/hanabi_v4", - # hanabi_v4, - # dict( - # colors=5, - # ranks=3, - # players=4, - # hand_size=5, - # max_information_tokens=3, - # max_life_tokens=2, - # ), - # ], - # ["classic/hanabi_v4", hanabi_v4, dict(observation_type=0)], - # ["classic/hanabi_v4", hanabi_v4, dict(observation_type=1)], - # ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=False)], - # ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=True)], - # ["mpe/simple_adversary_v2", simple_adversary_v2, dict(N=4)], - # ["mpe/simple_reference_v2", simple_reference_v2, dict(local_ratio=0.2)], - # ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5)], - # [ - # "mpe/simple_tag_v2", - # simple_tag_v2, - # dict(num_good=5, num_adversaries=10, num_obstacles=4), - # ], - # [ - # "mpe/simple_tag_v2", - # simple_tag_v2, - # dict(num_good=1, num_adversaries=1, num_obstacles=1), - # ], - # [ - # "mpe/simple_world_comm_v2", - # simple_world_comm_v2, - # dict(num_good=5, num_adversaries=10, num_obstacles=4, num_food=3), - # ], - # [ - # "mpe/simple_world_comm_v2", - # simple_world_comm_v2, - # dict(num_good=1, num_adversaries=1, num_obstacles=1, num_food=1), - # ], - # [ - # "mpe/simple_adversary_v2", - # simple_adversary_v2, - # dict(N=4, continuous_actions=True), - # ], - # [ - # "mpe/simple_reference_v2", - # simple_reference_v2, - # dict(local_ratio=0.2, continuous_actions=True), - # ], - # ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5, continuous_actions=True)], - # [ - # "mpe/simple_tag_v2", - # simple_tag_v2, - # dict(num_good=5, num_adversaries=10, num_obstacles=4, continuous_actions=True), - # ], - # [ - # "mpe/simple_tag_v2", - # simple_tag_v2, - # dict(num_good=1, num_adversaries=1, num_obstacles=1, continuous_actions=True), - # ], - # [ - # "mpe/simple_world_comm_v2", - # simple_world_comm_v2, - # dict( - # num_good=5, - # num_adversaries=10, - # num_obstacles=4, - # num_food=3, - # continuous_actions=True, - # ), - # ], - # [ - # "mpe/simple_world_comm_v2", - # simple_world_comm_v2, - # dict( - # num_good=1, - # num_adversaries=1, - # num_obstacles=1, - # num_food=1, - # continuous_actions=True, - # ), - # ], - # ["sisl/multiwalker_v9", multiwalker_v9, dict(n_walkers=10)], - # ["sisl/multiwalker_v9", multiwalker_v9, dict(shared_reward=False)], - # ["sisl/multiwalker_v9", multiwalker_v9, dict(terminate_on_fall=False)], - # [ - # "sisl/multiwalker_v8", - # multiwalker_v9, - # dict(terminate_on_fall=False, remove_on_fall=False), - # ], - # ["sisl/pursuit_v4", pursuit_v4, dict(x_size=8, y_size=19)], - # ["sisl/pursuit_v4", pursuit_v4, dict(shared_reward=True)], - # ["sisl/pursuit_v4", pursuit_v4, dict(n_evaders=5, n_pursuers=16)], - # ["sisl/pursuit_v4", pursuit_v4, dict(obs_range=15)], - # ["sisl/pursuit_v4", pursuit_v4, dict(n_catch=3)], - # ["sisl/pursuit_v4", pursuit_v4, dict(freeze_evaders=True)], - # ["sisl/waterworld_v4", waterworld_v4, dict(n_pursuers=3, n_evaders=6)], - # ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], - # ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], - # ["sisl/waterworld_v4", waterworld_v4, dict(n_poisons=4)], - # ["sisl/waterworld_v4", waterworld_v4, dict(n_sensors=4)], - # ["sisl/waterworld_v4", waterworld_v4, dict(local_ratio=0.5)], - # ["sisl/waterworld_v4", waterworld_v4, dict(speed_features=False)], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(vector_state=False), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(vector_state=False, pad_observation=False), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(max_cycles=100), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(use_typemasks=False), + ], + [ + "butterfly/knights_archers_zombies_v10", + knights_archers_zombies_v10, + dict(max_zombies=2, max_arrows=60), + ], + ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=True)], + ["butterfly/pistonball_v6", pistonball_v6, dict(n_pistons=30)], + ["butterfly/pistonball_v6", pistonball_v6, dict(continuous=False)], + [ + "butterfly/pistonball_v6", + pistonball_v6, + dict(random_drop=True, random_rotate=True), + ], + [ + "butterfly/pistonball_v6", + pistonball_v6, + dict(random_drop=False, random_rotate=False), + ], + ["classic/go_v5", go_v5, dict(board_size=13, komi=2.5)], + ["classic/go_v5", go_v5, dict(board_size=9, komi=0.0)], + ["classic/hanabi_v4", hanabi_v4, dict(colors=3)], + ["classic/hanabi_v4", hanabi_v4, dict(ranks=3)], + ["classic/hanabi_v4", hanabi_v4, dict(players=4)], + ["classic/hanabi_v4", hanabi_v4, dict(hand_size=5)], + ["classic/hanabi_v4", hanabi_v4, dict(max_information_tokens=3)], + ["classic/hanabi_v4", hanabi_v4, dict(max_life_tokens=2)], + [ + "classic/hanabi_v4", + hanabi_v4, + dict( + colors=5, + ranks=3, + players=4, + hand_size=5, + max_information_tokens=3, + max_life_tokens=2, + ), + ], + ["classic/hanabi_v4", hanabi_v4, dict(observation_type=0)], + ["classic/hanabi_v4", hanabi_v4, dict(observation_type=1)], + ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=False)], + ["classic/hanabi_v4", hanabi_v4, dict(random_start_player=True)], + ["mpe/simple_adversary_v2", simple_adversary_v2, dict(N=4)], + ["mpe/simple_reference_v2", simple_reference_v2, dict(local_ratio=0.2)], + ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5)], + [ + "mpe/simple_tag_v2", + simple_tag_v2, + dict(num_good=5, num_adversaries=10, num_obstacles=4), + ], + [ + "mpe/simple_tag_v2", + simple_tag_v2, + dict(num_good=1, num_adversaries=1, num_obstacles=1), + ], + [ + "mpe/simple_world_comm_v2", + simple_world_comm_v2, + dict(num_good=5, num_adversaries=10, num_obstacles=4, num_food=3), + ], + [ + "mpe/simple_world_comm_v2", + simple_world_comm_v2, + dict(num_good=1, num_adversaries=1, num_obstacles=1, num_food=1), + ], + [ + "mpe/simple_adversary_v2", + simple_adversary_v2, + dict(N=4, continuous_actions=True), + ], + [ + "mpe/simple_reference_v2", + simple_reference_v2, + dict(local_ratio=0.2, continuous_actions=True), + ], + ["mpe/simple_spread_v2", simple_spread_v2, dict(N=5, continuous_actions=True)], + [ + "mpe/simple_tag_v2", + simple_tag_v2, + dict(num_good=5, num_adversaries=10, num_obstacles=4, continuous_actions=True), + ], + [ + "mpe/simple_tag_v2", + simple_tag_v2, + dict(num_good=1, num_adversaries=1, num_obstacles=1, continuous_actions=True), + ], + [ + "mpe/simple_world_comm_v2", + simple_world_comm_v2, + dict( + num_good=5, + num_adversaries=10, + num_obstacles=4, + num_food=3, + continuous_actions=True, + ), + ], + [ + "mpe/simple_world_comm_v2", + simple_world_comm_v2, + dict( + num_good=1, + num_adversaries=1, + num_obstacles=1, + num_food=1, + continuous_actions=True, + ), + ], + ["sisl/multiwalker_v9", multiwalker_v9, dict(n_walkers=10)], + ["sisl/multiwalker_v9", multiwalker_v9, dict(shared_reward=False)], + ["sisl/multiwalker_v9", multiwalker_v9, dict(terminate_on_fall=False)], + [ + "sisl/multiwalker_v8", + multiwalker_v9, + dict(terminate_on_fall=False, remove_on_fall=False), + ], + ["sisl/pursuit_v4", pursuit_v4, dict(x_size=8, y_size=19)], + ["sisl/pursuit_v4", pursuit_v4, dict(shared_reward=True)], + ["sisl/pursuit_v4", pursuit_v4, dict(n_evaders=5, n_pursuers=16)], + ["sisl/pursuit_v4", pursuit_v4, dict(obs_range=15)], + ["sisl/pursuit_v4", pursuit_v4, dict(n_catch=3)], + ["sisl/pursuit_v4", pursuit_v4, dict(freeze_evaders=True)], + ["sisl/waterworld_v4", waterworld_v4, dict(n_pursuers=3, n_evaders=6)], + ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], + ["sisl/waterworld_v4", waterworld_v4, dict(n_coop=1)], + ["sisl/waterworld_v4", waterworld_v4, dict(n_poisons=4)], + ["sisl/waterworld_v4", waterworld_v4, dict(n_sensors=4)], + ["sisl/waterworld_v4", waterworld_v4, dict(local_ratio=0.5)], + ["sisl/waterworld_v4", waterworld_v4, dict(speed_features=False)], ] From 0e7c52a545cfb537b515221dd7d37f55701406e6 Mon Sep 17 00:00:00 2001 From: Jet Date: Sat, 22 Oct 2022 13:20:43 +0100 Subject: [PATCH 17/24] remove rogue prints --- pettingzoo/test/state_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pettingzoo/test/state_test.py b/pettingzoo/test/state_test.py index d049e614c..62cdd95dd 100644 --- a/pettingzoo/test/state_test.py +++ b/pettingzoo/test/state_test.py @@ -55,8 +55,6 @@ def test_state(env, num_cycles): env.step(action) new_state = env.state() - print(new_state) - print(env.state_space) assert env.state_space.contains( new_state ), "Environment's state is outside of it's state space" From 111d50989df9f7c3c0040386c1a8292965a19ce1 Mon Sep 17 00:00:00 2001 From: Jet Date: Mon, 24 Oct 2022 00:45:03 +0100 Subject: [PATCH 18/24] black --- .../knights_archers_zombies/knights_archers_zombies.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index c3efae1a7..ba12d49d7 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -261,13 +261,9 @@ def __init__( # variable state space self.graph_space = graph_space if self.graph_space: - assert ( - vector_state - ), "vector_state must be True if graph_space is True." + assert vector_state, "vector_state must be True if graph_space is True." - assert ( - use_typemasks - ), "use_typemasks should be True if graph_space is True" + assert use_typemasks, "use_typemasks should be True if graph_space is True" # whether we want RGB state or vector state self.vector_state = vector_state From af6545849536ae95ca3e6f505147d40b9f89f28b Mon Sep 17 00:00:00 2001 From: snow-fox Date: Mon, 13 Feb 2023 16:30:48 +0000 Subject: [PATCH 19/24] change graph to sequence --- .../knights_archers_zombies.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index 36d7f9786..553182005 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -95,9 +95,9 @@ As a result, setting `use_typemask=True` results in the observation being a (N+1)x11 vector. -**Graph Space** (Experimental) +**Sequence Space** (Experimental) -There is an option to also pass `graph_space=True` as a kwarg to the environment. This just removes all non-existent entities from the observation and state vectors. Note that this is **still experimental** as the state and observation size are no longer constant. In particular, `N` is now a +There is an option to also pass `sequence_space=True` as a kwarg to the environment. This just removes all non-existent entities from the observation and state vectors. Note that this is **still experimental** as the state and observation size are no longer constant. In particular, `N` is now a variable number. #### Image-based @@ -131,7 +131,7 @@ max_cycles=900, vector_state=True, use_typemasks=False, - graph_space=False, + sequence_space=False, ``` `spawn_rate`: how many cycles before a new zombie is spawned. A lower number means zombies are spawned at a higher rate. @@ -156,7 +156,7 @@ `use_typemasks`: only relevant when `vector_state=True` is set, adds typemasks to the vectors. -`graph_space`: **experimental**, only relevant when `vector_state=True` is set, removes non-existent entities in the vector state. +`sequence_space`: **experimental**, only relevant when `vector_state=True` is set, removes non-existent entities in the vector state. ### Version History @@ -183,7 +183,7 @@ import numpy as np import pygame import pygame.gfxdraw -from gymnasium.spaces import Box, Discrete, Graph, GraphInstance +from gymnasium.spaces import Box, Discrete, Sequence from gymnasium.utils import EzPickle, seeding from pettingzoo import AECEnv @@ -234,7 +234,7 @@ def __init__( max_cycles=900, vector_state=True, use_typemasks=False, - graph_space=False, + sequence_space=False, render_mode=None, ): EzPickle.__init__( @@ -251,15 +251,15 @@ def __init__( max_cycles, vector_state, use_typemasks, - graph_space, + sequence_space, render_mode, ) # variable state space - self.graph_space = graph_space - if self.graph_space: - assert vector_state, "vector_state must be True if graph_space is True." + self.sequence_space = sequence_space + if self.sequence_space: + assert vector_state, "vector_state must be True if sequence_space is True." - assert use_typemasks, "use_typemasks should be True if graph_space is True" + assert use_typemasks, "use_typemasks should be True if sequence_space is True" # whether we want RGB state or vector state self.vector_state = vector_state @@ -267,7 +267,7 @@ def __init__( self.num_tracked = ( num_archers + num_knights + max_zombies + num_knights + max_arrows ) - self.use_typemasks = True if graph_space else use_typemasks + self.use_typemasks = True if sequence_space else use_typemasks self.typemask_width = 6 self.vector_width = 4 + self.typemask_width if use_typemasks else 4 @@ -318,7 +318,7 @@ def __init__( low = 0 if not self.vector_state else -1.0 high = 255 if not self.vector_state else 1.0 dtype = np.uint8 if not self.vector_state else np.float64 - if not self.graph_space: + if not self.sequence_space: obs_space = Box(low=low, high=high, shape=shape, dtype=dtype) self.observation_spaces = dict( zip( @@ -328,7 +328,7 @@ def __init__( ) else: box_space = Box(low=low, high=high, shape=[shape[-1]], dtype=dtype) - obs_space = Graph(node_space=box_space, edge_space=None) + obs_space = Sequence(space=box_space) self.observation_spaces = dict( zip( self.agents, @@ -578,10 +578,9 @@ def observe(self, agent): # prepend agent state to the observation state = np.concatenate([agent_state, state], axis=0) - if self.graph_space: - # remove pure zero rows if using graph space + if self.sequence_space: + # remove pure zero rows if using sequence space state = state[~np.all(state == 0, axis=-1)] - state = GraphInstance(nodes=state, edges=None, edge_links=None) return state From 15fe5e698b2d0d515c6c562fef6e3f519e033135 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 16:38:36 +0000 Subject: [PATCH 20/24] add farama notifications --- pettingzoo/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pettingzoo/__init__.py b/pettingzoo/__init__.py index e2a56276c..45f5bfbc8 100644 --- a/pettingzoo/__init__.py +++ b/pettingzoo/__init__.py @@ -13,3 +13,12 @@ os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide" __version__ = "1.22.3" + +try: + import sys + from farama_notifications import notifications + + if "pettingzoo" in notifications and __version__ in notifications["pettingzoo"]: + print(notifications["pettingzoo"][__version__], file=sys.stderr) +except Exception: # nosec + pass From 53dcbba73c22a2eaaf99a0aebce6a62bbb23be5e Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 17:08:47 +0000 Subject: [PATCH 21/24] fix graph space bug --- test/all_parameter_combs_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/all_parameter_combs_test.py b/test/all_parameter_combs_test.py index d8ee83a6c..0bdc70ec7 100644 --- a/test/all_parameter_combs_test.py +++ b/test/all_parameter_combs_test.py @@ -68,7 +68,7 @@ [ "butterfly/knights_archers_zombies_v10", knights_archers_zombies_v10, - dict(graph_space=True, use_typemasks=True), + dict(sequence_space=True, use_typemasks=True), ], [ "butterfly/knights_archers_zombies_v10", From 89fca91183806ec96a0a5cc2c4696bcfc62a32fe Mon Sep 17 00:00:00 2001 From: jjshoots Date: Tue, 28 Feb 2023 14:12:22 +0000 Subject: [PATCH 22/24] precommit --- .../knights_archers_zombies/knights_archers_zombies.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index f70465b93..ec5c0cca9 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -258,7 +258,9 @@ def __init__( if self.sequence_space: assert vector_state, "vector_state must be True if sequence_space is True." - assert use_typemasks, "use_typemasks should be True if sequence_space is True" + assert ( + use_typemasks + ), "use_typemasks should be True if sequence_space is True" # whether we want RGB state or vector state self.vector_state = vector_state From d2e48e81c9d86528e263ac59203df4a7edc7f040 Mon Sep 17 00:00:00 2001 From: snow-fox Date: Tue, 28 Feb 2023 17:16:22 +0000 Subject: [PATCH 23/24] precommit --- pettingzoo/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pettingzoo/__init__.py b/pettingzoo/__init__.py index 45f5bfbc8..14f67d04c 100644 --- a/pettingzoo/__init__.py +++ b/pettingzoo/__init__.py @@ -16,6 +16,7 @@ try: import sys + from farama_notifications import notifications if "pettingzoo" in notifications and __version__ in notifications["pettingzoo"]: From 10c2c7d9f59b7508cb05b2ba83c2b9a66b2576a5 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Fri, 10 Mar 2023 15:21:48 +0000 Subject: [PATCH 24/24] fix precommit --- pettingzoo/mpe/_mpe_utils/simple_env.py | 1 - pettingzoo/utils/conversions.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pettingzoo/mpe/_mpe_utils/simple_env.py b/pettingzoo/mpe/_mpe_utils/simple_env.py index 818a75953..d76e8842d 100644 --- a/pettingzoo/mpe/_mpe_utils/simple_env.py +++ b/pettingzoo/mpe/_mpe_utils/simple_env.py @@ -283,7 +283,6 @@ def render(self): elif self.render_mode == "human": pygame.display.flip() return - def draw(self): # clear screen diff --git a/pettingzoo/utils/conversions.py b/pettingzoo/utils/conversions.py index 765a7ad94..0f4a9e351 100644 --- a/pettingzoo/utils/conversions.py +++ b/pettingzoo/utils/conversions.py @@ -38,7 +38,7 @@ class my_par_class(pettingzoo.utils.env.ParallelEnv): def aec_fn(**kwargs): par_env = par_env_fn(**kwargs) - aec_env = pettingzoo.utils.parallel_to_aec(par_env) + aec_env = parallel_to_aec(par_env) return aec_env return aec_fn