Skip to content

Commit

Permalink
Custom component in rerender (#10170)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* Support gr.Examples in gr.render (#10173)

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <[email protected]>

* add changeset

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
3 people authored Dec 13, 2024
1 parent e525680 commit 5e6e234
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 19 deletions.
6 changes: 6 additions & 0 deletions .changeset/modern-views-say.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/core": patch
"gradio": patch
---

fix:Custom component in rerender
2 changes: 1 addition & 1 deletion demo/render_tests/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: render_tests"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from datetime import datetime\n", "\n", "import gradio as gr\n", "\n", "def update_log():\n", " return datetime.now().timestamp()\n", "\n", "def get_target(evt: gr.EventData):\n", " return evt.target\n", "\n", "def get_select_index(evt: gr.SelectData):\n", " return evt.index\n", "\n", "with gr.Blocks() as demo:\n", " gr.Textbox(value=update_log, every=0.2, label=\"Time\")\n", " \n", " slider = gr.Slider(1, 10, step=1)\n", " @gr.render(inputs=[slider])\n", " def show_log(s):\n", " with gr.Row():\n", " for i in range(s):\n", " gr.Textbox(value=update_log, every=0.2, label=f\"Render {i + 1}\")\n", "\n", " with gr.Row():\n", " selected_btn = gr.Textbox(label=\"Selected Button\")\n", " selected_chat = gr.Textbox(label=\"Selected Chat\")\n", " @gr.render(inputs=[slider])\n", " def show_buttons(s):\n", " with gr.Row():\n", " with gr.Column():\n", " for i in range(s):\n", " btn = gr.Button(f\"Button {i + 1}\")\n", " btn.click(get_target, None, selected_btn)\n", " chatbot = gr.Chatbot([[\"Hello\", \"Hi\"], [\"How are you?\", \"I'm good.\"]])\n", " chatbot.select(get_select_index, None, selected_chat)\n", "\n", "if __name__ == '__main__':\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: render_tests"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from datetime import datetime\n", "\n", "import gradio as gr\n", "\n", "def update_log():\n", " return datetime.now().timestamp()\n", "\n", "def get_target(evt: gr.EventData):\n", " return evt.target\n", "\n", "def get_select_index(evt: gr.SelectData):\n", " return evt.index\n", "\n", "with gr.Blocks() as demo:\n", " gr.Textbox(value=update_log, every=0.2, label=\"Time\")\n", " \n", " slider = gr.Slider(1, 10, step=1)\n", " @gr.render(inputs=[slider])\n", " def show_log(s):\n", " with gr.Row():\n", " for i in range(s):\n", " gr.Textbox(value=update_log, every=0.2, label=f\"Render {i + 1}\")\n", "\n", " with gr.Row():\n", " selected_btn = gr.Textbox(label=\"Selected Button\")\n", " selected_chat = gr.Textbox(label=\"Selected Chat\")\n", " @gr.render(inputs=[slider])\n", " def show_buttons(s):\n", " with gr.Row():\n", " with gr.Column():\n", " for i in range(s):\n", " btn = gr.Button(f\"Button {i + 1}\")\n", " btn.click(get_target, None, selected_btn)\n", " chatbot = gr.Chatbot([[\"Hello\", \"Hi\"], [\"How are you?\", \"I'm good.\"]])\n", " chatbot.select(get_select_index, None, selected_chat)\n", "\n", " @gr.render()\n", " def examples_in_interface():\n", " gr.Interface(lambda x:x, gr.Textbox(label=\"input\"), gr.Textbox(), examples=[[\"test\"]])\n", "\n", " @gr.render()\n", " def examples_in_blocks():\n", " a = gr.Textbox(label=\"little textbox\")\n", " gr.Examples([[\"abc\"], [\"def\"]], [a])\n", "\n", "\n", "if __name__ == '__main__':\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
10 changes: 10 additions & 0 deletions demo/render_tests/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,15 @@ def show_buttons(s):
chatbot = gr.Chatbot([["Hello", "Hi"], ["How are you?", "I'm good."]])
chatbot.select(get_select_index, None, selected_chat)

@gr.render()
def examples_in_interface():
gr.Interface(lambda x:x, gr.Textbox(label="input"), gr.Textbox(), examples=[["test"]])

@gr.render()
def examples_in_blocks():
a = gr.Textbox(label="little textbox")
gr.Examples([["abc"], ["def"]], [a])


if __name__ == '__main__':
demo.launch()
13 changes: 8 additions & 5 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from gradio_client.documentation import document

from gradio import components, oauth, processing_utils, routes, utils, wasm_utils
from gradio.context import Context, LocalContext
from gradio.context import Context, LocalContext, get_blocks_context
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
Expand Down Expand Up @@ -338,10 +338,13 @@ def _get_processed_example(self, example):

def create(self) -> None:
"""Creates the Dataset component to hold the examples"""

self.root_block = Context.root_block
if self.root_block:
self.root_block.extra_startup_events.append(self._start_caching)
blocks_config = get_blocks_context()
self.root_block = Context.root_block or (
blocks_config.root_block if blocks_config else None
)
if blocks_config:
if self.root_block:
self.root_block.extra_startup_events.append(self._start_caching)

if self.cache_examples:

Expand Down
14 changes: 6 additions & 8 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,25 +613,23 @@ def custom_component_path(
raise HTTPException(
status_code=404, detail="Environment not supported."
)
config = app.get_blocks().config
components = config["components"]
components = utils.get_all_components()
location = next(
(item for item in components if item["component_class_id"] == id), None
(item for item in components if item.get_component_class_id() == id),
None,
)
if location is None:
raise HTTPException(status_code=404, detail="Component not found.")

component_instance = app.get_blocks().get_component(location["id"])

module_name = component_instance.__class__.__module__
module_name = location.__module__
module_path = sys.modules[module_name].__file__

if module_path is None or component_instance is None:
if module_path is None:
raise HTTPException(status_code=404, detail="Component not found.")

try:
requested_path = utils.safe_join(
component_instance.__class__.TEMPLATE_DIR,
location.TEMPLATE_DIR,
UserProvidedPath(f"{type}/{file_name}"),
)
except InvalidPathError:
Expand Down
2 changes: 1 addition & 1 deletion js/core/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@
rerender_layout({
components: _components,
layout: render_layout,
root: root,
root: root + api_prefix,
dependencies: dependencies,
render_id: render_id
});
Expand Down
14 changes: 10 additions & 4 deletions js/core/src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
{} as { [id: number]: ComponentMeta }
);

await walk_layout(layout, root);
await walk_layout(layout, root, _components);

layout_store.set(_rootNode);
set_event_specific_args(dependencies);
Expand Down Expand Up @@ -230,7 +230,12 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
] = instance_map[layout.id];
}

walk_layout(layout, root, current_element.parent).then(() => {
walk_layout(
layout,
root,
_components.concat(components),
current_element.parent
).then(() => {
layout_store.set(_rootNode);
});

Expand All @@ -240,6 +245,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
async function walk_layout(
node: LayoutNode,
root: string,
components: ComponentMeta[],
parent?: ComponentMeta
): Promise<ComponentMeta> {
const instance = instance_map[node.id];
Expand All @@ -254,7 +260,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {
instance.type,
instance.component_class_id,
root,
_components,
components,
instance.props.components
).example_components;
}
Expand Down Expand Up @@ -288,7 +294,7 @@ export function create_components(initial_layout: ComponentMeta | undefined): {

if (node.children) {
instance.children = await Promise.all(
node.children.map((v) => walk_layout(v, root, instance))
node.children.map((v) => walk_layout(v, root, components, instance))
);
}

Expand Down
9 changes: 9 additions & 0 deletions js/spa/test/render_tests.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,12 @@ test("Test event/selection data works in render", async ({ page }) => {
await page.getByText("Hi").click();
await expect(selected_chat).toHaveValue("[0, 1]");
});

test("Test examples work in render", async ({ page }) => {
await page.getByRole("button", { name: "test" }).click();
await expect(page.getByLabel("input", { exact: true })).toHaveValue("test");
await page.getByRole("button", { name: "def", exact: true }).click();
await expect(page.getByLabel("little textbox", { exact: true })).toHaveValue(
"def"
);
});

0 comments on commit 5e6e234

Please sign in to comment.