diff --git a/langchain/chains/python.py b/langchain/chains/python.py index b3b6f1b6fc776..30a9c84525949 100644 --- a/langchain/chains/python.py +++ b/langchain/chains/python.py @@ -1,4 +1,7 @@ -"""Chain that runs python code.""" +"""Chain that runs python code. + +Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py +""" import sys from io import StringIO from typing import Dict, List diff --git a/langchain/chains/serpapi.py b/langchain/chains/serpapi.py new file mode 100644 index 0000000000000..6a9f4a17e5ad4 --- /dev/null +++ b/langchain/chains/serpapi.py @@ -0,0 +1,84 @@ +"""Chain that calls SerpAPI. + +Heavily borrowed from https://github.com/ofirpress/self-ask +""" +import os +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.chains.base import Chain + + +class SerpAPIChain(Chain, BaseModel): + """Chain that calls SerpAPI.""" + + search_engine: Any + input_key: str = "search_query" + output_key: str = "search_result" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @property + def input_keys(self) -> List[str]: + """Return the singular input key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key.""" + return [self.output_key] + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + if "SERPAPI_API_KEY" not in os.environ: + raise ValueError( + "Did not find SerpAPI API key, please add an environment variable" + " `SERPAPI_API_KEY` which contains it." + ) + try: + from serpapi import GoogleSearch + + values["search_engine"] = GoogleSearch + except ImportError: + raise ValueError( + "Could not import serpapi python package. " + "Please it install it with `pip install google-search-results`." + ) + return values + + def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: + params = { + "api_key": os.environ["SERPAPI_API_KEY"], + "engine": "google", + "q": inputs[self.input_key], + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + + search = self.search_engine(params) + res = search.get_dict() + + if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + toret = res["answer_box"]["answer"] + elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret = res["answer_box"]["snippet"] + elif ( + "answer_box" in res.keys() + and "snippet_highlighted_words" in res["answer_box"].keys() + ): + toret = res["answer_box"]["snippet_highlighted_words"][0] + elif "snippet" in res["organic_results"][0].keys(): + toret = res["organic_results"][0]["snippet"] + else: + toret = None + return {self.output_key: toret} + + def search(self, search_question: str) -> str: + """More user-friendly interface for interfacing with search.""" + return self({self.input_key: search_question})[self.output_key] diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index e690b454dda9a..61504ba63df60 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -34,7 +34,7 @@ class Config: @root_validator() def template_is_valid(cls, values: Dict) -> Dict: - """Validate that api key python package exists in environment.""" + """Validate that api key and python package exists in environment.""" if "COHERE_API_KEY" not in os.environ: raise ValueError( "Did not find Cohere API key, please add an environment variable" diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 76bdb09073fd3..70296a9ff0593 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -27,7 +27,7 @@ class Config: @root_validator() def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key python package exists in environment.""" + """Validate that api key and python package exists in environment.""" if "OPENAI_API_KEY" not in os.environ: raise ValueError( "Did not find OpenAI API key, please add an environment variable" diff --git a/requirements.txt b/requirements.txt index 8ecab12d964e3..43adc5d2051c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ flake8 flake8-docstrings cohere openai +google-search-results diff --git a/tests/integration_tests/chains/__init__.py b/tests/integration_tests/chains/__init__.py new file mode 100644 index 0000000000000..3ca2420123d3a --- /dev/null +++ b/tests/integration_tests/chains/__init__.py @@ -0,0 +1 @@ +"""All integration tests for chains.""" diff --git a/tests/integration_tests/chains/test_serpapi.py b/tests/integration_tests/chains/test_serpapi.py new file mode 100644 index 0000000000000..5f4aa04887676 --- /dev/null +++ b/tests/integration_tests/chains/test_serpapi.py @@ -0,0 +1,9 @@ +"""Integration test for SerpAPI.""" +from langchain.chains.serpapi import SerpAPIChain + + +def test_call() -> None: + """Test that call gives the correct answer.""" + chain = SerpAPIChain() + output = chain.search("What was Obama's first name?") + assert output == "Barack Hussein Obama II"