Skip to content

Commit

Permalink
Support file upload for FAQGen in the UI (#866)
Browse files Browse the repository at this point in the history
* Support file upload for FAQGen in the UI

Signed-off-by: Xinyao Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xinyao Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
XinyaoWa and pre-commit-ci[bot] authored Nov 8, 2024
1 parent 78d8276 commit 453ff72
Showing 1 changed file with 59 additions and 37 deletions.
96 changes: 59 additions & 37 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,41 @@
from .micro_service import MicroService


def read_pdf(file):
    """Load the PDF at *file* and return it split into LangChain documents."""
    from langchain.document_loaders import PyPDFLoader

    return PyPDFLoader(file).load_and_split()


def read_text_from_file(file, save_file_name):
    """Extract text from an uploaded file based on its content type.

    Args:
        file: UploadFile-like object; ``file.headers["content-type"]`` selects
            the parser, and ``file.file`` is a readable binary stream (used
            only for plain text).
        save_file_name: Path of the copy already saved to disk (used for the
            PDF and DOCX parsers, which read from a file path).

    Returns:
        A list of text chunks for plain text and PDF input, or a single
        string for DOCX input.

    Raises:
        ValueError: If the content type is not one of the supported kinds.
    """
    import docx2txt
    from langchain.text_splitter import CharacterTextSplitter

    content_type = file.headers["content-type"]

    if content_type == "text/plain":
        # Rewind first: the caller may already have consumed the stream
        # while saving the upload to disk.
        file.file.seek(0)
        content = file.file.read().decode("utf-8")
        # Split the raw text into chunks for downstream processing.
        text_splitter = CharacterTextSplitter()
        file_content = text_splitter.split_text(content)
    elif content_type == "application/pdf":
        documents = read_pdf(save_file_name)
        file_content = [doc.page_content for doc in documents]
    elif content_type in (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/octet-stream",
    ):
        file_content = docx2txt.process(save_file_name)
    else:
        # Previously an unrecognized type fell through all branches and the
        # final return raised UnboundLocalError; fail with a clear message.
        raise ValueError(f"Unsupported file content type: {content_type}")

    return file_content


class Gateway:
def __init__(
self,
Expand Down Expand Up @@ -365,39 +400,6 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
megaservice, host, port, str(MegaServiceEndpoint.DOC_SUMMARY), ChatCompletionRequest, ChatCompletionResponse
)

def read_pdf(self, file):
    """Load the PDF at *file* and return it split into LangChain documents."""
    from langchain.document_loaders import PyPDFLoader

    return PyPDFLoader(file).load_and_split()

def read_text_from_file(self, file, save_file_name):
    """Extract text from an uploaded file based on its content type.

    Args:
        file: UploadFile-like object; ``file.headers["content-type"]`` selects
            the parser, and ``file.file`` is a readable binary stream (used
            only for plain text).
        save_file_name: Path of the copy already saved to disk (used for the
            PDF and DOCX parsers, which read from a file path).

    Returns:
        A list of text chunks for plain text and PDF input, or a single
        string for DOCX input.

    Raises:
        ValueError: If the content type is not one of the supported kinds.
    """
    import docx2txt
    from langchain.text_splitter import CharacterTextSplitter

    content_type = file.headers["content-type"]

    if content_type == "text/plain":
        # Rewind first: the caller may already have consumed the stream
        # while saving the upload to disk.
        file.file.seek(0)
        content = file.file.read().decode("utf-8")
        # Split the raw text into chunks for downstream processing.
        text_splitter = CharacterTextSplitter()
        file_content = text_splitter.split_text(content)
    elif content_type == "application/pdf":
        documents = self.read_pdf(save_file_name)
        file_content = [doc.page_content for doc in documents]
    elif content_type in (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/octet-stream",
    ):
        file_content = docx2txt.process(save_file_name)
    else:
        # Previously an unrecognized type fell through all branches and the
        # final return raised UnboundLocalError; fail with a clear message.
        raise ValueError(f"Unsupported file content type: {content_type}")

    return file_content

async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
data = await request.form()
stream_opt = data.get("stream", True)
Expand All @@ -411,7 +413,7 @@ async def handle_request(self, request: Request, files: List[UploadFile] = File(

async with aiofiles.open(file_path, "wb") as f:
await f.write(await file.read())
docs = self.read_text_from_file(file, file_path)
docs = read_text_from_file(file, file_path)
os.remove(file_path)
if isinstance(docs, list):
file_summaries.extend(docs)
Expand Down Expand Up @@ -547,11 +549,31 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
megaservice, host, port, str(MegaServiceEndpoint.FAQ_GEN), ChatCompletionRequest, ChatCompletionResponse
)

async def handle_request(self, request: Request):
data = await request.json()
async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
data = await request.form()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
file_summaries = []
if files:
for file in files:
file_path = f"/tmp/{file.filename}"

import aiofiles

async with aiofiles.open(file_path, "wb") as f:
await f.write(await file.read())
docs = read_text_from_file(file, file_path)
os.remove(file_path)
if isinstance(docs, list):
file_summaries.extend(docs)
else:
file_summaries.append(docs)

if file_summaries:
prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
else:
prompt = self._handle_message(chat_request.messages)

parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
Expand Down

0 comments on commit 453ff72

Please sign in to comment.