Skip to content

Commit

Permalink
Fix serialization error in curl api (#9189)
Browse files Browse the repository at this point in the history
* Fix code

* add changeset

* Fix code

* fix

* add another test

---------

Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
  • Loading branch information
3 people authored Aug 27, 2024
1 parent 4a85559 commit ab142ee
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .changeset/moody-dogs-search.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Fix serialization error in curl api
5 changes: 3 additions & 2 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,12 +891,13 @@ async def simple_predict_get(
event_id: str,
):
def process_msg(message: EventMessage) -> str | None:
msg = message.model_dump()
if isinstance(message, ProcessCompletedMessage):
event = "complete" if message.success else "error"
data = message.output.get("data")
data = msg["output"].get("data")
elif isinstance(message, ProcessGeneratingMessage):
event = "generating" if message.success else "error"
data = message.output.get("data")
data = msg["output"].get("data")
elif isinstance(message, HeartbeatMessage):
event = "heartbeat"
data = None
Expand Down
33 changes: 33 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Contains tests for networking.py and app.py"""

import functools
import json
import os
import pickle
import tempfile
Expand Down Expand Up @@ -1454,3 +1455,35 @@ def test_file_access():
demo.close()
not_allowed_file.unlink()
allowed_file.unlink()


def test_bash_api_serialization():
demo = gr.Interface(lambda x: x, "json", "json")

app, _, _ = demo.launch(prevent_thread_lock=True)
test_client = TestClient(app)

with test_client:
submit = test_client.post("/call/predict", json={"data": [{"a": 1}]})
event_id = submit.json()["event_id"]
response = test_client.get(f"/call/predict/{event_id}")
assert response.status_code == 200
assert "event: complete\ndata:" in response.text
assert json.dumps({"a": 1}) in response.text


def test_bash_api_multiple_inputs_outputs():
demo = gr.Interface(
lambda x, y: (y, x), ["textbox", "number"], ["number", "textbox"]
)

app, _, _ = demo.launch(prevent_thread_lock=True)
test_client = TestClient(app)

with test_client:
submit = test_client.post("/call/predict", json={"data": ["abc", 123]})
event_id = submit.json()["event_id"]
response = test_client.get(f"/call/predict/{event_id}")
assert response.status_code == 200
assert "event: complete\ndata:" in response.text
assert json.dumps([123, "abc"]) in response.text

0 comments on commit ab142ee

Please sign in to comment.