From 9df4b3c040f0f852ba0d40e73e41137d8650b5d8 Mon Sep 17 00:00:00 2001
From: lvliang-intel
Date: Mon, 30 Sep 2024 09:37:00 +0800
Subject: [PATCH] TEI rerank microservice async support (#746)

* TEI rerank microservice support async

Signed-off-by: lvliang-intel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: lvliang-intel
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/reranks/tei/reranking_tei.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/comps/reranks/tei/reranking_tei.py b/comps/reranks/tei/reranking_tei.py
index cb423cf83..5e5acd716 100644
--- a/comps/reranks/tei/reranking_tei.py
+++ b/comps/reranks/tei/reranking_tei.py
@@ -1,14 +1,12 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import heapq
 import json
 import os
-import re
 import time
 from typing import Union
 
-import requests
+import aiohttp
 
 from comps import (
     CustomLogger,
@@ -27,12 +25,12 @@
     RerankingResponseData,
 )
 
-logger = CustomLogger("reranking_tgi_gaudi")
+logger = CustomLogger("reranking_tei")
 logflag = os.getenv("LOGFLAG", False)
 
 
 @register_microservice(
-    name="opea_service@reranking_tgi_gaudi",
+    name="opea_service@reranking_tei",
     service_type=ServiceType.RERANK,
     endpoint="/v1/reranking",
     host="0.0.0.0",
@@ -40,8 +38,8 @@
     input_datatype=SearchedDoc,
     output_datatype=LLMParamsDoc,
 )
-@register_statistics(names=["opea_service@reranking_tgi_gaudi"])
-def reranking(
+@register_statistics(names=["opea_service@reranking_tei"])
+async def reranking(
     input: Union[SearchedDoc, RerankingRequest, ChatCompletionRequest]
 ) -> Union[LLMParamsDoc, RerankingResponse, ChatCompletionRequest]:
     if logflag:
@@ -58,15 +56,16 @@ def reranking(
             query = input.input
         data = {"query": query, "texts": docs}
         headers = {"Content-Type": "application/json"}
-        response = requests.post(url, data=json.dumps(data), headers=headers)
-        response_data = response.json()
+        async with aiohttp.ClientSession() as session:
+            async with session.post(url, data=json.dumps(data), headers=headers) as response:
+                response_data = await response.json()
 
         for best_response in response_data[: input.top_n]:
             reranking_results.append(
                 {"text": input.retrieved_docs[best_response["index"]].text, "score": best_response["score"]}
             )
 
-        statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None)
+        statistics_dict["opea_service@reranking_tei"].append_latency(time.time() - start, None)
         if isinstance(input, SearchedDoc):
             result = [doc["text"] for doc in reranking_results]
             if logflag:
@@ -92,4 +91,4 @@ def reranking(
 
 if __name__ == "__main__":
     tei_reranking_endpoint = os.getenv("TEI_RERANKING_ENDPOINT", "http://localhost:8080")
-    opea_microservices["opea_service@reranking_tgi_gaudi"].start()
+    opea_microservices["opea_service@reranking_tei"].start()
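
For context, the core of this change is replacing the blocking requests.post call with an aiohttp session inside the now-async handler. The following is a minimal standalone sketch of that pattern, not part of the patch itself. It assumes the TEI reranker serves a POST "/rerank" route and returns a JSON list of {"index", "score"} entries, as the handler's indexing into the response suggests; the rerank_via_tei helper and the localhost URL are purely illustrative.

# Hedged sketch of the async rerank call adopted in this patch.
# Assumptions (not shown in the diff): the TEI server exposes POST /rerank
# and returns a JSON list of {"index": int, "score": float} entries.
import asyncio
import json

import aiohttp


async def rerank_via_tei(endpoint: str, query: str, texts: list[str], top_n: int = 1):
    """Post a query and candidate texts to a TEI reranker and return the top-n hits."""
    data = {"query": query, "texts": texts}
    headers = {"Content-Type": "application/json"}
    # Same non-blocking pattern the patched handler uses: open a session,
    # await the POST, then await the JSON body.
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint + "/rerank", data=json.dumps(data), headers=headers) as response:
            response_data = await response.json()
    # Each entry carries the index of the original text plus its relevance score.
    return [{"text": texts[hit["index"]], "score": hit["score"]} for hit in response_data[:top_n]]


if __name__ == "__main__":
    # Hypothetical local TEI endpoint; adjust to your deployment.
    results = asyncio.run(rerank_via_tei("http://localhost:8080", "What is OPEA?", ["doc one", "doc two"]))
    print(results)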