diff --git a/shiny/templates/chat/hello-providers/litellm/_template.json b/shiny/templates/chat/hello-providers/litellm/_template.json
new file mode 100644
index 000000000..52866eeb9
--- /dev/null
+++ b/shiny/templates/chat/hello-providers/litellm/_template.json
@@ -0,0 +1,5 @@
+{
+  "type": "app",
+  "id": "chat-ai-litellm",
+  "title": "Chat AI using LiteLLM"
+}
diff --git a/shiny/templates/chat/hello-providers/litellm/app.py b/shiny/templates/chat/hello-providers/litellm/app.py
new file mode 100644
index 000000000..7cc90a87d
--- /dev/null
+++ b/shiny/templates/chat/hello-providers/litellm/app.py
@@ -0,0 +1,22 @@
+# ------------------------------------------------------------------------------------
+# A basic Shiny Chat example powered by OpenAI's GPT-4o model using the `litellm` library.
+# To run it, you'll need an OpenAI API key.
+# To get set up, follow the instructions at https://platform.openai.com/docs/quickstart
+# ------------------------------------------------------------------------------------
+import litellm
+from app_utils import load_dotenv
+
+from shiny.express import ui
+
+# Load a .env file (if it exists) to get the OpenAI API key
+load_dotenv()
+
+chat = ui.Chat(id="chat")
+chat.ui()
+
+
+@chat.on_user_submit
+async def _():
+    messages = chat.messages()
+    response = await litellm.acompletion(model="gpt-4o", messages=messages, stream=True)
+    await chat.append_message_stream(response)
diff --git a/shiny/templates/chat/hello-providers/litellm/app_utils.py b/shiny/templates/chat/hello-providers/litellm/app_utils.py
new file mode 100644
index 000000000..404a13730
--- /dev/null
+++ b/shiny/templates/chat/hello-providers/litellm/app_utils.py
@@ -0,0 +1,26 @@
+import os
+from pathlib import Path
+from typing import Any
+
+app_dir = Path(__file__).parent
+env_file = app_dir / ".env"
+
+
+def load_dotenv(dotenv_path: os.PathLike[str] = env_file, **kwargs: Any) -> None:
+    """
+    A convenience wrapper around `dotenv.load_dotenv` that warns if `dotenv` is not installed.
+    It also returns `None` to make it easier to ignore the return value.
+    """
+    try:
+        import dotenv
+
+        dotenv.load_dotenv(dotenv_path=dotenv_path, **kwargs)
+    except ImportError:
+        import warnings
+
+        warnings.warn(
+            "Could not import `dotenv`. If you want to use `.env` files to "
+            "load environment variables, please install it using "
+            "`pip install python-dotenv`.",
+            stacklevel=2,
+        )
diff --git a/shiny/templates/chat/hello-providers/litellm/requirements.txt b/shiny/templates/chat/hello-providers/litellm/requirements.txt
new file mode 100644
index 000000000..0519f0ea0
--- /dev/null
+++ b/shiny/templates/chat/hello-providers/litellm/requirements.txt
@@ -0,0 +1,5 @@
+shiny
+python-dotenv
+tokenizers
+openai
+litellm
diff --git a/shiny/ui/_chat_normalize.py b/shiny/ui/_chat_normalize.py
index 2d5063324..b778a9fbd 100644
--- a/shiny/ui/_chat_normalize.py
+++ b/shiny/ui/_chat_normalize.py
@@ -18,6 +18,9 @@ class GenerateContentResponse:
         text: str
 
     from langchain_core.messages import BaseMessage, BaseMessageChunk
+    from litellm.types.utils import (  # pyright: ignore[reportMissingTypeStubs]
+        ModelResponse,
+    )
     from openai.types.chat import ChatCompletion, ChatCompletionChunk
 
 
@@ -137,6 +140,36 @@ def can_normalize_chunk(self, chunk: Any) -> bool:
             return False
 
 
+class LiteLlmNormalizer(OpenAINormalizer):
+    def normalize(self, message: Any) -> ChatMessage:
+        x = cast("ModelResponse", message)
+        return super().normalize(x)
+
+    def normalize_chunk(self, chunk: Any) -> ChatMessage:
+        x = cast("ModelResponse", chunk)
+        return super().normalize_chunk(x)
+
+    def can_normalize(self, message: Any) -> bool:
+        try:
+            from litellm.types.utils import (  # pyright: ignore[reportMissingTypeStubs]
+                ModelResponse,
+            )
+
+            return isinstance(message, ModelResponse)
+        except Exception:
+            return False
+
+    def can_normalize_chunk(self, chunk: Any) -> bool:
+        try:
+            from litellm.types.utils import (  # pyright: ignore[reportMissingTypeStubs]
+                ModelResponse,
+            )
+
+            return isinstance(chunk, ModelResponse)
+        except Exception:
+            return False
+
+
 class AnthropicNormalizer(BaseMessageNormalizer):
     def normalize(self, message: Any) -> ChatMessage:
         x = cast("AnthropicMessage", message)
@@ -248,6 +281,7 @@ def __init__(self) -> None:
             "anthropic": AnthropicNormalizer(),
             "google": GoogleNormalizer(),
             "langchain": LangChainNormalizer(),
+            "litellm": LiteLlmNormalizer(),
             "ollama": OllamaNormalizer(),
             "dict": DictNormalizer(),
             "string": StringNormalizer(),
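For reference, a minimal sketch (not part of the diff) of the streaming shape these changes rely on: `litellm.acompletion(..., stream=True)` yields OpenAI-style chunks, which is why `LiteLlmNormalizer` can simply defer to `OpenAINormalizer`. It assumes `litellm` is installed and an `OPENAI_API_KEY` is set in the environment; run it standalone to see the deltas that `chat.append_message_stream()` would consume in the template app.

import asyncio

import litellm


async def main() -> None:
    # Same call shape as in app.py above.
    messages = [{"role": "user", "content": "Say hello in one sentence."}]
    response = await litellm.acompletion(model="gpt-4o", messages=messages, stream=True)
    async for chunk in response:
        # Each streamed chunk is a litellm ModelResponse shaped like an OpenAI
        # ChatCompletionChunk, so the text delta lives at choices[0].delta.content.
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())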