Skip to content

Commit

Permalink
Umshini: fix minor typo with player_names kwarg (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower authored Dec 21, 2023
1 parent 4abf91e commit faecfa4
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 14 deletions.
9 changes: 5 additions & 4 deletions chatarena/environments/umshini/content_moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def get_rewards(self) -> tuple[dict[str, float], str, str, str]:
self._moderator_speak("The defender is not responding. No score.")
return scores, "", "", ""

explanation = ""
vulnerability_category = ""

if self.disable_judging:
violation = True if random.randint(0, 1) else False
explanation = "EXPLANATION: Judging has been disabled."
vulnerability_category = (
"VULNERABILITY CATEGORY: Judging has been disabled."
)
else:
response = self.judge_deception(message_state=self.message_pool).content
try:
Expand Down Expand Up @@ -265,7 +266,7 @@ def create_content_moderation_env(
):
if player_names is not None:
assert isinstance(player_names, list), "player_names must be a list"
assert len(player_names == 2), "number of players must be 2"
assert len(player_names) == 2, "number of players must be 2"
assert isinstance(player_names[0], str), "player names must be strings"
assert isinstance(player_names[1], str), "player names must be strings"

Expand Down
2 changes: 1 addition & 1 deletion chatarena/environments/umshini/debate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def create_debate_env(
):
if player_names is not None:
assert isinstance(player_names, list), "player_names must be a list"
assert len(player_names == 2), "number of players must be 2"
assert len(player_names) == 2, "number of players must be 2"
assert isinstance(player_names[0], str), "player names must be strings"
assert isinstance(player_names[1], str), "player names must be strings"

Expand Down
9 changes: 4 additions & 5 deletions chatarena/environments/umshini/deception.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def get_rewards(self) -> tuple[dict[str, float], str, str, str]:
self._moderator_speak("The defender is not responding. No score.")
return scores, "", "", ""

explanation = ""
vulnerability_category = ""

if self.disable_judging:
violation = True if random.randint(0, 1) else False
explanation = None
explanation = "EXPLANATION: Judging has been disabled."
vulnerability_category = (
"VULNERABILITY CATEGORY: Judging has been disabled."
)
else:
response = self.judge_deception(message_state=self.message_pool).content
try:
Expand Down Expand Up @@ -172,7 +172,6 @@ def step(self, player_name: str, action: str) -> TimeStep:
explanation,
vulnerability_category,
) = self.get_rewards()
# TODO: make it say something along these lines
# JUDGEMENT: This response has violated the content moderation policy: {policy}
# EXPLANATION: {explanation}
# SCORES: {scores}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@ def get_rewards(self) -> tuple[dict[str, float], str, str]:
self._moderator_speak("The defender is not responding. No score.")
return scores, "", ""

violation = False
explanation = None
if self.disable_judging:
violation = True if random.randint(0, 1) else False
explanation = "EXPLANATION: Judging has been disabled."
else:
response = self.judge_deception(message_state=self.message_pool).content
try:
Expand Down Expand Up @@ -289,7 +288,7 @@ def create_symmetric_content_moderation_env(
):
if player_names is not None:
assert isinstance(player_names, list), "player_names must be a list"
assert len(player_names == 2), "number of players must be 2"
assert len(player_names) == 2, "number of players must be 2"
assert isinstance(player_names[0], str), "player names must be strings"
assert isinstance(player_names[1], str), "player names must be strings"

Expand Down
2 changes: 1 addition & 1 deletion chatarena/environments/umshini/symmetric_deception.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_rewards(self) -> tuple[dict[str, float], str, str]:
violation = False
if self.disable_judging:
violation = True if random.randint(0, 1) else False
explanation = None
explanation = "EXPLANATION: Judging has been disabled."
else:
response = self.judge_deception(message_state=self.message_pool).content
try:
Expand Down

0 comments on commit faecfa4

Please sign in to comment.