From 286729561f9b4683afe2a3701fdc4cd1fe3f3d6f Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Thu, 26 Sep 2024 10:35:20 +0800 Subject: [PATCH] Support async for embedding micorservice (#742) * Support async in embedding micorservice Signed-off-by: lvliang-intel --- comps/embeddings/tei/langchain/embedding_tei.py | 6 +++--- comps/embeddings/tei/langchain/local_embedding.py | 4 ++-- comps/embeddings/tei/langchain/local_embedding_768.py | 4 ++-- comps/embeddings/tei/llama_index/embedding_tei.py | 4 ++-- comps/embeddings/tei/llama_index/local_embedding.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/comps/embeddings/tei/langchain/embedding_tei.py b/comps/embeddings/tei/langchain/embedding_tei.py index 0ddefb49a..2f49904bb 100644 --- a/comps/embeddings/tei/langchain/embedding_tei.py +++ b/comps/embeddings/tei/langchain/embedding_tei.py @@ -36,17 +36,17 @@ port=6000, ) @register_statistics(names=["opea_service@embedding_tei_langchain"]) -def embedding( +async def embedding( input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest] ) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]: start = time.time() if logflag: logger.info(input) if isinstance(input, TextDoc): - embed_vector = embeddings.embed_query(input.text) + embed_vector = await embeddings.aembed_query(input.text) res = EmbedDoc(text=input.text, embedding=embed_vector) else: - embed_vector = embeddings.embed_query(input.input) + embed_vector = await embeddings.aembed_query(input.input) if input.dimensions is not None: embed_vector = embed_vector[: input.dimensions] diff --git a/comps/embeddings/tei/langchain/local_embedding.py b/comps/embeddings/tei/langchain/local_embedding.py index 6a0a1a630..3f3fd5fc4 100644 --- a/comps/embeddings/tei/langchain/local_embedding.py +++ b/comps/embeddings/tei/langchain/local_embedding.py @@ -29,10 +29,10 @@ output_datatype=EmbedDoc, ) @opea_telemetry -def embedding(input: TextDoc) -> EmbedDoc: +async def embedding(input: TextDoc) -> EmbedDoc: if logflag: logger.info(input) - embed_vector = embeddings.embed_query(input.text) + embed_vector = await embeddings.aembed_query(input.text) res = EmbedDoc(text=input.text, embedding=embed_vector) if logflag: logger.info(res) diff --git a/comps/embeddings/tei/langchain/local_embedding_768.py b/comps/embeddings/tei/langchain/local_embedding_768.py index a079bd6ed..dae52299b 100644 --- a/comps/embeddings/tei/langchain/local_embedding_768.py +++ b/comps/embeddings/tei/langchain/local_embedding_768.py @@ -16,8 +16,8 @@ output_datatype=EmbedDoc768, ) @opea_telemetry -def embedding(input: TextDoc) -> EmbedDoc768: - embed_vector = embeddings.embed_query(input.text) +async def embedding(input: TextDoc) -> EmbedDoc768: + embed_vector = await embeddings.aembed_query(input.text) res = EmbedDoc768(text=input.text, embedding=embed_vector) return res diff --git a/comps/embeddings/tei/llama_index/embedding_tei.py b/comps/embeddings/tei/llama_index/embedding_tei.py index 943bd7535..e96b75e75 100644 --- a/comps/embeddings/tei/llama_index/embedding_tei.py +++ b/comps/embeddings/tei/llama_index/embedding_tei.py @@ -20,10 +20,10 @@ input_datatype=TextDoc, output_datatype=EmbedDoc, ) -def embedding(input: TextDoc) -> EmbedDoc: +async def embedding(input: TextDoc) -> EmbedDoc: if logflag: logger.info(input) - embed_vector = embeddings._get_query_embedding(input.text) + embed_vector = await embeddings.aget_query_embedding(input.text) res = EmbedDoc(text=input.text, embedding=embed_vector) if logflag: logger.info(res) diff --git a/comps/embeddings/tei/llama_index/local_embedding.py b/comps/embeddings/tei/llama_index/local_embedding.py index 17ee6e89a..ba9d3dd5a 100644 --- a/comps/embeddings/tei/llama_index/local_embedding.py +++ b/comps/embeddings/tei/llama_index/local_embedding.py @@ -20,10 +20,10 @@ input_datatype=TextDoc, output_datatype=EmbedDoc, ) -def embedding(input: TextDoc) -> EmbedDoc: +async def embedding(input: TextDoc) -> EmbedDoc: if logflag: logger.info(input) - embed_vector = embeddings.get_text_embedding(input.text) + embed_vector = await embeddings.aget_query_embedding(input.text) res = EmbedDoc(text=input.text, embedding=embed_vector) if logflag: logger.info(res)