Skip to content

Commit

Permalink
Merge pull request #70 from elliottower/umshini-printing-update
Browse files: browse the repository at this point in the history
Umshini rendering update (attack/defend color text)
  • Loading branch information
elliottower authored Oct 27, 2023
2 parents 000ae96 + d51e2fe commit 88c9ead
Showing 5 changed files with 34 additions and 15 deletions.
21 changes: 18 additions & 3 deletions chatarena/environments/umshini/pettingzoo_wrapper.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@

from typing import List

from colorama import Fore
from chatarena.environments import Environment
from chatarena.environments.base import TimeStep
from chatarena.message import Message
@@ -227,9 +228,21 @@ def render(self):
raise Exception("New messages not found")
else:
for message in new_messages:
print(
f"[{message.agent_name}->{message.visible_to}]: {message.content}\n"
)
# Don't repeat things from previous turns
if not self.current_turn > message.turn:
if message.agent_name == "Moderator":
color = Fore.BLACK
role = ""
else:
if self.infos[message.agent_name]["role"] == "attacker":
color = Fore.RED
role = "(attacker)"
else:
color = Fore.BLUE
role = "(defender)"
print(
color + f"[{message.agent_name} {role} -> {message.visible_to}]: {message.content}\n "
)

def observe(self, agent: AgentID) -> ObsType:
"""observe.
@@ -380,6 +393,8 @@ def _unravel_timestep(self, timestep: TimeStep):
self.infos[self.possible_agents[0]]["role"] = "attacker"
self.infos[self.possible_agents[1]]["role"] = "defender"

info["role"] = self.infos[self.agent_selection]["role"]

# info: environment specific information
if hasattr(self, "restricted_action"):
info["restricted_action"] = self.restricted_action
18 changes: 8 additions & 10 deletions chatarena/environments/umshini/symmetric_content_moderation.py
Original file line number Diff line number Diff line change
@@ -121,23 +121,22 @@ def step(self, player_name: str, action: str) -> TimeStep:
# it is player 2's turn to respond to player 1's query
# they should be judged on whether they allow the restricted action
# return rewards this step
rews = self.get_rewards()
timestep = TimeStep(
observation=self.get_observation(player_name=player_name),
reward=rews,
terminal=False,
)
if self._current_turn == self._round_length:
# swap roles
self._moderator_speak(f"Roles are being swapped. {self.player_names[0]} is defending, and {self.player_names[1]} is attacking.")
self.player_names.reverse()
self.agent_selector.reinit(self.player_names)
self._current_phase = "player_2_attack"
timestep = TimeStep(
observation=self.get_observation(player_name=player_name),
reward=self.get_rewards(),
terminal=False,
)
return timestep

elif self._current_turn <= 2 * self._round_length:
self._current_phase = "player_2_attack"
if self._current_turn % 2 == 1:
# it is player 1's turn to go first in attacking
# it is player 2's turn to go first in attacking
# add their message to the pool, return no reward
return TimeStep(
observation=self.get_observation(player_name=player_name),
@@ -151,10 +150,9 @@ def step(self, player_name: str, action: str) -> TimeStep:
is_now_terminal = self._current_turn == 2 * self._round_length

# get the rewards before getting the observation, so that the moderator's final message is displayed (winner)
rews = self.get_rewards()
return TimeStep(
observation=self.get_observation(player_name=player_name),
reward=rews,
reward=self.get_rewards(),
terminal=is_now_terminal,
)
else:
6 changes: 6 additions & 0 deletions chatarena/environments/umshini/symmetric_deception.py
Original file line number Diff line number Diff line change
@@ -124,6 +124,12 @@ def step(self, player_name: str, action: str) -> TimeStep:
self._moderator_speak(f"Roles are being swapped. {self.player_names[0]} is defending, and {self.player_names[1]} is attacking.")
self.player_names.reverse()
self.agent_selector.reinit(self.player_names)
self._current_phase = "player_2_attack"
timestep = TimeStep(
observation=self.get_observation(player_name=player_name),
reward=self.get_rewards(),
terminal=False,
)
return timestep
elif self._current_turn <= 2 * self._round_length + 1:
self._current_phase = "player_2_attack"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ bard = ["bardapi==0.1.11"]
langchain = ["langchain>=0.0.135"]
gradio = ["gradio>=3.34.0"]
pettingzoo = ["pettingzoo[classic]>=1.23.1"]
umshini = ["pettingzoo>=1.23.1", "langchain>=0.0.135"]
umshini = ["pettingzoo>=1.23.1", "langchain>=0.0.135", "colorama>=0.4.6"]
all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"]
all_envs = ["pettingzoo[classic]>=1.23.1", "langchain>=0.0.135"]
all = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "gradio>=3.34.0", "pettingzoo>=1.23.1",
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ def remove_duplicate_requirements(requirements):
langchain_requirements = ["langchain>=0.0.135"]
gradio_requirements = ["gradio>=3.34.0"]
pettingzoo_requirements = ["pettingzoo[classic]>=1.23.1", "chess==1.9.4"]
umshini_requirements = ["pettingzoo>=1.23.1", "langchain>=0.0.135"]
umshini_requirements = ["pettingzoo>=1.23.1", "langchain>=0.0.135", "colorama>=0.4.6"]

all_backends = anthropic_requirements + cohere_requirements + hf_requirements + bard_requirements + \
langchain_requirements

0 comments on commit 88c9ead

Please sign in to comment.