Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JSON type fix in Client and typing fix for /chat endpoint in gr.ChatInterface #10193

Merged
merged 12 commits into from
Dec 13, 2024
6 changes: 6 additions & 0 deletions .changeset/warm-dragons-carry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

feat:JSON type fix in Client and typing fix for `/chat` endpoint in `gr.ChatInterface`
2 changes: 1 addition & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
type_ = get_type(schema)
if type_ == {}:
if "json" in schema.get("description", {}):
return "Dict[Any, Any]"
return "str | float | bool | list | dict"
else:
return "Any"
elif type_ == "$ref":
Expand Down
2 changes: 1 addition & 1 deletion client/python/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_json_schema_to_python_type(schema):
elif schema == "FileSerializable":
answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]"
elif schema == "JSONSerializable":
answer = "Dict[Any, Any]"
answer = "str | float | bool | list | dict"
elif schema == "GallerySerializable":
answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]"
elif schema == "SingleFileSerializable":
Expand Down
10 changes: 6 additions & 4 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gradio import utils
from gradio.blocks import Blocks
from gradio.components import (
JSON,
Button,
Chatbot,
Component,
Expand Down Expand Up @@ -283,7 +284,7 @@ def __init__(
self.textbox.stop_btn = False

self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
self.api_response = JSON(
label="Response", visible=False
) # Used to store the response from the API call

Expand Down Expand Up @@ -311,6 +312,7 @@ def __init__(
input_component.render()

self.saved_input = State() # Stores the most recent user message
self.null_component = State() # Used to discard unneeded values
self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([])
)
Expand Down Expand Up @@ -357,8 +359,7 @@ def _setup_events(self) -> None:
submit_fn_kwargs = {
"fn": submit_fn,
"inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
"outputs": [self.fake_response_textbox, self.chatbot]
+ self.additional_outputs,
"outputs": [self.null_component, self.chatbot] + self.additional_outputs,
"show_api": False,
"concurrency_limit": cast(
Union[int, Literal["default"], None], self.concurrency_limit
Expand Down Expand Up @@ -395,11 +396,12 @@ def _setup_events(self) -> None:
self.fake_api_btn.click(
submit_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.fake_response_textbox, self.chatbot_state] + self.additional_outputs,
[self.api_response, self.chatbot_state] + self.additional_outputs,
api_name=cast(Union[str, Literal[False]], self.api_name),
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
postprocess=False,
)

if (
Expand Down
Loading