Skip to content

Commit

Permalink
refact openai v2v (#438)
Browse files Browse the repository at this point in the history
* refact openai v2v

* reconnect opneai v2v
  • Loading branch information
TomasBack2Future authored Dec 2, 2024
1 parent df8a106 commit 817a1e3
Show file tree
Hide file tree
Showing 19 changed files with 652 additions and 784 deletions.
4 changes: 2 additions & 2 deletions agents/examples/demo/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down Expand Up @@ -848,7 +848,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down
10 changes: 5 additions & 5 deletions agents/examples/experimental/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10,
"max_history": 10,
"vendor": "azure",
"base_uri": "${env:AZURE_OPENAI_REALTIME_BASE_URI}",
"path": "/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview",
"system_message": ""
"prompt": ""
}
},
{
Expand Down Expand Up @@ -2444,7 +2444,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down Expand Up @@ -2566,7 +2566,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down Expand Up @@ -2724,7 +2724,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10,
"max_history": 10,
"enable_storage": true
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
from fastapi import Depends, FastAPI, HTTPException, Request
import asyncio

# Enable Pydantic debug mode
from pydantic import BaseConfig

BaseConfig.debug = True

# Set up logging
logging.config.dictConfig({
"version": 1,
Expand Down
222 changes: 156 additions & 66 deletions agents/ten_packages/extension/glue_python_async/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import traceback
import aiohttp
import json
import time
import re

from datetime import datetime
import numpy as np
from typing import List, Any, AsyncGenerator
from dataclasses import dataclass
from dataclasses import dataclass, field
from pydantic import BaseModel

from ten import (
Expand All @@ -23,7 +25,7 @@
Data,
)

from ten_ai_base import BaseConfig, ChatMemory
from ten_ai_base import BaseConfig, ChatMemory, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails, EVENT_MEMORY_APPENDED
from ten_ai_base.llm import AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata
from ten_ai_base.types import LLMChatCompletionUserMessageParam, LLMToolResult

Expand Down Expand Up @@ -84,27 +86,9 @@ class Choice(BaseModel):
index: int
finish_reason: str | None

class CompletionTokensDetails(BaseModel):
accepted_prediction_tokens: int = 0
audio_tokens: int = 0
reasoning_tokens: int = 0
rejected_prediction_tokens: int = 0

class PromptTokensDetails(BaseModel):
audio_tokens: int = 0
cached_tokens: int = 0

class Usage(BaseModel):
completion_tokens: int = 0
prompt_tokens: int = 0
total_tokens: int = 0

completion_tokens_details: CompletionTokensDetails | None = None
prompt_tokens_details: PromptTokensDetails | None = None

class ResponseChunk(BaseModel):
choices: List[Choice]
usage: Usage | None = None
usage: LLMUsage | None = None

@dataclass
class GlueConfig(BaseConfig):
Expand All @@ -113,17 +97,29 @@ class GlueConfig(BaseConfig):
prompt: str = ""
max_history: int = 10
greeting: str = ""
failure_info: str = ""
modalities: List[str] = field(default_factory=lambda: ["text"])
rtm_enabled: bool = True
ssml_enabled: bool = False
context_enabled: bool = False
extra_context: dict = field(default_factory=dict)
enable_storage: bool = False

class AsyncGlueExtension(AsyncLLMBaseExtension):
config : GlueConfig = None
sentence_fragment: str = ""
ten_env: AsyncTenEnv = None
loop: asyncio.AbstractEventLoop = None
stopped: bool = False
memory: ChatMemory = None
total_usage: Usage = Usage()
total_usage: LLMUsage = LLMUsage()
users_count = 0

completion_times = []
connect_times = []
first_token_times = []

remote_stream_id: int = 999 # TODO

async def on_init(self, ten_env: AsyncTenEnv) -> None:
await super().on_init(ten_env)
ten_env.log_debug("on_init")
Expand All @@ -139,6 +135,21 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:

self.memory = ChatMemory(self.config.max_history)

if self.config.enable_storage:
result = await ten_env.send_cmd(Cmd.create("retrieve"))
if result.get_status_code() == StatusCode.OK:
try:
history = json.loads(result.get_property_string("response"))
for i in history:
self.memory.put(i)
ten_env.log_info(f"on retrieve context {history}")
except Exception as e:
ten_env.log_error("Failed to handle retrieve result {e}")
else:
ten_env.log_warn("Failed to retrieve content")

self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

self.ten_env = ten_env

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
Expand Down Expand Up @@ -187,7 +198,21 @@ async def on_data_chat_completion(self, ten_env: AsyncTenEnv, **kargs: LLMDataCo
messages = []
if self.config.prompt:
messages.append({"role": "system", "content": self.config.prompt})
messages.extend(self.memory.get())

history = self.memory.get()
while history:
if history[0].get("role") == "tool":
history = history[1:]
continue
if history[0].get("role") == "assistant" and history[0].get("tool_calls"):
history = history[1:]
continue

# Skip the first tool role
break

messages.extend(history)

if not input:
ten_env.log_warn("No message in data")
else:
Expand Down Expand Up @@ -220,6 +245,10 @@ def tool_dict(tool: LLMToolMetadata):
json["function"]["parameters"]["required"].append(param.name)

return json

def trim_xml(input_string):
return re.sub(r'<[^>]+>', '', input_string).strip()

tools = []
for tool in self.available_tools:
tools.append(tool_dict(tool))
Expand All @@ -229,16 +258,25 @@ def tool_dict(tool: LLMToolMetadata):
calls = {}

sentences = []
start_time = time.time()
first_token_time = None
response = self._stream_chat(messages=messages, tools=tools)
async for message in response:
self.ten_env.log_info(f"content: {message}")
self.ten_env.log_debug(f"content: {message}")
# TODO: handle tool call
try:
c = ResponseChunk(**message)
if c.choices:
if c.choices[0].delta.content:
total_output += c.choices[0].delta.content
sentences, sentence_fragment = parse_sentences(sentence_fragment, c.choices[0].delta.content)
if first_token_time is None:
first_token_time = time.time()
self.first_token_times.append(first_token_time - start_time)

content = c.choices[0].delta.content
if self.config.ssml_enabled and content.startswith("<speak>"):
content = trim_xml(content)
total_output += content
sentences, sentence_fragment = parse_sentences(sentence_fragment, content)
for s in sentences:
await self._send_text(s)
if c.choices[0].delta.tool_calls:
Expand All @@ -252,10 +290,14 @@ def tool_dict(tool: LLMToolMetadata):
calls[call.index].function.arguments += call.function.arguments
if c.usage:
self.ten_env.log_info(f"usage: {c.usage}")
self._update_usage(c.usage)
await self._update_usage(c.usage)
except Exception as e:
self.ten_env.log_error(f"Failed to parse response: {message} {e}")
traceback.print_exc()
if sentence_fragment:
await self._send_text(sentence_fragment)
end_time = time.time()
self.completion_times.append(end_time - start_time)

if total_output:
self.memory.put({"role": "assistant", "content": total_output})
Expand Down Expand Up @@ -343,48 +385,67 @@ async def _send_text(self, text: str) -> None:
self.ten_env.send_data(data)

async def _stream_chat(self, messages: List[Any], tools: List[Any]) -> AsyncGenerator[dict, None]:
session = aiohttp.ClientSession()
try:
payload = {
"messages": messages,
"tools": tools,
"tools_choice": "auto" if tools else "none",
"model": "gpt-3.5-turbo",
"stream": True,
"stream_options": {"include_usage": True}
}
self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}")
headers = {
"Authorization": f"Bearer {self.config.token}",
"Content-Type": "application/json"
}

async with session.post(self.config.api_url, json=payload, headers=headers) as response:
if response.status != 200:
r = await response.json()
self.ten_env.log_error(f"Received unexpected status {r} from the server.")
return
async with aiohttp.ClientSession() as session:
try:
payload = {
"messages": messages,
"tools": tools,
"tools_choice": "auto" if tools else "none",
"model": "gpt-3.5-turbo",
"stream": True,
"stream_options": {"include_usage": True},
"ssml_enabled": self.config.ssml_enabled
}
if self.config.context_enabled:
payload["context"] = {
**self.config.extra_context
}
self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}")
headers = {
"Authorization": f"Bearer {self.config.token}",
"Content-Type": "application/json"
}

async for line in response.content:
if line:
l = line.decode('utf-8').strip()
if l.startswith("data:"):
content = l[5:].strip()
if content == "[DONE]":
break
self.ten_env.log_info(f"content: {content}")
yield json.loads(content)
except Exception as e:
self.ten_env.log_error(f"Failed to handle {e}")
finally:
await session.close()
session = None
start_time = time.time()
async with session.post(self.config.api_url, json=payload, headers=headers) as response:
if response.status != 200:
r = await response.json()
self.ten_env.log_error(f"Received unexpected status {r} from the server.")
if self.config.failure_info:
await self._send_text(self.config.failure_info)
return
end_time = time.time()
self.connect_times.append(end_time - start_time)

async for line in response.content:
if line:
l = line.decode('utf-8').strip()
if l.startswith("data:"):
content = l[5:].strip()
if content == "[DONE]":
break
self.ten_env.log_debug(f"content: {content}")
yield json.loads(content)
except Exception as e:
traceback.print_exc()
self.ten_env.log_error(f"Failed to handle {e}")
finally:
await session.close()
session = None

async def _update_usage(self, usage: LLMUsage) -> None:
if not self.config.rtm_enabled:
return

async def _update_usage(self, usage: Usage) -> None:
self.total_usage.completion_tokens += usage.completion_tokens
self.total_usage.prompt_tokens += usage.prompt_tokens
self.total_usage.total_tokens += usage.total_tokens

if self.total_usage.completion_tokens_details is None:
self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
if self.total_usage.prompt_tokens_details is None:
self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()

if usage.completion_tokens_details:
self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage.completion_tokens_details.accepted_prediction_tokens
self.total_usage.completion_tokens_details.audio_tokens += usage.completion_tokens_details.audio_tokens
Expand All @@ -395,4 +456,33 @@ async def _update_usage(self, usage: Usage) -> None:
self.total_usage.prompt_tokens_details.audio_tokens += usage.prompt_tokens_details.audio_tokens
self.total_usage.prompt_tokens_details.cached_tokens += usage.prompt_tokens_details.cached_tokens

self.ten_env.log_info(f"total usage: {self.total_usage}")
self.ten_env.log_info(f"total usage: {self.total_usage}")

data = Data.create("llm_stat")
data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
if self.connect_times and self.completion_times and self.first_token_times:
data.set_property_from_json("latency", json.dumps({
"connection_latency_95": np.percentile(self.connect_times, 95),
"completion_latency_95": np.percentile(self.completion_times, 95),
"first_token_latency_95": np.percentile(self.first_token_times, 95),
"connection_latency_99": np.percentile(self.connect_times, 99),
"completion_latency_99": np.percentile(self.completion_times, 99),
"first_token_latency_99": np.percentile(self.first_token_times, 99)
}))
self.ten_env.send_data(data)

async def _on_memory_appended(self, message: dict) -> None:
self.ten_env.log_info(f"Memory appended: {message}")
if not self.config.enable_storage:
return

role = message.get("role")
stream_id = self.remote_stream_id if role == "user" else 0
try:
d = Data.create("append")
d.set_property_string("text", message.get("content"))
d.set_property_string("role", role)
d.set_property_int("stream_id", stream_id)
self.ten_env.send_data(d)
except Exception as e:
self.ten_env.log_error(f"Error send append_context data {message} {e}")
Loading

0 comments on commit 817a1e3

Please sign in to comment.