From aac623f14438c42638ba50b1b5d1d55bda86b2c9 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Oct 2022 22:19:33 -0700 Subject: [PATCH 1/2] stash --- langchain/chains/python.py | 5 +- langchain/chains/serpapi.py | 80 +++++++++++++++++++ langchain/llms/cohere.py | 2 +- langchain/llms/openai.py | 2 +- requirements.txt | 1 + tests/integration_tests/chains/__init__.py | 1 + .../integration_tests/chains/test_serpapi.py | 7 ++ 7 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 langchain/chains/serpapi.py create mode 100644 tests/integration_tests/chains/__init__.py create mode 100644 tests/integration_tests/chains/test_serpapi.py diff --git a/langchain/chains/python.py b/langchain/chains/python.py index b3b6f1b6fc776..30a9c84525949 100644 --- a/langchain/chains/python.py +++ b/langchain/chains/python.py @@ -1,4 +1,7 @@ -"""Chain that runs python code.""" +"""Chain that runs python code. + +Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py +""" import sys from io import StringIO from typing import Dict, List diff --git a/langchain/chains/serpapi.py b/langchain/chains/serpapi.py new file mode 100644 index 0000000000000..ff8d021c50e6d --- /dev/null +++ b/langchain/chains/serpapi.py @@ -0,0 +1,80 @@ +"""Chain that calls SerpAPI. + +Heavily borrowed from https://github.com/ofirpress/self-ask +""" +import os +from typing import Any, Callable, Dict, List + +from pydantic import BaseModel, root_validator + +from langchain.chains.base import Chain + + +class SerpAPIChain(Chain, BaseModel): + """Chain that calls SerpAPI.""" + + search: Callable + input_key: str = "search_query" + output_key: str = "search_result" + + @property + def input_keys(self) -> List[str]: + """Return the singular input key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key.""" + return [self.output_key] + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + if "SERPAPI_API_KEY" not in os.environ: + raise ValueError( + "Did not find SerpAPI API key, please add an environment variable" + " `SERPAPI_API_KEY` which contains it." + ) + try: + from serpapi import GoogleSearch + + values["search"] = GoogleSearch + except ImportError: + raise ValueError( + "Could not import serpapi python package. " + "Please it install it with `pip install google-search-results`." + ) + return values + + def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: + params = { + "api_key": os.environ["SERPAPI_API_KEY"], + "engine": "google", + "q": inputs[self.input_key], + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + + search = self.search(params) + res = search.get_dict() + + if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + toret = res["answer_box"]["answer"] + elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret = res["answer_box"]["snippet"] + elif ( + "answer_box" in res.keys() + and "snippet_highlighted_words" in res["answer_box"].keys() + ): + toret = res["answer_box"]["snippet_highlighted_words"][0] + elif "snippet" in res["organic_results"][0].keys(): + toret = res["organic_results"][0]["snippet"] + else: + toret = None + + return {self.output_key: toret} + + def search(self, search_question: str) -> str: + """More user-friendly interface for interfacing with search.""" + return self({self.input_key: search_question})[self.output_key] diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index e690b454dda9a..61504ba63df60 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -34,7 +34,7 @@ class Config: @root_validator() def template_is_valid(cls, values: Dict) -> Dict: - """Validate that api key python package exists in environment.""" + """Validate that api key and python package exists in environment.""" if "COHERE_API_KEY" not in os.environ: raise ValueError( "Did not find Cohere API key, please add an environment variable" diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 76bdb09073fd3..70296a9ff0593 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -27,7 +27,7 @@ class Config: @root_validator() def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key python package exists in environment.""" + """Validate that api key and python package exists in environment.""" if "OPENAI_API_KEY" not in os.environ: raise ValueError( "Did not find OpenAI API key, please add an environment variable" diff --git a/requirements.txt b/requirements.txt index 8ecab12d964e3..43adc5d2051c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ flake8 flake8-docstrings cohere openai +google-search-results diff --git a/tests/integration_tests/chains/__init__.py b/tests/integration_tests/chains/__init__.py new file mode 100644 index 0000000000000..3ca2420123d3a --- /dev/null +++ b/tests/integration_tests/chains/__init__.py @@ -0,0 +1 @@ +"""All integration tests for chains.""" diff --git a/tests/integration_tests/chains/test_serpapi.py b/tests/integration_tests/chains/test_serpapi.py new file mode 100644 index 0000000000000..453558fa9e9a9 --- /dev/null +++ b/tests/integration_tests/chains/test_serpapi.py @@ -0,0 +1,7 @@ +"""Integration test for SerpAPI.""" +from langchain.chains.serpapi import SerpAPIChain + +def test_call(): + chain = SerpAPIChain() + output = chain.search("What was Obama's first name?") + breakpoint() \ No newline at end of file From 141f0fa8662059a091577651d9271c15bd9b2d28 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Oct 2022 22:37:41 -0700 Subject: [PATCH 2/2] add serpapi --- langchain/chains/serpapi.py | 16 ++++++++++------ tests/integration_tests/chains/test_serpapi.py | 6 ++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/langchain/chains/serpapi.py b/langchain/chains/serpapi.py index ff8d021c50e6d..6a9f4a17e5ad4 100644 --- a/langchain/chains/serpapi.py +++ b/langchain/chains/serpapi.py @@ -3,9 +3,9 @@ Heavily borrowed from https://github.com/ofirpress/self-ask """ import os -from typing import Any, Callable, Dict, List +from typing import Any, Dict, List -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain @@ -13,10 +13,15 @@ class SerpAPIChain(Chain, BaseModel): """Chain that calls SerpAPI.""" - search: Callable + search_engine: Any input_key: str = "search_query" output_key: str = "search_result" + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + @property def input_keys(self) -> List[str]: """Return the singular input key.""" @@ -38,7 +43,7 @@ def validate_environment(cls, values: Dict) -> Dict: try: from serpapi import GoogleSearch - values["search"] = GoogleSearch + values["search_engine"] = GoogleSearch except ImportError: raise ValueError( "Could not import serpapi python package. " @@ -56,7 +61,7 @@ def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: "hl": "en", } - search = self.search(params) + search = self.search_engine(params) res = search.get_dict() if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): @@ -72,7 +77,6 @@ def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: toret = res["organic_results"][0]["snippet"] else: toret = None - return {self.output_key: toret} def search(self, search_question: str) -> str: diff --git a/tests/integration_tests/chains/test_serpapi.py b/tests/integration_tests/chains/test_serpapi.py index 453558fa9e9a9..5f4aa04887676 100644 --- a/tests/integration_tests/chains/test_serpapi.py +++ b/tests/integration_tests/chains/test_serpapi.py @@ -1,7 +1,9 @@ """Integration test for SerpAPI.""" from langchain.chains.serpapi import SerpAPIChain -def test_call(): + +def test_call() -> None: + """Test that call gives the correct answer.""" chain = SerpAPIChain() output = chain.search("What was Obama's first name?") - breakpoint() \ No newline at end of file + assert output == "Barack Hussein Obama II"