diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 2b8bcbcaa..a6a96bfec 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -4,10 +4,10 @@ import base64 import os from io import BytesIO -from typing import Union +from typing import List, Union import requests -from fastapi import Request +from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse from PIL import Image @@ -361,11 +361,63 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888): megaservice, host, port, str(MegaServiceEndpoint.DOC_SUMMARY), ChatCompletionRequest, ChatCompletionResponse ) - async def handle_request(self, request: Request): - data = await request.json() + def read_pdf(self, file): + from langchain.document_loaders import PyPDFLoader + + loader = PyPDFLoader(file) + docs = loader.load_and_split() + return docs + + def read_text_from_file(self, file, save_file_name): + import docx2txt + from langchain.text_splitter import CharacterTextSplitter + + # read text file + if file.headers["content-type"] == "text/plain": + file.file.seek(0) + content = file.file.read().decode("utf-8") + # Split text + text_splitter = CharacterTextSplitter() + texts = text_splitter.split_text(content) + # Create multiple documents + file_content = texts + # read pdf file + elif file.headers["content-type"] == "application/pdf": + documents = self.read_pdf(save_file_name) + file_content = [doc.page_content for doc in documents] + # read docx file + elif ( + file.headers["content-type"] == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + or file.headers["content-type"] == "application/octet-stream" + ): + file_content = docx2txt.process(save_file_name) + + return file_content + + async def handle_request(self, request: Request, files: List[UploadFile] = File(...)): + data = await request.form() stream_opt = data.get("stream", True) chat_request = ChatCompletionRequest.parse_obj(data) - prompt = self._handle_message(chat_request.messages) + file_summaries = [] + for file in files: + file_path = f"/tmp/{file.filename}" + + import aiofiles + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await file.read()) + docs = self.read_text_from_file(file, file_path) + os.remove(file_path) + if isinstance(docs, list): + file_summaries.extend(docs) + else: + file_summaries.append(docs) + + if file_summaries: + prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries) + else: + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, top_k=chat_request.top_k if chat_request.top_k else 10, diff --git a/requirements.txt b/requirements.txt index c88a356a8..9c4b2d770 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,19 @@ +aiofiles aiohttp docarray +docx2txt fastapi httpx kubernetes +langchain +langchain-community opentelemetry-api opentelemetry-exporter-otlp opentelemetry-sdk Pillow prometheus-fastapi-instrumentator +pypdf +python-multipart pyyaml requests shortuuid