Add LiteLLM support to Chat() #1674

Draft · wants to merge 2 commits into main
5 changes: 5 additions & 0 deletions shiny/templates/chat/hello-providers/litellm/_template.json
@@ -0,0 +1,5 @@
{
"type": "app",
"id": "chat-ai-litellm",
"title": "Chat AI using LiteLLM"
}
22 changes: 22 additions & 0 deletions shiny/templates/chat/hello-providers/litellm/app.py
@@ -0,0 +1,22 @@
# ------------------------------------------------------------------------------------
# A basic Shiny Chat example powered by OpenAI's GPT-4o model using the `litellm` library.
# To run it, you'll need an OpenAI API key.
# To get set up, follow the instructions at https://platform.openai.com/docs/quickstart
# ------------------------------------------------------------------------------------
import litellm
from app_utils import load_dotenv

from shiny.express import ui

# Load a .env file (if it exists) to get the OpenAI API key
load_dotenv()

chat = ui.Chat(id="chat")
chat.ui()


@chat.on_user_submit
async def _():
messages = chat.messages()
response = await litellm.acompletion(model="gpt-4o", messages=messages, stream=True)
await chat.append_message_stream(response)
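Because LiteLLM routes requests for many providers through the same `acompletion()` interface, the handler above should only need a different model string to target another backend. A minimal sketch of that variation (the Anthropic model id and the presence of an ANTHROPIC_API_KEY in the environment are assumptions, not part of this PR):

@chat.on_user_submit
async def _():
    messages = chat.messages()
    # Hypothetical variation: any LiteLLM-supported model id works here,
    # provided the matching provider API key is set (e.g. ANTHROPIC_API_KEY).
    response = await litellm.acompletion(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=messages,
        stream=True,
    )
    await chat.append_message_stream(response)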
26 changes: 26 additions & 0 deletions shiny/templates/chat/hello-providers/litellm/app_utils.py
@@ -0,0 +1,26 @@
import os
from pathlib import Path
from typing import Any

app_dir = Path(__file__).parent
env_file = app_dir / ".env"


def load_dotenv(dotenv_path: os.PathLike[str] = env_file, **kwargs: Any) -> None:
"""
A convenience wrapper around `dotenv.load_dotenv` that warns if `dotenv` is not installed.
It also returns `None` to make it easier to ignore the return value.
"""
try:
import dotenv

dotenv.load_dotenv(dotenv_path=dotenv_path, **kwargs)
except ImportError:
import warnings

warnings.warn(
"Could not import `dotenv`. If you want to use `.env` files to "
"load environment variables, please install it using "
"`pip install python-dotenv`.",
stacklevel=2,
)
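Usage mirrors `dotenv.load_dotenv`; when `python-dotenv` is not installed the call only warns and returns, so the app can still start if the key is exported directly in the shell. A small illustration (the explicit OPENAI_API_KEY lookup is for demonstration only, not part of the template):

import os
from app_utils import load_dotenv

load_dotenv()  # reads .env when python-dotenv is available; otherwise warns and continues
print("key configured:", "OPENAI_API_KEY" in os.environ)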
5 changes: 5 additions & 0 deletions shiny/templates/chat/hello-providers/litellm/requirements.txt
@@ -0,0 +1,5 @@
shiny
python-dotenv
tokenizers
openai
litellm
34 changes: 34 additions & 0 deletions shiny/ui/_chat_normalize.py
@@ -18,6 +18,9 @@ class GenerateContentResponse:
text: str

from langchain_core.messages import BaseMessage, BaseMessageChunk
from litellm.types.utils import ( # pyright: ignore[reportMissingTypeStubs]
ModelResponse,
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk


@@ -137,6 +140,36 @@ def can_normalize_chunk(self, chunk: Any) -> bool:
return False


class LiteLlmNormalizer(OpenAINormalizer):
def normalize(self, message: Any) -> ChatMessage:
x = cast("ModelResponse", message)
return super().normalize(x)

def normalize_chunk(self, chunk: Any) -> ChatMessage:
x = cast("ModelResponse", chunk)
return super().normalize_chunk(x)

def can_normalize(self, message: Any) -> bool:
try:
from litellm.types.utils import ( # pyright: ignore[reportMissingTypeStubs]
ModelResponse,
)

return isinstance(message, ModelResponse)
except Exception:
return False

def can_normalize_chunk(self, chunk: Any) -> bool:
try:
from litellm.types.utils import ( # pyright: ignore[reportMissingTypeStubs]
ModelResponse,
)

return isinstance(chunk, ModelResponse)
except Exception:
return False


class AnthropicNormalizer(BaseMessageNormalizer):
def normalize(self, message: Any) -> ChatMessage:
x = cast("AnthropicMessage", message)
@@ -248,6 +281,7 @@ def __init__(self) -> None:
"anthropic": AnthropicNormalizer(),
"google": GoogleNormalizer(),
"langchain": LangChainNormalizer(),
"litellm": LiteLlmNormalizer(),
"ollama": OllamaNormalizer(),
"dict": DictNormalizer(),
"string": StringNormalizer(),