feat: improved chunk validation #159

Draft · wants to merge 21 commits into base: development (changes shown from all commits)
18 changes: 9 additions & 9 deletions aidial_sdk/application.py
@@ -30,7 +30,7 @@
from aidial_sdk.utils._reflection import get_method_implementation
from aidial_sdk.utils.log_config import LogConfig
from aidial_sdk.utils.logging import log_debug, set_log_deployment
from aidial_sdk.utils.streaming import merge_chunks
from aidial_sdk.utils.streaming import to_sse_stream

logging.config.dictConfig(LogConfig().dict())

@@ -201,18 +201,18 @@ async def _handler(original_request: Request):
impl.chat_completion, request
)

stream = response._generate_stream(first_chunk)

if request.stream:
return StreamingResponse(
response._generate_stream(first_chunk),
media_type="text/event-stream",
to_sse_stream(stream), media_type="text/event-stream"
)
else:
response_json = await merge_chunks(
response._generate_stream(first_chunk)
)

log_debug(f"response: {response_json}")
return JSONResponse(content=response_json)
async for _chunk in stream:
pass
response_dict = response.get_block_response()
log_debug(f"response: {response_dict}")
return JSONResponse(content=response_dict)

return _handler

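Note on this change: the streaming branch now hands the raw chunk stream to `to_sse_stream`, while the non-streaming branch drains the same stream and builds the block response via `response.get_block_response()`. As a rough sketch only (not the SDK's actual implementation of `aidial_sdk.utils.streaming.to_sse_stream`), an SSE adapter paired with `media_type="text/event-stream"` could look like this, assuming the generator yields JSON-serializable chunk dicts:

```python
# Illustrative sketch only; the real helper lives in
# aidial_sdk.utils.streaming and may differ in name, signature, and framing.
import json
from typing import Any, AsyncIterator, Dict


async def to_sse_stream_sketch(
    stream: AsyncIterator[Dict[str, Any]],
) -> AsyncIterator[str]:
    # Wrap each chunk dict in an SSE "data:" frame and finish with the
    # OpenAI-style "[DONE]" terminator.
    async for chunk in stream:
        yield "data: " + json.dumps(chunk, separators=(",", ":")) + "\n\n"
    yield "data: [DONE]\n\n"
```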
7 changes: 7 additions & 0 deletions aidial_sdk/chat_completion/_chunk_consumer.py
@@ -0,0 +1,7 @@
from typing import Protocol

from aidial_sdk.chat_completion.chunks import BaseChunk


class ChunkConsumer(Protocol):
def send_chunk(self, chunk: BaseChunk): ...
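The new `ChunkConsumer` protocol decouples chunk producers from the asyncio `Queue` they previously wrote to: any object with a matching `send_chunk` method satisfies it structurally. A minimal sketch (the queue-backed class name below is hypothetical, not part of this diff):

```python
# Hypothetical consumer that satisfies ChunkConsumer structurally;
# no inheritance is needed because Protocol relies on duck typing.
from asyncio import Queue

from aidial_sdk.chat_completion._chunk_consumer import ChunkConsumer
from aidial_sdk.chat_completion.chunks import BaseChunk


class QueueChunkConsumer:
    def __init__(self) -> None:
        self._queue: "Queue[BaseChunk]" = Queue()

    def send_chunk(self, chunk: BaseChunk) -> None:
        self._queue.put_nowait(chunk)


consumer: ChunkConsumer = QueueChunkConsumer()  # type-checks via structural match
```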
15 changes: 6 additions & 9 deletions aidial_sdk/chat_completion/choice.py
@@ -1,8 +1,7 @@
import json
from asyncio import Queue
from types import TracebackType
from typing import Any, Optional, Type, overload

from aidial_sdk.chat_completion._chunk_consumer import ChunkConsumer
from aidial_sdk.chat_completion.choice_base import ChoiceBase
from aidial_sdk.chat_completion.chunks import (
AttachmentChunk,
@@ -21,11 +20,10 @@
from aidial_sdk.utils._attachment import create_attachment
from aidial_sdk.utils._content_stream import ContentStream
from aidial_sdk.utils.errors import runtime_error
from aidial_sdk.utils.logging import log_debug


class Choice(ChoiceBase):
_queue: Queue
_sink: ChunkConsumer
_index: int
_last_attachment_index: int
_last_stage_index: int
@@ -36,8 +34,8 @@ class Choice(ChoiceBase):
_state_submitted: bool
_last_finish_reason: Optional[FinishReason]

def __init__(self, queue: Queue, choice_index: int):
self._queue = queue
def __init__(self, sink: ChunkConsumer, choice_index: int):
self._sink = sink
self._index = choice_index
self._last_attachment_index = 0
self._last_stage_index = 0
@@ -62,8 +60,7 @@ def __exit__(
return False

def send_chunk(self, chunk: BaseChunk) -> None:
log_debug("chunk: " + json.dumps(chunk.to_dict()))
self._queue.put_nowait(chunk)
self._sink.send_chunk(chunk)

@property
def index(self) -> int:
@@ -167,7 +164,7 @@ def create_stage(self, name: Optional[str] = None) -> Stage:
if self._closed:
raise runtime_error("Trying to create stage to a closed choice")

stage = Stage(self._queue, self._index, self._last_stage_index, name)
stage = Stage(self, self._index, self._last_stage_index, name)
self._last_stage_index += 1

return stage
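With this change `Choice` no longer writes to an asyncio queue; it forwards every chunk to its `ChunkConsumer` sink, and because `Choice` itself exposes `send_chunk`, it satisfies the protocol and can be passed straight into `Stage(self, ...)`. A rough sketch of the resulting delegation chain (the `CollectingSink` class is hypothetical, for illustration only; in the SDK the sink is presumably the response object):

```python
# Hypothetical in-memory sink used only to illustrate the
# Stage -> Choice -> sink chain implied by this diff.
from typing import List

from aidial_sdk.chat_completion.chunks import BaseChunk


class CollectingSink:
    """Collects chunks in memory; structurally a ChunkConsumer."""

    def __init__(self) -> None:
        self.chunks: List[BaseChunk] = []

    def send_chunk(self, chunk: BaseChunk) -> None:
        self.chunks.append(chunk)


# choice = Choice(CollectingSink(), choice_index=0)
# stage = choice.create_stage("search")  # Stage(choice, ...): chunks flow
#                                        # Stage -> Choice -> CollectingSink
```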
75 changes: 54 additions & 21 deletions aidial_sdk/chat_completion/chunks.py
@@ -1,24 +1,54 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict

from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from aidial_sdk.chat_completion.enums import FinishReason, Status
from aidial_sdk.pydantic_v1 import BaseModel, root_validator
from aidial_sdk.utils.json import remove_nones


class DefaultChunk(TypedDict, total=False):
id: str
model: str
created: int
object: str


class BaseChunk(ABC):
_overrides: Optional[DefaultChunk] = None

def set_overrides(self, overrides: DefaultChunk):
self._overrides = overrides

@abstractmethod
def to_dict(self) -> Dict[str, Any]:
def to_raw_dict(self) -> Dict[str, Any]:
pass

def to_dict(self) -> Dict[str, Any]:
dict = self.to_raw_dict()
if self._overrides:
dict.update(self._overrides)
return dict


class ArbitraryChunk(BaseChunk):
data: ChatCompletionChunk

def __init__(self, data: ChatCompletionChunk):
self.data = data

def to_raw_dict(self):
return self.data.dict()


class StartChoiceChunk(BaseChunk):
choice_index: int

def __init__(self, choice_index: int):
self.choice_index = choice_index

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -39,7 +69,7 @@ def __init__(self, finish_reason: FinishReason, choice_index: int):
self.finish_reason = finish_reason
self.choice_index = choice_index

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -60,7 +90,7 @@ def __init__(self, content: str, choice_index: int):
self.content = content
self.choice_index = choice_index

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -94,7 +124,7 @@ def __init__(
self.name = name
self.arguments = arguments

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -138,7 +168,7 @@ def __init__(
self.name = name
self.arguments = arguments

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -170,7 +200,7 @@ def __init__(
self.stage_index = stage_index
self.name = name

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -203,7 +233,7 @@ def __init__(self, choice_index: int, stage_index: int, status: Status):
self.stage_index = stage_index
self.status = status

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -235,7 +265,7 @@ def __init__(self, choice_index: int, stage_index: int, content: str):
self.stage_index = stage_index
self.content = content

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -268,7 +298,7 @@ def __init__(self, choice_index: int, stage_index: int, name: str):
self.stage_index = stage_index
self.name = name

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -302,6 +332,9 @@ class Attachment(BaseModel):
reference_url: Optional[str]
reference_type: Optional[str]

class Config:
extra = "allow"

@root_validator
def check_data_or_url(cls, values):
data, url = values.get("data"), values.get("url")
@@ -333,7 +366,7 @@ def attachment_dict(self, index: int):


class AttachmentChunk(Attachment, BaseChunk):
def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
Expand All @@ -355,7 +388,7 @@ def to_dict(self):
class AttachmentStageChunk(Attachment, BaseChunk):
stage_index: int

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -390,7 +423,7 @@ def __init__(self, choice_index: int, state: Any):
self.state = state
self.choice_index = choice_index

def to_dict(self):
def to_raw_dict(self):
return {
"choices": [
{
@@ -411,13 +444,13 @@ def __init__(self, prompt_tokens: int, completion_tokens: int):
self.prompt_tokens = prompt_tokens
self.completion_tokens = completion_tokens

def to_dict(self):
def to_raw_dict(self):
return {
"usage": {
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"total_tokens": self.prompt_tokens + self.completion_tokens,
}
},
}


@@ -439,7 +472,7 @@ def __init__(
self.prompt_tokens = prompt_tokens
self.completion_tokens = completion_tokens

def to_dict(self):
def to_raw_dict(self):
return {
"statistics": {
"usage_per_model": [
Expand All @@ -452,7 +485,7 @@ def to_dict(self):
+ self.completion_tokens,
}
]
}
},
}


@@ -462,15 +495,15 @@ class DiscardedMessagesChunk(BaseChunk):
def __init__(self, discarded_messages: List[int]):
self.discarded_messages = discarded_messages

def to_dict(self):
def to_raw_dict(self):
return {
"statistics": {
"discarded_messages": self.discarded_messages,
}
},
}


class EndChunk:
class EndMarker:
exc: Optional[Exception]

def __init__(self, exc: Optional[Exception] = None):
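The new `DefaultChunk`/`set_overrides` mechanism lets the common envelope fields (`id`, `model`, `created`, `object`) be stamped onto every chunk at serialization time: `to_dict()` takes the subclass's `to_raw_dict()` payload and merges the overrides on top. A small usage sketch based on the code in this diff (the concrete id/model/created values are invented for illustration):

```python
# Usage sketch of the override mechanism added in this diff.
from aidial_sdk.chat_completion.chunks import DefaultChunk, DiscardedMessagesChunk

chunk = DiscardedMessagesChunk(discarded_messages=[0, 1])
chunk.set_overrides(
    DefaultChunk(
        id="chatcmpl-123",
        model="my-deployment",
        created=1700000000,
        object="chat.completion.chunk",
    )
)

print(chunk.to_dict())
# {"statistics": {"discarded_messages": [0, 1]},
#  "id": "chatcmpl-123", "model": "my-deployment",
#  "created": 1700000000, "object": "chat.completion.chunk"}
```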