fix: supported approximate tokenization of tools and functions (#170)
adubovik authored and roman-romanov-o committed Nov 27, 2024
1 parent 35d7ee7 commit 37236f8
Showing 10 changed files with 193 additions and 129 deletions.
32 changes: 18 additions & 14 deletions aidial_adapter_openai/gpt.py
@@ -23,29 +23,32 @@
 
 
 def plain_text_truncate_prompt(
-    messages: List[dict], max_prompt_tokens: int, tokenizer: PlainTextTokenizer
+    request: dict,
+    messages: List[dict],
+    max_prompt_tokens: int,
+    tokenizer: PlainTextTokenizer,
 ) -> Tuple[List[dict], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
-        message_tokens=tokenizer.calculate_message_tokens,
+        message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message["role"] == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )
 
 
 async def gpt_chat_completion(
-    data: dict,
+    request: dict,
     deployment_id: str,
     upstream_endpoint: str,
     creds: OpenAICreds,
     api_version: str,
     tokenizer: PlainTextTokenizer,
 ):
     discarded_messages = None
-    prompt_tokens = None
-    if "max_prompt_tokens" in data:
-        max_prompt_tokens = data["max_prompt_tokens"]
+    estimated_prompt_tokens = None
+    if "max_prompt_tokens" in request:
+        max_prompt_tokens = request["max_prompt_tokens"]
         if not isinstance(max_prompt_tokens, int):
             raise InvalidRequestError(
                 f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
@@ -54,11 +57,12 @@ async def gpt_chat_completion(
             raise InvalidRequestError(
                 f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
             )
-        del data["max_prompt_tokens"]
+        del request["max_prompt_tokens"]
 
-        data["messages"], discarded_messages, prompt_tokens = (
+        request["messages"], discarded_messages, estimated_prompt_tokens = (
             plain_text_truncate_prompt(
-                messages=cast(List[dict], data["messages"]),
+                request=request,
+                messages=cast(List[dict], request["messages"]),
                 max_prompt_tokens=max_prompt_tokens,
                 tokenizer=tokenizer,
             )
@@ -68,14 +72,14 @@ async def gpt_chat_completion(
         {**creds, "api_version": api_version}
     )
     response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
-        await call_with_extra_body(client.chat.completions.create, data)
+        await call_with_extra_body(client.chat.completions.create, request)
    )
 
     if isinstance(response, AsyncIterator):
         return generate_stream(
-            get_prompt_tokens=lambda: prompt_tokens
-            or tokenizer.calculate_prompt_tokens(data["messages"]),
-            tokenize=tokenizer.calculate_text_tokens,
+            get_prompt_tokens=lambda: estimated_prompt_tokens
+            or tokenizer.tokenize_request(request, request["messages"]),
+            tokenize_response=tokenizer.tokenize_response,
             deployment=deployment_id,
             discarded_messages=discarded_messages,
             stream=map_stream(chunk_to_dict, response),
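The renamed tokenizer entry points above (tokenize_request_message, tokenize_request, tokenize_response) are what let the adapter account for tools and functions when estimating prompt tokens: passing the whole request to tokenize_request(request, []) yields the request-level overhead, including tool/function declarations, which then serves as initial_prompt_tokens for truncation. The sketch below is a minimal illustration of that interface, assuming a tiktoken encoding and a JSON-serialization heuristic for tool definitions; it is not the adapter's actual implementation, and the constants are hypothetical.

import json
from typing import List

import tiktoken  # assumed dependency; any encoder with an .encode() method works


class ApproximateTokenizer:
    """Illustrative stand-in for the PlainTextTokenizer interface used above."""

    TOKENS_PER_REQUEST = 3  # hypothetical fixed per-request overhead

    def __init__(self, encoding_name: str = "cl100k_base"):
        self._encoding = tiktoken.get_encoding(encoding_name)

    def _text_tokens(self, text: str) -> int:
        return len(self._encoding.encode(text))

    def tokenize_request_message(self, message: dict) -> int:
        # Rough per-message cost: a small constant plus the text content.
        return 4 + self._text_tokens(str(message.get("content") or ""))

    def tokenize_request(self, request: dict, messages: List[dict]) -> int:
        # Request-level overhead, plus an approximate cost for the declared
        # tools/functions, plus the messages that are being kept.
        tokens = self.TOKENS_PER_REQUEST
        for tool in request.get("tools") or request.get("functions") or []:
            tokens += self._text_tokens(json.dumps(tool))
        return tokens + sum(self.tokenize_request_message(m) for m in messages)

Called as tokenize_request(request, []), this reduces to the request overhead plus the tool definitions, which is exactly the role initial_prompt_tokens plays in truncate_prompt.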
43 changes: 23 additions & 20 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -25,6 +25,9 @@
     ResourceProcessor,
 )
 from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
+from aidial_adapter_openai.utils.chat_completion_response import (
+    ChatCompletionBlock,
+)
 from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage
 from aidial_adapter_openai.utils.sse_stream import parse_openai_sse_stream
@@ -116,18 +119,18 @@ async def predict_non_stream(
 
 
 def multi_modal_truncate_prompt(
+    request: dict,
     messages: List[MultiModalMessage],
     max_prompt_tokens: int,
-    initial_prompt_tokens: int,
     tokenizer: MultiModalTokenizer,
 ) -> Tuple[List[MultiModalMessage], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
-        message_tokens=tokenizer.calculate_message_tokens,
+        message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message.raw_message["role"]
         == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=initial_prompt_tokens,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )
 
 
@@ -215,18 +218,18 @@ async def chat_completion(
     if max_prompt_tokens is not None:
         multi_modal_messages, discarded_messages, estimated_prompt_tokens = (
             multi_modal_truncate_prompt(
+                request=request,
                 messages=multi_modal_messages,
                 max_prompt_tokens=max_prompt_tokens,
-                initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
                 tokenizer=tokenizer,
             )
         )
         logger.debug(
             f"prompt tokens after truncation: {estimated_prompt_tokens}"
         )
     else:
-        estimated_prompt_tokens = tokenizer.calculate_prompt_tokens(
-            multi_modal_messages
+        estimated_prompt_tokens = tokenizer.tokenize_request(
+            request, multi_modal_messages
         )
         logger.debug(
             f"prompt tokens without truncation: {estimated_prompt_tokens}"
@@ -255,7 +258,7 @@ def debug_print(chunk: T) -> T:
            debug_print,
            generate_stream(
                get_prompt_tokens=lambda: estimated_prompt_tokens,
-                tokenize=tokenizer.calculate_text_tokens,
+                tokenize_response=tokenizer.tokenize_response,
                deployment=deployment,
                discarded_messages=discarded_messages,
                stream=map_stream(
@@ -277,25 +280,25 @@ def debug_print(chunk: T) -> T:
            type="invalid_response_error",
        )
 
-    content = response["choices"][0]["message"].get("content") or ""
-    usage = response["usage"]
-
     if discarded_messages:
         response |= {
             "statistics": {"discarded_messages": discarded_messages}
         }
 
-    actual_prompt_tokens = usage["prompt_tokens"]
-    if actual_prompt_tokens != estimated_prompt_tokens:
-        logger.warning(
-            f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
-        )
+    if usage := response.get("usage"):
+        actual_prompt_tokens = usage["prompt_tokens"]
+        if actual_prompt_tokens != estimated_prompt_tokens:
+            logger.warning(
+                f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
+            )
 
-    actual_completion_tokens = usage["completion_tokens"]
-    estimated_completion_tokens = tokenizer.calculate_text_tokens(content)
-    if actual_completion_tokens != estimated_completion_tokens:
-        logger.warning(
-            f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
+        actual_completion_tokens = usage["completion_tokens"]
+        estimated_completion_tokens = tokenizer.tokenize_response(
+            ChatCompletionBlock(resp=response)
+        )
+        if actual_completion_tokens != estimated_completion_tokens:
+            logger.warning(
+                f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
             )
 
     return response
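The response-side check above now goes through tokenizer.tokenize_response(ChatCompletionBlock(resp=response)) instead of counting only the first choice's plain-text content, so completion tokens can be approximated even when the model answers with tool or function calls. Continuing the hypothetical ApproximateTokenizer sketch from the previous file, a method of that shape might look as follows; the tool_calls traversal is an assumption, not code from this diff:

def tokenize_response(self, response) -> int:
    # `response` is a ChatCompletionResponse (see the new module below):
    # walk every choice's message, counting text content and the
    # JSON-encoded name/arguments of any tool calls approximately.
    tokens = 0
    for message in response.messages:
        if content := message.get("content"):
            tokens += self._text_tokens(str(content))
        for tool_call in message.get("tool_calls") or []:
            function = tool_call.get("function") or {}
            tokens += self._text_tokens(function.get("name") or "")
            tokens += self._text_tokens(function.get("arguments") or "")
    return tokens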
51 changes: 51 additions & 0 deletions aidial_adapter_openai/utils/chat_completion_response.py
@@ -0,0 +1,51 @@
from typing import Any, Iterable, Literal, Self

from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
from pydantic import BaseModel


class ChatCompletionResponse(BaseModel):
    message_key: Literal["delta", "message"]
    resp: dict = {}

    @property
    def usage(self) -> Any | None:
        return self.resp.get("usage")

    @property
    def is_empty(self) -> bool:
        return self.resp == {}

    @property
    def finish_reasons(self) -> Iterable[Any]:
        for choice in self.resp.get("choices") or []:
            if (reason := choice.get("finish_reason")) is not None:
                yield reason

    @property
    def has_finish_reason(self) -> bool:
        return len(list(self.finish_reasons)) > 0

    @property
    def messages(self) -> Iterable[Any]:
        for choice in self.resp.get("choices") or []:
            if (message := choice.get(self.message_key)) is not None:
                yield message

    @property
    def has_messages(self) -> bool:
        return len(list(self.messages)) > 0


class ChatCompletionBlock(ChatCompletionResponse):
    def __init__(self, **kwargs):
        super().__init__(message_key="message", **kwargs)


class ChatCompletionStreamingChunk(ChatCompletionResponse):
    def __init__(self, **kwargs):
        super().__init__(message_key="delta", **kwargs)

    def merge(self, chunk: dict) -> Self:
        self.resp = merge_chat_completion_chunks(self.resp, chunk)
        return self
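A brief usage sketch of the new module (not part of the diff), assuming merge_chat_completion_chunks follows the usual OpenAI streaming semantics of merging per-choice deltas:

from aidial_adapter_openai.utils.chat_completion_response import (
    ChatCompletionBlock,
    ChatCompletionStreamingChunk,
)

# Fold streamed chunks into a single accumulated response view.
accumulated = ChatCompletionStreamingChunk()
accumulated.merge(
    {"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}
)
accumulated.merge(
    {"choices": [{"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}]}
)
accumulated.merge({"usage": {"prompt_tokens": 10, "completion_tokens": 2}})

# The shared properties expose the merged choices, finish reasons and usage.
print(list(accumulated.messages), accumulated.has_finish_reason, accumulated.usage)

# Non-streaming responses use the "message" key instead of "delta".
block = ChatCompletionBlock(
    resp={"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello"}}]}
)
print(list(block.messages))  # expected: [{'role': 'assistant', 'content': 'Hello'}]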
29 changes: 0 additions & 29 deletions aidial_adapter_openai/utils/merge_chunks.py

This file was deleted.
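The chunk-merging helpers previously defined here appear to be superseded by merge_chat_completion_chunks from aidial_sdk, which the new chat_completion_response.py module imports instead.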

