Skip to content

Commit

Permalink
Expose send_data (#8)
Browse files Browse the repository at this point in the history
Also inform server of client_version (and api_version), handle new
transcript structure, expose data_message events, and update
dependencies.

This now matches fixie-ai/ultravox-client-sdk-flutter#11
  • Loading branch information
mdepinet authored Nov 26, 2024
1 parent e7a4220 commit 6633577
Show file tree
Hide file tree
Showing 7 changed files with 423 additions and 403 deletions.
2 changes: 1 addition & 1 deletion Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ default: format check test

# Install dependencies for local development.
install:
pip install poetry==1.7.1
pip install poetry==1.8.4
cd ultravox-client && poetry install --sync
cd example && poetry install --sync --no-root

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ This project uses [Poetry](https://python-poetry.org/) to manage dependencies al

## Publishing to PyPi
1. Bump version number in `ultravox_client/pyproject.toml`
1. (in the `ultravox_client` directory) Run `poetry publish --build -u __token__ -p <your_pypi_token>`
1. (in the `ultravox_client` directory) Run `poetry publish --build -u __token__ -p <your_pypi_token>`
1. Please tag the new version in GitHub and create a release, preferably with a changelog.
343 changes: 172 additions & 171 deletions example/poetry.lock

Large diffs are not rendered by default.

335 changes: 168 additions & 167 deletions ultravox-client/poetry.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions ultravox-client/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ultravox-client"
version = "0.0.6"
version = "0.0.7"
packages = [
{ include = "ultravox_client", from = "." },
]
Expand All @@ -15,11 +15,11 @@ keywords = ["ultravox", "audio", "realtime", "artificial intelligence"]

[tool.poetry.dependencies]
python = "^3.11"
livekit = "0.8"
websockets = "^12.0"
livekit = "^0.18.1"
websockets = "^14.1"
pyee = "^11.0.1"
sounddevice = "^0.5.0"
numpy = "^2.1.1"
sounddevice = "^0.5.1"
numpy = "^2.1.3"

[tool.poetry.group.dev.dependencies]
pytest = "^7.1.3"
Expand Down
129 changes: 72 additions & 57 deletions ultravox-client/ultravox_client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import json
import logging
import urllib.parse
from importlib import metadata
from typing import Any, Awaitable, Callable, Literal, Tuple

import websockets
from livekit import rtc
from websockets.asyncio import client as ws_client
from websockets import exceptions as ws_exceptions

from ultravox_client import async_close
from ultravox_client import audio
Expand Down Expand Up @@ -89,11 +91,8 @@ async def _pump(self, stream: rtc.AudioStream):
async with contextlib.AsyncExitStack() as stack:
stack.push_async_callback(stream.aclose)
async for chunk in stream:
self._sink.write(
chunk.data.tobytes()
if self._enabled
else b"\x00" * len(chunk.data.tobytes())
)
data = chunk.frame.data.tobytes()
self._sink.write(data if self._enabled else b"\x00" * len(data))


class UltravoxSessionStatus(enum.Enum):
Expand Down Expand Up @@ -157,16 +156,19 @@ class UltravoxSession(patched_event_emitter.PatchedAsyncIOEventEmitter):
The message is included as the first argument to the event handler.
- "mic_muted": emitted when the user's microphone is muted or unmuted.
- "speaker_muted": emitted when the user's speaker (i.e. output audio from the agent) is muted or unmuted.
- "data_message": emitted when any data message is received (including those
typically handled by this SDK). See https://docs.ultravox.ai/api/data_messages.
The message is included as the first argument to the event handler.
"""

def __init__(self, experimental_messages: set[str] | None = None) -> None:
super().__init__()
self._transcripts: list[Transcript] = []
self._transcripts: list[Transcript | None] = []
self._status = UltravoxSessionStatus.DISCONNECTED

self._room: rtc.Room | None = None
self._room_listener: room_listener.RoomListener | None = None
self._socket: websockets.WebSocketClientProtocol | None = None
self._socket: ws_client.ClientConnection | None = None
self._receive_task: asyncio.Task | None = None
self._source_adapter: _AudioSourceToSendTrackAdapter | None = None
self._sink_adapter: _AudioSinkFromRecvTrackAdapter | None = None
Expand All @@ -179,7 +181,7 @@ def status(self):

@property
def transcripts(self):
return self._transcripts.copy()
return [t for t in self._transcripts if t is not None]

@property
def mic_muted(self) -> bool:
Expand Down Expand Up @@ -218,7 +220,7 @@ def toggle_speaker_muted(self) -> None:
async def set_output_medium(self, medium: Medium) -> None:
"""Sets the agent's output medium. If the agent is currently speaking, this will
take effect at the end of the agent's utterance. Also see speaker_muted above."""
await self._send_data({"type": "set_output_medium", "medium": medium})
await self.send_data({"type": "set_output_medium", "medium": medium})

def register_tool_implementation(
self, name: str, tool_impl: ClientToolImplementation
Expand Down Expand Up @@ -247,18 +249,26 @@ async def join_call(
join_url: str,
source: audio.AudioSource | None = None,
sink: audio.AudioSink | None = None,
client_version: str | None = None,
) -> None:
"""Connects to a call using the given joinUrl."""
if self._status != UltravoxSessionStatus.DISCONNECTED:
raise RuntimeError("Cannot join a new call while already in a call.")
self._update_status(UltravoxSessionStatus.CONNECTING)

url_parts = list(urllib.parse.urlparse(join_url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
if self._experimental_messages:
url_parts = list(urllib.parse.urlparse(join_url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query["experimentalMessages"] = ",".join(self._experimental_messages)
url_parts[4] = urllib.parse.urlencode(query)
join_url = urllib.parse.urlunparse(url_parts)
self._socket = await websockets.connect(join_url)
uv_client_version = f"python_{metadata.version('ultravox-client')}"
if client_version:
uv_client_version += f":{client_version}"
query["clientVersion"] = uv_client_version
query["apiVersion"] = "1"
url_parts[4] = urllib.parse.urlencode(query)
join_url = urllib.parse.urlunparse(url_parts)

self._socket = await ws_client.connect(join_url)
self._source_adapter = _AudioSourceToSendTrackAdapter(
source or audio.LocalAudioSource()
)
Expand All @@ -277,7 +287,15 @@ async def send_text(self, text: str):
raise RuntimeError(
f"Cannot send text while not connected. Current status is {self.status}"
)
await self._send_data({"type": "input_text_message", "text": text})
await self.send_data({"type": "input_text_message", "text": text})

async def send_data(self, msg: dict):
"""Sends a data message to the Ultravox server. See https://docs.ultravox.ai/api/data_messages."""
if not self._room:
raise RuntimeError("Cannot send data while not connected")
if "type" not in msg:
raise ValueError("Message must have a 'type' field")
await self._room.local_participant.publish_data(json.dumps(msg).encode("utf-8"))

async def _socket_receive(self):
assert self._socket
Expand All @@ -286,7 +304,7 @@ async def _socket_receive(self):
if isinstance(message, str):
await self._on_message(message)
except Exception as e:
if not isinstance(e, websockets.ConnectionClosed):
if not isinstance(e, ws_exceptions.ConnectionClosed):
logging.exception("UltravoxSession websocket error", exc_info=e)
await self._disconnect()

Expand Down Expand Up @@ -343,6 +361,7 @@ def _on_track_subscribed(
async def _on_data_received(self, data_packet: rtc.DataPacket):
msg = json.loads(data_packet.data.decode("utf-8"))
assert isinstance(msg, dict)
self.emit("data_message", msg)
match msg.get("type", None):
case "state":
match msg.get("state", None):
Expand All @@ -353,32 +372,18 @@ async def _on_data_received(self, data_packet: rtc.DataPacket):
case "speaking":
self._update_status(UltravoxSessionStatus.SPEAKING)
case "transcript":
transcript = Transcript(
msg["transcript"]["text"],
msg["transcript"]["final"],
"user",
msg["transcript"]["medium"],
ordinal = msg.get("ordinal", -1)
medium = msg.get("medium", "voice")
role = msg.get("role", "agent")
final = msg.get("final", False)
self._add_or_update_transcript(
ordinal,
medium,
role,
final,
text=msg.get("text", None),
delta=msg.get("delta", None),
)
self._add_or_update_transcript(transcript)
case "voice_synced_transcript" | "agent_text_transcript":
medium = "voice" if msg["type"] == "voice_synced_transcript" else "text"
if msg.get("text", None):
transcript = Transcript(
msg["text"], msg.get("final", False), "agent", medium
)
self._add_or_update_transcript(transcript)
elif msg.get("delta", None):
last_transcript = (
self._transcripts[-1] if self._transcripts else None
)
if last_transcript and last_transcript.speaker == "agent":
transcript = Transcript(
last_transcript.text + msg["delta"],
msg.get("final", False),
"agent",
medium,
)
self._add_or_update_transcript(transcript)
case "client_tool_invocation":
await self._invoke_client_tool(
msg["toolName"], msg["invocationId"], msg["parameters"]
Expand All @@ -400,7 +405,7 @@ async def _invoke_client_tool(
"errorType": "undefined",
"errorMessage": f"Client tool {tool_name} is not registered (Python client)",
}
await self._send_data(result_msg)
await self.send_data(result_msg)
return
try:
result = self._registered_tools[tool_name](parameters)
Expand All @@ -421,7 +426,7 @@ async def _invoke_client_tool(
if response_type:
assert isinstance(response_type, str)
result_msg["responseType"] = response_type
await self._send_data(result_msg)
await self.send_data(result_msg)
except Exception as e:
logging.exception(f"Error invoking client tool {tool_name}", exc_info=e)
result_msg = {
Expand All @@ -430,25 +435,35 @@ async def _invoke_client_tool(
"errorType": "implementation-error",
"errorMessage": str(e),
}
await self._send_data(result_msg)

async def _send_data(self, msg: dict):
assert self._room
await self._room.local_participant.publish_data(json.dumps(msg).encode("utf-8"))
await self.send_data(result_msg)

def _update_status(self, status: UltravoxSessionStatus):
if self._status == status:
return
self._status = status
self.emit("status")

def _add_or_update_transcript(self, transcript: Transcript):
if (
self._transcripts
and not self._transcripts[-1].final
and self._transcripts[-1].speaker == transcript.speaker
):
self._transcripts[-1] = transcript
def _add_or_update_transcript(
self,
ordinal: int,
medium: Medium,
role: Role,
final: bool,
*,
text: str | None = None,
delta: str | None = None,
):
present_text = text or delta or ""
while len(self._transcripts) < ordinal:
self._transcripts.append(None)
if len(self._transcripts) == ordinal:
self._transcripts.append(Transcript(present_text, final, role, medium))
else:
self._transcripts.append(transcript)
if text is not None:
new_text = text
else:
prior_transcript = self._transcripts[ordinal]
prior_text = prior_transcript.text if prior_transcript else ""
new_text = prior_text + (delta or "")
self._transcripts[ordinal] = Transcript(new_text, final, role, medium)
self.emit("transcripts")
4 changes: 3 additions & 1 deletion ultravox-client/ultravox_client/session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def side_effect(url, extra_headers=None):
)
return server

mocker.patch("websockets.connect", side_effect=side_effect)
mocker.patch("websockets.asyncio.client.connect", side_effect=side_effect)
yield server
server.flush()

Expand Down Expand Up @@ -142,6 +142,7 @@ async def tool_impl(params: dict[str, Any]):
}
).encode(),
kind=rtc.DataPacketKind.KIND_RELIABLE,
participant=None,
)
fake_room.emit("data_received", data_packet)
await asyncio.sleep(0.001)
Expand Down Expand Up @@ -174,6 +175,7 @@ def tool_impl(params: dict[str, Any]):
}
).encode(),
kind=rtc.DataPacketKind.KIND_RELIABLE,
participant=None,
)
fake_room.emit("data_received", data_packet)
await asyncio.sleep(0.001)
Expand Down

0 comments on commit 6633577

Please sign in to comment.