From 7e13d824f8241b66e2035d53d18da0d4f3f68a18 Mon Sep 17 00:00:00 2001 From: elliottower Date: Tue, 20 Jun 2023 21:05:31 -0400 Subject: [PATCH] Remove pz classic requirement from umshini, slight bugfix in envs --- .../environments/umshini/pettingzoo_wrapper.py | 14 ++++++++++++-- pyproject.toml | 4 ++-- setup.py | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/chatarena/environments/umshini/pettingzoo_wrapper.py b/chatarena/environments/umshini/pettingzoo_wrapper.py index 393ff02d..d70f24ac 100644 --- a/chatarena/environments/umshini/pettingzoo_wrapper.py +++ b/chatarena/environments/umshini/pettingzoo_wrapper.py @@ -82,7 +82,17 @@ def __init__( "ChatArena Environment or environment name must be specified" ) elif env is not None: + # TODO: test that human rendering works with this self._env = env + if hasattr(env, "topic"): + self.topic = topic + self.max_turns = round_length + elif hasattr(env, "moderation_policy"): + self.moderation_policy = env.moderation_policy + self.max_turns = round_length * 2 + elif hasattr(env, "restricted_action"): + self.restricted_action = env.restricted_action + self.max_turns = round_length * 2 elif env_name is not None: if env_name == "debate": assert topic is not None, "topic must be specified for debate env" @@ -115,7 +125,7 @@ def __init__( self.max_turns = round_length * 2 else: raise TypeError( - "Environment not found. Options: debate, content_moderation, deception" + f"Environment not found: {env_name}. Options: debate, content_moderation, deception" ) else: raise TypeError( @@ -216,7 +226,7 @@ def render(self): self.clock.tick(self.metadata["render_fps"]) self.screen.fill("black") - font = pygame.Font(None, 32) # default font + font = pygame.font.Font(None, 32) # default font self.line_spacing = 10 self.border_padding = 10 diff --git a/pyproject.toml b/pyproject.toml index 1a8f034b..ec062198 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,10 +26,10 @@ anthropic = ["anthropic>=0.2.8"] cohere = ["cohere>=4.3.1"] huggingface = ["transformers>=4.27.4"] bard = ["bardapi==0.1.11"] -langchain_requirements = ["langchain>=0.0.135"] +langchain = ["langchain>=0.0.135"] gradio = ["gradio>=3.34.0"] pettingzoo = ["pettingzoo[classic]>=1.23.1"] -umshini_requirements = ["pettingzoo[classic]>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"] +umshini = ["pettingzoo>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"] all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"] all_envs = ["pettingzoo[classic]>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"] all = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "gradio>=3.34.0", "pettingzoo>=1.23.1", diff --git a/setup.py b/setup.py index c95dbe20..351f6a63 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ def remove_duplicate_requirements(requirements): langchain_requirements = ["langchain>=0.0.135"] gradio_requirements = ["gradio>=3.34.0"] pettingzoo_requirements = ["pettingzoo[classic]>=1.23.1", "chess==1.9.4"] -umshini_requirements = ["pettingzoo[classic]>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"] +umshini_requirements = ["pettingzoo>=1.23.1", "pygame-ce>=2.2.1", "langchain>=0.0.135"] all_backends = anthropic_requirements + cohere_requirements + hf_requirements + bard_requirements + \ langchain_requirements