Skip to content

Commit

Permalink
fix pydantic BaseModel in/out in dataflow (#818)
Browse files Browse the repository at this point in the history
* fix protocol in/out supported types

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Spycsh and pre-commit-ci[bot] authored Oct 23, 2024
1 parent 3473bfb commit 02c3dfe
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
11 changes: 9 additions & 2 deletions comps/cores/mega/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aiohttp
import requests
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from ..proto.docarray import LLMParams
from .constants import ServiceType
Expand Down Expand Up @@ -44,7 +45,7 @@ def flow_to(self, from_service, to_service):
logger.error(e)
return False

async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
async def schedule(self, initial_inputs: Dict | BaseModel, llm_parameters: LLMParams = LLMParams(), **kwargs):
result_dict = {}
runtime_graph = DAG()
runtime_graph.graph = copy.deepcopy(self.graph)
Expand Down Expand Up @@ -201,7 +202,13 @@ def generate():
else:
if LOGFLAG:
logger.info(inputs)
async with session.post(endpoint, json=inputs) as response:
if not isinstance(inputs, dict):
input_data = inputs.dict()
# remove null
input_data = {k: v for k, v in input_data.items() if v is not None}
else:
input_data = inputs
async with session.post(endpoint, json=input_data) as response:
if response.content_type == "audio/wav":
audio_data = await response.read()
data = self.align_outputs(
Expand Down
39 changes: 39 additions & 0 deletions tests/cores/mega/test_service_orchestrator_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import unittest

from comps import ServiceOrchestrator, opea_microservices, register_microservice
from comps.cores.proto.api_protocol import ChatCompletionRequest


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def s1_add(request: ChatCompletionRequest) -> ChatCompletionRequest:
# support pydantic protocol message object in/out in data flow
return request


class TestServiceOrchestratorProtocol(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.s1 = opea_microservices["s1"]
self.s1.start()

self.service_builder = ServiceOrchestrator()

self.service_builder.add(opea_microservices["s1"])

def tearDown(self):
self.s1.stop()

async def test_schedule(self):
input_data = ChatCompletionRequest(messages=[{"role": "user", "content": "What's up man?"}], seed=None)
result_dict, _ = await self.service_builder.schedule(initial_inputs=input_data)
self.assertEqual(
result_dict[self.s1.name]["messages"],
[{"role": "user", "content": "What's up man?"}],
)
self.assertEqual(result_dict[self.s1.name]["seed"], None)


if __name__ == "__main__":
unittest.main()

0 comments on commit 02c3dfe

Please sign in to comment.