aiohttp_openai/ fixes - allow using aiohttp_openai/gpt-4o (BerriAI#7598)

* fixes for get_complete_url

* update aiohttp tests

* fix event loop for aiohttp

* ci/cd run again

* test_aiohttp_openai
ishaan-jaff authored and rajatvig committed Jan 15, 2025
1 parent 10d21be commit 107ba94
Showing 5 changed files with 106 additions and 61 deletions.
27 changes: 22 additions & 5 deletions litellm/llms/aiohttp_openai/chat/transformation.py
```diff
@@ -9,7 +9,7 @@

 from typing import TYPE_CHECKING, Any, List, Optional

-import httpx
+from aiohttp import ClientResponse

 from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
 from litellm.types.llms.openai import AllMessageValues
@@ -24,6 +24,22 @@


 class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
+    def get_complete_url(
+        self,
+        api_base: str,
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        Ensure - /v1/chat/completions is at the end of the url
+        """
+
+        if not api_base.endswith("/chat/completions"):
+            api_base += "/chat/completions"
+        return api_base
+
     def validate_environment(
         self,
         headers: dict,
@@ -33,12 +49,12 @@ def validate_environment(
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
-        return {}
+        return {"Authorization": f"Bearer {api_key}"}

-    def transform_response(
+    async def transform_response( # type: ignore
         self,
         model: str,
-        raw_response: httpx.Response,
+        raw_response: ClientResponse,
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
         request_data: dict,
@@ -49,4 +65,5 @@ def transform_response(
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        return ModelResponse(**raw_response.json())
+        _json_response = await raw_response.json()
+        return ModelResponse(**_json_response)
```
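
The key change here is the switch of response types: aiohttp's `ClientResponse.json()` is a coroutine, unlike `httpx.Response.json()`, so `transform_response` becomes `async` and awaits the body before unpacking it into a `ModelResponse`. The new `get_complete_url` makes the provider tolerant of an `api_base` passed with or without the route suffix. A minimal standalone sketch of that suffix check (a hypothetical free function mirroring the method above):

```python
# Standalone sketch of the suffix normalization added above; the real
# logic lives on AiohttpOpenAIChatConfig.get_complete_url.
def get_complete_url(api_base: str) -> str:
    # Append the chat-completions route only when it is missing, so both
    # bare and fully-qualified base URLs resolve to the same endpoint.
    if not api_base.endswith("/chat/completions"):
        api_base += "/chat/completions"
    return api_base


assert get_complete_url("https://api.openai.com/v1") == "https://api.openai.com/v1/chat/completions"
assert get_complete_url("https://api.openai.com/v1/chat/completions") == "https://api.openai.com/v1/chat/completions"
```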
16 changes: 13 additions & 3 deletions litellm/llms/custom_httpx/aiohttp_handler.py
```diff
@@ -172,9 +172,19 @@ async def async_completion(
             litellm_params=litellm_params,
             stream=False,
         )
-        _json_response = await _response.json()
-
-        return _json_response
+        _transformed_response = await provider_config.transform_response( # type: ignore
+            model=model,
+            raw_response=_response, # type: ignore
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+        return _transformed_response

     def completion(
         self,
```
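
With the provider config now owning response parsing, the handler awaits `transform_response` instead of calling `_response.json()` itself. A self-contained sketch of that awaitable contract, using a stand-in for aiohttp's `ClientResponse` (all names here are illustrative, not part of the commit):

```python
import asyncio


class FakeClientResponse:
    """Stand-in for aiohttp.ClientResponse, whose .json() is a coroutine."""

    async def json(self) -> dict:
        return {"id": "chatcmpl-123", "model": "gpt-4o", "choices": []}


async def transform_response(raw_response: FakeClientResponse) -> dict:
    # Mirrors the provider config: the body must be awaited before use.
    return await raw_response.json()


async def async_completion() -> dict:
    _response = FakeClientResponse()
    # The handler awaits the provider's transform instead of parsing JSON itself.
    return await transform_response(_response)


print(asyncio.run(async_completion()))
```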
73 changes: 37 additions & 36 deletions litellm/main.py
```diff
@@ -1478,6 +1478,43 @@ def completion( # type: ignore # noqa: PLR0915
             custom_llm_provider=custom_llm_provider,
             encoding=encoding,
         )
+    elif custom_llm_provider == "aiohttp_openai":
+        # NEW aiohttp provider for 10-100x higher RPS
+        api_base = (
+            api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or get_secret("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        # set API KEY
+        api_key = (
+            api_key
+            or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or get_secret("OPENAI_API_KEY")
+        )
+
+        headers = headers or litellm.headers
+
+        if extra_headers is not None:
+            optional_params["extra_headers"] = extra_headers
+        response = base_llm_aiohttp_handler.completion(
+            model=model,
+            messages=messages,
+            headers=headers,
+            model_response=model_response,
+            api_key=api_key,
+            api_base=api_base,
+            acompletion=acompletion,
+            logging_obj=logging,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            timeout=timeout,
+            client=client,
+            custom_llm_provider=custom_llm_provider,
+            encoding=encoding,
+            stream=stream,
+        )
     elif (
         model in litellm.open_ai_chat_completion_models
         or custom_llm_provider == "custom_openai"
@@ -2802,42 +2839,6 @@ def completion( # type: ignore # noqa: PLR0915
             )
             return response
         response = model_response
-    elif custom_llm_provider == "aiohttp_openai":
-        api_base = (
-            api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
-            or litellm.api_base
-            or get_secret("OPENAI_API_BASE")
-            or "https://api.openai.com/v1"
-        )
-        # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
-            or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
-        )
-
-        headers = headers or litellm.headers
-
-        if extra_headers is not None:
-            optional_params["extra_headers"] = extra_headers
-        response = base_llm_aiohttp_handler.completion(
-            model=model,
-            messages=messages,
-            headers=headers,
-            model_response=model_response,
-            api_key=api_key,
-            api_base=api_base,
-            acompletion=acompletion,
-            logging_obj=logging,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            timeout=timeout,
-            client=client,
-            custom_llm_provider=custom_llm_provider,
-            encoding=encoding,
-            stream=stream,
-        )
     elif custom_llm_provider == "custom":
         url = litellm.api_base or api_base or ""
         if url is None or url == "":
```
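
Moving this branch ahead of the generic OpenAI chat-completion check is what lets `aiohttp_openai/gpt-4o` reach the aiohttp handler: a bare `gpt-4o` would otherwise match `litellm.open_ai_chat_completion_models` first and take the standard OpenAI path. A minimal usage sketch of the new route (assumes litellm is installed and a real key is available via `OPENAI_API_KEY` or passed explicitly):

```python
import asyncio

import litellm


async def main() -> None:
    response = await litellm.acompletion(
        # The "aiohttp_openai/" prefix selects the new aiohttp-based branch.
        model="aiohttp_openai/gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        # api_base falls back to litellm.api_base, then OPENAI_API_BASE,
        # then https://api.openai.com/v1; "/chat/completions" is appended
        # if missing. api_key falls back similarly to OPENAI_API_KEY.
    )
    print(response)


asyncio.run(main())
```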
39 changes: 23 additions & 16 deletions tests/llm_translation/conftest.py
```diff
@@ -3,7 +3,7 @@
 import importlib
 import os
 import sys
-
+import asyncio
 import pytest

 sys.path.insert(
@@ -12,31 +12,38 @@
 import litellm


+@pytest.fixture(scope="session")
+def event_loop():
+    """Create an instance of the default event loop for each test session."""
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+
+
 @pytest.fixture(scope="function", autouse=True)
-def setup_and_teardown():
-    """
-    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
-    """
-    curr_dir = os.getcwd() # Get the current working directory
-    sys.path.insert(
-        0, os.path.abspath("../..")
-    ) # Adds the project directory to the system path
+def setup_and_teardown(event_loop): # Add event_loop as a dependency
+    curr_dir = os.getcwd()
+    sys.path.insert(0, os.path.abspath("../.."))

     import litellm
     from litellm import Router

     importlib.reload(litellm)
-    import asyncio
-
-    loop = asyncio.get_event_loop_policy().new_event_loop()
-    asyncio.set_event_loop(loop)
+
+    # Set the event loop from the fixture
+    asyncio.set_event_loop(event_loop)

     print(litellm)
     # from litellm import Router, completion, aembedding, acompletion, embedding
     yield

-    # Teardown code (executes after the yield point)
-    loop.close() # Close the loop created earlier
-    asyncio.set_event_loop(None) # Remove the reference to the loop
+    # Clean up any pending tasks
+    pending = asyncio.all_tasks(event_loop)
+    for task in pending:
+        task.cancel()
+
+    # Run the event loop until all tasks are cancelled
+    if pending:
+        event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))


 def pytest_collection_modifyitems(config, items):
```
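
The session-scoped `event_loop` fixture gives the whole test session one loop; the autouse fixture installs it and cancels leftover tasks on teardown instead of closing a fresh per-test loop, so aiohttp sessions no longer end up bound to an already-closed loop. A hypothetical consuming test (not part of the commit; assumes the classic pytest-asyncio behavior where redefining `event_loop` makes marked tests run on it):

```python
import asyncio

import pytest


@pytest.mark.asyncio()
async def test_runs_on_shared_session_loop(event_loop):
    # The loop created by the session-scoped fixture above should be the
    # one actually executing this coroutine.
    assert asyncio.get_running_loop() is event_loop
```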
12 changes: 11 additions & 1 deletion tests/llm_translation/test_aiohttp_openai.py
```diff
@@ -11,7 +11,7 @@
 import litellm


-@pytest.mark.asyncio
+@pytest.mark.asyncio()
 async def test_aiohttp_openai():
     litellm.set_verbose = True
     response = await litellm.acompletion(
@@ -21,3 +21,13 @@ async def test_aiohttp_openai():
         api_key="fake-key",
     )
     print(response)
+
+
+@pytest.mark.asyncio()
+async def test_aiohttp_openai_gpt_4o():
+    litellm.set_verbose = True
+    response = await litellm.acompletion(
+        model="aiohttp_openai/gpt-4o",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+    )
+    print(response)
```