gr.load_chat: Allow loading any openai-compatible server immediately as a ChatInterface (#10222)

* changes

* add changeset

* add changeset

* Update gradio/external.py

Co-authored-by: Abubakar Abid <[email protected]>

* changes

* changes

* Update guides/05_chatbots/01_creating-a-chatbot-fast.md

Co-authored-by: Abubakar Abid <[email protected]>

* changes

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
4 people authored and freddyaboulton committed Dec 23, 2024
1 parent 958c4f4 commit 2a4fd72
Showing 5 changed files with 85 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .changeset/thick-dingos-help.md
```diff
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:gr.load_chat: Allow loading any openai-compatible server immediately as a ChatInterface
```
2 changes: 1 addition & 1 deletion gradio/__init__.py
```diff
@@ -77,7 +77,7 @@
     on,
 )
 from gradio.exceptions import Error
-from gradio.external import load
+from gradio.external import load, load_chat
 from gradio.flagging import (
     CSVLogger,
     FlaggingCallback,
```
5 changes: 4 additions & 1 deletion gradio/chat_interface.py
```diff
@@ -155,7 +155,10 @@ def __init__(
         self.type = type
         self.multimodal = multimodal
         self.concurrency_limit = concurrency_limit
-        self.fn = fn
+        if isinstance(fn, ChatInterface):
+            self.fn = fn.fn
+        else:
+            self.fn = fn
         self.is_async = inspect.iscoroutinefunction(
             self.fn
         ) or inspect.isasyncgenfunction(self.fn)
```
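A note on the change above: unwrapping `fn.fn` lets a `ChatInterface` instance itself be passed where a chat function is expected, so the interface returned by the new `gr.load_chat` can be re-wrapped with custom options. A hedged sketch of that use (the endpoint URL, model name, and `title` kwarg are illustrative, not from this commit):

```python
import gradio as gr

# gr.load_chat returns a ChatInterface; the isinstance check above reuses its
# underlying fn, so the loaded chat logic can be re-wrapped with a custom UI.
base = gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token=None)
gr.ChatInterface(base, type="messages", title="Local Llama Chat").launch()
```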
66 changes: 65 additions & 1 deletion gradio/external.py
```diff
@@ -8,7 +8,7 @@
 import re
 import tempfile
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Generator
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
@@ -30,6 +30,7 @@
 
 if TYPE_CHECKING:
     from gradio.blocks import Blocks
+    from gradio.chat_interface import ChatInterface
     from gradio.interface import Interface
 
 
@@ -581,3 +582,66 @@ def fn(*data):
     kwargs["_api_mode"] = True
     interface = gradio.Interface(**kwargs)
     return interface
+
+
+@document()
+def load_chat(
+    base_url: str,
+    model: str,
+    token: str | None = None,
+    *,
+    system_message: str | None = None,
+    streaming: bool = True,
+) -> ChatInterface:
+    """
+    Load a chat interface from an OpenAI-API-compatible chat endpoint.
+    Parameters:
+        base_url: The base URL of the endpoint.
+        model: The model name.
+        token: The API token.
+        system_message: The system message for the conversation, if any.
+        streaming: Whether the response should be streamed.
+    """
+    try:
+        from openai import OpenAI
+    except ImportError as e:
+        raise ImportError(
+            "To use the OpenAI API client, you must install the `openai` package. You can install it with `pip install openai`."
+        ) from e
+    from gradio.chat_interface import ChatInterface
+
+    client = OpenAI(api_key=token, base_url=base_url)
+    start_message = (
+        [{"role": "system", "content": system_message}] if system_message else []
+    )
+
+    def open_api(message: str, history: list | None) -> str:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        return (
+            client.chat.completions.create(
+                model=model,
+                messages=history + [{"role": "user", "content": message}],
+            )
+            .choices[0]
+            .message.content
+        )
+
+    def open_api_stream(
+        message: str, history: list | None
+    ) -> Generator[str, None, None]:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        stream = client.chat.completions.create(
+            model=model,
+            messages=history + [{"role": "user", "content": message}],
+            stream=True,
+        )
+        response = ""
+        for chunk in stream:
+            response += chunk.choices[0].delta.content or ""
+            yield response
+
+    return ChatInterface(open_api_stream if streaming else open_api, type="messages")
```
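For reference, a minimal sketch of calling the new function with every parameter from the signature above; the URL, model name, and system message are placeholders for whatever your own OpenAI-compatible server expects:

```python
import gradio as gr

# Hypothetical local endpoint; any OpenAI-API-compatible server works.
demo = gr.load_chat(
    base_url="http://localhost:11434/v1/",  # e.g. a local Ollama server
    model="llama3.2",                       # a model name the server recognizes
    token=None,                             # pass an API key if the server requires one
    system_message="You are a concise assistant.",  # optional system prompt
    streaming=True,                         # the default: stream tokens as they arrive
)
demo.launch()
```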
10 changes: 10 additions & 0 deletions guides/05_chatbots/01_creating-a-chatbot-fast.md
````diff
@@ -14,6 +14,16 @@ This tutorial uses `gr.ChatInterface()`, which is a high-level abstraction that
 $ pip install --upgrade gradio
 ```
 
+## Quickly loading from Ollama or any OpenAI-API-compatible endpoint
+
+If you have a chat server serving an OpenAI-API-compatible endpoint (skip ahead if you don't), you can spin up a ChatInterface in a single line. First, also run `pip install openai`. Then, with your own URL, model, and optional token:
+
+```python
+import gradio as gr
+
+gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token=None).launch()
+```
+
 ## Defining a chat function
 
 When working with `gr.ChatInterface()`, the first thing you should do is define your **chat function**. In the simplest case, your chat function should accept two arguments: `message` and `history` (the arguments can be named anything, but must be in this order).
````
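Following the guide text above, a minimal sketch of a chat function with the `message`/`history` signature it describes (a toy echo bot for illustration, not part of this commit):

```python
import gradio as gr

def echo(message, history):
    # `message` is the latest user input; `history` is the prior conversation.
    return f"You said: {message}"

gr.ChatInterface(echo, type="messages").launch()
```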
