Skip to content

Commit

Permalink
Merge pull request #50 from elliottower/umshini-minor-bugfix
Browse files Browse the repository at this point in the history
Remove pz classic requirement from umshini, slight bugfix in envs
  • Loading branch information
yuxiang-wu authored Jun 21, 2023
2 parents 2f8efa6 + 7e13d82 commit 1497ce0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
14 changes: 12 additions & 2 deletions chatarena/environments/umshini/pettingzoo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,17 @@ def __init__(
"ChatArena Environment or environment name must be specified"
)
elif env is not None:
# TODO: test that human rendering works with this
self._env = env
if hasattr(env, "topic"):
self.topic = topic
self.max_turns = round_length
elif hasattr(env, "moderation_policy"):
self.moderation_policy = env.moderation_policy
self.max_turns = round_length * 2
elif hasattr(env, "restricted_action"):
self.restricted_action = env.restricted_action
self.max_turns = round_length * 2
elif env_name is not None:
if env_name == "debate":
assert topic is not None, "topic must be specified for debate env"
Expand Down Expand Up @@ -115,7 +125,7 @@ def __init__(
self.max_turns = round_length * 2
else:
raise TypeError(
"Environment not found. Options: debate, content_moderation, deception"
f"Environment not found: {env_name}. Options: debate, content_moderation, deception"
)
else:
raise TypeError(
Expand Down Expand Up @@ -216,7 +226,7 @@ def render(self):
self.clock.tick(self.metadata["render_fps"])

self.screen.fill("black")
font = pygame.Font(None, 32) # default font
font = pygame.font.Font(None, 32) # default font
self.line_spacing = 10
self.border_padding = 10

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ anthropic = ["anthropic>=0.2.8"]
cohere = ["cohere>=4.3.1"]
huggingface = ["transformers>=4.27.4"]
bard = ["bardapi==0.1.11"]
langchain_requirements = ["langchain>=0.0.135"]
langchain = ["langchain>=0.0.135"]
gradio = ["gradio>=3.34.0"]
pettingzoo = ["pettingzoo[classic]>=1.23.1"]
umshini_requirements = ["pettingzoo[classic]>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"]
umshini = ["pettingzoo>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"]
all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"]
all_envs = ["pettingzoo[classic]>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"]
all = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "gradio>=3.34.0", "pettingzoo>=1.23.1",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def remove_duplicate_requirements(requirements):
langchain_requirements = ["langchain>=0.0.135"]
gradio_requirements = ["gradio>=3.34.0"]
pettingzoo_requirements = ["pettingzoo[classic]>=1.23.1", "chess==1.9.4"]
umshini_requirements = ["pettingzoo[classic]>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"]
umshini_requirements = ["pettingzoo>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"]

all_backends = anthropic_requirements + cohere_requirements + hf_requirements + bard_requirements + \
langchain_requirements
Expand Down

0 comments on commit 1497ce0

Please sign in to comment.