From faecfa4e4e606d13998b0edddc802a1b6df48f03 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Wed, 20 Dec 2023 19:36:38 -0500 Subject: [PATCH] Umshini: fix minor typo with player_names kwarg (#117) --- chatarena/environments/umshini/content_moderation.py | 9 +++++---- chatarena/environments/umshini/debate.py | 2 +- chatarena/environments/umshini/deception.py | 9 ++++----- .../environments/umshini/symmetric_content_moderation.py | 5 ++--- chatarena/environments/umshini/symmetric_deception.py | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/chatarena/environments/umshini/content_moderation.py b/chatarena/environments/umshini/content_moderation.py index fbe42e1d..4fa2a5ca 100644 --- a/chatarena/environments/umshini/content_moderation.py +++ b/chatarena/environments/umshini/content_moderation.py @@ -71,11 +71,12 @@ def get_rewards(self) -> tuple[dict[str, float], str, str, str]: self._moderator_speak("The defender is not responding. No score.") return scores, "", "", "" - explanation = "" - vulnerability_category = "" - if self.disable_judging: violation = True if random.randint(0, 1) else False + explanation = "EXPLANATION: Judging has been disabled." + vulnerability_category = ( + "VULNERABILITY CATEGORY: Judging has been disabled." + ) else: response = self.judge_deception(message_state=self.message_pool).content try: @@ -265,7 +266,7 @@ def create_content_moderation_env( ): if player_names is not None: assert isinstance(player_names, list), "player_names must be a list" - assert len(player_names == 2), "number of players must be 2" + assert len(player_names) == 2, "number of players must be 2" assert isinstance(player_names[0], str), "player names must be strings" assert isinstance(player_names[1], str), "player names must be strings" diff --git a/chatarena/environments/umshini/debate.py b/chatarena/environments/umshini/debate.py index c55d6be7..d31edafb 100644 --- a/chatarena/environments/umshini/debate.py +++ b/chatarena/environments/umshini/debate.py @@ -109,7 +109,7 @@ def create_debate_env( ): if player_names is not None: assert isinstance(player_names, list), "player_names must be a list" - assert len(player_names == 2), "number of players must be 2" + assert len(player_names) == 2, "number of players must be 2" assert isinstance(player_names[0], str), "player names must be strings" assert isinstance(player_names[1], str), "player names must be strings" diff --git a/chatarena/environments/umshini/deception.py b/chatarena/environments/umshini/deception.py index f1aedda9..0b184afb 100644 --- a/chatarena/environments/umshini/deception.py +++ b/chatarena/environments/umshini/deception.py @@ -70,12 +70,12 @@ def get_rewards(self) -> tuple[dict[str, float], str, str, str]: self._moderator_speak("The defender is not responding. No score.") return scores, "", "", "" - explanation = "" - vulnerability_category = "" - if self.disable_judging: violation = True if random.randint(0, 1) else False - explanation = None + explanation = "EXPLANATION: Judging has been disabled." + vulnerability_category = ( + "VULNERABILITY CATEGORY: Judging has been disabled." + ) else: response = self.judge_deception(message_state=self.message_pool).content try: @@ -172,7 +172,6 @@ def step(self, player_name: str, action: str) -> TimeStep: explanation, vulnerability_category, ) = self.get_rewards() - # TODO: make it say something along these lines # JUDGEMENT: This response has violated the content moderation policy: {policy} # EXPLANATION: {explanation} # SCORES: {scores} diff --git a/chatarena/environments/umshini/symmetric_content_moderation.py b/chatarena/environments/umshini/symmetric_content_moderation.py index 28be3c36..e9f0af69 100644 --- a/chatarena/environments/umshini/symmetric_content_moderation.py +++ b/chatarena/environments/umshini/symmetric_content_moderation.py @@ -71,10 +71,9 @@ def get_rewards(self) -> tuple[dict[str, float], str, str]: self._moderator_speak("The defender is not responding. No score.") return scores, "", "" - violation = False - explanation = None if self.disable_judging: violation = True if random.randint(0, 1) else False + explanation = "EXPLANATION: Judging has been disabled." else: response = self.judge_deception(message_state=self.message_pool).content try: @@ -289,7 +288,7 @@ def create_symmetric_content_moderation_env( ): if player_names is not None: assert isinstance(player_names, list), "player_names must be a list" - assert len(player_names == 2), "number of players must be 2" + assert len(player_names) == 2, "number of players must be 2" assert isinstance(player_names[0], str), "player names must be strings" assert isinstance(player_names[1], str), "player names must be strings" diff --git a/chatarena/environments/umshini/symmetric_deception.py b/chatarena/environments/umshini/symmetric_deception.py index 104bdf0d..75a125f1 100644 --- a/chatarena/environments/umshini/symmetric_deception.py +++ b/chatarena/environments/umshini/symmetric_deception.py @@ -73,7 +73,7 @@ def get_rewards(self) -> tuple[dict[str, float], str, str]: violation = False if self.disable_judging: violation = True if random.randint(0, 1) else False - explanation = None + explanation = "EXPLANATION: Judging has been disabled." else: response = self.judge_deception(message_state=self.message_pool).content try: