From 1cf27817aadad016a1b682be26c6a326df88a686 Mon Sep 17 00:00:00 2001 From: minmin-intel Date: Mon, 18 Nov 2024 21:54:59 -0800 Subject: [PATCH] fix retriever and reranker to process chat completion request (#915) * fix retriever and reranker to process chat completion request Signed-off-by: minmin-intel * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: minmin-intel Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- comps/cores/mega/gateway.py | 4 ++-- comps/reranks/tei/reranking_tei.py | 4 ++-- .../redis/langchain/retriever_redis.py | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 63a9fe3b2..29642eea5 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -773,7 +773,7 @@ def __init__(self, megaservice, host="0.0.0.0", port=8889): host, port, str(MegaServiceEndpoint.RETRIEVALTOOL), - Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], + Union[TextDoc, ChatCompletionRequest], Union[RerankedDoc, LLMParamsDoc], ) @@ -789,7 +789,7 @@ def parser_input(data, TypeClass, key): data = await request.json() query = None - for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + for key, TypeClass in zip(["text", "messages"], [TextDoc, ChatCompletionRequest]): query, chat_request = parser_input(data, TypeClass, key) if query is not None: break diff --git a/comps/reranks/tei/reranking_tei.py b/comps/reranks/tei/reranking_tei.py index daae461da..682346f6d 100644 --- a/comps/reranks/tei/reranking_tei.py +++ b/comps/reranks/tei/reranking_tei.py @@ -41,8 +41,8 @@ endpoint="/v1/reranking", host="0.0.0.0", port=8000, - input_datatype=SearchedDoc, - output_datatype=LLMParamsDoc, + input_datatype=Union[SearchedDoc, RerankingRequest, ChatCompletionRequest], + output_datatype=Union[LLMParamsDoc, RerankingResponse, ChatCompletionRequest], ) @register_statistics(names=["opea_service@reranking_tei"]) async def reranking( diff --git a/comps/retrievers/redis/langchain/retriever_redis.py b/comps/retrievers/redis/langchain/retriever_redis.py index d46e792f0..ada07d236 100644 --- a/comps/retrievers/redis/langchain/retriever_redis.py +++ b/comps/retrievers/redis/langchain/retriever_redis.py @@ -23,6 +23,7 @@ ) from comps.cores.proto.api_protocol import ( ChatCompletionRequest, + EmbeddingResponse, RetrievalRequest, RetrievalResponse, RetrievalResponseData, @@ -54,12 +55,25 @@ async def retrieve( else: if isinstance(input, EmbedDoc): query = input.text + embedding_data_input = input.embedding else: # for RetrievalRequest, ChatCompletionRequest query = input.input + if isinstance(input.embedding, EmbeddingResponse): + embeddings = input.embedding.data + embedding_data_input = [] + for emb in embeddings: + # each emb is EmbeddingResponseData + # print("Embedding data: ", emb.embedding) + # print("Embedding data length: ",len(emb.embedding)) + embedding_data_input.append(emb.embedding) + # print("All Embedding data length: ",len(embedding_data_input)) + else: + embedding_data_input = input.embedding + # if the Redis index has data, perform the search if input.search_type == "similarity": - search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k) + search_res = await vector_db.asimilarity_search_by_vector(embedding=embedding_data_input, k=input.k) elif input.search_type == "similarity_distance_threshold": if input.distance_threshold is None: raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")