Allow editable ChatInterface (#10229)
* changes

* add changeset

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
3 people authored Dec 21, 2024
1 parent 506bd28 commit 1be31c1
Showing 6 changed files with 60 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .changeset/slimy-pants-hang.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Allow editable ChatInterface
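
For reference, a minimal sketch of how the new flag is used (it mirrors the demo change below; the echo function is illustrative, not part of this commit):

```python
import gradio as gr

def slow_echo(message, history):
    # Stream the reply back one character at a time
    for i in range(len(message)):
        yield "You typed: " + message[: i + 1]

# editable=True lets users edit a past message; the history is truncated at that
# point and the edited text is resubmitted to the chat function.
demo = gr.ChatInterface(slow_echo, editable=True)

if __name__ == "__main__":
    demo.launch()
```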
2 changes: 1 addition & 1 deletion demo/test_chatinterface_streaming_echo/run.ipynb
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", " global runs\n", " runs = 0\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True)\n", "\n", "with gr.Blocks() as demo:\n", " chat.render()\n", " # We reset the global variable to minimize flakes\n", " # this works because CI runs only one test at at time\n", " # need to use gr.State if we want to parallelize this test\n", " # currently chatinterface does not support that\n", " demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", " global runs\n", " runs = 0\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True, editable=True)\n", "\n", "with gr.Blocks() as demo:\n", " chat.render()\n", " # We reset the global variable to minimize flakes\n", " # this works because CI runs only one test at at time\n", " # need to use gr.State if we want to parallelize this test\n", " # currently chatinterface does not support that\n", " demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
2 changes: 1 addition & 1 deletion demo/test_chatinterface_streaming_echo/run.py
@@ -12,7 +12,7 @@ def slow_echo(message, history):
for i in range(len(message)):
yield f"Run {runs} - You typed: " + message[: i + 1]

chat = gr.ChatInterface(slow_echo, fill_height=True)
chat = gr.ChatInterface(slow_echo, fill_height=True, editable=True)

with gr.Blocks() as demo:
chat.render()
26 changes: 25 additions & 1 deletion gradio/chat_interface.py
@@ -36,7 +36,7 @@
)
from gradio.components.multimodal_textbox import MultimodalPostprocess, MultimodalValue
from gradio.context import get_blocks_context
from gradio.events import Dependency, SelectData
from gradio.events import Dependency, EditData, SelectData
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.helpers import special_args, update
from gradio.layouts import Accordion, Column, Group, Row
@@ -75,6 +75,7 @@ def __init__(
additional_inputs: str | Component | list[str | Component] | None = None,
additional_inputs_accordion: str | Accordion | None = None,
additional_outputs: Component | list[Component] | None = None,
editable: bool = False,
examples: list[str] | list[MultimodalValue] | list[list] | None = None,
example_labels: list[str] | None = None,
example_icons: list[str] | None = None,
@@ -108,6 +109,7 @@ def __init__(
type: The format of the messages passed into the chat history parameter of `fn`. If "messages", passes the history as a list of dictionaries with openai-style "role" and "content" keys. The "content" key's value should be one of the following - (1) strings in valid Markdown (2) a dictionary with a "path" key and value corresponding to the file to display or (3) an instance of a Gradio component: at the moment gr.Image, gr.Plot, gr.Video, gr.Gallery, gr.Audio, and gr.HTML are supported. The "role" key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output. If this parameter is 'tuples' (deprecated), passes the chat history as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message.
chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
textbox: an instance of the gr.Textbox or gr.MultimodalTextbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox or gr.MultimodalTextbox component will be created.
editable: if True, users can edit past messages to regenerate responses.
additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If the components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion. The values of these components will be passed into `fn` as arguments in order after the chat history.
additional_inputs_accordion: if a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
additional_outputs: an instance or list of instances of gradio components to use as additional outputs from the chat function. These must be components that are already defined in the same Blocks scope. If provided, the chat function should return additional values for these components. See $demo/chatinterface_artifacts.
@@ -173,6 +175,7 @@ def __init__(
self.run_examples_on_click = run_examples_on_click
self.cache_examples = cache_examples
self.cache_mode = cache_mode
self.editable = editable
self.additional_inputs = [
get_component_instance(i)
for i in utils.none_or_singleton_to_list(additional_inputs)
@@ -490,6 +493,14 @@ def _setup_events(self) -> None:

self.chatbot.clear(**synchronize_chat_state_kwargs)

if self.editable:
self.chatbot.edit(
self._edit_message,
[self.chatbot],
[self.chatbot, self.chatbot_state, self.saved_input],
show_api=False,
).success(**submit_fn_kwargs).success(**synchronize_chat_state_kwargs)

def _setup_stop_events(
self, event_triggers: list[Callable], events_to_cancel: list[Dependency]
) -> None:
@@ -712,6 +723,19 @@ def example_populated(self, example: SelectData):
else:
return example.value["text"]

def _edit_message(
self, history: list[MessageDict] | TupleFormat, edit_data: EditData
) -> tuple[
list[MessageDict] | TupleFormat,
list[MessageDict] | TupleFormat,
str | MultimodalPostprocess,
]:
if isinstance(edit_data.index, (list, tuple)):
history = history[: edit_data.index[0]]
else:
history = history[: edit_data.index]
return history, history, edit_data.value

def example_clicked(
self, example: SelectData
) -> Generator[
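
To make `_edit_message` above concrete, here is a standalone sketch of the truncation logic. `EditDataStub` is a stand-in for `gradio.events.EditData`; only its `index` and `value` fields are assumed:

```python
from dataclasses import dataclass

@dataclass
class EditDataStub:
    # Stand-in for gradio.events.EditData: position of the edited message and its new text
    index: int | tuple[int, int]
    value: str

def edit_message(history: list[dict], edit_data: EditDataStub):
    # Drop the edited message and everything after it; the tuple branch appears to
    # cover the deprecated "tuples" history format, where the index is a (row, column) pair.
    cutoff = edit_data.index[0] if isinstance(edit_data.index, (list, tuple)) else edit_data.index
    history = history[:cutoff]
    # Return the truncated history twice (chatbot value + chatbot state) plus the
    # edited text, which is saved as the next input to resubmit.
    return history, history, edit_data.value

history = [
    {"role": "user", "content": "Lets"},
    {"role": "assistant", "content": "Run 1 - You typed: Lets"},
    {"role": "user", "content": "This"},
    {"role": "assistant", "content": "Run 2 - You typed: This"},
]
# Editing the second user message (index 2) discards it and the reply that followed.
print(edit_message(history, EditDataStub(index=2, value="Do")))
```

The chained `.success(**submit_fn_kwargs)` in `_setup_events` then re-runs the chat function on the edited text, which is what regenerates the response.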
3 changes: 3 additions & 0 deletions gradio/events.py
@@ -906,6 +906,9 @@ class Events:
edit = EventListener(
"edit",
doc="This listener is triggered when the user edits the {{ component }} (e.g. image) using the built-in editor.",
callback=lambda block: setattr(block, "editable", "user")
if getattr(block, "editable", None) is None
else None,
)
clear = EventListener(
"clear",
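
The callback added here means that attaching an `.edit` listener marks a component as user-editable unless `editable` was already set explicitly. A sketch of what that enables on a plain `gr.Chatbot`; the handler body and names are illustrative, not part of this commit:

```python
import gradio as gr
from gradio.events import EditData

def apply_edit(history: list[dict], edit_data: EditData):
    # Keep everything before the edited message, then append the new text as the user turn.
    history = history[: edit_data.index]
    history.append({"role": "user", "content": edit_data.value})
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    # Registering .edit fires the new callback, which sets chatbot.editable = "user"
    # because it was not set explicitly above.
    chatbot.edit(apply_edit, chatbot, chatbot)

if __name__ == "__main__":
    demo.launch()
```

This appears to be why `ChatInterface` only needs to register the `.edit` listener in `_setup_events`: the callback flips the underlying Chatbot's `editable` setting automatically.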
25 changes: 25 additions & 0 deletions js/spa/test/test_chatinterface_streaming_echo.spec.ts
@@ -115,3 +115,28 @@ test("test stopping generation", async ({ page }) => {
await expect(current_content).toBe(new_content);
await expect(new_content!.length).toBeLessThan(3000);
});

test("editing messages", async ({ page }) => {
const submit_button = page.locator(".submit-button");
const textbox = page.locator(".input-container textarea");
const chatbot = page.getByLabel("chatbot conversation");

await textbox.fill("Lets");
await submit_button.click();
await expect(chatbot).toContainText("You typed: Lets");

await textbox.fill("Test");
await submit_button.click();
await expect(chatbot).toContainText("You typed: Test");

await textbox.fill("This");
await submit_button.click();
await expect(chatbot).toContainText("You typed: This");

await page.getByLabel("Edit").nth(1).click();
await page.getByLabel("chatbot conversation").getByRole("textbox").fill("Do");
await page.getByLabel("Submit").click();

await expect(chatbot).toContainText("You typed: Do");
await expect(chatbot).not.toContainText("You typed: This");
});
