Optimize mega flow by removing microservice wrapper #582

Merged: 14 commits, Sep 4, 2024
18 changes: 16 additions & 2 deletions comps/cores/mega/gateway.py
@@ -20,7 +20,7 @@
    EmbeddingRequest,
    UsageInfo,
)
-from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
+from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
from .micro_service import MicroService

@@ -167,8 +167,22 @@
            streaming=stream_opt,
            chat_template=chat_request.chat_template if chat_request.chat_template else None,
        )
        retriever_parameters = RetrieverParms(
            search_type=chat_request.search_type if chat_request.search_type else "similarity",
            k=chat_request.k if chat_request.k else 4,
            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
        )
        reranker_parameters = RerankerParms(
            top_n=chat_request.top_n if chat_request.top_n else 1,
        )
        result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={"text": prompt}, llm_parameters=parameters
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
        )
        for node, response in result_dict.items():
            if isinstance(response, StreamingResponse):
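A minimal sketch (not part of the diff) of how these fallbacks resolve: any field the client request omits falls back to the same defaults that the new RetrieverParms and RerankerParms classes declare, so constructing the docs directly behaves identically.

# Sketch only; RerankerParms/RetrieverParms are the classes added in comps/cores/proto/docarray.py below.
from comps.cores.proto.docarray import RerankerParms, RetrieverParms

retriever_parameters = RetrieverParms(search_type="mmr", k=10)  # override two fields
reranker_parameters = RerankerParms()  # top_n falls back to the class default

print(retriever_parameters.fetch_k)  # 20, the class default
print(reranker_parameters.top_n)  # 1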
69 changes: 56 additions & 13 deletions comps/cores/mega/orchestrator.py
@@ -4,6 +4,7 @@
import asyncio
import copy
import json
import os
import re
from typing import Dict, List

@@ -14,6 +15,10 @@
from ..proto.docarray import LLMParams
from .constants import ServiceType
from .dag import DAG
from .logger import CustomLogger

logger = CustomLogger("comps-core-orchestrator")
LOGFLAG = os.getenv("LOGFLAG", False)


class ServiceOrchestrator(DAG):
@@ -36,18 +41,22 @@
            self.add_edge(from_service.name, to_service.name)
            return True
        except Exception as e:
-            print(e)
+            logger.error(e)
            return False

-    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
+    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
        result_dict = {}
        runtime_graph = DAG()
        runtime_graph.graph = copy.deepcopy(self.graph)
+        if LOGFLAG:
+            logger.info(initial_inputs)

        timeout = aiohttp.ClientTimeout(total=1000)
        async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
            pending = {
-                asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph, llm_parameters))
+                asyncio.create_task(
+                    self.execute(session, node, initial_inputs, runtime_graph, llm_parameters, **kwargs)
+                )
                for node in self.ind_nodes()
            }
            ind_nodes = self.ind_nodes()
@@ -67,11 +76,12 @@
                            for downstream in reversed(downstreams):
                                try:
                                    if re.findall(black_node, downstream):
-                                        print(f"skip forwardding to {downstream}...")
+                                        if LOGFLAG:
+                                            logger.info(f"skip forwarding to {downstream}...")
                                        runtime_graph.delete_edge(node, downstream)
                                        downstreams.remove(downstream)
                                except re.error as e:
-                                    print("Pattern invalid! Operation cancelled.")
+                                    logger.error("Pattern invalid! Operation cancelled.")

                    if len(downstreams) == 0 and llm_parameters.streaming:
                        # turn the response to a StreamingResponse
                        # to make the response uniform to UI
@@ -90,7 +100,7 @@
                            inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
                            pending.add(
                                asyncio.create_task(
-                                    self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
+                                    self.execute(session, d_node, inputs, runtime_graph, llm_parameters, **kwargs)
                                )
                            )
        nodes_to_keep = []
@@ -121,21 +131,33 @@
        inputs: Dict,
        runtime_graph: DAG,
        llm_parameters: LLMParams = LLMParams(),
+        **kwargs,
    ):
        # send the cur_node request/reply
        endpoint = self.services[cur_node].endpoint_path
        llm_parameters_dict = llm_parameters.dict()
-        for field, value in llm_parameters_dict.items():
-            if inputs.get(field) != value:
-                inputs[field] = value
+        if self.services[cur_node].service_type == ServiceType.LLM:
+            for field, value in llm_parameters_dict.items():
+                if inputs.get(field) != value:
+                    inputs[field] = value

        # pre-process
        inputs = self.align_inputs(inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs)

        if (
            self.services[cur_node].service_type == ServiceType.LLM
            or self.services[cur_node].service_type == ServiceType.LVM
        ) and llm_parameters.streaming:
            # Still use synchronous requests.post for StreamingResponse
            if LOGFLAG:
                logger.info(inputs)

            response = requests.post(
-                url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
+                url=endpoint,
+                data=json.dumps(inputs),
+                headers={"Content-type": "application/json"},
+                proxies={"http": None},
+                stream=True,
+                timeout=1000,
            )
            downstream = runtime_graph.downstream(cur_node)
            if downstream:
@@ -169,11 +191,32 @@
                            else:
                                yield chunk

-            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
+            return (
+                StreamingResponse(self.align_generator(generate(), **kwargs), media_type="text/event-stream"),
+                cur_node,
+            )
        else:
            if LOGFLAG:
                logger.info(inputs)

            async with session.post(endpoint, json=inputs) as response:
                print(f"{cur_node}: {response.status}")
-                return await response.json(), cur_node
+                # Parse as JSON
+                data = await response.json()
+                # post process
+                data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)

+                return data, cur_node

    def align_inputs(self, inputs, *args, **kwargs):
        """Override this method in megaservice definition."""
        return inputs

    def align_outputs(self, data, *args, **kwargs):
        """Override this method in megaservice definition."""
        return data

    def align_generator(self, gen, *args, **kwargs):
        """Override this method in megaservice definition."""
        return gen

    def dump_outputs(self, node, response, result_dict):
        result_dict[node] = response
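The three align_* methods above are deliberate no-ops: they give a megaservice a pre-process, post-process, and stream-filter hook around execute(). A hedged sketch (class name hypothetical) of consuming them by subclassing, mirroring the logic of the new unit test below:

from comps import ServiceOrchestrator
from comps.cores.mega.constants import ServiceType


class ParamAwareOrchestrator(ServiceOrchestrator):
    def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
        # Forward retriever knobs only to retriever nodes.
        if self.services[cur_node].service_type == ServiceType.RETRIEVER and "retriever_parameters" in kwargs:
            inputs["k"] = kwargs["retriever_parameters"].k
        return inputs

    def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
        # Trim reranker output to top_n before it flows downstream.
        if self.services[cur_node].service_type == ServiceType.RERANK and "reranker_parameters" in kwargs:
            data["text"] = data["text"][: kwargs["reranker_parameters"].top_n]
        return data

    def align_generator(self, gen, *args, **kwargs):
        # Drop empty chunks from a streaming response.
        return (chunk for chunk in gen if chunk)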
13 changes: 13 additions & 0 deletions comps/cores/proto/docarray.py
@@ -173,6 +173,19 @@ class LLMParams(BaseDoc):
    )


class RetrieverParms(BaseDoc):
    search_type: str = "similarity"
    k: int = 4
    distance_threshold: Optional[float] = None
    fetch_k: int = 20
    lambda_mult: float = 0.5
    score_threshold: float = 0.2


class RerankerParms(BaseDoc):
    top_n: int = 1


class RAGASParams(BaseDoc):
    questions: DocList[TextDoc]
    answers: DocList[TextDoc]
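Both new classes inherit BaseDoc, which is pydantic-based, so they serialize the same way the orchestrator already serializes LLMParams (llm_parameters.dict() in execute above). A small sketch:

from comps.cores.proto.docarray import RetrieverParms

params = RetrieverParms(search_type="mmr", k=8)
print(params.dict())  # omitted fields keep their class defaults, e.g. fetch_k=20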
@@ -0,0 +1,83 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from comps import (
    EmbedDoc,
    Gateway,
    RerankedDoc,
    ServiceOrchestrator,
    TextDoc,
    opea_microservices,
    register_microservice,
)
from comps.cores.mega.constants import ServiceType
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add", service_type=ServiceType.RETRIEVER)
async def s1_add(request: EmbedDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    text += f"opea top_k {req_dict['k']}"
    return {"text": text}


@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.RERANK)
async def s2_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    text += "project!"
    return {"text": text}


def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
    if self.services[cur_node].service_type == ServiceType.RETRIEVER:
        inputs["k"] = kwargs["retriever_parameters"].k

    return inputs


def align_outputs(self, outputs, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
    if self.services[cur_node].service_type == ServiceType.RERANK:
        top_n = kwargs["reranker_parameters"].top_n
        outputs["text"] = outputs["text"][:top_n]
    return outputs


class TestServiceOrchestratorParams(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.s1 = opea_microservices["s1"]
        self.s2 = opea_microservices["s2"]
        self.s1.start()
        self.s2.start()

        ServiceOrchestrator.align_inputs = align_inputs
        ServiceOrchestrator.align_outputs = align_outputs
        self.service_builder = ServiceOrchestrator()

        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
        self.service_builder.flow_to(self.s1, self.s2)
        self.gateway = Gateway(self.service_builder, port=9898)

    def tearDown(self):
        self.s1.stop()
        self.s2.stop()
        self.gateway.stop()

    async def test_retriever_schedule(self):
        result_dict, _ = await self.service_builder.schedule(
            initial_inputs={"text": "hello, ", "embedding": [1.0, 2.0, 3.0]},
            retriever_parameters=RetrieverParms(k=8),
            reranker_parameters=RerankerParms(top_n=20),
        )
        self.assertEqual(len(result_dict[self.s2.name]["text"]), 20)  # Check reranker top_n is accessed
        self.assertTrue("8" in result_dict[self.s2.name]["text"])  # Check retriever k is accessed


if __name__ == "__main__":
    unittest.main()
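End to end, the client-facing knobs are the optional fields the gateway reads off the incoming chat request (see the gateway.py diff at the top). A hypothetical request payload exercising each of them:

# Hypothetical payload; the parameter names match what gateway.py reads off chat_request.
payload = {
    "messages": [{"role": "user", "content": "What is OPEA?"}],
    "search_type": "similarity",  # -> RetrieverParms
    "k": 4,
    "distance_threshold": None,
    "fetch_k": 20,
    "lambda_mult": 0.5,
    "score_threshold": 0.2,
    "top_n": 1,  # -> RerankerParms
}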