diff --git a/comps/asr/asr.py b/comps/asr/asr.py index 97fbb0bb0..1f5cf2df4 100644 --- a/comps/asr/asr.py +++ b/comps/asr/asr.py @@ -10,8 +10,8 @@ from comps import ( Base64ByteStrDoc, + LLMParamsDoc, ServiceType, - TextDoc, opea_microservices, register_microservice, register_statistics, @@ -26,7 +26,7 @@ host="0.0.0.0", port=9099, input_datatype=Base64ByteStrDoc, - output_datatype=TextDoc, + output_datatype=LLMParamsDoc, ) @register_statistics(names=["opea_service@asr"]) async def audio_to_text(audio: Base64ByteStrDoc): @@ -37,7 +37,7 @@ async def audio_to_text(audio: Base64ByteStrDoc): response = requests.post(url=f"{asr_endpoint}/v1/asr", data=json.dumps(inputs), proxies={"http": None}) statistics_dict["opea_service@asr"].append_latency(time.time() - start, None) - return TextDoc(text=response.json()["asr_result"]) + return LLMParamsDoc(query=response.json()["asr_result"]) if __name__ == "__main__": diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py index 723f0db5d..0217d0d17 100644 --- a/comps/cores/mega/orchestrator.py +++ b/comps/cores/mega/orchestrator.py @@ -131,13 +131,6 @@ def generate(): return StreamingResponse(generate(), media_type="text/event-stream"), cur_node else: - if ( - self.services[cur_node].service_type == ServiceType.LLM - and runtime_graph.predecessors(cur_node) - and "asr" in runtime_graph.predecessors(cur_node)[0] - ): - inputs["query"] = inputs["text"] - del inputs["text"] async with session.post(endpoint, json=inputs) as response: print(response.status) return await response.json(), cur_node