
fix: migrate data stream protocol to fix text display
wangyoucao577 committed Oct 23, 2024
1 parent c2baf8c commit e08a07c
Showing 2 changed files with 184 additions and 116 deletions.
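
What changed, in short: instead of sending each split part as its own JSON message with part_number/total_parts fields, the extension now base64-encodes the whole JSON payload and slices it into packets of at most 1024 bytes, each framed as message_id|part_index|total_parts|content. A minimal sketch of one frame (the values below are illustrative, not taken from the commit):

import base64
import json

# Hypothetical payload; the real one also carries stream_id, text_ts, etc.
payload = json.dumps({"text": "hello", "is_final": True})
frame = f"9f0c2a1b|1|1|{base64.b64encode(payload.encode('utf-8')).decode('ascii')}"
print(frame)  # 9f0c2a1b|1|1|eyJ0ZXh0Ijog...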
145 changes: 92 additions & 53 deletions agents/ten_packages/extension/message_collector/src/extension.py
@@ -5,7 +5,9 @@
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import base64
import json
import threading
import time
import uuid
from ten import (
@@ -18,6 +20,7 @@
CmdResult,
Data,
)
import asyncio
from .log import logger

MAX_SIZE = 800 # 800-byte text budget, leaving headroom for metadata within the 1 KB packet limit
@@ -32,9 +35,70 @@

# record the cached text data for each stream id
cached_text_map = {}
MAX_CHUNK_SIZE_BYTES = 1024

def _text_to_base64_chunks(text: str, msg_id: str) -> list:
# Ensure msg_id does not exceed 36 characters (the length of a full UUID string)
if len(msg_id) > 36:
raise ValueError("msg_id cannot exceed 36 characters.")

# Convert text to bytearray
byte_array = bytearray(text, 'utf-8')

# Encode the bytearray into base64
base64_encoded = base64.b64encode(byte_array).decode('utf-8')

# Initialize list to hold the final chunks
chunks = []

# We'll split the base64 string dynamically based on the final byte size
part_index = 0
total_parts = None # We'll calculate total parts once we know how many chunks we create

# Process the base64-encoded content in chunks
current_position = 0
total_length = len(base64_encoded)

while current_position < total_length:
part_index += 1

# Start guessing the chunk size by limiting the base64 content part
estimated_chunk_size = MAX_CHUNK_SIZE_BYTES # We'll reduce this dynamically
content_chunk = ""
count = 0
while True:
# Create the content part of the chunk
content_chunk = base64_encoded[current_position:current_position + estimated_chunk_size]

# Format the chunk
formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}"

# Check if the byte length of the formatted chunk exceeds the max allowed size
if len(bytearray(formatted_chunk, 'utf-8')) <= MAX_CHUNK_SIZE_BYTES:
break
else:
# Reduce the estimated chunk size if the formatted chunk is too large
estimated_chunk_size -= 100 # Reduce content size gradually
count += 1

logger.debug(f"chunk estimate guess: {count}")

# Add the current chunk to the list
chunks.append(formatted_chunk)
current_position += estimated_chunk_size # Move to the next part of the content

# Now that we know the total number of parts, update the chunks with correct total_parts
total_parts = len(chunks)
updated_chunks = [
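# plain-text replace is safe here: '?' appears in neither the hex msg_id nor the base64 body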
chunk.replace("???", str(total_parts)) for chunk in chunks
]

return updated_chunks
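
A quick way to sanity-check the chunker above (a sketch; _demo_roundtrip is a made-up name, and it relies on the module-level MAX_CHUNK_SIZE_BYTES and the imports at the top of this file):

def _demo_roundtrip():
    text = json.dumps({"text": "hi" * 2000})
    chunks = _text_to_base64_chunks(text, "deadbeef")
    # every framed chunk fits in one 1 KB data-stream packet
    assert all(len(c.encode("utf-8")) <= MAX_CHUNK_SIZE_BYTES for c in chunks)
    # drop the "msg_id|part|total|" header and reassemble the base64 body
    joined = "".join(c.split("|", 3)[3] for c in chunks)
    assert base64.b64decode(joined).decode("utf-8") == text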

class MessageCollectorExtension(Extension):
# Create the queue for message processing
queue = asyncio.Queue()

def on_init(self, ten_env: TenEnv) -> None:
logger.info("MessageCollectorExtension on_init")
ten_env.on_init_done()
@@ -43,6 +107,13 @@ def on_start(self, ten_env: TenEnv) -> None:
logger.info("MessageCollectorExtension on_start")

# TODO: read properties, initialize resources
self.loop = asyncio.new_event_loop()
def start_loop():
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
threading.Thread(target=start_loop, args=[]).start()

# run_coroutine_threadsafe, because loop.create_task is not thread-safe once the loop runs in its own thread
asyncio.run_coroutine_threadsafe(self._process_queue(ten_env), self.loop)

ten_env.on_start_done()

@@ -123,7 +194,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
cached_text_map[stream_id] = text

# Generate a unique message ID for this batch of parts
message_id = str(uuid.uuid4())
message_id = str(uuid.uuid4())[:8]

# Prepare the main JSON structure without the text field
base_msg_data = {
@@ -132,61 +203,13 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
"message_id": message_id, # Add message_id to identify the split message
"data_type": "transcribe",
"text_ts": int(time.time() * 1000), # Convert to milliseconds
"text": text,
}

try:
# Convert the text to UTF-8 bytes
text_bytes = text.encode('utf-8')

# If the text + metadata fits within the size limit, send it directly
if len(text_bytes) + OVERHEAD_ESTIMATE <= MAX_SIZE:
base_msg_data["text"] = text
msg_data = json.dumps(base_msg_data)
ten_data = Data.create("data")
ten_data.set_property_buf("data", msg_data.encode())
ten_env.send_data(ten_data)
else:
# Split the text bytes into smaller chunks, ensuring safe UTF-8 splitting
max_text_size = MAX_SIZE - OVERHEAD_ESTIMATE
total_length = len(text_bytes)
total_parts = (total_length + max_text_size - 1) // max_text_size # Calculate number of parts

def get_valid_utf8_chunk(start, end):
"""Helper function to ensure valid UTF-8 chunks."""
while end > start:
try:
# Decode to check if this chunk is valid UTF-8
text_part = text_bytes[start:end].decode('utf-8')
return text_part, end
except UnicodeDecodeError:
# Reduce the end point to avoid splitting in the middle of a character
end -= 1
# If no valid chunk is found (shouldn't happen with valid UTF-8 input), return an empty string
return "", start

part_number = 0
start_index = 0
while start_index < total_length:
part_number += 1
# Get a valid UTF-8 chunk
text_part, end_index = get_valid_utf8_chunk(start_index, min(start_index + max_text_size, total_length))

# Prepare the part data with metadata
part_data = base_msg_data.copy()
part_data.update({
"text": text_part,
"part_number": part_number,
"total_parts": total_parts,
})

# Send each part
part_msg_data = json.dumps(part_data)
ten_data = Data.create("data")
ten_data.set_property_buf("data", part_msg_data.encode())
ten_env.send_data(ten_data)

# Move to the next chunk
start_index = end_index
chunks = _text_to_base64_chunks(json.dumps(base_msg_data), message_id)
for chunk in chunks:
asyncio.run_coroutine_threadsafe(self._queue_message(chunk), self.loop)

except Exception as e:
logger.warning(f"on_data new_data error: {e}")
@@ -199,3 +222,19 @@ def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
# TODO: process image frame
pass


async def _queue_message(self, data: str):
await self.queue.put(data)

async def _process_queue(self, ten_env: TenEnv):
while True:
data = await self.queue.get()
if data is None:
break
# process data
ten_data = Data.create("data")
ten_data.set_property_buf("data", data.encode())
ten_env.send_data(ten_data)
self.queue.task_done()
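# Pacing note: Agora's data stream channel is rate-limited (roughly 30 packets of up to 1 KB per second), so sleeping 40 ms (~25 msg/s) leaves headroom; the exact limit is an assumption, not stated in this diff.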
await asyncio.sleep(0.04)
155 changes: 92 additions & 63 deletions demo/src/manager/rtc/rtc.ts
@@ -12,6 +12,16 @@ import { AGEventEmitter } from "../events"
import { RtcEvents, IUserTracks } from "./types"
import { apiGenAgoraData } from "@/common"


const TIMEOUT_MS = 5000; // Timeout for incomplete messages

interface TextDataChunk {
message_id: string;
part_index: number;
total_parts: number;
content: string;
}

export class RtcManager extends AGEventEmitter<RtcEvents> {
private _joined
client: IAgoraRTCClient
@@ -110,75 +120,94 @@ export class RtcManager extends AGEventEmitter<RtcEvents> {
private _parseData(data: any): ITextItem | void {
let decoder = new TextDecoder('utf-8');
let decodedMessage = decoder.decode(data);
const textstream = JSON.parse(decodedMessage);

console.log("[test] textstream raw data", JSON.stringify(textstream));
console.log("[test] textstream raw data", decodedMessage);

const { stream_id, is_final, text, text_ts, data_type, message_id, part_number, total_parts } = textstream;
// const { stream_id, is_final, text, text_ts, data_type, message_id, part_number, total_parts } = textstream;

if (total_parts > 0) {
// If message is split, handle it accordingly
this._handleSplitMessage(message_id, part_number, total_parts, stream_id, is_final, text, text_ts);
} else {
// If there is no message_id, treat it as a complete message
this._handleCompleteMessage(stream_id, is_final, text, text_ts);
}
// if (total_parts > 0) {
// // If message is split, handle it accordingly
// this._handleSplitMessage(message_id, part_number, total_parts, stream_id, is_final, text, text_ts);
// } else {
// // If there is no message_id, treat it as a complete message
// this._handleCompleteMessage(stream_id, is_final, text, text_ts);
// }

this.handleChunk(decodedMessage);
}

private messageCache: { [key: string]: { parts: string[], totalParts: number } } = {};

/**
* Handle complete messages (not split).
*/
private _handleCompleteMessage(stream_id: number, is_final: boolean, text: string, text_ts: number): void {
const textItem: ITextItem = {
uid: `${stream_id}`,
time: text_ts,
dataType: "transcribe",
text: text,
isFinal: is_final
};

if (text.trim().length > 0) {
this.emit("textChanged", textItem);

private messageCache: { [key: string]: TextDataChunk[] } = {};

// Function to process received chunk via event emitter
handleChunk(formattedChunk: string) {
try {
// Split the chunk by the delimiter "|"
const [message_id, partIndexStr, totalPartsStr, content] = formattedChunk.split('|');

const part_index = parseInt(partIndexStr, 10);
const total_parts = totalPartsStr === '???' ? -1 : parseInt(totalPartsStr, 10); // -1 means total parts unknown

// Ensure total_parts is known before processing further
if (total_parts === -1) {
console.warn(`Total parts for message ${message_id} unknown, waiting for further parts.`);
return;
}

const chunkData: TextDataChunk = {
message_id,
part_index,
total_parts,
content,
};

// Check if we already have an entry for this message
if (!this.messageCache[message_id]) {
this.messageCache[message_id] = [];
// Set a timeout to discard incomplete messages
setTimeout(() => {
if (this.messageCache[message_id] && this.messageCache[message_id].length !== total_parts) { // guard: the entry is already gone once the message completed
console.warn(`Incomplete message with ID ${message_id} discarded`);
delete this.messageCache[message_id]; // Discard incomplete message
}
}, TIMEOUT_MS);
}

// Cache this chunk by message_id
this.messageCache[message_id].push(chunkData);

// If all parts are received, reconstruct the message
if (this.messageCache[message_id].length === total_parts) {
const completeMessage = this.reconstructMessage(this.messageCache[message_id]);
// atob() yields a byte string, so decode via TextDecoder to keep multi-byte UTF-8 text intact
const bytes = Uint8Array.from(atob(completeMessage), (c) => c.charCodeAt(0));
const { stream_id, is_final, text, text_ts } = JSON.parse(new TextDecoder("utf-8").decode(bytes));
const textItem: ITextItem = {
uid: `${stream_id}`,
time: text_ts,
dataType: "transcribe",
text: text,
isFinal: is_final
};

if (text.trim().length > 0) {
this.emit("textChanged", textItem);
}


// Clean up the cache
delete this.messageCache[message_id];
}
} catch (error) {
console.error('Error processing chunk:', error);
}
}

/**
* Handle split messages, track parts, and reassemble once all parts are received.
*/
private _handleSplitMessage(
message_id: string,
part_number: number,
total_parts: number,
stream_id: number,
is_final: boolean,
text: string,
text_ts: number
): void {
// Ensure the messageCache entry exists for this message_id
if (!this.messageCache[message_id]) {
this.messageCache[message_id] = { parts: [], totalParts: total_parts };
}

const cache = this.messageCache[message_id];

// Store the received part at the correct index (part_number starts from 1, so we use part_number - 1)
cache.parts[part_number - 1] = text;

// Check if all parts have been received
const receivedPartsCount = cache.parts.filter(part => part !== undefined).length;

if (receivedPartsCount === total_parts) {
// All parts have been received, reassemble the message
const fullText = cache.parts.join('');

// Now that the message is reassembled, handle it like a complete message
this._handleCompleteMessage(stream_id, is_final, fullText, text_ts);

// Remove the cached message since it is now fully processed
delete this.messageCache[message_id];
}

// Function to reconstruct the full message from chunks
reconstructMessage(chunks: TextDataChunk[]): string {
// Sort chunks by their part index
chunks.sort((a, b) => a.part_index - b.part_index);

// Concatenate all chunks to form the full message
return chunks.map(chunk => chunk.content).join('');
}
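
For reference, reassembly is order-insensitive: chunks are cached per message_id and sorted by part_index before joining, so frames may arrive out of order. A compact Python equivalent of the client-side logic, useful for offline testing (a sketch, not part of this commit):

import base64

def reassemble(frames: list[str]) -> bytes:
    # each frame is "msg_id|part_index|total_parts|base64_content"
    parts = sorted((f.split("|", 3) for f in frames), key=lambda p: int(p[1]))
    assert len({p[0] for p in parts}) == 1, "mixed message ids"
    assert len(parts) == int(parts[0][2]), "missing parts"
    return base64.b64decode("".join(p[3] for p in parts))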


Expand All @@ -196,4 +225,4 @@ export class RtcManager extends AGEventEmitter<RtcEvents> {
}


export const rtcManager = new RtcManager()
export const rtcManager = new RtcManager()
