Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Umshini: fix bug with debate env and self.role attribute #76

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions chatarena/environments/umshini/pettingzoo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,19 @@ def __init__(
elif env is not None:
self._env = env
if hasattr(env, "topic"):
self.env_name = "debate"
self.topic = topic
self.max_turns = round_length
elif hasattr(env, "moderation_policy"):
self.env_name = "content_moderation"
self.moderation_policy = env.moderation_policy
self.max_turns = round_length * 2
elif hasattr(env, "restricted_action"):
self.env_name = "deception"
self.restricted_action = env.restricted_action
self.max_turns = round_length * 2
elif env_name is not None:
self.env_name = env_name
if env_name == "debate":
assert topic is not None, "topic must be specified for debate env"
self._env = create_debate_env(
Expand Down Expand Up @@ -233,6 +237,12 @@ def render(self):
if message.agent_name == "Moderator":
color = Fore.BLACK
role = ""
elif message.agent_name == "Proponent":
color = Fore.BLUE
role = ""
elif message.agent_name == "Opponent":
color = Fore.RED
role = ""
else:
if self.infos[message.agent_name]["role"] == "attacker":
color = Fore.RED
Expand Down Expand Up @@ -292,14 +302,15 @@ def observe(self, agent: AgentID) -> ObsType:
}
self.infos[agent]["player_name"] = self.agent_selection

# Role in symmetric environments
if hasattr(self._env, "_current_phase"):
if self._env._current_phase == "player_2_attack" or self._env._current_phase == "end":
self.infos[self.possible_agents[0]]["role"] = "defender"
self.infos[self.possible_agents[1]]["role"] = "attacker"
else:
self.infos[self.possible_agents[0]]["role"] = "attacker"
self.infos[self.possible_agents[1]]["role"] = "defender"
# Role in symmetric environments (not applicable if env has terminated)
if self.env_name != "debate":
if hasattr(self._env, "_current_phase") and not any(self.terminations.values()):
if self._env._current_phase == "player_2_attack":
self.infos[self.possible_agents[0]]["role"] = "defender"
self.infos[self.possible_agents[1]]["role"] = "attacker"
else:
self.infos[self.possible_agents[0]]["role"] = "attacker"
self.infos[self.possible_agents[1]]["role"] = "defender"

# info: generate string of full chat log
if self.string_observation is True:
Expand Down