Skip to content

Commit

Permalink
feat: change the return value of python async send_cmd to async gener…
Browse files Browse the repository at this point in the history
…ator (#294)
  • Loading branch information
halajohn authored Nov 19, 2024
1 parent 8a69f98 commit ec86629
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 19 deletions.
32 changes: 23 additions & 9 deletions core/src/ten_runtime/binding/python/interface/ten/async_ten_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .cmd import Cmd
from .cmd_result import CmdResult
from .ten_env import TenEnv
from typing import AsyncGenerator


class AsyncTenEnv(TenEnv):
Expand All @@ -25,25 +26,38 @@ def __init__(
def __del__(self) -> None:
pass

async def send_cmd(self, cmd: Cmd) -> CmdResult:
q = asyncio.Queue(1)
async def send_cmd(self, cmd: Cmd) -> AsyncGenerator[CmdResult, None]:
q = asyncio.Queue(maxsize=10)
self._internal.send_cmd(
cmd,
lambda ten_env, result: asyncio.run_coroutine_threadsafe(
lambda _, result: asyncio.run_coroutine_threadsafe(
q.put(result), self._ten_loop
), # type: ignore
),
)
return await q.get()

async def send_json(self, json_str: str) -> CmdResult:
q = asyncio.Queue(1)
while True:
result: CmdResult = await q.get()
if result.is_completed():
yield result
# This is the final result, so break the while loop.
break
yield result

async def send_json(self, json_str: str) -> AsyncGenerator[CmdResult, None]:
q = asyncio.Queue(maxsize=10)
self._internal.send_json(
json_str,
lambda ten_env, result: asyncio.run_coroutine_threadsafe(
q.put(result), self._ten_loop
), # type: ignore
),
)
return await q.get()
while True:
result: CmdResult = await q.get()
if result.is_completed():
yield result
# This is the final result, so break the while loop.
break
yield result

def _deinit_routine(self) -> None:
# Wait for the internal thread to finish.
Expand Down
6 changes: 4 additions & 2 deletions packages/example_extensions/aio_http_server_python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ async def default_handler(self, request: web_request.Request):
'{"_ten":{"type":"close_app",'
'"dest":[{"app":"localhost"}]}}'
)
asyncio.create_task(self.ten_env.send_json(close_app_cmd_json))
asyncio.create_task(
anext(self.ten_env.send_json(close_app_cmd_json))
)
return web.Response(status=200, text="OK")
elif "name" in data["_ten"]:
# Send the command to the TEN runtime.
Expand All @@ -53,7 +55,7 @@ async def default_handler(self, request: web_request.Request):
if cmd is None:
return web.Response(status=400, text="Bad request")

cmd_result = await self.ten_env.send_cmd(cmd)
cmd_result = await anext(self.ten_env.send_cmd(cmd))
else:
return web.Response(status=404, text="Not found")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
# Send a new command to other extensions and wait for the result. The
# result will be returned to the original sender.
new_cmd = Cmd.create("hello")
cmd_result = await ten_env.send_cmd(new_cmd)
cmd_result = await anext(ten_env.send_cmd(new_cmd))
ten_env.return_result(cmd_result, cmd)

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def stop_thread(self):

async def send_cmd_async(self, ten_env: TenEnv, cmd: Cmd) -> CmdResult:
print("DefaultExtension send_cmd_async")
q = asyncio.Queue(1)
q = asyncio.Queue(maxsize=10)
ten_env.send_cmd(
cmd,
lambda ten_env, result: asyncio.run_coroutine_threadsafe(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ async def greeting(self, ten_env: AsyncTenEnv) -> CmdResult:
await asyncio.sleep(1)

new_cmd = Cmd.create("greeting")
return await ten_env.send_cmd(
new_cmd,
)
return await anext(ten_env.send_cmd(new_cmd))

async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_json = cmd.to_json()
Expand All @@ -66,7 +64,7 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:

await asyncio.sleep(0.5)

result = await ten_env.send_cmd(new_cmd)
result = await anext(ten_env.send_cmd(new_cmd))

statusCode = result.get_status_code()
detail = result.get_property_string("detail")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
# Send a new command to other extensions and wait for the result. The
# result will be returned to the original sender.
new_cmd = Cmd.create("hello")
cmd_result = await ten_env.send_cmd(new_cmd)
cmd_result = await anext(ten_env.send_cmd(new_cmd))
ten_env.return_result(cmd_result, cmd)

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
# Send a new command to other extensions and wait for the result. The
# result will be returned to the original sender.
new_cmd = Cmd.create("hello")
cmd_result = await ten_env.send_cmd(new_cmd)
cmd_result = await anext(ten_env.send_cmd(new_cmd))
ten_env.return_result(cmd_result, cmd)

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
Expand Down

0 comments on commit ec86629

Please sign in to comment.