From 5ce28324971e974ae24dc9a229b2160793140fb2 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Sat, 4 Jan 2025 09:42:05 -0800 Subject: [PATCH] Support saving chat history in `gr.ChatInterface` (#10191) * save history prototype * add changeset * Declare exports in __all__ for type checking (#10238) * Declare exports * add changeset * type fixes * more type fixes * add changeset * notebooks * changes --------- Co-authored-by: gradio-pr-bot Co-authored-by: Freddy Boulton Co-authored-by: Abubakar Abid * Add `gr.BrowserState` change event (#10245) * changes * changes * add changeset * format * changes --------- Co-authored-by: gradio-pr-bot * history * changes * changes * changes * history * changes * changes * changes * format * add changeset * changes * changes * more changes * changes * dataset changes * changes * add changeset * add md variant for button * add changeset * changes * changes * format * format * add changeset * changes * changes * more changes * changes * changes * add changeset * changes * docs * changes * changes * changes * changes * fix * fix tests * change * add changeset * fix logo issue * changes * version * add changeset * fix typecheck * remove redundant * pkg version * add changeset * changes * Revert "changes" This reverts commit 13bfe8c485d049f7d8c6f1e5c13e2bc04ab71dd5. * reorganize code * format * changes * add to deployed demos * fix icons * fix icon * lint * changes * example * changes * fix buttons * add changeset * format * add changeset * update icon --------- Co-authored-by: gradio-pr-bot Co-authored-by: Dawood Co-authored-by: Dmitry Ustalov Co-authored-by: Freddy Boulton --- .changeset/tiny-areas-train.md | 11 + demo/chatinterface_save_history/run.ipynb | 1 + demo/chatinterface_save_history/run.py | 16 + demo/chatinterface_streaming_echo/run.ipynb | 2 +- demo/chatinterface_streaming_echo/run.py | 8 +- gradio/__init__.py | 2 +- gradio/chat_interface.py | 364 ++++++++++++------ gradio/components/button.py | 7 +- gradio/components/clear_button.py | 5 +- gradio/components/dataset.py | 6 + gradio/components/download_button.py | 6 +- gradio/components/duplicate_button.py | 9 +- gradio/components/login_button.py | 7 +- gradio/components/logout_button.py | 2 +- gradio/components/upload_button.py | 4 +- gradio/icons/README.md | 2 + gradio/icons/huggingface-logo.svg | 37 ++ gradio/icons/plus.svg | 12 + gradio/themes/base.py | 20 + gradio/utils.py | 25 +- .../05_chatbots/01_creating-a-chatbot-fast.md | 14 +- js/button/shared/Button.svelte | 11 +- js/chatbot/shared/FlagActive.svelte | 7 + js/chatbot/shared/LikeDislike.svelte | 9 +- js/dataset/Index.svelte | 45 ++- .../shared/DownloadButton.svelte | 2 +- .../test_chatinterface_streaming_echo.spec.ts | 3 +- js/textbox/Example.svelte | 2 - js/uploadbutton/shared/UploadButton.svelte | 2 +- test/test_utils.py | 6 + 30 files changed, 484 insertions(+), 163 deletions(-) create mode 100644 .changeset/tiny-areas-train.md create mode 100644 demo/chatinterface_save_history/run.ipynb create mode 100644 demo/chatinterface_save_history/run.py create mode 100644 gradio/icons/README.md create mode 100644 gradio/icons/huggingface-logo.svg create mode 100644 gradio/icons/plus.svg create mode 100644 js/chatbot/shared/FlagActive.svelte diff --git a/.changeset/tiny-areas-train.md b/.changeset/tiny-areas-train.md new file mode 100644 index 0000000000000..59a85af3812a0 --- /dev/null +++ b/.changeset/tiny-areas-train.md @@ -0,0 +1,11 @@ +--- +"@gradio/button": minor +"@gradio/chatbot": minor +"@gradio/dataset": minor +"@gradio/downloadbutton": minor +"@gradio/textbox": minor +"@gradio/uploadbutton": minor +"gradio": minor +--- + +feat:Support saving chat history in `gr.ChatInterface` diff --git a/demo/chatinterface_save_history/run.ipynb b/demo/chatinterface_save_history/run.ipynb new file mode 100644 index 0000000000000..8880af469b52c --- /dev/null +++ b/demo/chatinterface_save_history/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_save_history"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def echo_multimodal(message, history):\n", " response = \"You wrote: '\" + message[\"text\"] + \"' and uploaded: \" + str(len(message[\"files\"])) + \" files\"\n", " return response\n", "\n", "demo = gr.ChatInterface(\n", " echo_multimodal,\n", " type=\"messages\",\n", " multimodal=True,\n", " textbox=gr.MultimodalTextbox(file_count=\"multiple\"),\n", " save_history=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatinterface_save_history/run.py b/demo/chatinterface_save_history/run.py new file mode 100644 index 0000000000000..0870f35e35e34 --- /dev/null +++ b/demo/chatinterface_save_history/run.py @@ -0,0 +1,16 @@ +import gradio as gr + +def echo_multimodal(message, history): + response = "You wrote: '" + message["text"] + "' and uploaded: " + str(len(message["files"])) + " files" + return response + +demo = gr.ChatInterface( + echo_multimodal, + type="messages", + multimodal=True, + textbox=gr.MultimodalTextbox(file_count="multiple"), + save_history=True, +) + +if __name__ == "__main__": + demo.launch() diff --git a/demo/chatinterface_streaming_echo/run.ipynb b/demo/chatinterface_streaming_echo/run.ipynb index 222c338b1cc6c..9b82faac936e2 100644 --- a/demo/chatinterface_streaming_echo/run.ipynb +++ b/demo/chatinterface_streaming_echo/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "def slow_echo(message, history):\n", " for i in range(len(message)):\n", " time.sleep(0.05)\n", " yield \"You typed: \" + message[: i + 1]\n", "\n", "demo = gr.ChatInterface(slow_echo, type=\"messages\", flagging_mode=\"manual\", flagging_options=[\"Like\", \"Spam\", \"Inappropriate\", \"Other\"])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "def slow_echo(message, history):\n", " for i in range(len(message)):\n", " time.sleep(0.05)\n", " yield \"You typed: \" + message[: i + 1]\n", "\n", "demo = gr.ChatInterface(\n", " slow_echo,\n", " type=\"messages\",\n", " flagging_mode=\"manual\",\n", " flagging_options=[\"Like\", \"Spam\", \"Inappropriate\", \"Other\"], \n", " save_history=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatinterface_streaming_echo/run.py b/demo/chatinterface_streaming_echo/run.py index 7ef7bcfc570e7..089c6711dc5eb 100644 --- a/demo/chatinterface_streaming_echo/run.py +++ b/demo/chatinterface_streaming_echo/run.py @@ -6,7 +6,13 @@ def slow_echo(message, history): time.sleep(0.05) yield "You typed: " + message[: i + 1] -demo = gr.ChatInterface(slow_echo, type="messages", flagging_mode="manual", flagging_options=["Like", "Spam", "Inappropriate", "Other"]) +demo = gr.ChatInterface( + slow_echo, + type="messages", + flagging_mode="manual", + flagging_options=["Like", "Spam", "Inappropriate", "Other"], + save_history=True, +) if __name__ == "__main__": demo.launch() diff --git a/gradio/__init__.py b/gradio/__init__.py index f621285f0039f..f0feeac8b727e 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -167,7 +167,6 @@ "ImageEditor", "ImageMask", "Info", - "Success", "Interface", "JSON", "Json", @@ -204,6 +203,7 @@ "Sketchpad", "Slider", "State", + "Success", "Tab", "TabItem", "TabbedInterface", diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 3d5d3b209fd56..d8bdaa39af7a2 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -20,9 +20,11 @@ from gradio.blocks import Blocks from gradio.components import ( JSON, + BrowserState, Button, Chatbot, Component, + Dataset, Markdown, MultimodalTextbox, State, @@ -106,10 +108,11 @@ def __init__( fill_height: bool = True, fill_width: bool = False, api_name: str | Literal[False] = "chat", + save_history: bool = False, ): """ Parameters: - fn: the function to wrap the chat interface around. In the default case (assuming `type` is set to "messages"), the function should accept two parameters: a `str` input message and `list` of openai-style dictionary {"role": "user" | "assistant", "content": `str` | {"path": `str`} | `gr.Component`} representing the chat history, and the function should return/yield a `str` (if a simple message), a supported Gradio component (to return a file), a `dict` (for a complete openai-style message response), or a `list` of such messages. + fn: the function to wrap the chat interface around. Normally (assuming `type` is set to "messages"), the function should accept two parameters: a `str` representing the input message and `list` of openai-style dictionaries: {"role": "user" | "assistant", "content": `str` | {"path": `str`} | `gr.Component`} representing the chat history. The function should return/yield a `str` (for a simple message), a supported Gradio component (e.g. gr.Image to return an image), a `dict` (for a complete openai-style message response), or a `list` of such messages. multimodal: if True, the chat interface will use a `gr.MultimodalTextbox` component for the input, which allows for the uploading of multimedia files. If False, the chat interface will use a gr.Textbox component for the input. If this is True, the first argument of `fn` should accept not a `str` message but a `dict` message with keys "text" and "files" type: The format of the messages passed into the chat history parameter of `fn`. If "messages", passes the history as a list of dictionaries with openai-style "role" and "content" keys. The "content" key's value should be one of the following - (1) strings in valid Markdown (2) a dictionary with a "path" key and value corresponding to the file to display or (3) an instance of a Gradio component: at the moment gr.Image, gr.Plot, gr.Video, gr.Gallery, gr.Audio, and gr.HTML are supported. The "role" key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output. If this parameter is 'tuples' (deprecated), passes the chat history as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created. @@ -146,6 +149,7 @@ def __init__( fill_height: if True, the chat interface will expand to the height of window. fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width. api_name: the name of the API endpoint to use for the chat interface. Defaults to "chat". Set to False to disable the API endpoint. + save_history: if True, will save the chat history to the browser's local storage and display previous conversations in a side panel. """ super().__init__( analytics_enabled=analytics_enabled, @@ -184,6 +188,15 @@ def __init__( self.cache_examples = cache_examples self.cache_mode = cache_mode self.editable = editable + self.fill_height = fill_height + self.autoscroll = autoscroll + self.autofocus = autofocus + self.title = title + self.description = description + self.show_progress = show_progress + if save_history and not type == "messages": + raise ValueError("save_history is only supported for type='messages'") + self.save_history = save_history self.additional_inputs = [ get_component_instance(i) for i in utils.none_or_singleton_to_list(additional_inputs) @@ -235,115 +248,147 @@ def __init__( self.flagging_dir = flagging_dir with self: + self.saved_conversations = BrowserState( + [], storage_key="_saved_conversations" + ) + self.conversation_id = State(None) + self.saved_input = State() # Stores the most recent user message + self.null_component = State() # Used to discard unneeded values + with Column(): - if title: - Markdown( - f"

{self.title}

" - ) - if description: - Markdown(description) - if chatbot: - if self.type: - if self.type != chatbot.type: - warnings.warn( - "The type of the gr.Chatbot does not match the type of the gr.ChatInterface." - f"The type of the gr.ChatInterface, '{self.type}', will be used." - ) - chatbot.type = self.type - chatbot._setup_data_model() - else: - warnings.warn( - f"The gr.ChatInterface was not provided with a type, so the type of the gr.Chatbot, '{chatbot.type}', will be used." + self._render_header() + with Row(): + self._render_history_area() + with Column(scale=6): + self._render_chatbot_area( + chatbot, textbox, submit_btn, stop_btn ) - self.type = chatbot.type - self.chatbot = cast( - Chatbot, get_component_instance(chatbot, render=True) + self._render_footer() + self._setup_events() + + def _render_header(self): + if self.title: + Markdown( + f"

{self.title}

" + ) + if self.description: + Markdown(self.description) + + def _render_history_area(self): + if self.save_history: + with Column(scale=1, min_width=100): + self.new_chat_button = Button( + "New chat", + variant="secondary", + size="md", + icon=utils.get_icon_path("plus.svg"), + ) + self.chat_history_dataset = Dataset( + components=[Textbox(visible=False)], + show_label=False, + layout="table", + type="index", + ) + + def _render_chatbot_area( + self, + chatbot: Chatbot | None, + textbox: Textbox | MultimodalTextbox | None, + submit_btn: str | bool | None, + stop_btn: str | bool | None, + ): + if chatbot: + if self.type: + if self.type != chatbot.type: + warnings.warn( + "The type of the gr.Chatbot does not match the type of the gr.ChatInterface." + f"The type of the gr.ChatInterface, '{self.type}', will be used." ) - if self.chatbot.examples and self.examples_messages: - warnings.warn( - "The ChatInterface already has examples set. The examples provided in the chatbot will be ignored." + chatbot.type = cast(Literal["messages", "tuples"], self.type) + chatbot._setup_data_model() + else: + warnings.warn( + f"The gr.ChatInterface was not provided with a type, so the type of the gr.Chatbot, '{chatbot.type}', will be used." + ) + self.type = chatbot.type + self.chatbot = cast(Chatbot, get_component_instance(chatbot, render=True)) + if self.chatbot.examples and self.examples_messages: + warnings.warn( + "The ChatInterface already has examples set. The examples provided in the chatbot will be ignored." + ) + self.chatbot.examples = ( + self.examples_messages + if not self._additional_inputs_in_examples + else None + ) + self.chatbot._setup_examples() + else: + self.type = self.type or "tuples" + self.chatbot = Chatbot( + label="Chatbot", + scale=1, + height=400 if self.fill_height else None, + type=cast(Literal["messages", "tuples"], self.type), + autoscroll=self.autoscroll, + examples=self.examples_messages + if not self._additional_inputs_in_examples + else None, + ) + with Group(): + with Row(): + if textbox: + textbox.show_label = False + textbox_ = get_component_instance(textbox, render=True) + if not isinstance(textbox_, (Textbox, MultimodalTextbox)): + raise TypeError( + f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {builtins.type(textbox_)}" ) - self.chatbot.examples = ( - self.examples_messages - if not self._additional_inputs_in_examples - else None - ) - self.chatbot._setup_examples() + self.textbox = textbox_ else: - self.type = self.type or "tuples" - self.chatbot = Chatbot( - label="Chatbot", - scale=1, - height=200 if fill_height else None, - type=self.type, - autoscroll=autoscroll, - examples=self.examples_messages - if not self._additional_inputs_in_examples - else None, + textbox_component = ( + MultimodalTextbox if self.multimodal else Textbox ) - with Group(): - with Row(): - if textbox: - textbox.show_label = False - textbox_ = get_component_instance(textbox, render=True) - if not isinstance(textbox_, (Textbox, MultimodalTextbox)): - raise TypeError( - f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {builtins.type(textbox_)}" - ) - self.textbox = textbox_ - else: - textbox_component = ( - MultimodalTextbox if self.multimodal else Textbox - ) - self.textbox = textbox_component( - show_label=False, - label="Message", - placeholder="Type a message...", - scale=7, - autofocus=autofocus, - submit_btn=submit_btn, - stop_btn=stop_btn, - ) - - # Hide the stop button at the beginning, and show it with the given value during the generator execution. - self.original_stop_btn = self.textbox.stop_btn - self.textbox.stop_btn = False - - self.fake_api_btn = Button("Fake API", visible=False) - self.api_response = JSON( - label="Response", visible=False - ) # Used to store the response from the API call - - if self.examples: - self.examples_handler = Examples( - examples=self.examples, - inputs=[self.textbox] + self.additional_inputs, - outputs=self.chatbot, - fn=self._examples_stream_fn - if self.is_generator - else self._examples_fn, - cache_examples=self.cache_examples, - cache_mode=self.cache_mode, - visible=self._additional_inputs_in_examples, - preprocess=self._additional_inputs_in_examples, + self.textbox = textbox_component( + show_label=False, + label="Message", + placeholder="Type a message...", + scale=7, + autofocus=self.autofocus, + submit_btn=submit_btn, + stop_btn=stop_btn, ) - any_unrendered_inputs = any( - not inp.is_rendered for inp in self.additional_inputs - ) - if self.additional_inputs and any_unrendered_inputs: - with Accordion(**self.additional_inputs_accordion_params): # type: ignore - for input_component in self.additional_inputs: - if not input_component.is_rendered: - input_component.render() - - self.saved_input = State() # Stores the most recent user message - self.null_component = State() # Used to discard unneeded values - self.chatbot_state = ( - State(self.chatbot.value) if self.chatbot.value else State([]) - ) - self.show_progress = show_progress - self._setup_events() + # Hide the stop button at the beginning, and show it with the given value during the generator execution. + self.original_stop_btn = self.textbox.stop_btn + self.textbox.stop_btn = False + + self.chatbot_state = State(self.chatbot.value if self.chatbot.value else []) + self.fake_api_btn = Button("Fake API", visible=False) + self.api_response = JSON( + label="Response", visible=False + ) # Used to store the response from the API call + + def _render_footer(self): + if self.examples: + self.examples_handler = Examples( + examples=self.examples, + inputs=[self.textbox] + self.additional_inputs, + outputs=self.chatbot, + fn=self._examples_stream_fn if self.is_generator else self._examples_fn, + cache_examples=self.cache_examples, + cache_mode=cast(Literal["eager", "lazy"], self.cache_mode), + visible=self._additional_inputs_in_examples, + preprocess=self._additional_inputs_in_examples, + ) + + any_unrendered_inputs = any( + not inp.is_rendered for inp in self.additional_inputs + ) + if self.additional_inputs and any_unrendered_inputs: + with Accordion(**self.additional_inputs_accordion_params): # type: ignore + for input_component in self.additional_inputs: + if not input_component.is_rendered: + input_component.render() def _setup_example_messages( self, @@ -382,7 +427,49 @@ def _setup_example_messages( examples_messages.append(example_message) return examples_messages + def _generate_chat_title(self, conversation: list[MessageDict]) -> str: + """ + Generate a title for a conversation by taking the first user message that is a string + and truncating it to 40 characters. If files are present, add a 📎 to the title. + """ + title = "" + for message in conversation: + if message["role"] == "user": + if isinstance(message["content"], str): + title += message["content"] + break + else: + title += "📎 " + if len(title) > 40: + title = title[:40] + "..." + return title or "Conversation" + + def _save_conversation( + self, + index: int | None, + conversation: list[MessageDict], + saved_conversations: list[list[MessageDict]], + ): + if self.save_history: + if index is not None: + saved_conversations[index] = conversation + else: + saved_conversations.append(conversation) + index = len(saved_conversations) - 1 + return index, saved_conversations + + def _delete_conversation( + self, + index: int | None, + saved_conversations: list[list[MessageDict]], + ): + if index is not None: + saved_conversations.pop(index) + return None, saved_conversations + def _setup_events(self) -> None: + from gradio import on + submit_triggers = [self.textbox.submit, self.chatbot.retry] submit_fn = self._stream_fn if self.is_generator else self._submit_fn if hasattr(self.fn, "zerogpu"): @@ -407,6 +494,17 @@ def _setup_events(self) -> None: Literal["full", "minimal", "hidden"], self.show_progress ), } + save_fn_kwargs = { + "fn": self._save_conversation, + "inputs": [ + self.conversation_id, + self.chatbot_state, + self.saved_conversations, + ], + "outputs": [self.conversation_id, self.saved_conversations], + "show_api": False, + "queue": False, + } submit_event = ( self.textbox.submit( @@ -430,7 +528,8 @@ def _setup_events(self) -> None: None, self.textbox, show_api=False, - ) + ).then(**save_fn_kwargs) + # Creates the "/chat" API endpoint self.fake_api_btn.click( submit_fn, @@ -492,7 +591,7 @@ def _setup_events(self) -> None: lambda: update(interactive=True), outputs=[self.textbox], show_api=False, - ) + ).then(**save_fn_kwargs) self._setup_stop_events(submit_triggers, [submit_event, retry_event]) @@ -502,16 +601,24 @@ def _setup_events(self) -> None: [self.chatbot, self.textbox], show_api=False, queue=False, - ).then(**synchronize_chat_state_kwargs) + ).then(**synchronize_chat_state_kwargs).then(**save_fn_kwargs) self.chatbot.option_select( self.option_clicked, [self.chatbot], [self.chatbot, self.saved_input], show_api=False, - ).then(**submit_fn_kwargs).then(**synchronize_chat_state_kwargs) + ).then(**submit_fn_kwargs).then(**synchronize_chat_state_kwargs).then( + **save_fn_kwargs + ) - self.chatbot.clear(**synchronize_chat_state_kwargs) + self.chatbot.clear(**synchronize_chat_state_kwargs).then( + self._delete_conversation, + [self.conversation_id, self.saved_conversations], + [self.conversation_id, self.saved_conversations], + show_api=False, + queue=False, + ) if self.editable: self.chatbot.edit( @@ -519,7 +626,48 @@ def _setup_events(self) -> None: [self.chatbot], [self.chatbot, self.chatbot_state, self.saved_input], show_api=False, - ).success(**submit_fn_kwargs).success(**synchronize_chat_state_kwargs) + ).success(**submit_fn_kwargs).success(**synchronize_chat_state_kwargs).then( + **save_fn_kwargs + ) + + if self.save_history: + self.new_chat_button.click( + lambda: (None, []), + None, + [self.conversation_id, self.chatbot], + show_api=False, + queue=False, + ).then( + lambda x: x, + [self.chatbot], + [self.chatbot_state], + show_api=False, + queue=False, + ) + + @on( + [self.load, self.saved_conversations.change], + inputs=[self.saved_conversations], + outputs=[self.chat_history_dataset], + show_api=False, + queue=False, + ) + def load_chat_history(conversations): + return Dataset( + samples=[ + [self._generate_chat_title(conv)] + for conv in conversations or [] + if conv + ] + ) + + self.chat_history_dataset.click( + lambda index, conversations: (index, conversations[index]), + [self.chat_history_dataset, self.saved_conversations], + [self.conversation_id, self.chatbot], + show_api=False, + queue=False, + ).then(**synchronize_chat_state_kwargs) if self.flagging_mode != "never": flagging_callback = ChatCSVLogger() diff --git a/gradio/components/button.py b/gradio/components/button.py index 8f420df8f7950..6bce6e75fe31c 100644 --- a/gradio/components/button.py +++ b/gradio/components/button.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal from gradio_client.documentation import document @@ -29,8 +30,8 @@ def __init__( every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop", "huggingface"] = "secondary", - size: Literal["sm", "lg"] | None = None, - icon: str | None = None, + size: Literal["sm", "md", "lg"] = "lg", + icon: str | Path | None = None, link: str | None = None, visible: bool = True, interactive: bool = True, @@ -47,7 +48,7 @@ def __init__( every: continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer. inputs: components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change. variant: sets the background and text color of the button. Use 'primary' for main call-to-action buttons, 'secondary' for a more subdued style, 'stop' for a stop button, 'huggingface' for a black background with white text, consistent with Hugging Face's button styles. - size: size of the button. Can be "sm" or "lg". + size: size of the button. Can be "sm", "md", or "lg". icon: URL or path to the icon file to display within the button. If None, no icon will be displayed. link: URL to open when the button is clicked. If None, no link will be used. visible: if False, component will be hidden. diff --git a/gradio/components/clear_button.py b/gradio/components/clear_button.py index f8ba9c404872d..6056990e87ba0 100644 --- a/gradio/components/clear_button.py +++ b/gradio/components/clear_button.py @@ -5,6 +5,7 @@ import copy import json from collections.abc import Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal from gradio_client.documentation import document @@ -36,8 +37,8 @@ def __init__( every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop"] = "secondary", - size: Literal["sm", "lg"] | None = None, - icon: str | None = None, + size: Literal["sm", "md", "lg"] = "lg", + icon: str | Path | None = None, link: str | None = None, visible: bool = True, interactive: bool = True, diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 731a63eada5e1..db67b6a5a31fa 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -29,11 +29,13 @@ def __init__( self, *, label: str | None = None, + show_label: bool = True, components: Sequence[Component] | list[str] | None = None, component_props: list[dict[str, Any]] | None = None, samples: list[list[Any]] | None = None, headers: list[str] | None = None, type: Literal["values", "index", "tuple"] = "values", + layout: Literal["gallery", "table"] | None = None, samples_per_page: int = 10, visible: bool = True, elem_id: str | None = None, @@ -49,10 +51,12 @@ def __init__( """ Parameters: label: the label for this component, appears above the component. + show_label: If True, the label will be shown above the component. components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents an value for each component headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels type: "values" if clicking on a sample should pass the value of the sample, "index" if it should pass the index of the sample, or "tuple" if it should pass both the index and the value of the sample. + layout: "gallery" if the dataset should be displayed as a gallery with each sample in a clickable card, or "table" if it should be displayed as a table with each sample in a row. By default, "gallery" is used if there is a single component, and "table" is used if there are more than one component. If there are more than one component, the layout can only be "table". samples_per_page: how many examples to show per page. visible: If False, component will be hidden. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -75,6 +79,8 @@ def __init__( self.container = container self.scale = scale self.min_width = min_width + self.layout = layout + self.show_label = show_label self._components = [get_component_instance(c) for c in components or []] if component_props is None: self.component_props = [ diff --git a/gradio/components/download_button.py b/gradio/components/download_button.py index 6c34718c09745..ac7bdbf5491bc 100644 --- a/gradio/components/download_button.py +++ b/gradio/components/download_button.py @@ -1,4 +1,4 @@ -"""gr.UploadButton() component.""" +"""gr.DownloadButton() component.""" from __future__ import annotations @@ -37,7 +37,7 @@ def __init__( inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop"] = "secondary", visible: bool = True, - size: Literal["sm", "lg"] | None = None, + size: Literal["sm", "md", "lg"] = "lg", icon: str | None = None, scale: int | None = None, min_width: int | None = None, @@ -55,7 +55,7 @@ def __init__( inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change. variant: 'primary' for main call-to-action, 'secondary' for a more subdued style, 'stop' for a stop button. visible: If False, component will be hidden. - size: Size of the button. Can be "sm" or "lg". + size: size of the button. Can be "sm", "md", or "lg". icon: URL or path to the icon file to display within the button. If None, no icon will be displayed. scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. diff --git a/gradio/components/duplicate_button.py b/gradio/components/duplicate_button.py index 30433d2d6462d..f7a9a6a1f13f7 100644 --- a/gradio/components/duplicate_button.py +++ b/gradio/components/duplicate_button.py @@ -1,8 +1,9 @@ -"""Predefined buttons with bound events that can be included in a gr.Blocks for convenience.""" +"""gr.DuplicateButton() component""" from __future__ import annotations from collections.abc import Sequence +from pathlib import Path from typing import TYPE_CHECKING, Literal from gradio_client.documentation import document @@ -30,8 +31,8 @@ def __init__( every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop", "huggingface"] = "huggingface", - size: Literal["sm", "lg"] | None = "sm", - icon: str | None = None, + size: Literal["sm", "md", "lg"] = "sm", + icon: str | Path | None = None, link: str | None = None, visible: bool = True, interactive: bool = True, @@ -50,7 +51,7 @@ def __init__( every: continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer. inputs: components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change. variant: sets the background and text color of the button. Use 'primary' for main call-to-action buttons, 'secondary' for a more subdued style, 'stop' for a stop button, 'huggingface' for a black background with white text, consistent with Hugging Face's button styles. - size: size of the button. Can be "sm" or "lg". + size: size of the button. Can be "sm", "md", or "lg". icon: URL or path to the icon file to display within the button. If None, no icon will be displayed. link: URL to open when the button is clicked. If None, no link will be used. visible: if False, component will be hidden. diff --git a/gradio/components/login_button.py b/gradio/components/login_button.py index ca34d2ec34a57..8190011a44988 100644 --- a/gradio/components/login_button.py +++ b/gradio/components/login_button.py @@ -6,10 +6,12 @@ import time import warnings from collections.abc import Sequence +from pathlib import Path from typing import TYPE_CHECKING, Literal from gradio_client.documentation import document +from gradio import utils from gradio.components import Button, Component from gradio.context import get_blocks_context from gradio.routes import Request @@ -34,9 +36,8 @@ def __init__( every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop", "huggingface"] = "huggingface", - size: Literal["sm", "lg"] | None = None, - icon: str - | None = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg", + size: Literal["sm", "md", "lg"] = "lg", + icon: str | Path | None = utils.get_icon_path("huggingface-logo.svg"), link: str | None = None, visible: bool = True, interactive: bool = True, diff --git a/gradio/components/logout_button.py b/gradio/components/logout_button.py index 2a1e44412ee6d..3393f19ba33e0 100644 --- a/gradio/components/logout_button.py +++ b/gradio/components/logout_button.py @@ -32,7 +32,7 @@ def __init__( every: Timer | float | None = None, inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop"] = "secondary", - size: Literal["sm", "lg"] | None = None, + size: Literal["sm", "lg"] = "lg", icon: str | None = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg", # Link to logout page (which will delete the session cookie and redirect to landing page). diff --git a/gradio/components/upload_button.py b/gradio/components/upload_button.py index 5a1fad11c2289..fcfae1c7bdde7 100644 --- a/gradio/components/upload_button.py +++ b/gradio/components/upload_button.py @@ -42,7 +42,7 @@ def __init__( inputs: Component | Sequence[Component] | set[Component] | None = None, variant: Literal["primary", "secondary", "stop"] = "secondary", visible: bool = True, - size: Literal["sm", "lg"] | None = None, + size: Literal["sm", "md", "lg"] = "lg", icon: str | None = None, scale: int | None = None, min_width: int | None = None, @@ -63,7 +63,7 @@ def __init__( inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change. variant: 'primary' for main call-to-action, 'secondary' for a more subdued style, 'stop' for a stop button. visible: If False, component will be hidden. - size: Size of the button. Can be "sm" or "lg". + size: size of the button. Can be "sm", "md", or "lg". icon: URL or path to the icon file to display within the button. If None, no icon will be displayed. scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. diff --git a/gradio/icons/README.md b/gradio/icons/README.md new file mode 100644 index 0000000000000..f38df88f11fe1 --- /dev/null +++ b/gradio/icons/README.md @@ -0,0 +1,2 @@ +The icons in this directory are loaded via `gradio.utils.get_icon_path` and +can be used directly in backend code (e.g. to populate icons in components). \ No newline at end of file diff --git a/gradio/icons/huggingface-logo.svg b/gradio/icons/huggingface-logo.svg new file mode 100644 index 0000000000000..43c5d3c0c97a9 --- /dev/null +++ b/gradio/icons/huggingface-logo.svg @@ -0,0 +1,37 @@ + + + + + + + + + + + diff --git a/gradio/icons/plus.svg b/gradio/icons/plus.svg new file mode 100644 index 0000000000000..7f9fa48d4b523 --- /dev/null +++ b/gradio/icons/plus.svg @@ -0,0 +1,12 @@ + + + diff --git a/gradio/themes/base.py b/gradio/themes/base.py index 31d571c49d959..279a45f31b2a4 100644 --- a/gradio/themes/base.py +++ b/gradio/themes/base.py @@ -708,6 +708,10 @@ def set( button_small_radius=None, button_small_text_size=None, button_small_text_weight=None, + button_medium_padding=None, + button_medium_radius=None, + button_medium_text_size=None, + button_medium_text_weight=None, button_primary_background_fill=None, button_primary_background_fill_dark=None, button_primary_background_fill_hover=None, @@ -1006,6 +1010,10 @@ def set( button_small_radius: The corner radius of a button set to "small" size. button_small_text_size: The text size of a button set to "small" size. button_small_text_weight: The text weight of a button set to "small" size. + button_medium_padding: The padding of a button set to "medium" size. + button_medium_radius: The corner radius of a button set to "medium" size. + button_medium_text_size: The text size of a button set to "medium" size. + button_medium_text_weight: The text weight of a button set to "medium" size. button_transition: The transition animation duration of a button between regular, hover, and focused states. button_transform_hover: The transform animation of a button on hover. button_transform_active: The transform animation of a button when pressed. @@ -1956,5 +1964,17 @@ def set( self.button_small_text_weight = button_small_text_weight or getattr( self, "button_small_text_weight", "400" ) + self.button_medium_padding = button_medium_padding or getattr( + self, "button_medium_padding", "*spacing_md calc(2 * *spacing_md)" + ) + self.button_medium_radius = button_medium_radius or getattr( + self, "button_medium_radius", "*radius_md" + ) + self.button_medium_text_size = button_medium_text_size or getattr( + self, "button_medium_text_size", "*text_md" + ) + self.button_medium_text_weight = button_medium_text_weight or getattr( + self, "button_medium_text_weight", "600" + ) return self diff --git a/gradio/utils.py b/gradio/utils.py index c6d5ba5a6aace..080a3d1181a4e 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -8,6 +8,7 @@ import functools import hashlib import importlib +import importlib.resources import importlib.util import inspect import json @@ -1098,7 +1099,7 @@ def is_in_or_equal(path_1: str | Path, path_2: str | Path) -> bool: @document() -def set_static_paths(paths: list[str | Path]) -> None: +def set_static_paths(paths: str | Path | list[str | Path]) -> None: """ Set the static paths to be served by the gradio app. @@ -1109,7 +1110,7 @@ def set_static_paths(paths: list[str | Path]) -> None: Calling this function will set the static paths for all gradio applications defined in the same interpreter session until it is called again or the session ends. Parameters: - paths: List of filepaths or directory names to be served by the gradio app. If it is a directory name, ALL files located within that directory will be considered static and not moved to the gradio cache. This also means that ALL files in that directory will be accessible over the network. + paths: filepath or list of filepaths or directory names to be served by the gradio app. If it is a directory name, ALL files located within that directory will be considered static and not moved to the gradio cache. This also means that ALL files in that directory will be accessible over the network. Example: import gradio as gr @@ -1130,6 +1131,8 @@ def set_static_paths(paths: list[str | Path]) -> None: """ from gradio.data_classes import _StaticFiles + if isinstance(paths, (str, Path)): + paths = [Path(paths)] _StaticFiles.all_paths.extend([Path(p).resolve() for p in paths]) @@ -1587,3 +1590,21 @@ def none_or_singleton_to_list(value: Any) -> list: if isinstance(value, (list, tuple)): return list(value) return [value] + + +def get_icon_path(icon_name: str) -> str: + """Get the path to an icon file in the "gradio/icons/" directory + and return it as a static file path so that it can be used by components. + + Parameters: + icon_name: Name of the icon file (e.g. "plus.svg") + Returns: + str: Full path to the icon file served as a static file + """ + icon_path = str( + importlib.resources.files("gradio").joinpath(str(Path("icons") / icon_name)) + ) + if Path(icon_path).exists(): + set_static_paths(icon_path) + return icon_path + raise ValueError(f"Icon file not found: {icon_name}") diff --git a/guides/05_chatbots/01_creating-a-chatbot-fast.md b/guides/05_chatbots/01_creating-a-chatbot-fast.md index 36ab09ff1597e..27ef1b7692003 100644 --- a/guides/05_chatbots/01_creating-a-chatbot-fast.md +++ b/guides/05_chatbots/01_creating-a-chatbot-fast.md @@ -341,14 +341,22 @@ To use the endpoint, you should use either the [Gradio Python Client](/guides/ge * Slack bot [[tutorial]](../guides/creating-a-slack-bot-from-a-gradio-app) * Website widget [[tutorial]](../guides/creating-a-website-widget-from-a-gradio-chatbot) -## Collecting Feedback +## Chat History -To gather feedback on your generations, set `gr.ChatInterface(flagging_mode="manual")` and users can thumbs-up and down assistant responses. Each flagged response, along with the entire chat history, will get saved in a CSV file in the app folder (or wherever `flagging_dir` specifies). +You can enable persistent chat history for your ChatInterface, allowing users to maintain multiple conversations and easily switch between them. When enabled, conversations are stored locally and privately in the user's browser using local storage. So if you deploy a ChatInterface e.g. on [Hugging Face Spaces](https://hf.space), each user will have their own separate chat history that won't interfere with other users' conversations. This means multiple users can interact with the same ChatInterface simultaneously while maintaining their own private conversation histories. -You can also specify more feedback options via `flagging_options`, which will appear under a dedicated flag button. Here's an example that shows several flagging options. Because the case-sensitive string "Like" is one of the flagging options, the user will see a "thumbs up" icon next to each assistant message. The three other flagging options will appear under a dedicated "flag" icon. +To enable this feature, simply set `gr.ChatInterface(save_history=True)` (as shown in the example in the next section). Users will then see their previous conversations in a side panel and can continue any previous chat or start a new one. + +## Collecting User Feedback + +To gather feedback on your chat model, set `gr.ChatInterface(flagging_mode="manual")` and users will be able to thumbs-up or thumbs-down assistant responses. Each flagged response, along with the entire chat history, will get saved in a CSV file in the app working directory (this can be configured via the `flagging_dir` parameter). + +You can also change the feedback options via `flagging_options` parameter. The default options are "Like" and "Dislike", which appear as the thumbs-up and thumbs-down icons. Any other options appear under a dedicated flag icon. This example shows a ChatInterface that has both chat history (mentioned in the previous section) and user feedback enabled: $code_chatinterface_streaming_echo +Note that in this example, we set several flagging options: "Like", "Spam", "Inappropriate", "Other". Because the case-sensitive string "Like" is one of the flagging options, the user will see a thumbs-up icon next to each assistant message. The three other flagging options will appear in a dropdown under the flag icon. + ## What's Next? Now that you've learned about the `gr.ChatInterface` class and how it can be used to create chatbot UIs quickly, we recommend reading one of the following: diff --git a/js/button/shared/Button.svelte b/js/button/shared/Button.svelte index 42c3a5f85ccbd..1cd4ca7e84405 100644 --- a/js/button/shared/Button.svelte +++ b/js/button/shared/Button.svelte @@ -6,7 +6,7 @@ export let visible = true; export let variant: "primary" | "secondary" | "stop" | "huggingface" = "secondary"; - export let size: "sm" | "lg" = "lg"; + export let size: "sm" | "md" | "lg" = "lg"; export let value: string | null = null; export let link: string | null = null; export let icon: FileData | null = null; @@ -174,6 +174,13 @@ font-size: var(--button-small-text-size); } + .md { + border-radius: var(--button-medium-radius); + padding: var(--button-medium-padding); + font-weight: var(--button-medium-text-weight); + font-size: var(--button-medium-text-size); + } + .lg { border-radius: var(--button-large-radius); padding: var(--button-large-padding); @@ -186,7 +193,7 @@ height: var(--text-xl); } .button-icon.right-padded { - margin-right: var(--spacing-xl); + margin-right: var(--spacing-md); } .huggingface { diff --git a/js/chatbot/shared/FlagActive.svelte b/js/chatbot/shared/FlagActive.svelte new file mode 100644 index 0000000000000..804dce5979a7d --- /dev/null +++ b/js/chatbot/shared/FlagActive.svelte @@ -0,0 +1,7 @@ + diff --git a/js/chatbot/shared/LikeDislike.svelte b/js/chatbot/shared/LikeDislike.svelte index 1589f7afe7c5c..b4ac9c70f9563 100644 --- a/js/chatbot/shared/LikeDislike.svelte +++ b/js/chatbot/shared/LikeDislike.svelte @@ -5,6 +5,7 @@ import ThumbUpActive from "./ThumbUpActive.svelte"; import ThumbUpDefault from "./ThumbUpDefault.svelte"; import Flag from "./Flag.svelte"; + import FlagActive from "./FlagActive.svelte"; export let handle_action: (selected: string | null) => void; export let feedback_options: string[]; @@ -46,7 +47,13 @@ {#if extra_feedback.length > 0}
- +
{#each extra_feedback as option}