diff --git a/.changeset/thick-dingos-help.md b/.changeset/thick-dingos-help.md
new file mode 100644
index 0000000000000..60135ca030828
--- /dev/null
+++ b/.changeset/thick-dingos-help.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:gr.load_chat: Allow loading any OpenAI-compatible server immediately as a ChatInterface
diff --git a/gradio/__init__.py b/gradio/__init__.py
index 5135c9fef720b..512745b35dd5e 100644
--- a/gradio/__init__.py
+++ b/gradio/__init__.py
@@ -77,7 +77,7 @@
     on,
 )
 from gradio.exceptions import Error
-from gradio.external import load
+from gradio.external import load, load_chat
 from gradio.flagging import (
     CSVLogger,
     FlaggingCallback,
diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py
index 429301723e88b..35e36c5d89c89 100644
--- a/gradio/chat_interface.py
+++ b/gradio/chat_interface.py
@@ -155,7 +155,10 @@ def __init__(
         self.type = type
         self.multimodal = multimodal
         self.concurrency_limit = concurrency_limit
-        self.fn = fn
+        if isinstance(fn, ChatInterface):
+            self.fn = fn.fn
+        else:
+            self.fn = fn
         self.is_async = inspect.iscoroutinefunction(
             self.fn
         ) or inspect.isasyncgenfunction(self.fn)
diff --git a/gradio/external.py b/gradio/external.py
index 7cc01cc68fd15..d410dcdce304b 100644
--- a/gradio/external.py
+++ b/gradio/external.py
@@ -8,7 +8,7 @@
 import re
 import tempfile
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Generator
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
@@ -30,6 +30,7 @@
 
 if TYPE_CHECKING:
     from gradio.blocks import Blocks
+    from gradio.chat_interface import ChatInterface
     from gradio.interface import Interface
 
 
@@ -581,3 +582,66 @@ def fn(*data):
     kwargs["_api_mode"] = True
     interface = gradio.Interface(**kwargs)
     return interface
+
+
+@document()
+def load_chat(
+    base_url: str,
+    model: str,
+    token: str | None = None,
+    *,
+    system_message: str | None = None,
+    streaming: bool = True,
+) -> ChatInterface:
+    """
+    Load a chat interface from an OpenAI-API-compatible chat endpoint.
+    Parameters:
+        base_url: The base URL of the endpoint.
+        model: The model name.
+        token: The API token.
+        system_message: The system message for the conversation, if any.
+        streaming: Whether the response should be streamed.
+    """
+    try:
+        from openai import OpenAI
+    except ImportError as e:
+        raise ImportError(
+            "To use the OpenAI API client, you must install the `openai` package. You can install it with `pip install openai`."
+        ) from e
+    from gradio.chat_interface import ChatInterface
+
+    client = OpenAI(api_key=token, base_url=base_url)
+    start_message = (
+        [{"role": "system", "content": system_message}] if system_message else []
+    )
+
+    def open_api(message: str, history: list | None) -> str:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        return (
+            client.chat.completions.create(
+                model=model,
+                messages=history + [{"role": "user", "content": message}],
+            )
+            .choices[0]
+            .message.content
+        )
+
+    def open_api_stream(
+        message: str, history: list | None
+    ) -> Generator[str, None, None]:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        stream = client.chat.completions.create(
+            model=model,
+            messages=history + [{"role": "user", "content": message}],
+            stream=True,
+        )
+        response = ""
+        for chunk in stream:
+            response += chunk.choices[0].delta.content or ""
+            yield response
+
+    return ChatInterface(open_api_stream if streaming else open_api, type="messages")
diff --git a/guides/05_chatbots/01_creating-a-chatbot-fast.md b/guides/05_chatbots/01_creating-a-chatbot-fast.md
index df663fe679a04..4ed8b95a40379 100644
--- a/guides/05_chatbots/01_creating-a-chatbot-fast.md
+++ b/guides/05_chatbots/01_creating-a-chatbot-fast.md
@@ -14,6 +14,16 @@ This tutorial uses `gr.ChatInterface()`, which is a high-level abstraction that
 $ pip install --upgrade gradio
 ```
 
+## Quickly loading from Ollama or any OpenAI-API-compatible endpoint
+
+If you have a chat server serving an OpenAI-API-compatible endpoint (skip ahead if you don't), you can spin up a `ChatInterface` in a single line. First, install the extra dependency with `pip install openai`. Then, with your own URL, model, and token (servers like Ollama don't check the token, but the `openai` client requires one, so any placeholder string works):
+
+```python
+import gradio as gr
+
+gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token="***").launch()
+```
+
 ## Defining a chat function
 
 When working with `gr.ChatInterface()`, the first thing you should do is define your **chat function**. In the simplest case, your chat function should accept two arguments: `message` and `history` (the arguments can be named anything, but must be in this order).
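
As a quick usage sketch of the `load_chat` API this patch adds: the snippet below mirrors the guide example but also exercises the keyword-only parameters. The base URL, model name, and system message are placeholder values for a local OpenAI-compatible server such as Ollama, not anything the implementation requires:

```python
import gradio as gr

# Placeholder values: point these at your own OpenAI-compatible server.
demo = gr.load_chat(
    "http://localhost:11434/v1/",  # base_url of the endpoint
    model="llama3.2",              # a model name the server recognizes
    token="***",                   # Ollama ignores this, but the openai client requires a non-None key
    system_message="You are a helpful assistant.",  # prepended as a system turn
    streaming=False,               # return complete replies via open_api instead of open_api_stream
)
demo.launch()
```

Because `load_chat` returns an ordinary `ChatInterface`, the result can be launched directly or embedded anywhere a `ChatInterface` is accepted; with the default `streaming=True`, partial responses are yielded from `open_api_stream` as chunks arrive.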