From ab142ee13d19070b75b5eb03efcda7193b8993c2 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Tue, 27 Aug 2024 18:49:36 -0400 Subject: [PATCH] Fix serialization error in curl api (#9189) * Fix code * add changeset * Fix code * fix * add another test --------- Co-authored-by: gradio-pr-bot Co-authored-by: Abubakar Abid --- .changeset/moody-dogs-search.md | 5 +++++ gradio/routes.py | 5 +++-- test/test_routes.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 .changeset/moody-dogs-search.md diff --git a/.changeset/moody-dogs-search.md b/.changeset/moody-dogs-search.md new file mode 100644 index 0000000000000..1245e6b02c841 --- /dev/null +++ b/.changeset/moody-dogs-search.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Fix serialization error in curl api diff --git a/gradio/routes.py b/gradio/routes.py index 04afb5438b2fc..cc560dc20d047 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -891,12 +891,13 @@ async def simple_predict_get( event_id: str, ): def process_msg(message: EventMessage) -> str | None: + msg = message.model_dump() if isinstance(message, ProcessCompletedMessage): event = "complete" if message.success else "error" - data = message.output.get("data") + data = msg["output"].get("data") elif isinstance(message, ProcessGeneratingMessage): event = "generating" if message.success else "error" - data = message.output.get("data") + data = msg["output"].get("data") elif isinstance(message, HeartbeatMessage): event = "heartbeat" data = None diff --git a/test/test_routes.py b/test/test_routes.py index 7834f7a89c2f6..3aaa7eb70d6b1 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -1,6 +1,7 @@ """Contains tests for networking.py and app.py""" import functools +import json import os import pickle import tempfile @@ -1454,3 +1455,35 @@ def test_file_access(): demo.close() not_allowed_file.unlink() allowed_file.unlink() + + +def test_bash_api_serialization(): + demo = gr.Interface(lambda x: x, "json", "json") + + app, _, _ = demo.launch(prevent_thread_lock=True) + test_client = TestClient(app) + + with test_client: + submit = test_client.post("/call/predict", json={"data": [{"a": 1}]}) + event_id = submit.json()["event_id"] + response = test_client.get(f"/call/predict/{event_id}") + assert response.status_code == 200 + assert "event: complete\ndata:" in response.text + assert json.dumps({"a": 1}) in response.text + + +def test_bash_api_multiple_inputs_outputs(): + demo = gr.Interface( + lambda x, y: (y, x), ["textbox", "number"], ["number", "textbox"] + ) + + app, _, _ = demo.launch(prevent_thread_lock=True) + test_client = TestClient(app) + + with test_client: + submit = test_client.post("/call/predict", json={"data": ["abc", 123]}) + event_id = submit.json()["event_id"] + response = test_client.get(f"/call/predict/{event_id}") + assert response.status_code == 200 + assert "event: complete\ndata:" in response.text + assert json.dumps([123, "abc"]) in response.text