Ensure base_model cost tracking works across all endpoints #7989

Merged · 5 commits · Jan 25, 2025
Changes from all commits
1 change: 1 addition & 0 deletions litellm/cost_calculator.py
@@ -530,6 +530,7 @@ def completion_cost(  # noqa: PLR0915
     - For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
     """
     try:
+
         call_type = _infer_call_type(call_type, completion_response) or "completion"
 
         if (
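For context, completion_cost is the entry point that base_model ultimately feeds into. A minimal sketch of the behavior this PR targets, using LiteLLM's mock-response support (the model name and message are illustrative):

import litellm

# a mocked call still returns a normal ModelResponse with usage attached
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello!",
)

# completion_cost prices the response; with this PR, a base_model carried in
# litellm_params/metadata can override the deployment alias for pricing
cost = litellm.completion_cost(completion_response=response)
print(f"tracked cost: ${cost:.6f}")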
101 changes: 101 additions & 0 deletions litellm/litellm_core_utils/get_litellm_params.py
@@ -0,0 +1,101 @@
+from typing import Optional
+
+
+def _get_base_model_from_litellm_call_metadata(
+    metadata: Optional[dict],
+) -> Optional[str]:
+    if metadata is None:
+        return None
+
+    if metadata is not None:
+        model_info = metadata.get("model_info", {})
+
+        if model_info is not None:
+            base_model = model_info.get("base_model", None)
+            if base_model is not None:
+                return base_model
+    return None
+
+
+def get_litellm_params(
+    api_key: Optional[str] = None,
+    force_timeout=600,
+    azure=False,
+    logger_fn=None,
+    verbose=False,
+    hugging_face=False,
+    replicate=False,
+    together_ai=False,
+    custom_llm_provider: Optional[str] = None,
+    api_base: Optional[str] = None,
+    litellm_call_id=None,
+    model_alias_map=None,
+    completion_call_id=None,
+    metadata: Optional[dict] = None,
+    model_info=None,
+    proxy_server_request=None,
+    acompletion=None,
+    aembedding=None,
+    preset_cache_key=None,
+    no_log=None,
+    input_cost_per_second=None,
+    input_cost_per_token=None,
+    output_cost_per_token=None,
+    output_cost_per_second=None,
+    cooldown_time=None,
+    text_completion=None,
+    azure_ad_token_provider=None,
+    user_continue_message=None,
+    base_model: Optional[str] = None,
+    litellm_trace_id: Optional[str] = None,
+    hf_model_name: Optional[str] = None,
+    custom_prompt_dict: Optional[dict] = None,
+    litellm_metadata: Optional[dict] = None,
+    disable_add_transform_inline_image_block: Optional[bool] = None,
+    drop_params: Optional[bool] = None,
+    prompt_id: Optional[str] = None,
+    prompt_variables: Optional[dict] = None,
+    async_call: Optional[bool] = None,
+    ssl_verify: Optional[bool] = None,
+    **kwargs,
+) -> dict:
+    litellm_params = {
+        "acompletion": acompletion,
+        "api_key": api_key,
+        "force_timeout": force_timeout,
+        "logger_fn": logger_fn,
+        "verbose": verbose,
+        "custom_llm_provider": custom_llm_provider,
+        "api_base": api_base,
+        "litellm_call_id": litellm_call_id,
+        "model_alias_map": model_alias_map,
+        "completion_call_id": completion_call_id,
+        "aembedding": aembedding,
+        "metadata": metadata,
+        "model_info": model_info,
+        "proxy_server_request": proxy_server_request,
+        "preset_cache_key": preset_cache_key,
+        "no-log": no_log,
+        "stream_response": {},  # litellm_call_id: ModelResponse Dict
+        "input_cost_per_token": input_cost_per_token,
+        "input_cost_per_second": input_cost_per_second,
+        "output_cost_per_token": output_cost_per_token,
+        "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
+        "text_completion": text_completion,
+        "azure_ad_token_provider": azure_ad_token_provider,
+        "user_continue_message": user_continue_message,
+        "base_model": base_model
+        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
+        "litellm_trace_id": litellm_trace_id,
+        "hf_model_name": hf_model_name,
+        "custom_prompt_dict": custom_prompt_dict,
+        "litellm_metadata": litellm_metadata,
+        "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
+        "drop_params": drop_params,
+        "prompt_id": prompt_id,
+        "prompt_variables": prompt_variables,
+        "async_call": async_call,
+        "ssl_verify": ssl_verify,
+    }
+    return litellm_params
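A quick sketch of how the new helper resolves base_model, assuming the import path shown in the diff; the metadata shape (model_info.base_model) mirrors what the LiteLLM proxy attaches to each call:

from litellm.litellm_core_utils.get_litellm_params import get_litellm_params

# metadata as the proxy would attach it: the request names a deployment
# alias, but pricing should follow the underlying base model
metadata = {"model_info": {"base_model": "azure/gpt-4o"}}

params = get_litellm_params(metadata=metadata)
assert params["base_model"] == "azure/gpt-4o"

# an explicitly passed base_model takes precedence over the metadata fallback
params = get_litellm_params(base_model="gpt-4o-mini", metadata=metadata)
assert params["base_model"] == "gpt-4o-mini"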
16 changes: 15 additions & 1 deletion litellm/litellm_core_utils/litellm_logging.py
@@ -32,6 +32,7 @@
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.mlflow import MlflowLogger
 from litellm.integrations.pagerduty.pagerduty import PagerDutyAlerting
+from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_custom_logger,
     redact_message_input_output_from_logging,
@@ -256,10 +257,19 @@ def __init__(
         self.completion_start_time: Optional[datetime.datetime] = None
         self._llm_caching_handler: Optional[LLMCachingHandler] = None
 
+        # INITIAL LITELLM_PARAMS
+        litellm_params = {}
+        if kwargs is not None:
+            litellm_params = get_litellm_params(**kwargs)
+            litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
+
+        self.litellm_params = litellm_params
+
         self.model_call_details: Dict[str, Any] = {
             "litellm_trace_id": litellm_trace_id,
             "litellm_call_id": litellm_call_id,
             "input": _input,
+            "litellm_params": litellm_params,
         }
 
     def process_dynamic_callbacks(self):

@@ -358,7 +368,10 @@ def update_environment_variables(
         if model is not None:
             self.model = model
         self.user = user
-        self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
+        self.litellm_params = {
+            **self.litellm_params,
+            **scrub_sensitive_keys_in_metadata(litellm_params),
+        }
         self.logger_fn = litellm_params.get("logger_fn", None)
         verbose_logger.debug(f"self.optional_params: {self.optional_params}")
 
@@ -784,6 +797,7 @@ def _response_cost_calculator(

         used for consistent cost calculation across response headers + logging integrations.
         """
+
         ## RESPONSE COST ##
         custom_pricing = use_custom_pricing_for_model(
             litellm_params=(
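The effect of seeding self.litellm_params in __init__ and then merging (rather than overwriting) in update_environment_variables is that a base_model resolved at construction time survives later updates. A minimal sketch of the merge semantics, with stand-in values rather than the real Logging object:

# seeded in __init__ via get_litellm_params(**kwargs)
seeded = {"base_model": "azure/gpt-4o", "metadata": {"team": "search"}}

# passed later through update_environment_variables(...)
update = {"logger_fn": None, "cooldown_time": 30.0}

merged = {**seeded, **update}  # later keys win; earlier keys survive
assert merged["base_model"] == "azure/gpt-4o"
assert merged["cooldown_time"] == 30.0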
9 changes: 8 additions & 1 deletion litellm/litellm_core_utils/mock_functions.py
@@ -1,6 +1,12 @@
 from typing import List, Optional
 
-from ..types.utils import Embedding, EmbeddingResponse, ImageObject, ImageResponse
+from ..types.utils import (
+    Embedding,
+    EmbeddingResponse,
+    ImageObject,
+    ImageResponse,
+    Usage,
+)
 
 
 def mock_embedding(model: str, mock_response: Optional[List[float]]):
@@ -9,6 +15,7 @@ def mock_embedding(model: str, mock_response: Optional[List[float]]):
     return EmbeddingResponse(
         model=model,
         data=[Embedding(embedding=mock_response, index=0, object="embedding")],
+        usage=Usage(prompt_tokens=10, completion_tokens=0),
     )
 
 
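With Usage attached, a mocked embedding response now carries token counts that downstream cost tracking can price. A sketch, assuming the module path shown in the diff header:

from litellm.litellm_core_utils.mock_functions import mock_embedding

resp = mock_embedding(model="text-embedding-3-small", mock_response=[0.1, 0.2, 0.3])
# before this PR the mocked response carried no token counts to price
assert resp.usage.prompt_tokens == 10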
36 changes: 14 additions & 22 deletions litellm/main.py
@@ -3224,8 +3224,6 @@ def embedding(  # noqa: PLR0915
         **non_default_params,
     )
 
-    if mock_response is not None:
-        return mock_embedding(model=model, mock_response=mock_response)
     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
         litellm.register_model(
@@ -3248,28 +3246,22 @@
             }
         }
     )
+    litellm_params_dict = get_litellm_params(**kwargs)
+
+    logging: Logging = litellm_logging_obj  # type: ignore
+    logging.update_environment_variables(
+        model=model,
+        user=user,
+        optional_params=optional_params,
+        litellm_params=litellm_params_dict,
+        custom_llm_provider=custom_llm_provider,
+    )
+
+    if mock_response is not None:
+        return mock_embedding(model=model, mock_response=mock_response)
     try:
         response: Optional[EmbeddingResponse] = None
-        logging: Logging = litellm_logging_obj  # type: ignore
-        logging.update_environment_variables(
-            model=model,
-            user=user,
-            optional_params=optional_params,
-            litellm_params={
-                "timeout": timeout,
-                "azure": azure,
-                "litellm_call_id": litellm_call_id,
-                "logger_fn": logger_fn,
-                "proxy_server_request": proxy_server_request,
-                "model_info": model_info,
-                "metadata": metadata,
-                "aembedding": aembedding,
-                "preset_cache_key": None,
-                "stream_response": {},
-                "cooldown_time": cooldown_time,
-            },
-            custom_llm_provider=custom_llm_provider,
-        )
+
         if azure is True or custom_llm_provider == "azure":
             # azure configs
             api_type = get_secret_str("AZURE_API_TYPE") or "azure"
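Because the mock short-circuit now runs after update_environment_variables, metadata such as model_info.base_model reaches the logging object even for mocked calls, and the logged params come from the single get_litellm_params helper instead of a hand-built dict. A hedged sketch (the metadata payload is illustrative of what a proxy deployment would send):

import litellm

response = litellm.embedding(
    model="text-embedding-3-small",
    input=["hello"],
    mock_response=[0.1, 0.2, 0.3],
    metadata={"model_info": {"base_model": "text-embedding-3-small"}},
)
print(response.usage)  # mocked, yet still logged and priceable against base_model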
102 changes: 4 additions & 98 deletions litellm/utils.py
@@ -71,6 +71,10 @@
     exception_type,
     get_error_message,
 )
+from litellm.litellm_core_utils.get_litellm_params import (
+    _get_base_model_from_litellm_call_metadata,
+    get_litellm_params,
+)
 from litellm.litellm_core_utils.get_llm_provider_logic import (
     _is_non_openai_azure_model,
     get_llm_provider,
@@ -2094,88 +2098,6 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915
     return model_cost
 
 
-def get_litellm_params(
-    api_key: Optional[str] = None,
-    force_timeout=600,
-    azure=False,
-    logger_fn=None,
-    verbose=False,
-    hugging_face=False,
-    replicate=False,
-    together_ai=False,
-    custom_llm_provider: Optional[str] = None,
-    api_base: Optional[str] = None,
-    litellm_call_id=None,
-    model_alias_map=None,
-    completion_call_id=None,
-    metadata: Optional[dict] = None,
-    model_info=None,
-    proxy_server_request=None,
-    acompletion=None,
-    preset_cache_key=None,
-    no_log=None,
-    input_cost_per_second=None,
-    input_cost_per_token=None,
-    output_cost_per_token=None,
-    output_cost_per_second=None,
-    cooldown_time=None,
-    text_completion=None,
-    azure_ad_token_provider=None,
-    user_continue_message=None,
-    base_model: Optional[str] = None,
-    litellm_trace_id: Optional[str] = None,
-    hf_model_name: Optional[str] = None,
-    custom_prompt_dict: Optional[dict] = None,
-    litellm_metadata: Optional[dict] = None,
-    disable_add_transform_inline_image_block: Optional[bool] = None,
-    drop_params: Optional[bool] = None,
-    prompt_id: Optional[str] = None,
-    prompt_variables: Optional[dict] = None,
-    async_call: Optional[bool] = None,
-    ssl_verify: Optional[bool] = None,
-    **kwargs,
-) -> dict:
-    litellm_params = {
-        "acompletion": acompletion,
-        "api_key": api_key,
-        "force_timeout": force_timeout,
-        "logger_fn": logger_fn,
-        "verbose": verbose,
-        "custom_llm_provider": custom_llm_provider,
-        "api_base": api_base,
-        "litellm_call_id": litellm_call_id,
-        "model_alias_map": model_alias_map,
-        "completion_call_id": completion_call_id,
-        "metadata": metadata,
-        "model_info": model_info,
-        "proxy_server_request": proxy_server_request,
-        "preset_cache_key": preset_cache_key,
-        "no-log": no_log,
-        "stream_response": {},  # litellm_call_id: ModelResponse Dict
-        "input_cost_per_token": input_cost_per_token,
-        "input_cost_per_second": input_cost_per_second,
-        "output_cost_per_token": output_cost_per_token,
-        "output_cost_per_second": output_cost_per_second,
-        "cooldown_time": cooldown_time,
-        "text_completion": text_completion,
-        "azure_ad_token_provider": azure_ad_token_provider,
-        "user_continue_message": user_continue_message,
-        "base_model": base_model
-        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
-        "litellm_trace_id": litellm_trace_id,
-        "hf_model_name": hf_model_name,
-        "custom_prompt_dict": custom_prompt_dict,
-        "litellm_metadata": litellm_metadata,
-        "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
-        "drop_params": drop_params,
-        "prompt_id": prompt_id,
-        "prompt_variables": prompt_variables,
-        "async_call": async_call,
-        "ssl_verify": ssl_verify,
-    }
-    return litellm_params
-
-
 def _should_drop_param(k, additional_drop_params) -> bool:
     if (
         additional_drop_params is not None
@@ -5666,22 +5588,6 @@ def get_logging_id(start_time, response_obj):
     return None
 
 
-def _get_base_model_from_litellm_call_metadata(
-    metadata: Optional[dict],
-) -> Optional[str]:
-    if metadata is None:
-        return None
-
-    if metadata is not None:
-        model_info = metadata.get("model_info", {})
-
-        if model_info is not None:
-            base_model = model_info.get("base_model", None)
-            if base_model is not None:
-                return base_model
-    return None
-
-
 def _get_base_model_from_metadata(model_call_details=None):
     if model_call_details is None:
         return None
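The two helpers are moved, not removed: utils.py now re-imports them from litellm_core_utils, so existing import sites keep working. A small compatibility check (a sketch; both paths are taken from the diff):

from litellm.litellm_core_utils.get_litellm_params import (
    get_litellm_params as from_core_utils,
)
from litellm.utils import get_litellm_params as from_utils

# same function object either way — the move is backward compatible
assert from_utils is from_core_utils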