diff --git a/chatarena/environments/umshini/pettingzoo_wrapper.py b/chatarena/environments/umshini/pettingzoo_wrapper.py index ac008df6..f8f033f2 100644 --- a/chatarena/environments/umshini/pettingzoo_wrapper.py +++ b/chatarena/environments/umshini/pettingzoo_wrapper.py @@ -7,6 +7,7 @@ from typing import List +from colorama import Fore from chatarena.environments import Environment from chatarena.environments.base import TimeStep from chatarena.message import Message @@ -227,9 +228,21 @@ def render(self): raise Exception("New messages not found") else: for message in new_messages: - print( - f"[{message.agent_name}->{message.visible_to}]: {message.content}\n" - ) + # Don't repeat things from previous turns + if not self.current_turn > message.turn: + if message.agent_name == "Moderator": + color = Fore.BLACK + role = "" + else: + if self.infos[message.agent_name]["role"] == "attacker": + color = Fore.RED + role = "(attacker)" + else: + color = Fore.BLUE + role = "(defender)" + print( + color + f"[{message.agent_name} {role} -> {message.visible_to}]: {message.content}\n " + ) def observe(self, agent: AgentID) -> ObsType: """observe. @@ -380,6 +393,8 @@ def _unravel_timestep(self, timestep: TimeStep): self.infos[self.possible_agents[0]]["role"] = "attacker" self.infos[self.possible_agents[1]]["role"] = "defender" + info["role"] = self.infos[self.agent_selection]["role"] + # info: environment specific information if hasattr(self, "restricted_action"): info["restricted_action"] = self.restricted_action diff --git a/chatarena/environments/umshini/symmetric_content_moderation.py b/chatarena/environments/umshini/symmetric_content_moderation.py index 5aa7bdbf..0cb2d343 100644 --- a/chatarena/environments/umshini/symmetric_content_moderation.py +++ b/chatarena/environments/umshini/symmetric_content_moderation.py @@ -121,23 +121,22 @@ def step(self, player_name: str, action: str) -> TimeStep: # it is player 2's turn to respond to player 1's query # they should be judged on whether they allow the restricted action # return rewards this step - rews = self.get_rewards() - timestep = TimeStep( - observation=self.get_observation(player_name=player_name), - reward=rews, - terminal=False, - ) if self._current_turn == self._round_length: # swap roles self._moderator_speak(f"Roles are being swapped. {self.player_names[0]} is defending, and {self.player_names[1]} is attacking.") self.player_names.reverse() self.agent_selector.reinit(self.player_names) + self._current_phase = "player_2_attack" + timestep = TimeStep( + observation=self.get_observation(player_name=player_name), + reward=self.get_rewards(), + terminal=False, + ) return timestep elif self._current_turn <= 2 * self._round_length: - self._current_phase = "player_2_attack" if self._current_turn % 2 == 1: - # it is player 1's turn to go first in attacking + # it is player 2's turn to go first in attacking # add their message to the pool, return no reward return TimeStep( observation=self.get_observation(player_name=player_name), @@ -151,10 +150,9 @@ def step(self, player_name: str, action: str) -> TimeStep: is_now_terminal = self._current_turn == 2 * self._round_length # get the rewards before getting the observation, so that the moderator's final message is displayed (winner) - rews = self.get_rewards() return TimeStep( observation=self.get_observation(player_name=player_name), - reward=rews, + reward=self.get_rewards(), terminal=is_now_terminal, ) else: diff --git a/chatarena/environments/umshini/symmetric_deception.py b/chatarena/environments/umshini/symmetric_deception.py index 172f2bcc..7d4e48ae 100644 --- a/chatarena/environments/umshini/symmetric_deception.py +++ b/chatarena/environments/umshini/symmetric_deception.py @@ -124,6 +124,12 @@ def step(self, player_name: str, action: str) -> TimeStep: self._moderator_speak(f"Roles are being swapped. {self.player_names[0]} is defending, and {self.player_names[1]} is attacking.") self.player_names.reverse() self.agent_selector.reinit(self.player_names) + self._current_phase = "player_2_attack" + timestep = TimeStep( + observation=self.get_observation(player_name=player_name), + reward=self.get_rewards(), + terminal=False, + ) return timestep elif self._current_turn <= 2 * self._round_length + 1: self._current_phase = "player_2_attack" diff --git a/pyproject.toml b/pyproject.toml index 6c1b1a29..2f03d11d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ bard = ["bardapi==0.1.11"] langchain = ["langchain>=0.0.135"] gradio = ["gradio>=3.34.0"] pettingzoo = ["pettingzoo[classic]>=1.23.1"] -umshini = ["pettingzoo>=1.23.1", "langchain>=0.0.135"] +umshini = ["pettingzoo>=1.23.1", "langchain>=0.0.135", "colorama>=0.4.6"] all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"] all_envs = ["pettingzoo[classic]>=1.23.1", "langchain>=0.0.135"] all = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "gradio>=3.34.0", "pettingzoo>=1.23.1", diff --git a/setup.py b/setup.py index a68d087b..2ed1ef06 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ def remove_duplicate_requirements(requirements): langchain_requirements = ["langchain>=0.0.135"] gradio_requirements = ["gradio>=3.34.0"] pettingzoo_requirements = ["pettingzoo[classic]>=1.23.1", "chess==1.9.4"] -umshini_requirements = ["pettingzoo>=1.23.1", "langchain>=0.0.135"] +umshini_requirements = ["pettingzoo>=1.23.1", "langchain>=0.0.135", "colorama>=0.4.6"] all_backends = anthropic_requirements + cohere_requirements + hf_requirements + bard_requirements + \ langchain_requirements