(Fixes) OpenAI Streaming Token Counting + Fixes usage tracking when litellm.turn_off_message_logging=True #8156

Merged
merged 10 commits on Jan 31, 2025
71 changes: 40 additions & 31 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -1029,21 +1029,13 @@ def success_handler(  # noqa: PLR0915
] = None
if "complete_streaming_response" in self.model_call_details:
return # break out of this.
if self.stream and (
isinstance(result, litellm.ModelResponse)
or isinstance(result, TextCompletionResponse)
or isinstance(result, ModelResponseStream)
):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=self.sync_streaming_chunks,
is_async=False,
)
complete_streaming_response = self._get_assembled_streaming_response(
result=result,
start_time=start_time,
end_time=end_time,
is_async=False,
streaming_chunks=self.sync_streaming_chunks,
)
if complete_streaming_response is not None:
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
@@ -1542,22 +1534,13 @@ async def async_success_handler(  # noqa: PLR0915
return # break out of this.
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = None
if self.stream is True and (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.ModelResponseStream)
or isinstance(result, TextCompletionResponse)
):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=self.streaming_chunks,
is_async=True,
)
] = self._get_assembled_streaming_response(
result=result,
start_time=start_time,
end_time=end_time,
is_async=True,
streaming_chunks=self.streaming_chunks,
)

if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
@@ -2259,6 +2242,32 @@ def _remove_internal_custom_logger_callbacks(self, callbacks: List) -> List:
_new_callbacks.append(_c)
return _new_callbacks

def _get_assembled_streaming_response(
self,
result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
start_time: datetime.datetime,
end_time: datetime.datetime,
is_async: bool,
streaming_chunks: List[Any],
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
if isinstance(result, ModelResponse):
return result
elif isinstance(result, TextCompletionResponse):
return result
elif isinstance(result, ModelResponseStream):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=streaming_chunks,
is_async=is_async,
)
return complete_streaming_response
return None


def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
"""
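The new helper centralizes what the sync and async success handlers previously duplicated: a result that is already a complete ModelResponse or TextCompletionResponse is returned as-is, and only a ModelResponseStream triggers assembly from the buffered chunks. A minimal standalone sketch of that dispatch follows; the class and function names below are illustrative stand-ins, not the actual litellm types.

from typing import Any, List, Optional


class CompleteResponse:  # stand-in for ModelResponse / TextCompletionResponse
    pass


class StreamChunk:  # stand-in for ModelResponseStream
    pass


def assemble_from_chunks(chunks: List[StreamChunk]) -> Optional[CompleteResponse]:
    # stand-in for _assemble_complete_response_from_streaming_chunks
    return CompleteResponse() if chunks else None


def get_assembled_streaming_response(
    result: Any,
    streaming_chunks: List[StreamChunk],
) -> Optional[CompleteResponse]:
    if isinstance(result, CompleteResponse):
        # already a full response -- nothing to assemble
        return result
    if isinstance(result, StreamChunk):
        # final chunk of a stream -- build the full response from the buffer
        return assemble_from_chunks(streaming_chunks)
    return None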
68 changes: 18 additions & 50 deletions litellm/litellm_core_utils/streaming_handler.py
@@ -5,7 +5,6 @@
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, cast

import httpx
@@ -14,6 +13,7 @@
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.redact_messages import LiteLLMLoggingObject
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.utils import Delta
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import (
@@ -29,11 +29,6 @@
from .llm_response_utils.get_api_base import get_api_base
from .rules import Rules

MAX_THREADS = 100

# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)


def is_async_iterable(obj: Any) -> bool:
"""
@@ -1568,21 +1563,6 @@ async def __anext__(self):  # noqa: PLR0915
)
if processed_chunk is None:
continue
## LOGGING
executor.submit(
self.logging_obj.success_handler,
result=processed_chunk,
start_time=None,
end_time=None,
cache_hit=cache_hit,
)

asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit
)
)

if self.logging_obj._llm_caching_handler is not None:
asyncio.create_task(
@@ -1634,16 +1614,6 @@ async def __anext__(self):  # noqa: PLR0915
)
if processed_chunk is None:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log processed_chunk
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit
)
)

choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
@@ -1671,33 +1641,31 @@ async def __anext__(self):  # noqa: PLR0915
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(response, None, None, cache_hit),
).start() # log response
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response

asyncio.create_task(
self.logging_obj.async_success_handler(
response, cache_hit=cache_hit
complete_streaming_response,
cache_hit=cache_hit,
start_time=None,
end_time=None,
)
)
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response

executor.submit(
self.logging_obj.success_handler,
complete_streaming_response,
cache_hit=cache_hit,
start_time=None,
end_time=None,
)

raise StopAsyncIteration # Re-raise StopIteration
else:
self.sent_last_chunk = True
processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit
)
)
return processed_chunk
        except httpx.TimeoutException as e:  # if httpx read timeout error occurs
traceback_exception = traceback.format_exc()
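With the per-chunk success-handler calls removed above, logging now happens once per stream: after the final usage chunk is emitted, the async handler is scheduled as an asyncio task and the sync handler is submitted to the shared thread pool. A condensed sketch of that end-of-stream pattern, using the same handler signatures as the diff; the wrapper function itself is illustrative.

import asyncio

from litellm.litellm_core_utils.thread_pool_executor import executor


def schedule_stream_logging(logging_obj, complete_streaming_response, cache_hit: bool) -> None:
    # must be called from within a running event loop (as __anext__ is)
    # async success callbacks run on the event loop ...
    asyncio.create_task(
        logging_obj.async_success_handler(
            complete_streaming_response,
            cache_hit=cache_hit,
            start_time=None,
            end_time=None,
        )
    )
    # ... sync success callbacks run on the shared ThreadPoolExecutor
    executor.submit(
        logging_obj.success_handler,
        complete_streaming_response,
        cache_hit=cache_hit,
        start_time=None,
        end_time=None,
    )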
5 changes: 5 additions & 0 deletions litellm/litellm_core_utils/thread_pool_executor.py
@@ -0,0 +1,5 @@
from concurrent.futures import ThreadPoolExecutor

MAX_THREADS = 100
# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
25 changes: 21 additions & 4 deletions litellm/llms/openai/openai.py
@@ -14,6 +14,7 @@
Union,
cast,
)
from urllib.parse import urlparse

import httpx
import openai
@@ -833,8 +834,9 @@ def streaming(
stream_options: Optional[dict] = None,
):
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
data.update(
self.get_stream_options(stream_options=stream_options, api_base=api_base)
)

openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
@@ -893,8 +895,9 @@ async def async_streaming(
):
response = None
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
data.update(
self.get_stream_options(stream_options=stream_options, api_base=api_base)
)
for _ in range(2):
try:
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
@@ -977,6 +980,20 @@ async def async_streaming(
status_code=500, message=f"{str(e)}", headers=error_headers
)

def get_stream_options(
self, stream_options: Optional[dict], api_base: Optional[str]
) -> dict:
"""
Pass `stream_options` to the data dict for OpenAI requests
"""
if stream_options is not None:
return {"stream_options": stream_options}
else:
# by default litellm will include usage for openai endpoints
if api_base is None or urlparse(api_base).hostname == "api.openai.com":
return {"stream_options": {"include_usage": True}}
return {}

# Embedding
@track_llm_api_timing()
async def make_openai_embedding_request(
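The effect of get_stream_options: explicitly passed stream_options are forwarded unchanged, and when none are given, requests bound for api.openai.com (or with no api_base at all) default to {"include_usage": True} so the final streamed chunk carries token usage; other endpoints are left untouched. A standalone sketch of the same branching; the function name here is illustrative.

from typing import Optional
from urllib.parse import urlparse


def resolve_stream_options(stream_options: Optional[dict], api_base: Optional[str]) -> dict:
    if stream_options is not None:
        # caller-supplied options always win
        return {"stream_options": stream_options}
    if api_base is None or urlparse(api_base).hostname == "api.openai.com":
        # default: ask OpenAI to append a usage chunk to the stream
        return {"stream_options": {"include_usage": True}}
    # other / proxy endpoints: do not inject anything
    return {}


print(resolve_stream_options(None, None))
# -> {'stream_options': {'include_usage': True}}
print(resolve_stream_options(None, "https://my-proxy.example.com/v1"))
# -> {}
print(resolve_stream_options({"include_usage": False}, "https://api.openai.com/v1"))
# -> {'stream_options': {'include_usage': False}}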
6 changes: 1 addition & 5 deletions litellm/utils.py
@@ -166,7 +166,6 @@
# Convert to str (if necessary)
claude_json_str = json.dumps(json_data)
import importlib.metadata
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
@@ -185,6 +184,7 @@

from openai import OpenAIError as OriginalError

from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
@@ -235,10 +235,6 @@

####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100

# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
sentry_sdk_instance = None
capture_exception = None
add_breadcrumb = None
20 changes: 20 additions & 0 deletions tests/local_testing/test_custom_callback_input.py
@@ -418,6 +418,8 @@ async def test_async_chat_openai_stream():
)
async for chunk in response:
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.acompletion(
@@ -428,6 +430,7 @@
)
async for chunk in response:
continue
await asyncio.sleep(1)
except Exception:
pass
time.sleep(1)
@@ -499,6 +502,8 @@ async def test_async_chat_azure_stream():
)
async for chunk in response:
continue

await asyncio.sleep(1)
# test failure callback
try:
response = await litellm.acompletion(
@@ -509,6 +514,7 @@
)
async for chunk in response:
continue
await asyncio.sleep(1)
except Exception:
pass
await asyncio.sleep(1)
@@ -540,6 +546,8 @@ async def test_async_chat_openai_stream_options():

async for chunk in response:
continue

await asyncio.sleep(1)
print("mock client args list=", mock_client.await_args_list)
mock_client.assert_awaited_once()
except Exception as e:
@@ -607,6 +615,8 @@ async def test_async_chat_bedrock_stream():
async for chunk in response:
print(f"chunk: {chunk}")
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.acompletion(
@@ -617,6 +627,8 @@
)
async for chunk in response:
continue

await asyncio.sleep(1)
except Exception:
pass
await asyncio.sleep(1)
@@ -770,6 +782,8 @@ async def test_async_text_completion_bedrock():
async for chunk in response:
print(f"chunk: {chunk}")
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.atext_completion(
@@ -780,6 +794,8 @@
)
async for chunk in response:
continue

await asyncio.sleep(1)
except Exception:
pass
time.sleep(1)
@@ -809,6 +825,8 @@ async def test_async_text_completion_openai_stream():
async for chunk in response:
print(f"chunk: {chunk}")
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.atext_completion(
@@ -819,6 +837,8 @@
)
async for chunk in response:
continue

await asyncio.sleep(1)
except Exception:
pass
time.sleep(1)
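The added await asyncio.sleep(1) calls account for the scheduling change above: success callbacks now run as asyncio tasks or thread-pool jobs after the last chunk, so the tests yield control for a moment before asserting. A hedged sketch of the same wait-for-callback pattern, assuming litellm's documented CustomLogger interface; the recorder class and model name are illustrative.

import asyncio

import litellm
from litellm.integrations.custom_logger import CustomLogger


class UsageRecorder(CustomLogger):
    def __init__(self):
        self.usage = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # fires once per request, after the stream has been fully assembled
        self.usage = getattr(response_obj, "usage", None)


async def main():
    recorder = UsageRecorder()
    litellm.callbacks = [recorder]
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    async for _ in response:
        pass
    await asyncio.sleep(1)  # give the scheduled success callbacks time to complete
    assert recorder.usage is not None


asyncio.run(main())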