Skip to content

Commit

Permalink
Fix retriever and reranker to process chat completion requests (#915)
Browse files Browse the repository at this point in the history
* Fix retriever and reranker to process chat completion requests

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
minmin-intel and pre-commit-ci[bot] authored Nov 19, 2024
1 parent 8121602 commit 1cf2781
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
4 changes: 2 additions & 2 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def __init__(self, megaservice, host="0.0.0.0", port=8889):
host,
port,
str(MegaServiceEndpoint.RETRIEVALTOOL),
Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
Union[TextDoc, ChatCompletionRequest],
Union[RerankedDoc, LLMParamsDoc],
)

Expand All @@ -789,7 +789,7 @@ def parser_input(data, TypeClass, key):

data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
for key, TypeClass in zip(["text", "messages"], [TextDoc, ChatCompletionRequest]):
query, chat_request = parser_input(data, TypeClass, key)
if query is not None:
break
Expand Down
4 changes: 2 additions & 2 deletions comps/reranks/tei/reranking_tei.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
endpoint="/v1/reranking",
host="0.0.0.0",
port=8000,
input_datatype=SearchedDoc,
output_datatype=LLMParamsDoc,
input_datatype=Union[SearchedDoc, RerankingRequest, ChatCompletionRequest],
output_datatype=Union[LLMParamsDoc, RerankingResponse, ChatCompletionRequest],
)
@register_statistics(names=["opea_service@reranking_tei"])
async def reranking(
Expand Down
16 changes: 15 additions & 1 deletion comps/retrievers/redis/langchain/retriever_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
EmbeddingResponse,
RetrievalRequest,
RetrievalResponse,
RetrievalResponseData,
Expand Down Expand Up @@ -54,12 +55,25 @@ async def retrieve(
else:
if isinstance(input, EmbedDoc):
query = input.text
embedding_data_input = input.embedding
else:
# for RetrievalRequest, ChatCompletionRequest
query = input.input
if isinstance(input.embedding, EmbeddingResponse):
embeddings = input.embedding.data
embedding_data_input = []
for emb in embeddings:
# each emb is EmbeddingResponseData
# print("Embedding data: ", emb.embedding)
# print("Embedding data length: ",len(emb.embedding))
embedding_data_input.append(emb.embedding)
# print("All Embedding data length: ",len(embedding_data_input))
else:
embedding_data_input = input.embedding

# if the Redis index has data, perform the search
if input.search_type == "similarity":
search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k)
search_res = await vector_db.asimilarity_search_by_vector(embedding=embedding_data_input, k=input.k)
elif input.search_type == "similarity_distance_threshold":
if input.distance_threshold is None:
raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
Expand Down

0 comments on commit 1cf2781

Please sign in to comment.