From 4b8cd7a09a6a2bfe36e072189619cb54524de1de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Philippart?= Date: Mon, 4 Nov 2024 22:40:30 +0100 Subject: [PATCH] =?UTF-8?q?community:=20=E2=9C=A8=20Use=20new=20OVHcloud?= =?UTF-8?q?=20batch=20embedding=20(#26209)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Description:** change to do the batch embedding server side and not client side - **Twitter handle:** @wildagsx --------- Co-authored-by: ccurme --- .../embeddings/ovhcloud.py | 64 ++++++++++++------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/libs/community/langchain_community/embeddings/ovhcloud.py b/libs/community/langchain_community/embeddings/ovhcloud.py index 5786a761f49fe..49b9cbfa21097 100644 --- a/libs/community/langchain_community/embeddings/ovhcloud.py +++ b/libs/community/langchain_community/embeddings/ovhcloud.py @@ -1,3 +1,4 @@ +import json import logging import time from typing import Any, List @@ -41,17 +42,55 @@ def _generate_embedding(self, text: str) -> List[float]: Returns: List[float]: Embeddings for the text. """ + + return self._send_request_to_ai_endpoints("text/plain", text, "text2vec") + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed a list of documents. + Args: + texts (List[str]): The list of texts to embed. + + Returns: + List[List[float]]: List of embeddings, one for each input text. + + """ + + return self._send_request_to_ai_endpoints( + "application/json", json.dumps(texts), "batch_text2vec" + ) + + def embed_query(self, text: str) -> List[float]: + """Embed a single query text. + Args: + text (str): The text to embed. + Returns: + List[float]: Embeddings for the text. + """ + return self._generate_embedding(text) + + def _send_request_to_ai_endpoints( + self, contentType: str, payload: str, route: str + ) -> Any: + """Send a HTTPS request to OVHcloud AI Endpoints + Args: + contentType (str): The content type of the request, application/json or text/plain. + payload (str): The payload of the request. + route (str): The route of the request, batch_text2vec or text2vec. + """ # noqa: E501 headers = { - "content-type": "text/plain", + "content-type": contentType, "Authorization": f"Bearer {self.access_token}", } session = requests.session() while True: response = session.post( - f"https://{self.model_name}.endpoints.{self.region}.ai.cloud.ovh.net/api/text2vec", + ( + f"https://{self.model_name}.endpoints.{self.region}" + f".ai.cloud.ovh.net/api/{route}" + ), headers=headers, - data=text, + data=payload, ) if response.status_code != 200: if response.status_code == 429: @@ -74,22 +113,3 @@ def _generate_embedding(self, text: str) -> List[float]: ) ) return response.json() - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Create a retry decorator for PremAIEmbeddings. - Args: - texts (List[str]): The list of texts to embed. - - Returns: - List[List[float]]: List of embeddings, one for each input text. - """ - return [self._generate_embedding(text) for text in texts] - - def embed_query(self, text: str) -> List[float]: - """Embed a single query text. - Args: - text (str): The text to embed. - Returns: - List[float]: Embeddings for the text. - """ - return self._generate_embedding(text)