From d211cb2dbd493c4ace48521f830077ab541c54e6 Mon Sep 17 00:00:00 2001
From: Mustafa <109312699+MSCetin37@users.noreply.github.com>
Date: Thu, 14 Nov 2024 19:14:50 -0800
Subject: [PATCH] Docsum Gateway Fix (#902)

* update gateway

Signed-off-by: Mustafa

* update the gateway

Signed-off-by: Mustafa

* update the gateway

Signed-off-by: Mustafa

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mustafa
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/cores/mega/gateway.py | 66 +++++++++++++++++++++++++++++++++----
 1 file changed, 59 insertions(+), 7 deletions(-)

diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index e6a0d6dcf..1aa1a0e1a 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -419,12 +419,62 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
             output_datatype=ChatCompletionResponse,
         )
 
-    async def handle_request(self, request: Request):
-        data = await request.json()
-        stream_opt = data.get("stream", True)
-        chat_request = ChatCompletionRequest.model_validate(data)
+    async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
+
+        if "application/json" in request.headers.get("content-type"):
+            data = await request.json()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+            prompt = self._handle_message(chat_request.messages)
+
+            initial_inputs_data = {data["type"]: prompt}
+
+        elif "multipart/form-data" in request.headers.get("content-type"):
+            data = await request.form()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+
+            data_type = data.get("type")
+
+            file_summaries = []
+            if files:
+                for file in files:
+                    file_path = f"/tmp/{file.filename}"
+
+                    if data_type is not None and data_type in ["audio", "video"]:
+                        raise ValueError(
+                            "Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
+                        )
+
+                    else:
+                        import aiofiles
+
+                        async with aiofiles.open(file_path, "wb") as f:
+                            await f.write(await file.read())
+
+                        docs = read_text_from_file(file, file_path)
+                        os.remove(file_path)
+
+                        if isinstance(docs, list):
+                            file_summaries.extend(docs)
+                        else:
+                            file_summaries.append(docs)
+
+            if file_summaries:
+                prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
+            else:
+                prompt = self._handle_message(chat_request.messages)
+
+            data_type = data.get("type")
+            if data_type is not None:
+                initial_inputs_data = {}
+                initial_inputs_data[data_type] = prompt
+            else:
+                initial_inputs_data = {"query": prompt}
+
+        else:
+            raise ValueError(f"Unknown request type: {request.headers.get('content-type')}")
 
-        prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -434,12 +484,14 @@ async def handle_request(self, request: Request):
             presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
             repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
-            language=chat_request.language if chat_request.language else "auto",
             model=chat_request.model if chat_request.model else None,
+            language=chat_request.language if chat_request.language else "auto",
         )
+
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={data["type"]: prompt}, llm_parameters=parameters
+            initial_inputs=initial_inputs_data, llm_parameters=parameters
         )
+
         for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
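
A minimal client sketch (not part of the patch) for exercising both content-type branches of the updated handle_request. It assumes the DocSum gateway is reachable at http://localhost:8888 and mounted at /v1/docsum, and that a local report.txt exists; the host, port, route, and file name are all assumptions and may differ in a real deployment.

# client_sketch.py -- assumed endpoint and file names, adjust to your setup
import requests

GATEWAY_URL = "http://localhost:8888/v1/docsum"  # assumed route

# application/json branch: the "type" field becomes the key of the
# initial inputs passed to megaservice.schedule().
resp = requests.post(
    GATEWAY_URL,
    json={
        "type": "text",
        "messages": "Summarize: OPEA provides composable GenAI microservices.",
        "stream": False,
    },
)
print(resp.status_code, resp.text[:200])

# multipart/form-data branch: form fields plus uploads under the "files"
# field; audio/video types are rejected for curl-style uploads per the patch.
with open("report.txt", "rb") as f:  # hypothetical local file
    resp = requests.post(
        GATEWAY_URL,
        data={"type": "text", "messages": ""},
        files=[("files", ("report.txt", f, "text/plain"))],
    )
print(resp.status_code, resp.text[:200])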