From c90f94780253c70ec89ce1f7695dabf7017b7b14 Mon Sep 17 00:00:00 2001 From: Umut Ucak Date: Thu, 5 Oct 2023 17:09:10 +0200 Subject: [PATCH] Multiwalker dead variable removal + macOS pygame fix (#1107) --- pettingzoo/classic/connect_four/connect_four.py | 1 + pettingzoo/sisl/multiwalker/multiwalker.py | 6 ++++-- pettingzoo/sisl/multiwalker/multiwalker_base.py | 9 +++------ pettingzoo/sisl/pursuit/pursuit_base.py | 3 ++- pettingzoo/sisl/waterworld/waterworld_base.py | 3 ++- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pettingzoo/classic/connect_four/connect_four.py b/pettingzoo/classic/connect_four/connect_four.py index 5aa637930..e2a2390e9 100644 --- a/pettingzoo/classic/connect_four/connect_four.py +++ b/pettingzoo/classic/connect_four/connect_four.py @@ -283,6 +283,7 @@ def render(self): ) if self.render_mode == "human": + pygame.event.pump() pygame.display.update() self.clock.tick(self.metadata["render_fps"]) diff --git a/pettingzoo/sisl/multiwalker/multiwalker.py b/pettingzoo/sisl/multiwalker/multiwalker.py index b398268e9..839b6152e 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker.py +++ b/pettingzoo/sisl/multiwalker/multiwalker.py @@ -151,9 +151,11 @@ def __init__(self, *args, **kwargs): EzPickle.__init__(self, *args, **kwargs) self.env = _env(*args, **kwargs) self.render_mode = self.env.render_mode - self.agents = ["walker_" + str(r) for r in range(self.env.num_agents)] + self.agents = ["walker_" + str(r) for r in range(self.env.n_walkers)] self.possible_agents = self.agents[:] - self.agent_name_mapping = dict(zip(self.agents, list(range(self.num_agents)))) + self.agent_name_mapping = dict( + zip(self.agents, list(range(self.env.n_walkers))) + ) self._agent_selector = agent_selector(self.agents) # spaces self.action_spaces = dict(zip(self.agents, self.env.action_space)) diff --git a/pettingzoo/sisl/multiwalker/multiwalker_base.py b/pettingzoo/sisl/multiwalker/multiwalker_base.py index 3ff961237..3da9a1260 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker_base.py +++ b/pettingzoo/sisl/multiwalker/multiwalker_base.py @@ -333,7 +333,6 @@ def __init__( self._seed() self.setup() self.screen = None - self.agent_list = list(range(self.n_walkers)) self.last_rewards = [0 for _ in range(self.n_walkers)] self.last_dones = [False for _ in range(self.n_walkers)] self.last_obs = [None for _ in range(self.n_walkers)] @@ -359,15 +358,12 @@ def setup(self): BipedalWalker(self.world, init_x=sx, init_y=init_y, seed=self.seed_val) for sx in self.start_x ] - self.num_agents = len(self.walkers) self.observation_space = [agent.observation_space for agent in self.walkers] self.action_space = [agent.action_space for agent in self.walkers] self.package_scale = self.n_walkers / 1.75 self.package_length = PACKAGE_LENGTH / SCALE * self.package_scale - self.total_agents = self.n_walkers - self.prev_shaping = np.zeros(self.n_walkers) self.prev_package_shaping = 0.0 @@ -534,7 +530,7 @@ def get_last_rewards(self): ) def get_last_dones(self): - return dict(zip(self.agent_list, self.last_dones)) + return dict(zip(list(range(self.n_walkers)), self.last_dones)) def get_last_obs(self): return dict( @@ -692,7 +688,8 @@ def render(self, close=False): self.surf = pygame.transform.flip(self.surf, False, True) self.screen.blit(self.surf, (-self.scroll * render_scale - offset, 0)) if self.render_mode == "human": - pygame.display.flip() + pygame.event.pump() + pygame.display.update() elif self.render_mode == "rgb_array": return np.transpose( np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) diff --git a/pettingzoo/sisl/pursuit/pursuit_base.py b/pettingzoo/sisl/pursuit/pursuit_base.py index 7b31d1fa1..8ba3be943 100644 --- a/pettingzoo/sisl/pursuit/pursuit_base.py +++ b/pettingzoo/sisl/pursuit/pursuit_base.py @@ -417,7 +417,8 @@ def render(self): new_observation = np.copy(observation) del observation if self.render_mode == "human": - pygame.display.flip() + pygame.event.pump() + pygame.display.update() return ( np.transpose(new_observation, axes=(1, 0, 2)) if self.render_mode == "rgb_array" diff --git a/pettingzoo/sisl/waterworld/waterworld_base.py b/pettingzoo/sisl/waterworld/waterworld_base.py index d2a53a1e0..00cfbacb1 100644 --- a/pettingzoo/sisl/waterworld/waterworld_base.py +++ b/pettingzoo/sisl/waterworld/waterworld_base.py @@ -741,7 +741,8 @@ def render(self): del observation if self.render_mode == "human": - pygame.display.flip() + pygame.event.pump() + pygame.display.update() return ( np.transpose(new_observation, axes=(1, 0, 2)) if self.render_mode == "rgb_array"