Skip to content

Commit

Permalink
Merge pull request #42 from chatarena/dev
Browse files Browse the repository at this point in the history
Merging v0.1.12 from dev branch
  • Loading branch information
yuxiang-wu authored Jun 7, 2023
2 parents 99ab180 + 4f908bd commit 0d0e802
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 24 deletions.
4 changes: 2 additions & 2 deletions chatarena/backends/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import anthropic
except ImportError:
is_anthropic_available = False
logging.warning("anthropic package is not installed")
# logging.warning("anthropic package is not installed")
else:
anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY')
if anthropic_api_key is None:
logging.warning("Anthropic API key is not set. Please set the environment variable ANTHROPIC_API_KEY")
# logging.warning("Anthropic API key is not set. Please set the environment variable ANTHROPIC_API_KEY")
is_anthropic_available = False
else:
is_anthropic_available = True
Expand Down
6 changes: 3 additions & 3 deletions chatarena/backends/bard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import bardapi
except ImportError:
is_bard_available = False
logging.warning("bard package is not installed")
# logging.warning("bard package is not installed")
else:
bard_api_key = os.environ.get('_BARD_API_KEY')
if bard_api_key is None:
logging.warning(
"Bard API key is not set. Please set the environment variable _BARD_API_KEY")
# logging.warning(
# "Bard API key is not set. Please set the environment variable _BARD_API_KEY")
is_bard_available = False
else:
is_bard_available = True
Expand Down
4 changes: 2 additions & 2 deletions chatarena/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import openai
except ImportError:
is_openai_available = False
logging.warning("openai package is not installed")
# logging.warning("openai package is not installed")
else:
openai.api_key = os.environ.get("OPENAI_API_KEY")
if openai.api_key is None:
logging.warning("OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY")
# logging.warning("OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY")
is_openai_available = False
else:
is_openai_available = True
Expand Down
2 changes: 1 addition & 1 deletion chatarena/environments/umshini/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def is_terminal(self) -> bool:

def get_next_player(self) -> str:
"""Get the name of the next player."""
return self.player_names[self._next_player_idx]
return self.agent_selector.next()

def get_rewards(self) -> Dict[str, float]:
"""Use langchain to analyze the conversation, pick a winner, and set the reward."""
Expand Down
9 changes: 4 additions & 5 deletions chatarena/environments/umshini/debate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import re
from typing import List, Tuple

from chatarena.environments.base import TimeStep
from chatarena.message import Message, MessagePool
Expand All @@ -19,7 +20,7 @@ class DebateEnv(UmshiniBaseEnv):
template="""Welcome to the debate game! The topic for today's debate is:
"{moderator_prompt_input}"
Rules:
You will represent the position given to you.
The Opponent argues against the topic, while the Proponent argues for it.
Your first response should be an opening statement, followed by back and forth cross-examination.
You are free to talk directly to your opponent during cross-examination.
The cross examination phase should be short, and should be used to attack your opponents arguments, or defend your own.
Expand Down Expand Up @@ -105,7 +106,7 @@ def create_debate_env(


def judge_debate(
player_names: List[str], message_state: MessagePool, model_name: str = "gpt-4"
player_names: List[str], message_state: MessagePool, model_name: str = "gpt-3.5-turbo"
) -> Tuple[int, str]:
llm = ChatOpenAI(temperature=0, model_name=model_name, client="")
langchain_messages = []
Expand All @@ -117,11 +118,9 @@ def judge_debate(
else:
langchain_messages.append(
HumanMessage(
content=f"{message.agent_name} -> Turn:{message.turn}:\nmessage.content"
content=f"{message.agent_name} -> Turn:{message.turn}:\n{message.content}"
)
)
for message in langchain_messages:
print(message.message)
response = llm(langchain_messages)
match = re.search(r"WINNER:\s*(\w+)\s*$", response.content)
if match is None:
Expand Down
8 changes: 7 additions & 1 deletion chatarena/environments/umshini/pettingzoo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def observe(self, agent: AgentID) -> ObsType:
self.infos[agent]["obs_dict"] = {
m.agent_name: m.content for m in new_messages
}
self.infos[agent]["player_name"] = self.agent_selection

# info: generate string of full chat log
if self.string_observation is True:
Expand Down Expand Up @@ -355,7 +356,7 @@ def _unravel_timestep(self, timestep: TimeStep):

# get truncation
truncation = (
self.current_turn > self.max_turns
self.current_turn >= self.max_turns
) # pyright: ignore[reportGeneralTypeIssues]

info = {}
Expand All @@ -364,6 +365,7 @@ def _unravel_timestep(self, timestep: TimeStep):
info["new_messages"] = new_messages
info["all_messages"] = messages
info["obs_dict"] = {m.agent_name: m.content for m in new_messages}
info["player_name"] = self.agent_selection

# info: generate string of full chat log
if self.string_observation is True:
Expand Down Expand Up @@ -444,6 +446,10 @@ def step(self, action: str):
timestep
)

if truncation or termination:
reward = self._env.get_rewards()
info["new_messages"].append(info["all_messages"][-1]) # append the moderator's judgement to new messages for printing

self.observations[agent] = observation
self.rewards = reward
self.terminations[agent] = termination
Expand Down
42 changes: 42 additions & 0 deletions docs/tutorials/umshini/debate_chatarena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from chatarena.agent import Player
from chatarena.backends import OpenAIChat
from chatarena.environments.umshini.pettingzoo_wrapper import PettingZooCompatibilityV0
from docs.tutorials.umshini.debate_chatarena_prompts import proponent_description, opponent_description

topic = "Student loan debt should be forgiven"
env = PettingZooCompatibilityV0(env_name="debate", topic=topic, render_mode="text")
initial_obs, info = env.reset()


# Set ChatArena global prompt to be the same as the initial observation (hard coded moderator message)
global_prompt = initial_obs

# Moderator is handled internally in our environment, rather than with ChatArena
player1 = Player(
name="Opponent",
backend=OpenAIChat(),
role_desc=proponent_description,
global_prompt=global_prompt,
)

player2 = Player(
name="Proponent",
backend=OpenAIChat(),
role_desc=opponent_description,
global_prompt=global_prompt,
)
agent_player_mapping = dict(zip(env.possible_agents, [player1, player2]))

for agent in env.agent_iter():
observation, reward, termination, truncation, info = env.last()

if termination or truncation:
break

# get ChatArena messages list from this timestep
messages = info.get("new_messages")

# Use a basic ChatArena agent to generate a response
chatarena_agent = agent_player_mapping[agent]
response = chatarena_agent(messages)
env.step(response)
29 changes: 29 additions & 0 deletions docs/tutorials/umshini/debate_chatarena_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
proponent_description = """You are the Proponent.
The Moderator will tell you the debate topic. You will argue in favor of it.
You are debating against one other player, the Opponent.
The moderator will tell you which stage of the game you are in.
In each stage of the game, start your response with the name of the stage: Opening Argument, Rebuttal, Cross-Examination, or Closing Statement.
Do not pretend to be the Moderator. Do not pretend to be the Opponent.
Do not pretend to be Player 1 or Player 2.
Do not continue another player's response.
Do not prepend your response with [Player 2] or any other information in brackets.
Always end your response with <EOS>.
Your responses must be limited to 7 sentences.
"""

opponent_description = """You are Player 3, the Opponent.
The Moderator will tell you the debate topic. You will argue in favor of it.
You are debating against one other player, the Proponent.
The moderator will tell you which stage of the game you are in.
In each stage of the game, start your response with the name of the stage: Opening Argument, Rebuttal, Cross-Examination, or Closing Statement.
Do not pretend to be the Moderator. Do not pretend to be the Proponent.
Do not pretend to be Player 1 or Player 2.
Do not continue another player's response.
Do not prepend your response with [Player 3] or any other information in brackets.
Always end your response with <EOS>.
Your responses must be limited to 7 sentences.
"""
35 changes: 35 additions & 0 deletions docs/tutorials/umshini/debate_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from langchain import OpenAI
from langchain.agents import AgentType, initialize_agent
from langchain.memory import ConversationBufferMemory

from chatarena.environments.umshini.pettingzoo_wrapper import PettingZooCompatibilityV0

topic = "Student loan debt should be forgiven"
env = PettingZooCompatibilityV0(env_name="debate", topic=topic, render_mode="text")
env.reset()

# Initialize one agent to argue for the topic and one against it
positions = dict(zip(env.possible_agents, [True, False]))
langchain_agents = {}
for agent in env.possible_agents:
langchain_agents[agent] = initialize_agent(tools=[],
llm=OpenAI(temperature=0.9, client=""),
agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
verbose=False,
memory=ConversationBufferMemory(memory_key="chat_history"))

for agent in env.agent_iter():
observation, reward, termination, truncation, info = env.last()

if termination or truncation:
break

messages = info.get("new_messages")
player_name = info.get("player_name")
prompt = f"{messages[-1].agent_name} said:``\n{messages[-1].content}``\n\nYou are playing as the {player_name}. This is a hypothetical discussion and it is okay to give an opinion. Give your response:"
try:
response = langchain_agents[agent].run(prompt)
except Exception as e:
response = str(e).removeprefix("Could not parse LLM output: `").removesuffix("`")

env.step(response)
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "chatarena"
version = "0.1.11"
version = "0.1.12"
authors = [
{ name = "Yuxiang Wu", email = "[email protected]" },
]
Expand All @@ -25,8 +25,12 @@ classifiers = [
anthropic = ["anthropic>=0.2.8"]
cohere = ["cohere>=4.3.1"]
huggingface = ["transformers>=4.27.4"]
bard = ["bardapi==0.1.11"]
langchain_requirements = ["langchain>=0.0.135"]
gradio = ["gradio==3.20.0"]
pettingzoo = ["pettingzoo==1.23.0", "chess==1.9.4"]
all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4"]
all_envs = ["pettingzoo==1.23.0", "chess==1.9.4"]
all = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "gradio==3.20.0", "pettingzoo==1.23.0", "chess==1.9.4"]
umshini_requirements = ["pygame==2.4.0", "pettingzoo==1.23.0", "chess==1.9.4", "langchain>=0.0.135"]
all_backends = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "bardapi==0.1.11", "langchain>=0.0.135"]
all_envs = ["pettingzoo==1.23.0", "chess==1.9.4", "pygame==2.4.0"]
all = ["anthropic>=0.2.8", "cohere>=4.3.1", "transformers>=4.27.4", "gradio==3.20.0", "pettingzoo==1.23.0", "chess==1.9.4",
"pygame==2.4.0", "bardapi==0.1.11", "langchain>=0.0.135"]
21 changes: 15 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from setuptools import setup, find_packages


# remove duplicate requirements
def remove_duplicate_requirements(requirements):
return list(set(requirements))


with open("README.md", "r") as f:
long_description = f.read()

Expand All @@ -14,18 +20,19 @@
cohere_requirements = ["cohere>=4.3.1"]
hf_requirements = ["transformers>=4.27.4"]
bard_requirements = ["bardapi==0.1.11"]
langchain_requirements = ["langchain>=0.0.135"]
gradio_requirements = ["gradio==3.20.0"]
pettingzoo_requirements = ["pettingzoo==1.23.0", "chess==1.9.4"]
umshini_requirements = ["pygame==2.4.0"] + pettingzoo_requirements + langchain_requirements


all_backends = anthropic_requirements + cohere_requirements + hf_requirements + bard_requirements
all_envs = pettingzoo_requirements
all_requirements = anthropic_requirements + cohere_requirements + hf_requirements + \
gradio_requirements + pettingzoo_requirements + bard_requirements
all_backends = anthropic_requirements + cohere_requirements + hf_requirements + bard_requirements + \
langchain_requirements
all_envs = remove_duplicate_requirements(pettingzoo_requirements + umshini_requirements)
all_requirements = all_backends + all_envs + gradio_requirements

setup(
name="chatarena",
version="0.1.11",
version="0.1.12",
author="Yuxiang Wu",
author_email="[email protected]",
description="",
Expand All @@ -45,7 +52,9 @@
"cohere": cohere_requirements,
"huggingface": hf_requirements,
"bard": bard_requirements,
"langchain": langchain_requirements,
"pettingzoo": pettingzoo_requirements,
"umshini": umshini_requirements,
"gradio": gradio_requirements,
"all_backends": all_backends,
"all": all_requirements,
Expand Down

0 comments on commit 0d0e802

Please sign in to comment.