Skip to content

Commit

Permalink
Client-implemented tools (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdepinet authored Oct 4, 2024
1 parent 44b3f1d commit 01e7e60
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 5 deletions.
10 changes: 9 additions & 1 deletion ultravox-client/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ultravox-client"
version = "0.0.4"
version = "0.0.5"
packages = [
{ include = "ultravox_client", from = "." },
]
Expand Down Expand Up @@ -39,3 +39,11 @@ build-backend = "poetry.core.masonry.api"

[tool.deptry]
extend_exclude = [".*test\\.py", ".*tool\\.py"]


[tool.pytest.ini_options]
asyncio_mode = "auto"
addopts = "--doctest-modules"
filterwarnings = [
"error",
]
21 changes: 21 additions & 0 deletions ultravox-client/ultravox_client/room_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import asyncio
import typing

from livekit import rtc

from ultravox_client import patched_event_emitter

# Type arguments of livekit's ``rtc.room.EventTypes``, materialised as a list so
# RoomListener can subscribe to each of them in turn.
# NOTE(review): presumably EventTypes is a Literal of event-name strings — confirm.
EVENT_TYPES = list(typing.get_args(rtc.room.EventTypes))


class RoomListener(patched_event_emitter.PatchedAsyncIOEventEmitter):
    """Event emitter that mirrors the events of a LiveKit room.

    On construction it subscribes to every event named in ``EVENT_TYPES`` on
    the given room and re-emits each one on itself with the same arguments,
    so listeners can be attached to this emitter instead of to the room.
    Must be constructed while an asyncio event loop is running, because the
    running loop is handed to the base emitter.
    """

    def __init__(self, room: rtc.Room):
        running_loop = asyncio.get_running_loop()
        super().__init__(loop=running_loop)
        for event_name in EVENT_TYPES:
            forwarder = self.create_propagater(event_name)
            room.on(event_name, forwarder)

    def create_propagater(self, event: str):
        """Return a callback that re-emits ``event`` on this emitter.

        The callback accepts arbitrary positional and keyword arguments and
        passes them through to ``self.emit`` unchanged.
        """

        def forward(*args, **kwargs):
            self.emit(event, *args, **kwargs)

        return forward
94 changes: 90 additions & 4 deletions ultravox-client/ultravox_client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
import contextlib
import dataclasses
import enum
import inspect
import json
import logging
import urllib.parse
from typing import Literal
from typing import Any, Awaitable, Callable, Literal, Tuple

import websockets
from livekit import rtc

from ultravox_client import async_close
from ultravox_client import audio
from ultravox_client import patched_event_emitter
from ultravox_client import room_listener


class _AudioSourceToSendTrackAdapter:
Expand Down Expand Up @@ -124,6 +126,12 @@ def is_live(self):
Role = Literal["user", "agent"]


# Signature of a client-side tool implementation: it receives the tool's
# parameters as a dict and returns (or resolves to, when async) either a result
# string or a (result, response type) pair of strings.
ClientToolImplementation = Callable[
[dict[str, Any]],
str | Awaitable[str] | Tuple[str, str] | Awaitable[Tuple[str, str]],
]


@dataclasses.dataclass(frozen=True)
class Transcript:
"""A transcription of a single utterance."""
Expand Down Expand Up @@ -156,11 +164,13 @@ def __init__(self, experimental_messages: set[str] | None = None) -> None:
self._status = UltravoxSessionStatus.DISCONNECTED

self._room: rtc.Room | None = None
self._room_listener: room_listener.RoomListener | None = None
self._socket: websockets.WebSocketClientProtocol | None = None
self._receive_task: asyncio.Task | None = None
self._source_adapter: _AudioSourceToSendTrackAdapter | None = None
self._sink_adapter: _AudioSinkFromRecvTrackAdapter | None = None
self._experimental_messages = experimental_messages or set()
self._registered_tools: dict[str, ClientToolImplementation] = {}

@property
def status(self):
Expand Down Expand Up @@ -204,6 +214,28 @@ def toggle_speaker_muted(self) -> None:
"""Toggles the mute state of the user's speaker (i.e. output audio from the agent)."""
self.speaker_muted = not self.speaker_muted

def register_tool_implementation(
    self, name: str, tool_impl: ClientToolImplementation
) -> None:
    """Associate ``tool_impl`` with the client tool called ``name``.

    If the call is started with a client-implemented tool, the registered
    implementation is invoked whenever the model calls that tool. Registering
    the same name again replaces the earlier implementation.

    The implementation takes one argument, a ``dict[str, Any]`` holding the
    parameters defined by the tool. It returns either a string (the result
    value) or a tuple of two strings: the result value followed by the
    response type, which affects the call itself (e.g. having the agent hang
    up). It may be a plain function or an async one.

    See https://docs.ultravox.ai/tools for more information.
    """
    self._registered_tools.update({name: tool_impl})

def register_tool_implementations(
    self, impls: dict[str, ClientToolImplementation]
) -> None:
    """Register several client tool implementations in one call.

    ``impls`` maps tool names to implementations; each entry is stored the
    same way a single registration would store it, replacing any existing
    implementation with the same name.
    """
    for tool_name, tool_impl in impls.items():
        self._registered_tools[tool_name] = tool_impl

async def join_call(
self,
join_url: str,
Expand Down Expand Up @@ -257,8 +289,13 @@ async def _on_message(self, payload: str):
match msg.get("type", None):
case "room_info":
self._room = rtc.Room()
self._room.on("track_subscribed", self._on_track_subscribed)
self._room.on("data_received", self._on_data_received)
self._room_listener = room_listener.RoomListener(self._room)
self._room_listener.add_listener(
"track_subscribed", self._on_track_subscribed
)
self._room_listener.add_listener(
"data_received", self._on_data_received
)

await self._room.connect(msg["roomUrl"], msg["token"])
self._update_status(UltravoxSessionStatus.IDLE)
Expand Down Expand Up @@ -297,7 +334,7 @@ def _on_track_subscribed(
assert self._sink_adapter
self._sink_adapter.start(track)

def _on_data_received(self, data_packet: rtc.DataPacket):
async def _on_data_received(self, data_packet: rtc.DataPacket):
msg = json.loads(data_packet.data.decode("utf-8"))
assert isinstance(msg, dict)
match msg.get("type", None):
Expand Down Expand Up @@ -336,10 +373,59 @@ def _on_data_received(self, data_packet: rtc.DataPacket):
medium,
)
self._add_or_update_transcript(transcript)
case "client_tool_invocation":
await self._invoke_client_tool(
msg["toolName"], msg["invocationId"], msg["parameters"]
)
case _:
if self._experimental_messages:
self.emit("experimental_message", msg)

async def _invoke_client_tool(
    self, tool_name: str, invocation_id: str, parameters: dict[str, Any]
):
    """Run a registered client tool and publish its outcome to the room.

    Exactly one ``client_tool_result`` message is sent via ``_send_data``:

    * the tool is not registered -> ``errorType`` "undefined";
    * the tool raised, or returned something other than a string or a
      (string, string) pair -> ``errorType`` "implementation-error" with the
      exception text as ``errorMessage``;
    * otherwise -> ``result`` and, when the tool supplied a non-empty response
      type, ``responseType``.

    Args:
        tool_name: Name the tool was registered under.
        invocation_id: Identifier echoed back so the result can be matched to
            the invocation.
        parameters: Arguments passed to the tool implementation as-is.

    Raises:
        Whatever ``_send_data`` raises. A failure to publish is deliberately
        not reported as a tool implementation error.
    """
    result_msg: dict[str, Any] = {
        "type": "client_tool_result",
        "invocationId": invocation_id,
    }
    if tool_name not in self._registered_tools:
        logging.warning(
            f"Client tool {tool_name} was invoked but is not registered"
        )
        result_msg["errorType"] = "undefined"
        result_msg["errorMessage"] = (
            f"Client tool {tool_name} is not registered (Python client)"
        )
        await self._send_data(result_msg)
        return
    # Only the tool call and the validation of its return value are guarded;
    # publishing happens afterwards so transport errors are not misattributed
    # to the tool and the message is never sent twice.
    try:
        result = self._registered_tools[tool_name](parameters)
        if inspect.isawaitable(result):
            result = await result
        if isinstance(result, tuple):
            val = result[0]
            response_type = result[1]
        else:
            val = result
            response_type = None
        # Explicit checks instead of ``assert``: asserts vanish under ``-O``
        # and produce an empty error message.
        if not isinstance(val, str):
            raise TypeError(
                f"Client tool {tool_name} returned a result of type "
                f"{type(val).__name__}; expected str"
            )
        if response_type and not isinstance(response_type, str):
            raise TypeError(
                f"Client tool {tool_name} returned a response type of type "
                f"{type(response_type).__name__}; expected str"
            )
    except Exception as e:
        logging.exception(f"Error invoking client tool {tool_name}", exc_info=e)
        result_msg["errorType"] = "implementation-error"
        result_msg["errorMessage"] = str(e)
    else:
        result_msg["result"] = val
        if response_type:
            result_msg["responseType"] = response_type
    await self._send_data(result_msg)

async def _send_data(self, msg: dict):
assert self._room
await self._room.local_participant.publish_data(json.dumps(msg).encode("utf-8"))
Expand Down
190 changes: 190 additions & 0 deletions ultravox-client/ultravox_client/session_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import asyncio
from collections import defaultdict
from unittest import mock
from typing import Any, AsyncGenerator, Callable

import json
import pytest
import websockets
from livekit import rtc

from ultravox_client import audio
from ultravox_client import session


class FakeRoom:
    """Minimal in-memory stand-in for ``rtc.Room`` used by the session tests.

    It records listeners registered with ``on`` and lets a test fire events
    with ``emit``. ``connect`` and ``disconnect`` do nothing, and the local
    participant is an ``AsyncMock`` so published data can be asserted on.
    """

    def __init__(self):
        # Event name -> callbacks registered through on(), in registration order.
        self.listeners: dict[str, list[Callable]] = defaultdict(list)
        self.local_participant = mock.AsyncMock(spec=rtc.LocalParticipant)

    def on(self, event: str, listener: Callable):
        """Register ``listener`` to be called whenever ``event`` is emitted."""
        registered = self.listeners[event]
        registered.append(listener)

    def emit(self, event: str, *args, **kwargs):
        """Synchronously call each listener of ``event`` with the given arguments."""
        callbacks = self.listeners[event]
        for callback in callbacks:
            callback(*args, **kwargs)

    async def connect(self, room_url: str, token: str):
        """Accept any URL and token without doing anything."""

    async def disconnect(self):
        """Do nothing; there is no connection to tear down."""


@pytest.fixture(autouse=True)
def fake_room(mocker):
    """Patch ``livekit.rtc.Room`` to return a shared ``FakeRoom`` and yield it to tests.

    The ``start`` methods of the session's audio source and sink adapters are
    patched out as well. Applied automatically to every test in the module.
    """
    stub_room = FakeRoom()
    adapter_start_targets = (
        "ultravox_client.session._AudioSourceToSendTrackAdapter.start",
        "ultravox_client.session._AudioSinkFromRecvTrackAdapter.start",
    )
    mocker.patch("livekit.rtc.Room", return_value=stub_room)
    for target in adapter_start_targets:
        mocker.patch(target)
    return stub_room


class FakeWsServer:
    """In-memory stand-in for a websocket connection, driven by the test.

    Tests queue messages or exceptions; async iteration hands them out in
    order. A queued ``None`` (see ``flush``) ends the iteration. Any attempt
    by the client to ``send`` fails the test.
    """

    def __init__(self):
        super().__init__()
        self._messages = asyncio.Queue()
        self.open = True

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next queued message, raising queued exceptions instead."""
        item = await self._messages.get()
        if item is None:
            # Sentinel enqueued by flush(): iteration is over.
            raise StopAsyncIteration
        if not isinstance(item, Exception):
            return item
        if isinstance(item, websockets.ConnectionClosed):
            # A closed-connection error also marks the socket as closed.
            self.open = False
        raise item

    @property
    def response_headers(self):
        """Always an empty mapping of headers."""
        return {}

    def add_message(self, message: str):
        """Queue a message for the client to receive."""
        self._add(message)

    def add_error(self, error: Exception):
        """Queue an exception to be raised when the client reads it."""
        self._add(error)

    def flush(self):
        """Queue the end-of-stream sentinel."""
        self._add(None)

    def _add(self, message):
        self._messages.put_nowait(message)

    def reset(self, url: str):
        """Drop all queued items and reopen the connection; ``url`` is unused."""
        self.open = True
        self._messages = asyncio.Queue()

    @property
    def closed(self):
        """Inverse of ``open``."""
        return not self.open

    async def close(self):
        """Mark the connection closed and end any pending iteration."""
        self.open = False
        self.flush()

    async def send(self, message):
        """Fail loudly: these tests never expect the client to send anything."""
        raise AssertionError(f"Unexpected web socket message sent by client: {message}")


@pytest.fixture(autouse=True)
async def fake_ws_server(mocker):
    """Patch ``websockets.connect`` to hand out a shared ``FakeWsServer``.

    Each connect call resets the fake and queues a ``room_info`` message as
    the first thing the client will read. The end-of-stream sentinel is
    queued on teardown.
    """
    ws_server = FakeWsServer()

    async def connect_stub(url, extra_headers=None):
        ws_server.reset(url)
        room_info_message = (
            '{"type":"room_info", "roomUrl": "wss://some-url", "token": "banana"}'
        )
        ws_server.add_message(room_info_message)
        return ws_server

    mocker.patch("websockets.connect", side_effect=connect_stub)
    yield ws_server
    ws_server.flush()


class FakeAudioSource(audio.AudioSource):
    """Audio source for tests that yields one chunk of 3200 zero bytes, then ends."""

    async def stream(self) -> AsyncGenerator[bytes, None]:
        yield bytes(3200)


class FakeAudioSink(audio.AudioSink):
    """Audio sink for tests that discards everything written to it."""

    def write(self, data: bytes) -> None:
        """Ignore ``data``."""

    async def close(self) -> None:
        """Do nothing; there is nothing to release."""


async def test_client_tool_implementation(fake_room):
    """An async tool implementation is awaited and its string result is published."""
    ultravox = session.UltravoxSession()

    async def tool_impl(params: dict[str, Any]):
        assert params == {"foo": "bar"}
        await asyncio.sleep(0)
        return "baz"

    ultravox.register_tool_implementation("test_tool", tool_impl)
    await ultravox.join_call(
        "wss://test.ultravox.ai", FakeAudioSource(), FakeAudioSink()
    )
    # Brief pause before emitting room events.
    # NOTE(review): presumably lets the session process the queued room_info — confirm.
    await asyncio.sleep(0.001)

    invocation = {
        "type": "client_tool_invocation",
        "toolName": "test_tool",
        "invocationId": "call_1",
        "parameters": {"foo": "bar"},
    }
    packet = rtc.DataPacket(
        data=json.dumps(invocation).encode(),
        kind=rtc.DataPacketKind.KIND_RELIABLE,
    )
    fake_room.emit("data_received", packet)
    # Brief pause so the asynchronous tool invocation can finish before asserting.
    await asyncio.sleep(0.001)

    expected_result = {
        "type": "client_tool_result",
        "invocationId": "call_1",
        "result": "baz",
    }
    fake_room.local_participant.publish_data.assert_called_once_with(
        json.dumps(expected_result).encode("utf-8")
    )
    await ultravox.leave_call()


async def test_client_tool_implementation_with_response_type(fake_room):
    """A sync tool returning (result, response type) publishes both fields."""
    ultravox = session.UltravoxSession()

    def tool_impl(params: dict[str, Any]):
        assert params == {"foo": "bar"}
        return '{"strict": true}', "hang-up"

    ultravox.register_tool_implementation("test_tool", tool_impl)
    await ultravox.join_call(
        "wss://test.ultravox.ai", FakeAudioSource(), FakeAudioSink()
    )
    # Brief pause before emitting room events.
    # NOTE(review): presumably lets the session process the queued room_info — confirm.
    await asyncio.sleep(0.001)

    invocation = {
        "type": "client_tool_invocation",
        "toolName": "test_tool",
        "invocationId": "call_1",
        "parameters": {"foo": "bar"},
    }
    packet = rtc.DataPacket(
        data=json.dumps(invocation).encode(),
        kind=rtc.DataPacketKind.KIND_RELIABLE,
    )
    fake_room.emit("data_received", packet)
    # Brief pause so the tool invocation can finish before asserting.
    await asyncio.sleep(0.001)

    expected_result = {
        "type": "client_tool_result",
        "invocationId": "call_1",
        "result": '{"strict": true}',
        "responseType": "hang-up",
    }
    fake_room.local_participant.publish_data.assert_called_once_with(
        json.dumps(expected_result).encode("utf-8")
    )
    await ultravox.leave_call()

0 comments on commit 01e7e60

Please sign in to comment.