From f3fc5c2ea91ad422b272c3b5a5f4dde66f6e4759 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 13 Nov 2023 16:12:12 -0500 Subject: [PATCH] Add self.env_name and fix logic with debate vs symmetric envs --- .../umshini/pettingzoo_wrapper.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/chatarena/environments/umshini/pettingzoo_wrapper.py b/chatarena/environments/umshini/pettingzoo_wrapper.py index a6600b5d..00037276 100644 --- a/chatarena/environments/umshini/pettingzoo_wrapper.py +++ b/chatarena/environments/umshini/pettingzoo_wrapper.py @@ -91,15 +91,19 @@ def __init__( elif env is not None: self._env = env if hasattr(env, "topic"): + self.env_name = "debate" self.topic = topic self.max_turns = round_length elif hasattr(env, "moderation_policy"): + self.env_name = "content_moderation" self.moderation_policy = env.moderation_policy self.max_turns = round_length * 2 elif hasattr(env, "restricted_action"): + self.env_name = "deception" self.restricted_action = env.restricted_action self.max_turns = round_length * 2 elif env_name is not None: + self.env_name = env_name if env_name == "debate": assert topic is not None, "topic must be specified for debate env" self._env = create_debate_env( @@ -233,6 +237,12 @@ def render(self): if message.agent_name == "Moderator": color = Fore.BLACK role = "" + elif message.agent_name == "Proponent": + color = Fore.BLUE + role = "" + elif message.agent_name == "Opponent": + color = Fore.RED + role = "" else: if self.infos[message.agent_name]["role"] == "attacker": color = Fore.RED @@ -292,14 +302,15 @@ def observe(self, agent: AgentID) -> ObsType: } self.infos[agent]["player_name"] = self.agent_selection - # Role in symmetric environments - if hasattr(self._env, "_current_phase"): - if self._env._current_phase == "player_2_attack" or self._env._current_phase == "end": - self.infos[self.possible_agents[0]]["role"] = "defender" - self.infos[self.possible_agents[1]]["role"] = "attacker" - else: - self.infos[self.possible_agents[0]]["role"] = "attacker" - self.infos[self.possible_agents[1]]["role"] = "defender" + # Role in symmetric environments (not applicable if env has terminated) + if self.env_name != "debate": + if hasattr(self._env, "_current_phase") and not any(self.terminations.values()): + if self._env._current_phase == "player_2_attack": + self.infos[self.possible_agents[0]]["role"] = "defender" + self.infos[self.possible_agents[1]]["role"] = "attacker" + else: + self.infos[self.possible_agents[0]]["role"] = "attacker" + self.infos[self.possible_agents[1]]["role"] = "defender" # info: generate string of full chat log if self.string_observation is True: