From 9c08cdea9288ac3b3340bab21ff96c607390d52b Mon Sep 17 00:00:00 2001 From: Kahlil Wehmeyer <7523160+kwehmeyer@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:52:36 -0400 Subject: [PATCH 1/9] core[patch]: ToolException docs/exception message (#17590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description:** This PR adds a slightly more helpful message to a `ToolException`. ``` # current state langchain_core.tools.ToolException: Too many arguments to single-input tool # proposed state langchain_core.tools.ToolException: Too many arguments to single-input tool. Consider using a StructuredTool instead. ``` **Issue:** Somewhat discussed here 👉 #6197 **Dependencies:** None **Twitter handle:** N/A --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- libs/core/langchain_core/tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 585bd1d53e63e..97afdfaca0e46 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -563,7 +563,8 @@ def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict all_args = list(args) + list(kwargs.values()) if len(all_args) != 1: raise ToolException( - f"Too many arguments to single-input tool {self.name}." + f"""Too many arguments to single-input tool {self.name}. + Consider using StructuredTool instead.""" f" Args: {all_args}" ) return tuple(all_args), {} From b9016490325f2719ccb1dcec77963902477cebb1 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:55:16 -0700 Subject: [PATCH 2/9] docs: move extraction up (#19667) --- docs/docs/use_cases/extraction/index.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/use_cases/extraction/index.ipynb b/docs/docs/use_cases/extraction/index.ipynb index cc384d3f58bda..3122130a8a05a 100644 --- a/docs/docs/use_cases/extraction/index.ipynb +++ b/docs/docs/use_cases/extraction/index.ipynb @@ -7,7 +7,7 @@ "source": [ "---\n", "title: Extraction\n", - "sidebar_position: 3\n", + "sidebar_position: 0.05\n", "---" ] }, From be2adb108393cc85948fb4dc8df22b04a9142fed Mon Sep 17 00:00:00 2001 From: chyroc Date: Thu, 28 Mar 2024 06:03:48 +0800 Subject: [PATCH 3/9] community[patch]: support unstructured_kwargs for s3 loader (#15473) fix https://github.com/langchain-ai/langchain/issues/15472 Co-authored-by: Bagatur --- .../langchain_community/document_loaders/s3_file.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/s3_file.py b/libs/community/langchain_community/document_loaders/s3_file.py index 59b3164993afb..fb0f0c675aba9 100644 --- a/libs/community/langchain_community/document_loaders/s3_file.py +++ b/libs/community/langchain_community/document_loaders/s3_file.py @@ -2,7 +2,7 @@ import os import tempfile -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from langchain_community.document_loaders.unstructured import UnstructuredBaseLoader @@ -29,6 +29,7 @@ def __init__( boto_config: Optional[botocore.client.Config] = None, mode: str = "single", post_processors: Optional[List[Callable]] = None, + **unstructured_kwargs: Any, ): """Initialize with bucket and key name.
@@ -85,11 +86,13 @@ def __init__( the client will be the result of calling ``merge()`` on the default config with the config provided to this call. :param mode: Mode in which to read the file. Valid options are: single, - paged and elements + paged and elements. :param post_processors: Post processing functions to be applied to - extracted elements + extracted elements. + :param **unstructured_kwargs: Arbitrary additional kwargs to pass in when + calling `partition` """ - super().__init__(mode, post_processors) + super().__init__(mode, post_processors, **unstructured_kwargs) self.bucket = bucket self.key = key self.region_name = region_name @@ -129,7 +132,7 @@ def _get_elements(self) -> List: file_path = f"{temp_dir}/{self.key}" os.makedirs(os.path.dirname(file_path), exist_ok=True) s3.download_file(self.bucket, self.key, file_path) - return partition(filename=file_path) + return partition(filename=file_path, **self.unstructured_kwargs) def _get_metadata(self) -> dict: return {"source": f"s3://{self.bucket}/{self.key}"} From cf96060ab77f38ae5bb1ca99e588594a1649073c Mon Sep 17 00:00:00 2001 From: CaroFG <48251481+CaroFG@users.noreply.github.com> Date: Wed, 27 Mar 2024 22:08:27 +0000 Subject: [PATCH 4/9] community[patch]: update for compatibility with latest Meilisearch version (#18970) - **Description:** Updates Meilisearch vectorstore for compatibility with v1.6 and above. Adds embedders settings and embedder_name which are now required. --------- Co-authored-by: Bagatur --- .../vectorstores/meilisearch.ipynb | 41 +++++++++++++---- .../vectorstores/meilisearch.py | 46 +++++++++++++++++-- .../vectorstores/test_meilisearch.py | 41 +++++++++++++---- 3 files changed, 107 insertions(+), 21 deletions(-) diff --git a/docs/docs/integrations/vectorstores/meilisearch.ipynb b/docs/docs/integrations/vectorstores/meilisearch.ipynb index 11777cceda837..58ed11a56818c 100644 --- a/docs/docs/integrations/vectorstores/meilisearch.ipynb +++ b/docs/docs/integrations/vectorstores/meilisearch.ipynb @@ -130,7 +130,14 @@ "from langchain_openai import OpenAIEmbeddings\n", "from langchain_text_splitters import CharacterTextSplitter\n", "\n", - "embeddings = OpenAIEmbeddings()" + "embeddings = OpenAIEmbeddings()\n", + "embedders = {\n", + " \"default\": {\n", + " \"source\": \"userProvided\",\n", + " \"dimensions\": 1536,\n", + " }\n", + "}\n", + "embedder_name = \"default\"" ] }, { @@ -152,7 +159,9 @@ "outputs": [], "source": [ "# Use Meilisearch vector store to store texts & associated embeddings as vector\n", - "vector_store = Meilisearch.from_texts(texts=texts, embedding=embeddings)" + "vector_store = Meilisearch.from_texts(\n", + " texts=texts, embedding=embeddings, embedders=embedders, embedder_name=embedder_name\n", + ")" ] }, { @@ -188,11 +197,16 @@ "docs = text_splitter.split_documents(documents)\n", "\n", "# Import documents & embeddings in the vector store\n", - "vector_store = Meilisearch.from_documents(documents=documents, embedding=embeddings)\n", + "vector_store = Meilisearch.from_documents(\n", + " documents=documents,\n", + " embedding=embeddings,\n", + " embedders=embedders,\n", + " embedder_name=embedder_name,\n", + ")\n", "\n", "# Search in our vector store\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n", - "docs = vector_store.similarity_search(query)\n", + "docs = vector_store.similarity_search(query, embedder_name=embedder_name)\n", "print(docs[0].page_content)" ] }, @@ -221,7 +235,11 @@ "\n", "client = meilisearch.Client(url=\"http://127.0.0.1:7700\", 
api_key=\"***\")\n", "vector_store = Meilisearch(\n", - " embedding=embeddings, client=client, index_name=\"langchain_demo\", text_key=\"text\"\n", + " embedding=embeddings,\n", + " embedders=embedders,\n", + " client=client,\n", + " index_name=\"langchain_demo\",\n", + " text_key=\"text\",\n", ")\n", "vector_store.add_documents(documents)" ] @@ -232,7 +250,7 @@ "source": [ "## Similarity Search with score\n", "\n", - "This specific method allows you to return the documents and the distance score of the query to them." + "This specific method allows you to return the documents and the distance score of the query to them. `embedder_name` is the name of the embedder that should be used for semantic search, defaults to \"default\"." ] }, { @@ -241,7 +259,9 @@ "metadata": {}, "outputs": [], "source": [ - "docs_and_scores = vector_store.similarity_search_with_score(query)\n", + "docs_and_scores = vector_store.similarity_search_with_score(\n", + " query, embedder_name=embedder_name\n", + ")\n", "docs_and_scores[0]" ] }, @@ -249,7 +269,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Similarity Search by vector" + "## Similarity Search by vector\n", + "`embedder_name` is the name of the embedder that should be used for semantic search, defaults to \"default\"." ] }, { @@ -259,7 +280,9 @@ "outputs": [], "source": [ "embedding_vector = embeddings.embed_query(query)\n", - "docs_and_scores = vector_store.similarity_search_by_vector(embedding_vector)\n", + "docs_and_scores = vector_store.similarity_search_by_vector(\n", + " embedding_vector, embedder_name=embedder_name\n", + ")\n", "docs_and_scores[0]" ] }, diff --git a/libs/community/langchain_community/vectorstores/meilisearch.py b/libs/community/langchain_community/vectorstores/meilisearch.py index b34a990cce2df..522f107405dbe 100644 --- a/libs/community/langchain_community/vectorstores/meilisearch.py +++ b/libs/community/langchain_community/vectorstores/meilisearch.py @@ -65,8 +65,15 @@ class Meilisearch(VectorStore): # api_key is optional; provide it if your meilisearch instance requires it client = meilisearch.Client(url='http://127.0.0.1:7700', api_key='***') embeddings = OpenAIEmbeddings() + embedders = { + "theEmbedderName": { + "source": "userProvided", + "dimensions": "1536" + } + } vectorstore = Meilisearch( embedding=embeddings, + embedders=embedders, client=client, index_name='langchain_demo', text_key='text') @@ -81,6 +88,8 @@ def __init__( index_name: str = "langchain-demo", text_key: str = "text", metadata_key: str = "metadata", + *, + embedders: Optional[Dict[str, Any]] = None, ): """Initialize with Meilisearch client.""" client = _create_client(client=client, url=url, api_key=api_key) @@ -90,18 +99,24 @@ def __init__( self._embedding = embedding self._text_key = text_key self._metadata_key = metadata_key + self._embedders = embedders + self._embedders_settings = self._client.index( + str(self._index_name) + ).update_embedders(embedders) def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, + embedder_name: Optional[str] = "default", **kwargs: Any, ) -> List[str]: """Run more texts through the embedding and add them to the vector store. Args: texts (Iterable[str]): Iterable of strings/text to add to the vectorstore. + embedder_name: Name of the embedder. Defaults to "default". metadatas (Optional[List[dict]]): Optional list of metadata. Defaults to None. ids Optional[List[str]]: Optional list of IDs. 
@@ -128,7 +143,7 @@ def add_texts( docs.append( { "id": id, - "_vectors": embedding, + "_vectors": {f"{embedder_name}": embedding}, f"{self._metadata_key}": metadata, } ) @@ -142,12 +157,14 @@ def similarity_search( query: str, k: int = 4, filter: Optional[Dict[str, str]] = None, + embedder_name: Optional[str] = "default", **kwargs: Any, ) -> List[Document]: """Return meilisearch documents most similar to the query. Args: query (str): Query text for which to find similar documents. + embedder_name: Name of the embedder to be used. Defaults to "default". k (int): Number of documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. @@ -158,6 +175,7 @@ def similarity_search( """ docs_and_scores = self.similarity_search_with_score( query=query, + embedder_name=embedder_name, k=k, filter=filter, kwargs=kwargs, @@ -169,12 +187,14 @@ def similarity_search_with_score( query: str, k: int = 4, filter: Optional[Dict[str, str]] = None, + embedder_name: Optional[str] = "default", **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return meilisearch documents most similar to the query, along with scores. Args: query (str): Query text for which to find similar documents. + embedder_name: Name of the embedder to be used. Defaults to "default". k (int): Number of documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. @@ -187,6 +207,7 @@ def similarity_search_with_score( docs = self.similarity_search_by_vector_with_scores( embedding=_query, + embedder_name=embedder_name, k=k, filter=filter, kwargs=kwargs, @@ -196,6 +217,7 @@ def similarity_search_with_score( def similarity_search_by_vector_with_scores( self, embedding: List[float], + embedder_name: Optional[str] = "default", k: int = 4, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, @@ -204,6 +226,7 @@ def similarity_search_by_vector_with_scores( Args: embedding (List[float]): Embedding to look up similar documents. + embedder_name: Name of the embedder to be used. Defaults to "default". k (int): Number of documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. @@ -214,7 +237,13 @@ def similarity_search_by_vector_with_scores( """ docs = [] results = self._client.index(str(self._index_name)).search( - "", {"vector": embedding, "limit": k, "filter": filter} + "", + { + "vector": embedding, + "hybrid": {"semanticRatio": 1.0, "embedder": embedder_name}, + "limit": k, + "filter": filter, + }, ) for result in results["hits"]: @@ -233,12 +262,14 @@ def similarity_search_by_vector( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, + embedder_name: Optional[str] = "default", **kwargs: Any, ) -> List[Document]: """Return meilisearch documents most similar to embedding vector. Args: embedding (List[float]): Embedding to look up similar documents. + embedder_name: Name of the embedder to be used. Defaults to "default". k (int): Number of documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. 
@@ -249,6 +280,7 @@ def similarity_search_by_vector( """ docs = self.similarity_search_by_vector_with_scores( embedding=embedding, + embedder_name=embedder_name, k=k, filter=filter, kwargs=kwargs, @@ -268,6 +300,8 @@ def from_texts( ids: Optional[List[str]] = None, text_key: Optional[str] = "text", metadata_key: Optional[str] = "metadata", + embedders: Dict[str, Any] = {}, + embedder_name: Optional[str] = "default", **kwargs: Any, ) -> Meilisearch: """Construct Meilisearch wrapper from raw documents. @@ -288,21 +322,25 @@ def from_texts( # The environment should be the one specified next to the API key # in your Meilisearch console client = meilisearch.Client(url='http://127.0.0.1:7700', api_key='***') - embeddings = OpenAIEmbeddings() + embedding = OpenAIEmbeddings() + embedders: Embedders index setting. + embedder_name: Name of the embedder. Defaults to "default". docsearch = Meilisearch.from_texts( client=client, - embeddings=embeddings, + embedding=embedding, ) """ client = _create_client(client=client, url=url, api_key=api_key) vectorstore = cls( embedding=embedding, + embedders=embedders, client=client, index_name=index_name, ) vectorstore.add_texts( texts=texts, + embedder_name=embedder_name, metadatas=metadatas, ids=ids, text_key=text_key, diff --git a/libs/community/tests/integration_tests/vectorstores/test_meilisearch.py b/libs/community/tests/integration_tests/vectorstores/test_meilisearch.py index 1dd795f74ca25..3b6695dcb40bb 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_meilisearch.py +++ b/libs/community/tests/integration_tests/vectorstores/test_meilisearch.py @@ -1,5 +1,6 @@ """Test Meilisearch functionality.""" -from typing import TYPE_CHECKING, Generator + +from typing import TYPE_CHECKING, Any, Dict, Generator import pytest import requests @@ -33,6 +34,16 @@ def enable_vector_search(self) -> Generator[str, None, None]: timeout=10, ) + @pytest.fixture + def new_embedders(self) -> Dict[str, Dict[str, Any]]: + return { + "default": { + "source": "userProvided", + # Dimension defined in FakeEmbeddings as [float(1.0)] * 9 + [float(0.0)] + "dimensions": 10, + } + } + @pytest.fixture(autouse=True) def setup(self) -> None: self.delete_all_indexes() @@ -63,12 +74,14 @@ def _wait_last_task(self) -> None: # Wait for the last task to be completed client.wait_for_task(tasks.results[0].uid) - def test_meilisearch(self) -> None: + def test_meilisearch(self, new_embedders: Dict[str, Any]) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] vectorstore = Meilisearch.from_texts( texts=texts, embedding=FakeEmbeddings(), + embedders=new_embedders, + embedder_name=list(new_embedders)[0], url=TEST_MEILI_HTTP_ADDR, api_key=TEST_MEILI_MASTER_KEY, index_name=INDEX_NAME, @@ -77,12 +90,14 @@ def test_meilisearch(self) -> None: output = vectorstore.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] - def test_meilisearch_with_client(self) -> None: + def test_meilisearch_with_client(self, new_embedders: Dict[str, Any]) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] vectorstore = Meilisearch.from_texts( texts=texts, embedding=FakeEmbeddings(), + embedders=new_embedders, + embedder_name=list(new_embedders)[0], client=self.client(), index_name=INDEX_NAME, ) @@ -90,13 +105,15 @@ def test_meilisearch_with_client(self) -> None: output = vectorstore.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] - def test_meilisearch_with_metadatas(self) -> None: + def 
test_meilisearch_with_metadatas(self, new_embedders: Dict[str, Any]) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = Meilisearch.from_texts( texts=texts, embedding=FakeEmbeddings(), + embedders=new_embedders, + embedder_name=list(new_embedders)[0], url=TEST_MEILI_HTTP_ADDR, api_key=TEST_MEILI_MASTER_KEY, index_name=INDEX_NAME, @@ -109,13 +126,17 @@ def test_meilisearch_with_metadatas(self) -> None: assert output[0].metadata["page"] == 0 assert output == [Document(page_content="foo", metadata={"page": 0})] - def test_meilisearch_with_metadatas_with_scores(self) -> None: + def test_meilisearch_with_metadatas_with_scores( + self, new_embedders: Dict[str, Any] + ) -> None: """Test end to end construction and scored search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] docsearch = Meilisearch.from_texts( texts=texts, embedding=FakeEmbeddings(), + embedders=new_embedders, + embedder_name=list(new_embedders)[0], url=TEST_MEILI_HTTP_ADDR, api_key=TEST_MEILI_MASTER_KEY, index_name=INDEX_NAME, @@ -123,9 +144,11 @@ def test_meilisearch_with_metadatas_with_scores(self) -> None: ) self._wait_last_task() output = docsearch.similarity_search_with_score("foo", k=1) - assert output == [(Document(page_content="foo", metadata={"page": "0"}), 9.0)] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] - def test_meilisearch_with_metadatas_with_scores_using_vector(self) -> None: + def test_meilisearch_with_metadatas_with_scores_using_vector( + self, new_embedders: Dict[str, Any] + ) -> None: """Test end to end construction and scored search, using embedding vector.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": str(i)} for i in range(len(texts))] @@ -134,6 +157,8 @@ def test_meilisearch_with_metadatas_with_scores_using_vector(self) -> None: docsearch = Meilisearch.from_texts( texts=texts, embedding=FakeEmbeddings(), + embedders=new_embedders, + embedder_name=list(new_embedders)[0], url=TEST_MEILI_HTTP_ADDR, api_key=TEST_MEILI_MASTER_KEY, index_name=INDEX_NAME, @@ -144,4 +169,4 @@ def test_meilisearch_with_metadatas_with_scores_using_vector(self) -> None: output = docsearch.similarity_search_by_vector_with_scores( embedding=embedded_query, k=1 ) - assert output == [(Document(page_content="foo", metadata={"page": "0"}), 9.0)] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] From 9b70131aed02221866b2590e787a0e0703e1aad7 Mon Sep 17 00:00:00 2001 From: Hyeongchan Kim Date: Thu, 28 Mar 2024 07:31:54 +0900 Subject: [PATCH 5/9] community[patch]: refactor the type hint of `file_path` in `UnstructuredAPIFileLoader` class (#18839) * **Description**: add `None` type for `file_path` along with `str` and `List[str]` types. * `file_path`/`filename` arguments in `get_elements_from_api()` and `partition()` can be `None`, however, there's no `None` type hint for `file_path` in `UnstructuredAPIFileLoader` and `UnstructuredFileLoader` currently. * calling the function with `file_path=None` is no problem, but my IDE annoys me lol. 
* **Issue**: N/A * **Dependencies**: N/A Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../langchain_community/document_loaders/unstructured.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/unstructured.py b/libs/community/langchain_community/document_loaders/unstructured.py index 22df465589d36..bc056ca702583 100644 --- a/libs/community/langchain_community/document_loaders/unstructured.py +++ b/libs/community/langchain_community/document_loaders/unstructured.py @@ -156,7 +156,7 @@ class UnstructuredFileLoader(UnstructuredBaseLoader): def __init__( self, - file_path: Union[str, List[str], Path, List[Path]], + file_path: Union[str, List[str], Path, List[Path], None], mode: str = "single", **unstructured_kwargs: Any, ): @@ -255,7 +255,7 @@ class UnstructuredAPIFileLoader(UnstructuredFileLoader): def __init__( self, - file_path: Union[str, List[str]] = "", + file_path: Union[str, List[str], None] = "", mode: str = "single", url: str = "https://api.unstructured.io/general/v0/general", api_key: str = "", From 7e29b6061f7125d6cf43029697247c8ba03efb6f Mon Sep 17 00:00:00 2001 From: "yongheng.liu" <56812134+liuyonghengheng@users.noreply.github.com> Date: Thu, 28 Mar 2024 07:02:40 +0800 Subject: [PATCH 6/9] community[minor]: integrate China Mobile Ecloud vector search (#15298) - **Description:** integrates China Mobile Ecloud vector search - **Dependencies:** elasticsearch==7.10.1 Co-authored-by: liuyongheng Co-authored-by: Bagatur --- .../vectorstores/ecloud_vector_search.ipynb | 317 ++++++++++ .../vectorstores/__init__.py | 1 + .../vectorstores/ecloud_vector_search.py | 580 ++++++++++++++++++ .../vectorstores/test_ecloud_vector_search.py | 330 ++++++++++ .../vectorstores/test_indexing_docs.py | 1 + .../vectorstores/test_public_api.py | 1 + 6 files changed, 1230 insertions(+) create mode 100644 docs/docs/integrations/vectorstores/ecloud_vector_search.ipynb create mode 100644 libs/community/langchain_community/vectorstores/ecloud_vector_search.py create mode 100644 libs/community/tests/integration_tests/vectorstores/test_ecloud_vector_search.py diff --git a/docs/docs/integrations/vectorstores/ecloud_vector_search.ipynb b/docs/docs/integrations/vectorstores/ecloud_vector_search.ipynb new file mode 100644 index 0000000000000..d11d5ca411e22 --- /dev/null +++ b/docs/docs/integrations/vectorstores/ecloud_vector_search.ipynb @@ -0,0 +1,317 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# China Mobile ECloud ElasticSearch VectorSearch\n", + "\n", + ">[China Mobile ECloud VectorSearch](https://ecloud.10086.cn/portal/product/elasticsearch) is a fully managed, enterprise-level distributed search and analysis service. China Mobile ECloud VectorSearch provides a low-cost, high-performance, and reliable platform-level retrieval and analysis service for structured/unstructured data. As a vector database, it supports multiple index types and similarity distance methods.\n", + "\n", + "This notebook shows how to use functionality related to the `ECloud ElasticSearch VectorStore`.\n", + "To run, you should have a [China Mobile ECloud VectorSearch](https://ecloud.10086.cn/portal/product/elasticsearch) instance up and running.\n", + "\n", + "Read the [help document](https://ecloud.10086.cn/op-help-center/doc/category/1094) to quickly get familiar with and configure a China Mobile ECloud ElasticSearch instance."
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "After the instance is up and running, follow these steps to split documents, get embeddings, connect to the China Mobile ECloud Elasticsearch instance, index documents, and perform vector retrieval." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#!pip install elasticsearch==7.10.1" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "First, we want to use `OpenAIEmbeddings`, so we have to get the OpenAI API Key." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import getpass\n", "import os\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Second, split the documents and get their embeddings." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain.document_loaders import TextLoader\n", "from langchain.embeddings.openai import OpenAIEmbeddings\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.vectorstores import EcloudESVectorStore" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loader = TextLoader(\"../../../state_of_the_union.txt\")\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "docs = text_splitter.split_documents(documents)\n", "\n", "embeddings = OpenAIEmbeddings()\n", "\n", "ES_URL = \"http://localhost:9200\"\n", "USER = \"your user name\"\n", "PASSWORD = \"your password\"\n", "indexname = \"your index name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, index the documents." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "docsearch = EcloudESVectorStore.from_documents(\n", "    docs,\n", "    embeddings,\n", "    es_url=ES_URL,\n", "    user=USER,\n", "    password=PASSWORD,\n", "    index_name=indexname,\n", "    refresh_indices=True,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, query and retrieve data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "query = \"What did the president say about Ketanji Brown Jackson\"\n", "docs = docsearch.similarity_search(query, k=10)\n", "print(docs[0].page_content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A commonly used case" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_dense_float_vector_lsh_cosine() -> None:\n", "    \"\"\"\n", "    Test indexing with vector type knn_dense_float_vector and model-similarity of lsh-cosine.\n", "    This mapping is compatible with model exact and similarity l2/cosine,\n", "    and with model lsh and similarity cosine.\n", "    \"\"\"\n", "    docsearch = EcloudESVectorStore.from_documents(\n", "        docs,\n", "        embeddings,\n", "        es_url=ES_URL,\n", "        user=USER,\n", "        password=PASSWORD,\n", "        index_name=indexname,\n", "        refresh_indices=True,\n", "        text_field=\"my_text\",\n",
"        vector_field=\"my_vec\",\n", "        vector_type=\"knn_dense_float_vector\",\n", "        vector_params={\"model\": \"lsh\", \"similarity\": \"cosine\", \"L\": 99, \"k\": 1},\n", "    )\n", "\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        search_params={\n", "            \"model\": \"exact\",\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)\n", "\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        search_params={\n", "            \"model\": \"exact\",\n", "            \"similarity\": \"l2\",\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)\n", "\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        search_params={\n", "            \"model\": \"exact\",\n", "            \"similarity\": \"cosine\",\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)\n", "\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        search_params={\n", "            \"model\": \"lsh\",\n", "            \"similarity\": \"cosine\",\n", "            \"candidates\": 10,\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With filter case" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_dense_float_vector_exact_with_filter() -> None:\n", "    \"\"\"\n", "    Test indexing with vector type knn_dense_float_vector and the default model/similarity.\n", "    This mapping is compatible with model exact and similarity l2/cosine.\n", "    \"\"\"\n", "    docsearch = EcloudESVectorStore.from_documents(\n", "        docs,\n", "        embeddings,\n", "        es_url=ES_URL,\n", "        user=USER,\n", "        password=PASSWORD,\n", "        index_name=indexname,\n", "        refresh_indices=True,\n", "        text_field=\"my_text\",\n", "        vector_field=\"my_vec\",\n", "        vector_type=\"knn_dense_float_vector\",\n", "    )\n", "    # filter={\"match_all\": {}}, the default\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        filter={\"match_all\": {}},\n", "        search_params={\n", "            \"model\": \"exact\",\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)\n", "\n", "    # filter={\"term\": {\"my_text\": \"Jackson\"}}\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        filter={\"term\": {\"my_text\": \"Jackson\"}},\n", "        search_params={\n", "            \"model\": \"exact\",\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)\n", "\n", "    # filter={\"term\": {\"my_text\": \"president\"}}\n", "    docs = docsearch.similarity_search(\n", "        query,\n", "        k=10,\n", "        filter={\"term\": {\"my_text\": \"president\"}},\n", "        search_params={\n", "            \"model\": \"exact\",\n", "            \"similarity\": \"l2\",\n", "            \"vector_field\": \"my_vec\",\n", "            \"text_field\": \"my_text\",\n", "        },\n", "    )\n", "    print(docs[0].page_content)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 },
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/community/langchain_community/vectorstores/__init__.py b/libs/community/langchain_community/vectorstores/__init__.py index 4e1bfa9eeab92..f2a272b1a9647 100644 --- a/libs/community/langchain_community/vectorstores/__init__.py +++ b/libs/community/langchain_community/vectorstores/__init__.py @@ -52,6 +52,7 @@ "DocArrayInMemorySearch": "langchain_community.vectorstores.docarray", "DocumentDBVectorSearch": "langchain_community.vectorstores.documentdb", "DuckDB": "langchain_community.vectorstores.duckdb", + "EcloudESVectorStore": "langchain_community.vectorstores.ecloud_vector_search", "ElasticKnnSearch": "langchain_community.vectorstores.elastic_vector_search", "ElasticVectorSearch": "langchain_community.vectorstores.elastic_vector_search", "ElasticsearchStore": "langchain_community.vectorstores.elasticsearch", diff --git a/libs/community/langchain_community/vectorstores/ecloud_vector_search.py b/libs/community/langchain_community/vectorstores/ecloud_vector_search.py new file mode 100644 index 0000000000000..58401336f8e0c --- /dev/null +++ b/libs/community/langchain_community/vectorstores/ecloud_vector_search.py @@ -0,0 +1,580 @@ +import logging +import uuid +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) + +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore + +if TYPE_CHECKING: + from elasticsearch import Elasticsearch + +logger = logging.getLogger(__name__) + + +class EcloudESVectorStore(VectorStore): + """`ecloud Elasticsearch` vector store. + + Example: + .. code-block:: python + + from langchain.vectorstores import EcloudESVectorStore + from langchain.embeddings.openai import OpenAIEmbeddings + + embeddings = OpenAIEmbeddings() + vectorstore = EcloudESVectorStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_url="http://localhost:9200" + ) + + Args: + index_name: Name of the Elasticsearch index to create. + es_url: URL of the ecloud Elasticsearch instance to connect to. + user: Username to use when connecting to Elasticsearch. + password: Password to use when connecting to Elasticsearch. 
+ + """ + + def __init__( + self, + index_name: str, + es_url: str, + user: Optional[str] = None, + password: Optional[str] = None, + embedding: Optional[Embeddings] = None, + **kwargs: Optional[dict], + ) -> None: + self.embedding = embedding + self.index_name = index_name + self.text_field = kwargs.get("text_field", "text") + self.vector_field = kwargs.get("vector_field", "vector") + self.vector_type = kwargs.get("vector_type", "knn_dense_float_vector") + self.vector_params = kwargs.get("vector_params") or {} + self.model = self.vector_params.get("model", "") + self.index_settings = kwargs.get("index_settings") or {} + + key_list = [ + "text_field", + "vector_field", + "vector_type", + "vector_params", + "index_settings", + ] + [kwargs.pop(key, None) for key in key_list] + if es_url is not None: + self.client = EcloudESVectorStore.es_client( + es_url=es_url, username=user, password=password, **kwargs + ) + else: + raise ValueError("""Please specified a es connection url.""") + + @property + def embeddings(self) -> Optional[Embeddings]: + return self.embedding + + @staticmethod + def es_client( + *, + es_url: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + **kwargs: Optional[dict], + ) -> "Elasticsearch": + try: + import elasticsearch + except ImportError: + raise ImportError( + "Could not import elasticsearch python package. " + "Please install it with `pip install elasticsearch`." + ) + + connection_params: Dict[str, Any] = {"hosts": [es_url]} + + if username and password: + connection_params["http_auth"] = (username, password) + connection_params.update(kwargs) + + es_client = elasticsearch.Elasticsearch(**connection_params) + try: + es_client.info() + except Exception as e: + logger.error(f"Error connecting to Elasticsearch: {e}") + raise e + return es_client + + def _create_index_if_not_exists(self, dims_length: Optional[int] = None) -> None: + """Create the index if it doesn't already exist. + + Args: + dims_length: Length of the embedding vectors. + """ + + if self.client.indices.exists(index=self.index_name): + logger.info(f"Index {self.index_name} already exists. Skipping creation.") + + else: + if dims_length is None: + raise ValueError( + "Cannot create index without specifying dims_length " + + "when the index doesn't already exist. " + ) + + indexMapping = self._index_mapping(dims_length=dims_length) + + logger.debug( + f"Creating index {self.index_name} with mappings {indexMapping}" + ) + + self.client.indices.create( + index=self.index_name, + body={ + "settings": {"index.knn": True, **self.index_settings}, + "mappings": {"properties": indexMapping}, + }, + ) + + def _index_mapping(self, dims_length: Union[int, None]) -> Dict: + """ + Executes when the index is created. + + Args: + dims_length: Numeric length of the embedding vectors, + or None if not using vector-based query. + index_params: The extra pamameters for creating index. + + Returns: + Dict: The Elasticsearch settings and mappings for the strategy. 
+ """ + model = self.vector_params.get("model", "") + if "lsh" == model: + mapping: Dict[Any, Any] = { + self.vector_field: { + "type": self.vector_type, + "knn": { + "dims": dims_length, + "model": "lsh", + "similarity": self.vector_params.get("similarity", "cosine"), + "L": self.vector_params.get("L", 99), + "k": self.vector_params.get("k", 1), + }, + } + } + if mapping[self.vector_field]["knn"]["similarity"] == "l2": + mapping[self.vector_field]["knn"]["w"] = self.vector_params.get("w", 3) + return mapping + elif "permutation_lsh" == model: + return { + self.vector_field: { + "type": self.vector_type, + "knn": { + "dims": dims_length, + "model": "permutation_lsh", + "k": self.vector_params.get("k", 10), + "similarity": self.vector_params.get("similarity", "cosine"), + "repeating": self.vector_params.get("repeating", True), + }, + } + } + else: + return { + self.vector_field: { + "type": self.vector_type, + "knn": {"dims": dims_length}, + } + } + + def delete( + self, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> Optional[bool]: + """Delete documents from the index. + + Args: + ids: List of ids of documents to delete + """ + try: + from elasticsearch.helpers import BulkIndexError, bulk + except ImportError: + raise ImportError( + "Could not import elasticsearch python package. " + "Please install it with `pip install elasticsearch`." + ) + + body = [] + + if ids is None: + raise ValueError("ids must be provided.") + + for _id in ids: + body.append({"_op_type": "delete", "_index": self.index_name, "_id": _id}) + + if len(body) > 0: + try: + bulk( + self.client, + body, + refresh=kwargs.get("refresh_indices", True), + ignore_status=404, + ) + logger.debug(f"Deleted {len(body)} texts from index") + return True + except BulkIndexError as e: + logger.error(f"Error deleting texts: {e}") + raise e + else: + logger.info("No documents to delete") + return False + + def _query_body( + self, + query_vector: Union[List[float], None], + filter: Optional[dict] = None, + search_params: Dict = {}, + ) -> Dict: + query_vector_body = { + "field": search_params.get("vector_field", self.vector_field) + } + + if self.vector_type == "knn_dense_float_vector": + query_vector_body["vec"] = {"values": query_vector} + specific_params = self.get_dense_specific_model_similarity_params( + search_params + ) + query_vector_body.update(specific_params) + else: + query_vector_body["vec"] = { + "true_indices": query_vector, + "total_indices": len(query_vector) if query_vector is not None else 0, + } + specific_params = self.get_sparse_specific_model_similarity_params( + search_params + ) + query_vector_body.update(specific_params) + + query_vector_body = {"knn_nearest_neighbors": query_vector_body} + if filter is not None and len(filter) != 0: + query_vector_body = { + "function_score": {"query": filter, "functions": [query_vector_body]} + } + + return { + "size": search_params.get("size", 4), + "query": query_vector_body, + } + + @staticmethod + def get_dense_specific_model_similarity_params( + search_params: Dict[str, Any], + ) -> Dict: + model = search_params.get("model", "exact") + similarity = search_params.get("similarity", "cosine") + specific_params = {"model": model, "similarity": similarity} + if not model == "exact": + if model not in ("lsh", "permutation_lsh"): + raise ValueError( + f"vector type knn_dense_float_vector doesn't support model {model}" + ) + if similarity not in ("cosine", "l2"): + raise ValueError(f"model exact doesn't support similarity {similarity}") + 
specific_params["candidates"] = search_params.get( + "candidates", search_params.get("size", 4) + ) + if model == "lsh" and similarity == "l2": + specific_params["probes"] = search_params.get("probes", 0) + else: + if similarity not in ("cosine", "l2"): + raise ValueError(f"model exact don't support similarity {similarity}") + + return specific_params + + @staticmethod + def get_sparse_specific_model_similarity_params( + search_params: Dict[str, Any], + ) -> Dict: + model = search_params.get("model", "exact") + similarity = search_params.get("similarity", "hamming") + specific_params = {"model": model, "similarity": similarity} + if not model == "exact": + if model not in ("lsh",): + raise ValueError( + f"vector type knn_dense_float_vector doesn't support model {model}" + ) + if similarity not in ("hamming", "jaccard"): + raise ValueError(f"model exact doesn't support similarity {similarity}") + specific_params["candidates"] = search_params.get( + "candidates", search_params.get("size", 4) + ) + else: + if similarity not in ("hamming", "jaccard"): + raise ValueError(f"model exact don't support similarity {similarity}") + + return specific_params + + def _search( + self, + query: Optional[str] = None, + query_vector: Union[List[float], None] = None, + filter: Optional[dict] = None, + custom_query: Optional[Callable[[Dict, Union[str, None]], Dict]] = None, + search_params: Dict = {}, + ) -> List[Tuple[Document, float]]: + """Return searched documents result from ecloud ES + + Args: + query: Text to look up documents similar to. + query_vector: Embedding to look up documents similar to. + filter: Array of ecloud ElasticSearch filter clauses to apply to the query. + custom_query: Function to modify the query body before it is sent to ES. + + Returns: + List of Documents most similar to the query and score for each + """ + + if self.embedding and query is not None: + query_vector = self.embedding.embed_query(query) + + query_body = self._query_body( + query_vector=query_vector, filter=filter, search_params=search_params + ) + + if custom_query is not None: + query_body = custom_query(query_body, query) + logger.debug(f"Calling custom_query, Query body now: {query_body}") + + logger.debug(f"Query body: {query_body}") + + # Perform the kNN search on the ES index and return the results. + response = self.client.search(index=self.index_name, body=query_body) + logger.debug(f"response={response}") + + hits = [hit for hit in response["hits"]["hits"]] + docs_and_scores = [ + ( + Document( + page_content=hit["_source"][ + search_params.get("text_field", self.text_field) + ], + metadata=hit["_source"]["metadata"], + ), + hit["_score"], + ) + for hit in hits + ] + + return docs_and_scores + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Return documents most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the query, + in descending order of similarity. + """ + + results = self.similarity_search_with_score( + query=query, k=k, filter=filter, **kwargs + ) + return [doc for doc, _ in results] + + def similarity_search_with_score( + self, query: str, k: int, filter: Optional[dict] = None, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Return documents most similar to query, along with scores. 
+ + Args: + query: Text to look up documents similar to. + size: Number of Documents to return. Defaults to 4. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the query and score for each + """ + search_params: Dict[str, Any] = kwargs.get("search_params") or {} + + if len(search_params) == 0: + kwargs = {"search_params": {"size": k}} + elif search_params.get("size") is None: + search_params["size"] = k + kwargs["search_params"] = search_params + + return self._search(query=query, filter=filter, **kwargs) + + @classmethod + def from_documents( + cls, + documents: List[Document], + embedding: Optional[Embeddings] = None, + **kwargs: Any, + ) -> "EcloudESVectorStore": + """Construct EcloudESVectorStore wrapper from documents. + + Args: + documents: List of documents to add to the Elasticsearch index. + embedding: Embedding function to use to embed the texts. + Do not provide if using a strategy + that doesn't require inference. + kwargs: create index key words arguments + """ + + vectorStore = EcloudESVectorStore._es_vector_store( + embedding=embedding, **kwargs + ) + # Encode the provided texts and add them to the newly created index. + vectorStore.add_documents(documents) + + return vectorStore + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Optional[Embeddings] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + **kwargs: Any, + ) -> "EcloudESVectorStore": + """Construct EcloudESVectorStore wrapper from raw documents. + + Args: + texts: List of texts to add to the Elasticsearch index. + embedding: Embedding function to use to embed the texts. + metadatas: Optional list of metadatas associated with the texts. + index_name: Name of the Elasticsearch index to create. + kwargs: create index key words arguments + """ + + vectorStore = cls._es_vector_store(embedding=embedding, **kwargs) + + # Encode the provided texts and add them to the newly created index. + vectorStore.add_texts(texts, metadatas=metadatas, **kwargs) + + return vectorStore + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[Any, Any]]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + Returns: + List of ids from adding the texts into the vectorstore. + """ + try: + from elasticsearch.helpers import BulkIndexError, bulk + except ImportError: + raise ImportError( + "Could not import elasticsearch python package. " + "Please install it with `pip install elasticsearch`." 
+ ) + + embeddings = [] + create_index_if_not_exists = kwargs.get("create_index_if_not_exists", True) + ids = kwargs.get("ids", [str(uuid.uuid4()) for _ in texts]) + refresh_indices = kwargs.get("refresh_indices", False) + requests = [] + + if self.embedding is not None: + embeddings = self.embedding.embed_documents(list(texts)) + dims_length = len(embeddings[0]) + + if create_index_if_not_exists: + self._create_index_if_not_exists(dims_length=dims_length) + + for i, (text, vector) in enumerate(zip(texts, embeddings)): + metadata = metadatas[i] if metadatas else {} + doc = { + "_op_type": "index", + "_index": self.index_name, + self.text_field: text, + "metadata": metadata, + "_id": ids[i], + } + if self.vector_type == "knn_dense_float_vector": + doc[self.vector_field] = vector + elif self.vector_type == "knn_sparse_bool_vector": + doc[self.vector_field] = { + "true_indices": vector, + "total_indices": len(vector), + } + requests.append(doc) + else: + if create_index_if_not_exists: + self._create_index_if_not_exists() + + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + + requests.append( + { + "_op_type": "index", + "_index": self.index_name, + self.text_field: text, + "metadata": metadata, + "_id": ids[i], + } + ) + + if len(requests) > 0: + try: + success, failed = bulk( + self.client, requests, stats_only=True, refresh=refresh_indices + ) + logger.debug( + f"Added {success} and failed to add {failed} texts to index" + ) + + logger.debug(f"added texts {ids} to index") + if refresh_indices: + self.client.indices.refresh(index=self.index_name) + return ids + except BulkIndexError as e: + logger.error(f"Error adding texts: {e}") + firstError = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First error reason: {firstError.get('reason')}") + raise e + + else: + logger.debug("No texts to add to index") + return [] + + @staticmethod + def _es_vector_store( + embedding: Optional[Embeddings] = None, **kwargs: Any + ) -> "EcloudESVectorStore": + index_name = kwargs.get("index_name") + + if index_name is None: + raise ValueError("Please provide an index_name.") + + es_url = kwargs.get("es_url") + if es_url is None: + raise ValueError("Please provided a valid es connection url") + + return EcloudESVectorStore(embedding=embedding, **kwargs) diff --git a/libs/community/tests/integration_tests/vectorstores/test_ecloud_vector_search.py b/libs/community/tests/integration_tests/vectorstores/test_ecloud_vector_search.py new file mode 100644 index 0000000000000..9d764787105e1 --- /dev/null +++ b/libs/community/tests/integration_tests/vectorstores/test_ecloud_vector_search.py @@ -0,0 +1,330 @@ +"""Test EcloudESVectorStore functionality.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +from langchain_core.documents import Document + +from langchain_community.vectorstores.ecloud_vector_search import EcloudESVectorStore +from tests.integration_tests.vectorstores.fake_embeddings import ( + FakeEmbeddings, + fake_texts, +) + +if TYPE_CHECKING: + from elasticsearch.client import Elasticsearch + +user = "elastic" +password = "*****" +ES_URL = "http://localhost:9200" + + +def _ecloud_vector_db_from_texts( + metadatas: Optional[List[dict]] = None, index_name: str = "testknn" +) -> EcloudESVectorStore: + return EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + metadatas=metadatas, + es_url=ES_URL, + user=user, + password=password, + index_name=index_name, + refresh_indices=True, + ) + + +def 
delete_index(es: Elasticsearch, index: str) -> None: + """Delete the specific index""" + try: + es.indices.delete(index) + except Exception: + pass + + +def test_ecloud_vector_db() -> None: + """Test end to end construction and search.""" + index_name = "testknn1" + docsearch = _ecloud_vector_db_from_texts(index_name=index_name) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + delete_index(docsearch.client, index_name) + + +def test_ecloud_vector_index_settings() -> None: + index_name = "testknn2" + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + es_url=ES_URL, + user=user, + password=password, + index_name=index_name, + refresh_indices=True, + vector_field="my_vector", + text_field="custom_text", + time_out=120, + ) + res = docsearch.client.indices.get_settings(index=index_name) + assert res[index_name]["settings"]["index"]["number_of_shards"] == "1" + assert res[index_name]["settings"]["index"]["number_of_replicas"] == "1" + + delete_index(docsearch.client, index_name) + + index_name = "testknn3" + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + es_url=ES_URL, + user=user, + password=password, + index_name=index_name, + refresh_indices=True, + vector_field="my_vector", + text_field="custom_text", + index_settings={"index": {"number_of_shards": "3", "number_of_replicas": "0"}}, + ) + res = docsearch.client.indices.get_settings(index=index_name) + assert res[index_name]["settings"]["index"]["number_of_shards"] == "3" + assert res[index_name]["settings"]["index"]["number_of_replicas"] == "0" + delete_index(docsearch.client, index_name) + + +def test_similarity_search_with_score() -> None: + """Test similarity search with score using Approximate Search.""" + metadatas = [{"page": i} for i in range(len(fake_texts))] + index_name = "testknn4" + docsearch = _ecloud_vector_db_from_texts(metadatas=metadatas, index_name=index_name) + output = docsearch.similarity_search_with_score("foo", k=2) + assert output == [ + (Document(page_content="foo", metadata={"page": 0}), 2.0), + (Document(page_content="bar", metadata={"page": 1}), 1.9486833), + ] + delete_index(docsearch.client, index_name) + + +def test_ecloud_with_custom_field_name() -> None: + """Test indexing and search using custom vector field and text field name.""" + index_name = "testknn5" + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + es_url=ES_URL, + user=user, + password=password, + index_name=index_name, + refresh_indices=True, + vector_field="my_vector", + text_field="custom_text", + ) + output = docsearch.similarity_search( + "foo", k=1, vector_field="my_vector", text_field="custom_text" + ) + assert output == [Document(page_content="foo")] + + text_input = ["test", "add", "text", "method"] + EcloudESVectorStore.add_texts( + docsearch, text_input, vector_field="my_vector", text_field="custom_text" + ) + output = docsearch.similarity_search( + "add", k=1, vector_field="my_vector", text_field="custom_text" + ) + assert output == [Document(page_content="foo")] + delete_index(docsearch.client, index_name) + + +def test_ecloud_with_metadatas() -> None: + """Test end to end indexing and search with metadata.""" + index_name = "testknn6" + metadatas = [{"page": i} for i in range(len(fake_texts))] + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + index_name=index_name, + refresh_indices=True, + metadatas=metadatas, + es_url=ES_URL, + user=user, + 
password=password, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": 0})] + delete_index(docsearch.client, index_name) + + +def test_add_text() -> None: + """Test adding additional text elements to existing index.""" + index_name = "testknn7" + text_input = ["test", "add", "text", "method"] + metadatas = [{"page": i} for i in range(len(text_input))] + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + index_name=index_name, + refresh_indices=True, + es_url=ES_URL, + user=user, + password=password, + ) + docids = EcloudESVectorStore.add_texts(docsearch, text_input, metadatas) + assert len(docids) == len(text_input) + delete_index(docsearch.client, index_name) + + +def test_dense_float_vector_lsh_cosine() -> None: + """ + Test indexing with vector type knn_dense_float_vector and + model-similarity of lsh-cosine + this mapping is compatible with model of exact and similarity of l2/cosine + this mapping is compatible with model of lsh and similarity of cosine + """ + index_name = "testknn9" + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + index_name=index_name, + refresh_indices=True, + es_url=ES_URL, + user=user, + password=password, + text_field="my_text", + vector_field="my_vec", + vector_type="knn_dense_float_vector", + vector_params={"model": "lsh", "similarity": "cosine", "L": 99, "k": 1}, + ) + output = docsearch.similarity_search( + "foo", + k=1, + search_params={ + "model": "exact", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="foo")] + + output = docsearch.similarity_search( + "foo", + k=1, + search_params={ + "model": "exact", + "similarity": "l2", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="foo")] + + output = docsearch.similarity_search( + "foo", + k=1, + search_params={ + "model": "exact", + "similarity": "cosine", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="foo")] + + output = docsearch.similarity_search( + "foo", + k=1, + search_params={ + "model": "lsh", + "similarity": "cosine", + "candidates": 1, + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="foo")] + + delete_index(docsearch.client, index_name) + + +def test_dense_float_vector_exact_with_filter() -> None: + """ + Test indexing with vector type knn_dense_float_vector and + default model/similarity + this mapping is compatible with model of exact and + similarity of l2/cosine + """ + index_name = "testknn15" + docsearch = EcloudESVectorStore.from_texts( + fake_texts, + FakeEmbeddings(), + index_name=index_name, + refresh_indices=True, + es_url=ES_URL, + user=user, + password=password, + text_field="my_text", + vector_field="my_vec", + vector_type="knn_dense_float_vector", + ) + + output = docsearch.similarity_search( + "foo", + k=1, + filter={"match_all": {}}, + search_params={ + "model": "exact", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="foo")] + + output = docsearch.similarity_search( + "bar", + k=2, + filter={"term": {"my_text.keyword": "bar"}}, + search_params={ + "model": "exact", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="bar")] + + output = docsearch.similarity_search( + "bar", + k=2, + filter={"term": 
{"my_text.keyword": "foo"}}, + search_params={ + "model": "exact", + "similarity": "l2", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="foo")] + + output = docsearch.similarity_search( + "foo", + k=2, + filter={"bool": {"filter": {"term": {"my_text.keyword": "bar"}}}}, + search_params={ + "model": "exact", + "similarity": "cosine", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="bar")] + + output = docsearch.similarity_search( + "foo", + k=2, + filter={"bool": {"filter": [{"term": {"my_text.keyword": "bar"}}]}}, + search_params={ + "model": "exact", + "similarity": "cosine", + "vector_field": "my_vec", + "text_field": "my_text", + }, + ) + assert output == [Document(page_content="bar")] + + delete_index(docsearch.client, index_name) diff --git a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py index da068990c0071..f04cc0b64db21 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py +++ b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py @@ -82,6 +82,7 @@ def check_compatibility(vector_store: VectorStore) -> bool: "SurrealDBStore", "TileDB", "TimescaleVector", + "EcloudESVectorStore", "Vald", "Vearch", "VespaStore", diff --git a/libs/community/tests/unit_tests/vectorstores/test_public_api.py b/libs/community/tests/unit_tests/vectorstores/test_public_api.py index d91f0eb0f5350..b092e0fba2d74 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_public_api.py +++ b/libs/community/tests/unit_tests/vectorstores/test_public_api.py @@ -29,6 +29,7 @@ "DocArrayInMemorySearch", "DocumentDBVectorSearch", "DuckDB", + "EcloudESVectorStore", "ElasticKnnSearch", "ElasticVectorSearch", "ElasticsearchStore", From 5c41f4083e6a6b9cabdc6f502ad384e8149eea52 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:23:35 -0700 Subject: [PATCH 7/9] [Evals] Fix function calling support (#19658) Current implementation is overzealous in validating chat datasets Fixes [#langsmith-sdk:557](https://github.com/langchain-ai/langsmith-sdk/issues/557) --- .../smith/evaluation/runner_utils.py | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 35c7b8a7e823d..2559c8eb71e30 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -280,7 +280,11 @@ def _get_prompt(inputs: Dict[str, Any]) -> str: ) -def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]: +class ChatModelInput(TypedDict): + messages: List[BaseMessage] + + +def _get_messages(inputs: Dict[str, Any]) -> dict: """Get Chat Messages from inputs. Args: @@ -293,35 +297,29 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]: """ if not inputs: raise InputFormatError("Inputs should not be empty.") - + input_copy = inputs.copy() if "messages" in inputs: - single_input = inputs["messages"] + input_copy["input"] = input_copy.pop("messages") elif len(inputs) == 1: - single_input = next(iter(inputs.values())) - else: - raise InputFormatError( - f"Chat Run expects 'messages' in inputs when example has multiple" - f" input keys. 
Got {inputs}" - ) - if isinstance(single_input, list) and all( - isinstance(i, dict) for i in single_input - ): - raw_messages = [single_input] - elif isinstance(single_input, list) and all( - isinstance(i, list) for i in single_input - ): - raw_messages = single_input - else: - raise InputFormatError( - f"Chat Run expects List[dict] or List[List[dict]] values for" - f" 'messages' key input. Got {inputs}" - ) - if len(raw_messages) == 1: - return messages_from_dict(raw_messages[0]) + input_copy["input"] = next(iter(inputs.values())) + if "input" in input_copy: + raw_messages = input_copy["input"] + if isinstance(raw_messages, list) and all( + isinstance(i, dict) for i in raw_messages + ): + raw_messages = [raw_messages] + if len(raw_messages) == 1: + input_copy["input"] = messages_from_dict(raw_messages[0]) + else: + raise InputFormatError( + "Batch messages not supported. Please provide a" + " single list of messages." + ) + return input_copy else: raise InputFormatError( f"Chat Run expects single List[dict] or List[List[dict]] 'messages'" - f" input. Got {len(raw_messages)} messages from inputs {inputs}" + f" input. Got {inputs}" ) @@ -711,9 +709,9 @@ async def _arun_llm( ), ) except InputFormatError: - messages = _get_messages(inputs) + llm_inputs = _get_messages(inputs) llm_output = await llm.ainvoke( - messages, + **llm_inputs, config=RunnableConfig( callbacks=callbacks, tags=tags or [], metadata=metadata or {} ), @@ -864,9 +862,9 @@ def _run_llm( ), ) except InputFormatError: - llm_messages = _get_messages(inputs) + llm_inputs = _get_messages(inputs) llm_output = llm.invoke( - llm_messages, + **llm_inputs, config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}), ) return llm_output From 3685f8ceac42ee3614127c1210e3d746599219e5 Mon Sep 17 00:00:00 2001 From: harry-cohere <127103098+harry-cohere@users.noreply.github.com> Date: Thu, 28 Mar 2024 01:35:43 +0000 Subject: [PATCH 8/9] cohere[patch]: Add cohere tools agent (#19602) **Description**: Adds a cohere tools agent and related notebook. 
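
A minimal usage sketch (assumes `COHERE_API_KEY` is set; the tool, prompt, and input below are illustrative placeholders rather than the notebook's exact code):

```python
from langchain.agents import AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool

from langchain_cohere import ChatCohere, create_cohere_tools_agent


@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


llm = ChatCohere()
prompt = ChatPromptTemplate.from_template("{input}")
agent = create_cohere_tools_agent(llm, [magic_function], prompt)
agent_executor = AgentExecutor(agent=agent, tools=[magic_function])
agent_executor.invoke({"input": "what is magic_function(3)?"})
```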
--------- Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Co-authored-by: Erick Friis --- .../cohere/langchain_cohere/__init__.py | 2 + .../cohere/langchain_cohere/chat_models.py | 97 ++++++++-- .../cohere/langchain_cohere/cohere_agent.py | 168 ++++++++++++++++++ .../cohere/langchain_cohere/rag_retrievers.py | 1 + libs/partners/cohere/poetry.lock | 9 +- libs/partners/cohere/pyproject.toml | 4 +- .../integration_tests/test_chat_models.py | 81 +++++++++ .../tests/unit_tests/test_chat_models.py | 85 +++++++++ .../tests/unit_tests/test_cohere_agent.py | 82 +++++++++ .../cohere/tests/unit_tests/test_imports.py | 1 + 10 files changed, 510 insertions(+), 20 deletions(-) create mode 100644 libs/partners/cohere/langchain_cohere/cohere_agent.py create mode 100644 libs/partners/cohere/tests/unit_tests/test_cohere_agent.py diff --git a/libs/partners/cohere/langchain_cohere/__init__.py b/libs/partners/cohere/langchain_cohere/__init__.py index 1f554a006e258..52d53361931f4 100644 --- a/libs/partners/cohere/langchain_cohere/__init__.py +++ b/libs/partners/cohere/langchain_cohere/__init__.py @@ -1,4 +1,5 @@ from langchain_cohere.chat_models import ChatCohere +from langchain_cohere.cohere_agent import create_cohere_tools_agent from langchain_cohere.embeddings import CohereEmbeddings from langchain_cohere.rag_retrievers import CohereRagRetriever from langchain_cohere.rerank import CohereRerank @@ -9,4 +10,5 @@ "CohereEmbeddings", "CohereRagRetriever", "CohereRerank", + "create_cohere_tools_agent", ] diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py index f60f5636dd15d..ea830e81f9b29 100644 --- a/libs/partners/cohere/langchain_cohere/chat_models.py +++ b/libs/partners/cohere/langchain_cohere/chat_models.py @@ -1,9 +1,22 @@ -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +import json +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + Type, + Union, +) +from cohere.types import NonStreamedChatResponse, ToolCall from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -18,7 +31,11 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_cohere.cohere_agent import _format_to_cohere_tools from langchain_cohere.llms import BaseCohere @@ -143,6 +160,14 @@ def _default_params(self) -> Dict[str, Any]: } return {k: v for k, v in base_params.items() if v is not None} + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + formatted_tools = _format_to_cohere_tools(tools) + return super().bind(tools=formatted_tools, **kwargs) + @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" @@ -169,6 +194,14 @@ def _stream( if run_manager: run_manager.on_llm_new_token(delta, chunk=chunk) yield chunk + elif data.event_type == "stream-end": + generation_info = self._get_generation_info(data.response) + yield ChatGenerationChunk( + message=AIMessageChunk( + content="", additional_kwargs=generation_info + ), + 
generation_info=generation_info, + ) async def _astream( self, @@ -191,16 +224,34 @@ async def _astream( if run_manager: await run_manager.on_llm_new_token(delta, chunk=chunk) yield chunk - - def _get_generation_info(self, response: Any) -> Dict[str, Any]: + elif data.event_type == "stream-end": + generation_info = self._get_generation_info(data.response) + yield ChatGenerationChunk( + message=AIMessageChunk( + content="", additional_kwargs=generation_info + ), + generation_info=generation_info, + ) + + def _get_generation_info(self, response: NonStreamedChatResponse) -> Dict[str, Any]: """Get the generation info from cohere API response.""" - return { + generation_info = { "documents": response.documents, "citations": response.citations, "search_results": response.search_results, "search_queries": response.search_queries, - "token_count": response.token_count, + "is_search_required": response.is_search_required, + "generation_id": response.generation_id, } + if response.tool_calls: + # Only populate tool_calls when 1) present on the response and + # 2) has one or more calls. + generation_info["tool_calls"] = _format_cohere_tool_calls( + response.generation_id or "", response.tool_calls + ) + if hasattr(response, "token_count"): + generation_info["token_count"] = response.token_count + return generation_info def _generate( self, @@ -218,10 +269,8 @@ def _generate( request = get_cohere_chat_request(messages, **self._default_params, **kwargs) response = self.client.chat(**request) - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) + generation_info = self._get_generation_info(response) + message = AIMessage(content=response.text, additional_kwargs=generation_info) return ChatResult( generations=[ ChatGeneration(message=message, generation_info=generation_info) @@ -244,10 +293,8 @@ async def _agenerate( request = get_cohere_chat_request(messages, **self._default_params, **kwargs) response = self.client.chat(**request) - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) + generation_info = self._get_generation_info(response) + message = AIMessage(content=response.text, additional_kwargs=generation_info) return ChatResult( generations=[ ChatGeneration(message=message, generation_info=generation_info) @@ -257,3 +304,27 @@ async def _agenerate( def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" return len(self.client.tokenize(text).tokens) + + +def _format_cohere_tool_calls( + generation_id: str, tool_calls: Optional[List[ToolCall]] = None +) -> List[Dict]: + """ + Formats a Cohere API response into the tool call format used elsewhere in Langchain. 
+ """ + if not tool_calls: + return [] + + formatted_tool_calls = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": generation_id, + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.parameters), + }, + "type": "function", + } + ) + return formatted_tool_calls diff --git a/libs/partners/cohere/langchain_cohere/cohere_agent.py b/libs/partners/cohere/langchain_cohere/cohere_agent.py new file mode 100644 index 0000000000000..5bf8328e8c6f2 --- /dev/null +++ b/libs/partners/cohere/langchain_cohere/cohere_agent.py @@ -0,0 +1,168 @@ +from typing import Any, Dict, List, Sequence, Tuple, Type, Union + +from cohere.types import Tool, ToolParameterDefinitionsValue +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.outputs import Generation +from langchain_core.outputs.chat_generation import ChatGeneration +from langchain_core.prompts.chat import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable, RunnablePassthrough +from langchain_core.runnables.base import RunnableLambda +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function + + +def create_cohere_tools_agent( + llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate +) -> Runnable: + def llm_with_tools(input_: Dict) -> Runnable: + tool_results = ( + input_["tool_results"] if len(input_["tool_results"]) > 0 else None + ) + tools_ = input_["tools"] if len(input_["tools"]) > 0 else None + return RunnableLambda(lambda x: x["input"]) | llm.bind( + tools=tools_, tool_results=tool_results + ) + + agent = ( + RunnablePassthrough.assign( + # Intermediate steps are in tool results. + # Edit below to change the prompt parameters. + input=lambda x: prompt.format_messages( + input=x["input"], agent_scratchpad=[] + ), + tools=lambda x: _format_to_cohere_tools(tools), + tool_results=lambda x: _format_to_cohere_tools_messages( + x["intermediate_steps"] + ), + ) + | llm_with_tools + | _CohereToolsAgentOutputParser() + ) + return agent + + +def _format_to_cohere_tools( + tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]], +) -> List[Dict[str, Any]]: + return [_convert_to_cohere_tool(tool) for tool in tools] + + +def _format_to_cohere_tools_messages( + intermediate_steps: Sequence[Tuple[AgentAction, str]], +) -> list: + """Convert (AgentAction, tool output) tuples into tool messages.""" + if len(intermediate_steps) == 0: + return [] + tool_results = [] + for agent_action, observation in intermediate_steps: + tool_results.append( + { + "call": { + "name": agent_action.tool, + "parameters": agent_action.tool_input, + }, + "outputs": [{"answer": observation}], + } + ) + + return tool_results + + +def _convert_to_cohere_tool( + tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]], +) -> Dict[str, Any]: + """ + Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool. 
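+
+    The result follows Cohere's ``Tool`` schema, e.g. (values are
+    illustrative, mirroring ``expected_test_tool_definition`` in the unit
+    tests):
+
+        {
+            "name": "test_tool",
+            "description": "test_tool description",
+            "parameter_definitions": {
+                "arg_1": {"description": "...", "type": "string", "required": True},
+            },
+        }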
+ """ + if isinstance(tool, BaseTool): + return Tool( + name=tool.name, + description=tool.description, + parameter_definitions={ + param_name: ToolParameterDefinitionsValue( + description=param_definition.get("description"), + type=param_definition.get("type"), + required="default" not in param_definition, + ) + for param_name, param_definition in tool.args.items() + }, + ).dict() + elif isinstance(tool, dict): + if not all(k in tool for k in ("title", "description", "properties")): + raise ValueError( + "Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501 + ) + return Tool( + name=tool.get("title"), + description=tool.get("description"), + parameter_definitions={ + param_name: ToolParameterDefinitionsValue( + description=param_definition.get("description"), + type=param_definition.get("type"), + required="default" not in param_definition, + ) + for param_name, param_definition in tool.get("properties", {}).items() + }, + ).dict() + elif issubclass(tool, BaseModel): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + properties = parameters.get("properties", {}) + return Tool( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + # The Cohere API requires the description field. + "description", + as_json_schema_function.get("name"), + ), + parameter_definitions={ + param_name: ToolParameterDefinitionsValue( + description=param_definition.get("description"), + type=param_definition.get("type"), + required=param_name in parameters.get("required", []), + ) + for param_name, param_definition in properties.items() + }, + ).dict() + else: + raise ValueError( + f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." 
# noqa: E501 + ) + + +class _CohereToolsAgentOutputParser( + BaseOutputParser[Union[List[AgentAction], AgentFinish]] +): + """Parses a message into agent actions/finish.""" + + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> Union[List[AgentAction], AgentFinish]: + if not isinstance(result[0], ChatGeneration): + raise ValueError(f"Expected ChatGeneration, got {type(result)}") + if result[0].message.additional_kwargs["tool_calls"]: + actions = [] + for tool in result[0].message.additional_kwargs["tool_calls"]: + function = tool.get("function", {}) + actions.append( + AgentAction( + tool=function.get("name"), + tool_input=function.get("arguments"), + log=function.get("name"), + ) + ) + return actions + else: + return AgentFinish( + return_values={ + "text": result[0].message.content, + "additional_info": result[0].message.additional_kwargs, + }, + log="", + ) + + def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: + raise ValueError("Can only parse messages") diff --git a/libs/partners/cohere/langchain_cohere/rag_retrievers.py b/libs/partners/cohere/langchain_cohere/rag_retrievers.py index 91f4c3a0886f5..0d194596203c9 100644 --- a/libs/partners/cohere/langchain_cohere/rag_retrievers.py +++ b/libs/partners/cohere/langchain_cohere/rag_retrievers.py @@ -20,6 +20,7 @@ def _get_docs(response: Any) -> List[Document]: docs = ( [] if "documents" not in response.generation_info + or len(response.generation_info["documents"]) == 0 else [ Document(page_content=doc["snippet"], metadata=doc) for doc in response.generation_info["documents"] diff --git a/libs/partners/cohere/poetry.lock b/libs/partners/cohere/poetry.lock index e3fae42bd8e20..d18c8a94e5cf3 100644 --- a/libs/partners/cohere/poetry.lock +++ b/libs/partners/cohere/poetry.lock @@ -165,13 +165,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency [[package]] name = "cohere" -version = "5.1.2" +version = "5.1.4" description = "" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "cohere-5.1.2-py3-none-any.whl", hash = "sha256:7782e32cba671fc04203c3b56a9ce1b70e9459d7c983e8576b04d394fbe809f5"}, - {file = "cohere-5.1.2.tar.gz", hash = "sha256:21af5ed6edcf939062c41240040316084cd7e753cf3207f661f68abb4bbbe846"}, + {file = "cohere-5.1.4-py3-none-any.whl", hash = "sha256:b88c44dfa44301f55f509db120582a6127c2e391c6c43a4dc58767f4df056a9d"}, + {file = "cohere-5.1.4.tar.gz", hash = "sha256:81b45fe37df2d62aaf57094402cb62b5fed285c25667dab96023f2ad2591ff35"}, ] [package.dependencies] @@ -742,7 +742,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -957,4 +956,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "7ed2d31c084d528c64eb959df1a2ea29345a70117e9d29f322607fe247804cc5" +content-hash = "6a5887a0391a649e1a45f3e3c766a880e133367d2656a9b5a37d75ebc33adef6" diff --git a/libs/partners/cohere/pyproject.toml b/libs/partners/cohere/pyproject.toml index 8fcad47bfd323..71ddcca2e1a2e 100644 --- a/libs/partners/cohere/pyproject.toml +++ b/libs/partners/cohere/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-cohere" -version = "0.1.0rc1" +version = "0.1.0rc2" description = "An integration package connecting Cohere and LangChain" authors = [] readme = "README.md" @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1.32" -cohere = "^5.1.1" +cohere = "^5.1.4" [tool.poetry.group.test] optional = true diff --git a/libs/partners/cohere/tests/integration_tests/test_chat_models.py b/libs/partners/cohere/tests/integration_tests/test_chat_models.py index 81246c37aa534..a94c9e627d759 100644 --- a/libs/partners/cohere/tests/integration_tests/test_chat_models.py +++ b/libs/partners/cohere/tests/integration_tests/test_chat_models.py @@ -1,4 +1,12 @@ """Test ChatCohere chat model.""" + +import json +from typing import Any + +import pytest +from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.pydantic_v1 import BaseModel + from langchain_cohere import ChatCohere @@ -61,3 +69,76 @@ def test_invoke() -> None: result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) + + +def test_invoke_tool_calls() -> None: + llm = ChatCohere(temperature=0) + + class Person(BaseModel): + name: str + age: int + + tool_llm = llm.bind_tools([Person]) + + # where it calls the tool + result = tool_llm.invoke("Erick, 27 years old") + + assert isinstance(result, AIMessage) + additional_kwargs = result.additional_kwargs + assert "tool_calls" in additional_kwargs + assert len(additional_kwargs["tool_calls"]) == 1 + assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person" + assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == { + "name": "Erick", + "age": 27, + } + + +def test_streaming_tool_call() -> None: + llm = ChatCohere(temperature=0) + + class Person(BaseModel): + name: str + age: int + + tool_llm = llm.bind_tools([Person]) + + # where it calls the tool + strm = tool_llm.stream("Erick, 27 years old") + + additional_kwargs = None + for chunk in strm: + assert isinstance(chunk, AIMessageChunk) + assert chunk.content == "" + additional_kwargs = chunk.additional_kwargs + + assert additional_kwargs is not None + assert "tool_calls" in additional_kwargs + assert len(additional_kwargs["tool_calls"]) == 1 + assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person" + assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == { + "name": "Erick", + "age": 27, + } + + +@pytest.mark.xfail( + reason="Cohere models return empty output when a tool is passed in but not called." 
+) +def test_streaming_tool_call_no_tool_calls() -> None: + llm = ChatCohere(temperature=0) + + class Person(BaseModel): + name: str + age: int + + tool_llm = llm.bind_tools([Person]) + + # where it doesn't call the tool + strm = tool_llm.stream("What is 2+2?") + acc: Any = None + for chunk in strm: + assert isinstance(chunk, AIMessageChunk) + acc = chunk if acc is None else acc + chunk + assert acc.content != "" + assert "tool_calls" not in acc.additional_kwargs diff --git a/libs/partners/cohere/tests/unit_tests/test_chat_models.py b/libs/partners/cohere/tests/unit_tests/test_chat_models.py index eecfe33f3311a..545fa7f2887d4 100644 --- a/libs/partners/cohere/tests/unit_tests/test_chat_models.py +++ b/libs/partners/cohere/tests/unit_tests/test_chat_models.py @@ -2,6 +2,7 @@ import typing import pytest +from cohere.types import NonStreamedChatResponse, ToolCall from langchain_cohere.chat_models import ChatCohere @@ -28,3 +29,87 @@ def test_initialization() -> None: def test_default_params(chat_cohere: ChatCohere, expected: typing.Dict) -> None: actual = chat_cohere._default_params assert expected == actual + + +@pytest.mark.parametrize( + "response, expected", + [ + pytest.param( + NonStreamedChatResponse( + generation_id="foo", + text="", + tool_calls=[ + ToolCall(name="tool1", parameters={"arg1": 1, "arg2": "2"}), + ToolCall(name="tool2", parameters={"arg3": 3, "arg4": "4"}), + ], + ), + { + "documents": None, + "citations": None, + "search_results": None, + "search_queries": None, + "is_search_required": None, + "generation_id": "foo", + "tool_calls": [ + { + "id": "foo", + "function": { + "name": "tool1", + "arguments": '{"arg1": 1, "arg2": "2"}', + }, + "type": "function", + }, + { + "id": "foo", + "function": { + "name": "tool2", + "arguments": '{"arg3": 3, "arg4": "4"}', + }, + "type": "function", + }, + ], + }, + id="tools should be called", + ), + pytest.param( + NonStreamedChatResponse( + generation_id="foo", + text="", + tool_calls=[], + ), + { + "documents": None, + "citations": None, + "search_results": None, + "search_queries": None, + "is_search_required": None, + "generation_id": "foo", + }, + id="no tools should be called", + ), + pytest.param( + NonStreamedChatResponse( + generation_id="foo", + text="bar", + tool_calls=[], + ), + { + "documents": None, + "citations": None, + "search_results": None, + "search_queries": None, + "is_search_required": None, + "generation_id": "foo", + }, + id="chat response without tools/documents/citations/tools etc", + ), + ], +) +def test_get_generation_info( + response: typing.Any, expected: typing.Dict[str, typing.Any] +) -> None: + chat_cohere = ChatCohere(cohere_api_key="test") + + actual = chat_cohere._get_generation_info(response) + + assert expected == actual diff --git a/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py b/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py new file mode 100644 index 0000000000000..9dc082a55e671 --- /dev/null +++ b/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, Optional, Type, Union + +import pytest +from langchain_core.tools import BaseModel, BaseTool, Field + +from langchain_cohere.cohere_agent import _format_to_cohere_tools + +expected_test_tool_definition = { + "description": "test_tool description", + "name": "test_tool", + "parameter_definitions": { + "arg_1": { + "description": "Arg1 description", + "required": True, + "type": "string", + }, + "optional_arg_2": { + "description": "Arg2 description", + "required": 
False, + "type": "string", + }, + "arg_3": { + "description": "Arg3 description", + "required": True, + "type": "integer", + }, + }, +} + + +class _TestToolSchema(BaseModel): + arg_1: str = Field(description="Arg1 description") + optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2") + arg_3: int = Field(description="Arg3 description") + + +class _TestTool(BaseTool): + name = "test_tool" + description = "test_tool description" + args_schema: Type[_TestToolSchema] = _TestToolSchema + + def _run(self, *args: Any, **kwargs: Any) -> Any: + pass + + +class test_tool(BaseModel): + """test_tool description""" + + arg_1: str = Field(description="Arg1 description") + optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2") + arg_3: int = Field(description="Arg3 description") + + +test_tool_as_dict = { + "title": "test_tool", + "description": "test_tool description", + "properties": { + "arg_1": {"description": "Arg1 description", "type": "string"}, + "optional_arg_2": { + "description": "Arg2 description", + "type": "string", + "default": "2", + }, + "arg_3": {"description": "Arg3 description", "type": "integer"}, + }, +} + + +@pytest.mark.parametrize( + "tool", + [ + pytest.param(_TestTool(), id="tool from BaseTool"), + pytest.param(test_tool, id="BaseModel"), + pytest.param(test_tool_as_dict, id="JSON schema dict"), + ], +) +def test_format_to_cohere_tools( + tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]], +) -> None: + actual = _format_to_cohere_tools([tool]) + + assert [expected_test_tool_definition] == actual diff --git a/libs/partners/cohere/tests/unit_tests/test_imports.py b/libs/partners/cohere/tests/unit_tests/test_imports.py index ceff62f104e92..0159c19e94ea9 100644 --- a/libs/partners/cohere/tests/unit_tests/test_imports.py +++ b/libs/partners/cohere/tests/unit_tests/test_imports.py @@ -6,6 +6,7 @@ "CohereEmbeddings", "CohereRagRetriever", "CohereRerank", + "create_cohere_tools_agent", ] From fdfb51ad8daffa1e6e5c6889fd71627697de178e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 27 Mar 2024 18:45:01 -0700 Subject: [PATCH 9/9] core: Two updates to chat model interface (#19684) - .stream() and .astream() call on_llm_new_token, removing the need for subclasses to do so. Backwards compatible because now we don't pass run_manager into ._stream and ._astream - .generate() and .agenerate() now handle `stream: bool` kwarg for _generate and _agenerate. Subclasses handle this arg by delegating to ._stream(), now one less thing they need to do. Backwards compat because this is an optional arg that we now never pass to the subclasses - .generate() and .agenerate() now inspect callback handlers to decide on a default value for stream:bool if not passed in. 
This auto enables streaming when using astream_events and astream_log - as a result of these three changes any usage of .astream_events and .astream_log should now yield chat model stream events - In future PRs we can update all subclasses to reflect these two things now handled by base class, but in meantime all will continue to work --- .../language_models/chat_models.py | 81 ++++++-- .../runnables/test_runnable_events.py | 188 +++++++++++++++++- 2 files changed, 254 insertions(+), 15 deletions(-) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 60885a84a8694..235b5d4b7b283 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -50,6 +50,7 @@ from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables.config import ensure_config, run_in_executor +from langchain_core.tracers.log_stream import LogStreamCallbackHandler if TYPE_CHECKING: from langchain_core.runnables import RunnableConfig @@ -219,9 +220,10 @@ def stream( ) generation: Optional[ChatGenerationChunk] = None try: - for chunk in self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ): + for chunk in self._stream(messages, stop=stop, **kwargs): + run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) yield chunk.message if generation is None: @@ -287,9 +289,11 @@ async def astream( async for chunk in self._astream( messages, stop=stop, - run_manager=run_manager, **kwargs, ): + await run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) yield chunk.message if generation is None: @@ -585,12 +589,37 @@ def _generate_with_cache( raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - if inspect.signature(self._generate).parameters.get("run_manager"): - result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs + # If stream is not explicitly set, check if implicitly requested by + # astream_events() or astream_log(). Bail out if _stream not implemented + if type(self)._stream != BaseChatModel._stream and kwargs.pop( + "stream", + next( + ( + True + for h in run_manager.handlers + if isinstance(h, LogStreamCallbackHandler) + ), + False, ) + if run_manager + else False, + ): + chunks: List[ChatGenerationChunk] = [] + for chunk in self._stream(messages, stop=stop, **kwargs): + if run_manager: + run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) + chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) + chunks.append(chunk) + result = generate_from_stream(iter(chunks)) else: - result = self._generate(messages, stop=stop, **kwargs) + if inspect.signature(self._generate).parameters.get("run_manager"): + result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = self._generate(messages, stop=stop, **kwargs) # Add response metadata to each generation for generation in result.generations: @@ -634,12 +663,40 @@ async def _agenerate_with_cache( raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - if inspect.signature(self._agenerate).parameters.get("run_manager"): - result = await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs + # If stream is not explicitly set, check if implicitly requested by + # astream_events() or astream_log(). Bail out if _astream not implemented + if ( + type(self)._astream != BaseChatModel._astream + or type(self)._stream != BaseChatModel._stream + ) and kwargs.pop( + "stream", + next( + ( + True + for h in run_manager.handlers + if isinstance(h, LogStreamCallbackHandler) + ), + False, ) + if run_manager + else False, + ): + chunks: List[ChatGenerationChunk] = [] + async for chunk in self._astream(messages, stop=stop, **kwargs): + if run_manager: + await run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) + chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) + chunks.append(chunk) + result = generate_from_stream(iter(chunks)) else: - result = await self._agenerate(messages, stop=stop, **kwargs) + if inspect.signature(self._agenerate).parameters.get("run_manager"): + result = await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = await self._agenerate(messages, stop=stop, **kwargs) # Add response metadata to each generation for generation in result.generations: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index 3b822bce6fe72..bc5d6102ecccd 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -1,4 +1,5 @@ """Module that contains tests for runnable.astream_events API.""" +import sys from itertools import cycle from typing import Any, AsyncIterator, Dict, List, Sequence, cast @@ -22,6 +23,7 @@ from langchain_core.runnables import ( ConfigurableField, Runnable, + RunnableConfig, RunnableLambda, ) from langchain_core.runnables.history import RunnableWithMessageHistory @@ -314,9 +316,7 @@ async def test_event_stream_with_lambdas_from_lambda() -> None: async def test_astream_events_from_model() -> None: """Test the output of a model.""" - infinite_cycle = cycle( - [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")] - ) + infinite_cycle = cycle([AIMessage(content="hello world!")]) # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces model = ( GenericFakeChatModel(messages=infinite_cycle) @@ -373,6 +373,188 @@ async def test_astream_events_from_model() -> None: }, ] + @RunnableLambda + def i_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return model.invoke(input) + else: + return model.invoke(input, config) + + events = await _collect_events(i_dont_stream.astream_events("hello", version="v1")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", 
+ "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": { + "generations": [ + [ + { + "generation_info": None, + "message": AIMessage(content="hello world!"), + "text": "hello world!", + "type": "ChatGeneration", + } + ] + ], + "llm_output": None, + "run": None, + }, + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!")}, + "event": "on_chain_stream", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!")}, + "event": "on_chain_end", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + ] + + @RunnableLambda + async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return await model.ainvoke(input) + else: + return await model.ainvoke(input, config) + + events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": { + "generations": [ + [ + { + "generation_info": None, + "message": AIMessage(content="hello world!"), + "text": "hello world!", + "type": "ChatGeneration", + } + ] + ], + "llm_output": None, + "run": None, + }, + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!")}, + "event": "on_chain_stream", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!")}, + "event": "on_chain_end", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + ] + async def test_event_stream_with_simple_chain() -> None: """Test as event stream."""
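
A minimal sketch of the behavior the new tests above pin down: with this patch, a plain `.invoke()` on a chat model inside a lambda still surfaces `on_chat_model_stream` events through `astream_events()`. The fake-model import path below is the test suite's local helper and is an assumption here:

```python
import asyncio
from itertools import cycle

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda

# Test-suite helper; this import path is an assumption.
from tests.unit_tests.fake.chat_model import GenericFakeChatModel

# GenericFakeChatModel streams its message back in space-delimited chunks.
model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world!")]))


@RunnableLambda
def i_dont_stream(input_: str, config: RunnableConfig) -> AIMessage:
    # No streaming is requested here, but the base class now detects the
    # astream_events() log handler and streams the model under the hood.
    return model.invoke(input_, config)


async def main() -> None:
    async for event in i_dont_stream.astream_events("hello", version="v1"):
        if event["event"] == "on_chat_model_stream":
            print(event["data"]["chunk"])  # AIMessageChunk "hello", " ", "world!"


asyncio.run(main())
```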