Enable adjusting the parms for LLM, retriever and reranker #324

Closed
XuhuiRen wants to merge 27 commits into main from xh/parms
Commits (27)
74a2e8a
add ut for parms
XuhuiRen Jul 18, 2024
3973201
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
f9928a3
revise
XuhuiRen Jul 18, 2024
6338fc0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
28dcbfa
fix
XuhuiRen Jul 18, 2024
c2bf55a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
aee9267
add fake detection
XuhuiRen Jul 18, 2024
4024fee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
bc937e8
Update test_service_orchestrator_with_gatway_fake.py
XuhuiRen Jul 18, 2024
a9e0355
Merge branch 'main' into xh/parms
XuhuiRen Jul 19, 2024
75e23aa
Update test_service_orchestrator_with_gatway_fake.py
XuhuiRen Jul 24, 2024
a1aedc2
Merge branch 'main' into xh/parms
XuhuiRen Jul 24, 2024
aeb3842
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2024
766eef2
Update test_service_orchestrator_with_gatway_fake.py
XuhuiRen Jul 24, 2024
b92c3ed
Update orchestrator.py
XuhuiRen Jul 24, 2024
a05c6e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2024
2b4010e
Update orchestrator.py
XuhuiRen Jul 24, 2024
55f267f
Update orchestrator.py
XuhuiRen Jul 24, 2024
cdfb7cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2024
c87de4e
Merge branch 'main' into xh/parms
XuhuiRen Jul 24, 2024
2f38446
Merge branch 'main' into xh/parms
XuhuiRen Jul 29, 2024
101c3e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2024
98b00e3
Merge branch 'main' into xh/parms
XuhuiRen Aug 21, 2024
c35e9dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
0d6724d
Update orchestrator.py
XuhuiRen Aug 21, 2024
f917ae6
Update orchestrator.py
XuhuiRen Aug 21, 2024
ae2f6ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
21 changes: 18 additions & 3 deletions comps/cores/mega/gateway.py
@@ -20,7 +20,7 @@
EmbeddingRequest,
UsageInfo,
)
from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
from .micro_service import MicroService

@@ -158,7 +158,7 @@
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
llm_parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
@@ -167,8 +167,23 @@
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)

result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt}, llm_parameters=parameters
initial_inputs={"text": prompt},
llm_parameters=llm_parameters,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)
for node, response in result_dict.items():
if isinstance(response, StreamingResponse):
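For reference, a minimal client-side sketch of a request that exercises the new knobs end to end. The /v1/chatqna path and port 8888 are assumptions taken from the ChatQnA example later in this PR, and it further assumes ChatCompletionRequest accepts the extra retriever/reranker fields read in the gateway change above:

import requests

# Illustrative only -- not part of this diff. Endpoint path and port are assumptions.
payload = {
    "messages": "What is the revenue of Nike in 2023?",
    "stream": False,
    # existing LLM parameters
    "max_tokens": 512,
    "top_k": 10,
    "top_p": 0.95,
    # retriever parameters introduced by this PR
    "search_type": "similarity",
    "k": 4,
    "fetch_k": 20,
    "lambda_mult": 0.5,
    "score_threshold": 0.2,
    # reranker parameter introduced by this PR
    "top_n": 2,
}
resp = requests.post("http://localhost:8888/v1/chatqna", json=payload, timeout=120)
print(resp.json())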
138 changes: 90 additions & 48 deletions comps/cores/mega/orchestrator.py
@@ -11,7 +11,7 @@
import requests
from fastapi.responses import StreamingResponse

from ..proto.docarray import LLMParams
from ..proto.docarray import LLMParams, RerankerParms, RetrieverParms
from .constants import ServiceType
from .dag import DAG

@@ -39,7 +39,7 @@
print(e)
return False

async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
result_dict = {}
runtime_graph = DAG()
runtime_graph.graph = copy.deepcopy(self.graph)
@@ -76,11 +76,28 @@
for d_node in downstreams:
if all(i in result_dict for i in runtime_graph.predecessors(d_node)):
inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
pending.add(
asyncio.create_task(
self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
if "retriever_parameters" in kwargs and "reranker_parameters" in kwargs:
retriever_parameters = kwargs["retriever_parameters"]
reranker_parameters = kwargs["reranker_parameters"]
pending.add(
asyncio.create_task(
self.execute(
session,
d_node,
inputs,
runtime_graph,
llm_parameters,
retriever_parameters,
reranker_parameters,
)
)
)
else:
pending.add(
asyncio.create_task(
self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
)
)
)
nodes_to_keep = []
for i in ind_nodes:
nodes_to_keep.append(i)
@@ -109,55 +126,80 @@
inputs: Dict,
runtime_graph: DAG,
llm_parameters: LLMParams = LLMParams(),
retriever_parameters: RetrieverParms = RetrieverParms(),
reranker_parameters: RerankerParms = RerankerParms(),
):
# send the cur_node request/reply
endpoint = self.services[cur_node].endpoint_path
llm_parameters_dict = llm_parameters.dict()
for field, value in llm_parameters_dict.items():
if inputs.get(field) != value:
inputs[field] = value

if (
self.services[cur_node].service_type == ServiceType.LLM
or self.services[cur_node].service_type == ServiceType.LVM
) and llm_parameters.streaming:
# Still leave to sync requests.post for StreamingResponse
response = requests.post(
url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
)
downstream = runtime_graph.downstream(cur_node)
if downstream:
assert len(downstream) == 1, "Not supported multiple streaming downstreams yet!"
cur_node = downstream[0]
hitted_ends = [".", "?", "!", "。", ",", "!"]
downstream_endpoint = self.services[downstream[0]].endpoint_path

def generate():
if response:
buffered_chunk_str = ""
for chunk in response.iter_content(chunk_size=None):
if chunk:
if downstream:
chunk = chunk.decode("utf-8")
buffered_chunk_str += self.extract_chunk_str(chunk)
is_last = chunk.endswith("[DONE]\n\n")
if (buffered_chunk_str and buffered_chunk_str[-1] in hitted_ends) or is_last:
res = requests.post(
url=downstream_endpoint,
data=json.dumps({"text": buffered_chunk_str}),
proxies={"http": None},
)
res_json = res.json()
if "text" in res_json:
res_txt = res_json["text"]
else:
raise Exception("Other response types not supported yet!")
buffered_chunk_str = "" # clear
yield from self.token_generator(res_txt, is_last=is_last)
else:
yield chunk

return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
):
if llm_parameters.streaming:
llm_parameters_dict = llm_parameters.dict()

for field, value in llm_parameters_dict.items():
if inputs.get(field) != value:
inputs[field] = value
# Still leave to sync requests.post for StreamingResponse

response = requests.post(
url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
)
downstream = runtime_graph.downstream(cur_node)
if downstream:
assert len(downstream) == 1, "Not supported multiple streaming downstreams yet!"
cur_node = downstream[0]
hitted_ends = [".", "?", "!", "。", ",", "!"]
downstream_endpoint = self.services[downstream[0]].endpoint_path

def generate():
if response:
buffered_chunk_str = ""
for chunk in response.iter_content(chunk_size=None):
if chunk:
if downstream:
chunk = chunk.decode("utf-8")
buffered_chunk_str += self.extract_chunk_str(chunk)
is_last = chunk.endswith("[DONE]\n\n")
if (buffered_chunk_str and buffered_chunk_str[-1] in hitted_ends) or is_last:
res = requests.post(
url=downstream_endpoint,
data=json.dumps({"text": buffered_chunk_str}),
proxies={"http": None},
)
res_json = res.json()
if "text" in res_json:
res_txt = res_json["text"]
else:
raise Exception("Other response types not supported yet!")
buffered_chunk_str = "" # clear
yield from self.token_generator(res_txt, is_last=is_last)
else:
yield chunk

return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
else:
async with session.post(endpoint, json=inputs) as response:
print(f"{cur_node}: {response.status}")
return await response.json(), cur_node
elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
retriever_parameters_dict = retriever_parameters.dict()
for field, value in retriever_parameters_dict.items():
if inputs.get(field) != value:
inputs[field] = value
async with session.post(endpoint, json=inputs) as response:
print(response.status)
return await response.json(), cur_node
elif self.services[cur_node].service_type == ServiceType.RERANK:
reranker_parameters_dict = reranker_parameters.dict()
for field, value in reranker_parameters_dict.items():
if inputs.get(field) != value:
inputs[field] = value
async with session.post(endpoint, json=inputs) as response:
print(response.status)
return await response.json(), cur_node
else:
async with session.post(endpoint, json=inputs) as response:
print(f"{cur_node}: {response.status}")
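The new RETRIEVER and RERANK branches above apply the same override rule before POSTing to a microservice: copy every field of the params doc into the request payload unless the payload already carries that value. A standalone sketch of that rule (the helper name is ours, not from the diff):

from typing import Any, Dict

def apply_params(inputs: Dict[str, Any], params_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Copy each parameter into the request payload when it differs from what is already there."""
    for field, value in params_dict.items():
        if inputs.get(field) != value:
            inputs[field] = value
    return inputs

# e.g. apply_params({"text": "nike revenue"}, RetrieverParms().dict())
# -> {"text": "nike revenue", "search_type": "similarity", "k": 4, "fetch_k": 20, ...}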
13 changes: 13 additions & 0 deletions comps/cores/proto/docarray.py
@@ -133,6 +133,19 @@ class LLMParams(BaseDoc):
)


class RetrieverParms(BaseDoc):
search_type: str = "similarity"
k: int = 4
distance_threshold: Optional[float] = None
fetch_k: int = 20
lambda_mult: float = 0.5
score_threshold: float = 0.2


class RerankerParms(BaseDoc):
top_n: int = 1


class RAGASParams(BaseDoc):
questions: DocList[TextDoc]
answers: DocList[TextDoc]
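A minimal usage sketch of the two new parameter docs (values chosen purely for illustration); the schedule() call mirrors the unit test added below and is left commented out because it needs a running pipeline:

from comps.cores.proto.docarray import RerankerParms, RetrieverParms

retriever_parameters = RetrieverParms(search_type="mmr", k=8, fetch_k=40)  # illustrative values
reranker_parameters = RerankerParms(top_n=3)

# result_dict, runtime_graph = await megaservice.schedule(
#     initial_inputs={"text": "What is the revenue of Nike in 2023?"},
#     retriever_parameters=retriever_parameters,
#     reranker_parameters=reranker_parameters,
# )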
64 changes: 64 additions & 0 deletions tests/cores/mega/test_service_orchestrator_with_gatway_fake.py
@@ -0,0 +1,64 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from comps import Gateway, ServiceOrchestrator, TextDoc, opea_microservices, register_microservice
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def s1_add(request: TextDoc) -> TextDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
text = req_dict["text"]
text += "opea "
return {"text": text}


@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add")
async def s2_add(request: TextDoc) -> TextDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
text = req_dict["text"]
text += "project!"
return {"text": text}


class TestServiceOrchestratorParmLLM(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.s1 = opea_microservices["s1"]
self.s2 = opea_microservices["s2"]
self.s1.start()
self.s2.start()

self.service_builder = ServiceOrchestrator()

self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
self.service_builder.flow_to(self.s1, self.s2)
self.gateway = Gateway(self.service_builder, port=9898)

def tearDown(self):
self.s1.stop()
self.s2.stop()
self.gateway.stop()

# async def test_llm_schedule(self):
# result_dict, _ = await self.service_builder.schedule(
# initial_inputs={"text": "hello, "},
# llm_parameters=LLMParams(),
# )
# self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!")

async def test_retriever_schedule(self):
result_dict, _ = await self.service_builder.schedule(
initial_inputs={"text": "hello, "},
retriever_parameters=RetrieverParms(),
reranker_parameters=RerankerParms(),
)
self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!")


if __name__ == "__main__":
unittest.main()
@@ -0,0 +1,85 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import os

from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms

MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8888)
EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
EMBEDDING_SERVICE_PORT = os.getenv("EMBEDDING_SERVICE_PORT", 6000)
RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0")
RETRIEVER_SERVICE_PORT = os.getenv("RETRIEVER_SERVICE_PORT", 7000)
RERANK_SERVICE_HOST_IP = os.getenv("RERANK_SERVICE_HOST_IP", "0.0.0.0")
RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = os.getenv("LLM_SERVICE_PORT", 9000)


class ChatQnAService:
def __init__(self, host="0.0.0.0", port=8000):
self.host = host
self.port = port
self.megaservice = ServiceOrchestrator()

def add_remote_service(self):
embedding = MicroService(
name="embedding",
host=EMBEDDING_SERVICE_HOST_IP,
port=EMBEDDING_SERVICE_PORT,
endpoint="/v1/embeddings",
use_remote_service=True,
service_type=ServiceType.EMBEDDING,
)
retriever = MicroService(
name="retriever",
host=RETRIEVER_SERVICE_HOST_IP,
port=RETRIEVER_SERVICE_PORT,
endpoint="/v1/retrieval",
use_remote_service=True,
service_type=ServiceType.RETRIEVER,
)
rerank = MicroService(
name="rerank",
host=RERANK_SERVICE_HOST_IP,
port=RERANK_SERVICE_PORT,
endpoint="/v1/reranking",
use_remote_service=True,
service_type=ServiceType.RERANK,
)
llm = MicroService(
name="llm",
host=LLM_SERVICE_HOST_IP,
port=LLM_SERVICE_PORT,
endpoint="/v1/chat/completions",
use_remote_service=True,
service_type=ServiceType.LLM,
)
self.megaservice.add(embedding).add(retriever).add(rerank).add(llm)
self.megaservice.flow_to(embedding, retriever)
self.megaservice.flow_to(retriever, rerank)
self.megaservice.flow_to(rerank, llm)
self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)

async def schedule(self):
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": "What is the revenue of Nike in 2023?"},
llm_parameters=LLMParams(),
)
print(result_dict)

result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": "hello, "},
retriever_parameters=RetrieverParms(),
reranker_parameters=RerankerParms(),
)


if __name__ == "__main__":
chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
chatqna.add_remote_service()
asyncio.run(chatqna.schedule())