Skip to content

Commit

Permalink
Update PMAT and necessary changes for new langchain and small changes…
Browse files Browse the repository at this point in the history
… to be compatible with PMA (#103)
  • Loading branch information
kongzii authored Jul 16, 2024
1 parent 87a5bf0 commit ba98871
Show file tree
Hide file tree
Showing 14 changed files with 573 additions and 671 deletions.
Empty file removed evo_prophet/py.typed
Empty file.
Empty file removed evo_researcher/py.typed
Empty file.
1,188 changes: 538 additions & 650 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion prediction_prophet/autonolas/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from prediction_prophet.functions.cache import persistent_inmemory_cache
from prediction_prophet.functions.parallelism import par_map
from pydantic.types import SecretStr
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr

load_dotenv()

Expand Down Expand Up @@ -1186,7 +1187,7 @@ def make_prediction(

prediction_prompt = ChatPromptTemplate.from_template(template=PREDICTION_PROMPT)

llm = ChatOpenAI(model=engine, temperature=temperature, api_key=api_key.get_secret_value() if api_key else None)
llm = ChatOpenAI(model=engine, temperature=temperature, api_key=secretstr_to_v1_secretstr(api_key))
formatted_messages = prediction_prompt.format_messages(user_prompt=prompt, additional_information=additional_information, timestamp=formatted_time_utc)
generation = llm.generate([formatted_messages], logprobs=True, top_logprobs=5)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr


def create_embeddings_from_results(results: list[WebScrapeResult], text_splitter: RecursiveCharacterTextSplitter, api_key: SecretStr | None = None) -> Chroma:
if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")

collection = Chroma(embedding_function=OpenAIEmbeddings(api_key=api_key.get_secret_value() if api_key else None))
collection = Chroma(embedding_function=OpenAIEmbeddings(api_key=secretstr_to_v1_secretstr(api_key)))
texts = []
metadatas = []

Expand Down
5 changes: 3 additions & 2 deletions prediction_prophet/functions/debate_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr


PREDICTION_PROMPT = """
Expand Down Expand Up @@ -92,7 +93,7 @@ def make_debated_prediction(prompt: str, additional_information: str, api_key: S

prediction_chain = (
prediction_prompt |
ChatOpenAI(model="gpt-4-0125-preview", api_key=api_key.get_secret_value() if api_key else None) |
ChatOpenAI(model="gpt-4-0125-preview", api_key=secretstr_to_v1_secretstr(api_key)) |
StrOutputParser()
)

Expand Down Expand Up @@ -127,7 +128,7 @@ def make_debated_prediction(prompt: str, additional_information: str, api_key: S

extraction_chain = (
extraction_prompt |
ChatOpenAI(model="gpt-3.5-turbo-0125", api_key=api_key.get_secret_value() if api_key else None) |
ChatOpenAI(model="gpt-3.5-turbo-0125", api_key=secretstr_to_v1_secretstr(api_key)) |
StrOutputParser()
)

Expand Down
3 changes: 2 additions & 1 deletion prediction_prophet/functions/evaluate_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from prediction_prophet.functions.cache import persistent_inmemory_cache
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr


# I tried to make it return a JSON, but it didn't work well in combo with asking it to do chain of thought.
Expand Down Expand Up @@ -50,7 +51,7 @@ def is_predictable(

if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")
llm = ChatOpenAI(model=engine, temperature=0.0, api_key=api_key.get_secret_value() if api_key else None)
llm = ChatOpenAI(model=engine, temperature=0.0, api_key=secretstr_to_v1_secretstr(api_key))

prompt = ChatPromptTemplate.from_template(template=prompt_template)
messages = prompt.format_messages(question=question)
Expand Down
3 changes: 2 additions & 1 deletion prediction_prophet/functions/generate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain.prompts import ChatPromptTemplate
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr


subquery_generation_template = """
Expand All @@ -22,7 +23,7 @@ def generate_subqueries(query: str, limit: int, model: str, api_key: SecretStr |

subquery_generation_chain = (
subquery_generation_prompt |
ChatOpenAI(model=model, api_key=api_key.get_secret_value() if api_key else None) |
ChatOpenAI(model=model, api_key=secretstr_to_v1_secretstr(api_key)) |
CommaSeparatedListOutputParser()
)

Expand Down
4 changes: 2 additions & 2 deletions prediction_prophet/functions/is_predictable_and_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from prediction_prophet.functions.cache import persistent_inmemory_cache
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env

from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr


# I tried to make it return a JSON, but it didn't work well in combo with asking it to do chain of thought.
Expand Down Expand Up @@ -52,7 +52,7 @@ def is_predictable_and_binary(

if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")
llm = ChatOpenAI(model=engine, temperature=0.0, api_key=api_key.get_secret_value() if api_key else None)
llm = ChatOpenAI(model=engine, temperature=0.0, api_key=secretstr_to_v1_secretstr(api_key))

prompt = ChatPromptTemplate.from_template(template=prompt_template)
messages = prompt.format_messages(question=question)
Expand Down
5 changes: 3 additions & 2 deletions prediction_prophet/functions/prepare_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from prediction_prophet.functions.utils import trim_to_n_tokens
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from pydantic.types import SecretStr
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr


@persistent_inmemory_cache
Expand All @@ -28,7 +29,7 @@ def prepare_summary(goal: str, content: str, model: str, api_key: SecretStr | No

research_evaluation_chain = (
evaluation_prompt |
ChatOpenAI(model=model, api_key=api_key.get_secret_value() if api_key else None) |
ChatOpenAI(model=model, api_key=secretstr_to_v1_secretstr(api_key)) |
StrOutputParser()
)

Expand Down Expand Up @@ -69,7 +70,7 @@ def prepare_report(goal: str, scraped: list[str], model: str, api_key: SecretStr

research_evaluation_chain = (
evaluation_prompt |
ChatOpenAI(model=model, api_key=api_key.get_secret_value() if api_key else None) |
ChatOpenAI(model=model, api_key=secretstr_to_v1_secretstr(api_key)) |
StrOutputParser()
)

Expand Down
3 changes: 2 additions & 1 deletion prediction_prophet/functions/rerank_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain.schema.output_parser import StrOutputParser
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr

rerank_queries_template = """
I will present you with a list of queries to search the web for, for answers to the question: {goal}.
Expand All @@ -23,7 +24,7 @@ def rerank_subqueries(queries: list[str], goal: str, model: str, api_key: Secret

rerank_results_chain = (
rerank_results_prompt |
ChatOpenAI(model=model, api_key=api_key.get_secret_value() if api_key else None) |
ChatOpenAI(model=model, api_key=secretstr_to_v1_secretstr(api_key)) |
StrOutputParser()
)

Expand Down
6 changes: 5 additions & 1 deletion prediction_prophet/functions/research.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import typing as t

from langchain.text_splitter import RecursiveCharacterTextSplitter
from prediction_prophet.functions.create_embeddings_from_results import create_embeddings_from_results
Expand All @@ -10,6 +11,9 @@
from prediction_prophet.functions.search import search
from pydantic.types import SecretStr

if t.TYPE_CHECKING:
from loguru import Logger

def research(
goal: str,
use_summaries: bool,
Expand All @@ -22,7 +26,7 @@ def research(
use_tavily_raw_content: bool = False,
openai_api_key: SecretStr | None = None,
tavily_api_key: SecretStr | None = None,
logger: logging.Logger = logging.getLogger()
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger()
) -> str:
logger.info("Started subqueries generation")
queries = generate_subqueries(query=goal, limit=initial_subqueries_limit, model=model, api_key=openai_api_key)
Expand Down
4 changes: 3 additions & 1 deletion prediction_prophet/functions/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import ReduceDocumentsChain, StuffDocumentsChain, MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.text_splitter import RecursiveCharacterTextSplitter

def summarize(objective: str, content: str) -> str:
Expand Down
17 changes: 9 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,21 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "~3.10"
langchain = "^0.1.9"
langchain = "^0.2.6"
beautifulsoup4 = "^4.12.3"
click = "^8.0.2"
markdownify = "0.11.6"
pandas = "2.1.1"
pytest = "^8.0.0"
openai = "^1.10.0"
chromadb = "0.4.22"
chromadb = "0.4.24"
spacy = "3.7.5"
en_core_web_md = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1-py3-none-any.whl" }
google-api-python-client = "2.95.0"
tiktoken = "0.5.2"
tiktoken = "^0.7.0"
tavily-python = "^0.3.0"
tabulate = "^0.9.0"
pysqlite3-binary = {version="^0.5.2.post3", markers = "sys_platform == 'linux'"}
langchain-openai = "^0.0.5"
langchain-openai = "^0.1.0"
tenacity = "^8.2.3"
joblib = "^1.3.2"
streamlit = "^1.30.0"
Expand All @@ -31,15 +30,17 @@ scipy = "^1.12.0"
scikit-learn = "^1.4.0"
typer = ">=0.9.0,<1.0.0"
types-requests = "^2.31.0.20240125"
types-python-dateutil = "^2.8.19.20240106"
prediction-market-agent-tooling = { version = "^0.40.0", extras = ["langchain", "google"] }
langchain-community = "^0.0.32"
types-python-dateutil = "^2.9.0"
prediction-market-agent-tooling = { version = ">=0.43.0,<1", extras = ["langchain", "google"] }
langchain-community = "^0.2.6"
memory-profiler = "^0.61.0"
matplotlib = "^3.8.3"
pyautogen = "^0.2.19"
python-dateutil = "^2.9.0.post0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.8.0"
pytest = "^8.0.0"

[tool.poetry.scripts]
research= "prediction_prophet.main:research"
Expand Down

0 comments on commit ba98871

Please sign in to comment.