Skip to content

Commit

Permalink
JSON type fix in Client and and typing fix for /chat endpoint in `g…
Browse files Browse the repository at this point in the history
…r.ChatInterface` (#10193)

* fix

* add changeset

* add changeset

* fix

* chat interface fixes

* rename

* add changeset

* format

* changes

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Dec 13, 2024
1 parent 5e6e234 commit 424365b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .changeset/warm-dragons-carry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

feat:JSON type fix in Client and and typing fix for `/chat` endpoint in `gr.ChatInterface`
2 changes: 1 addition & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
type_ = get_type(schema)
if type_ == {}:
if "json" in schema.get("description", {}):
return "Dict[Any, Any]"
return "str | float | bool | list | dict"
else:
return "Any"
elif type_ == "$ref":
Expand Down
2 changes: 1 addition & 1 deletion client/python/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_json_schema_to_python_type(schema):
elif schema == "FileSerializable":
answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]"
elif schema == "JSONSerializable":
answer = "Dict[Any, Any]"
answer = "str | float | bool | list | dict"
elif schema == "GallerySerializable":
answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]"
elif schema == "SingleFileSerializable":
Expand Down
10 changes: 6 additions & 4 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gradio import utils
from gradio.blocks import Blocks
from gradio.components import (
JSON,
Button,
Chatbot,
Component,
Expand Down Expand Up @@ -283,7 +284,7 @@ def __init__(
self.textbox.stop_btn = False

self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
self.api_response = JSON(
label="Response", visible=False
) # Used to store the response from the API call

Expand Down Expand Up @@ -311,6 +312,7 @@ def __init__(
input_component.render()

self.saved_input = State() # Stores the most recent user message
self.null_component = State() # Used to discard unneeded values
self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([])
)
Expand Down Expand Up @@ -357,8 +359,7 @@ def _setup_events(self) -> None:
submit_fn_kwargs = {
"fn": submit_fn,
"inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
"outputs": [self.fake_response_textbox, self.chatbot]
+ self.additional_outputs,
"outputs": [self.null_component, self.chatbot] + self.additional_outputs,
"show_api": False,
"concurrency_limit": cast(
Union[int, Literal["default"], None], self.concurrency_limit
Expand Down Expand Up @@ -395,11 +396,12 @@ def _setup_events(self) -> None:
self.fake_api_btn.click(
submit_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.fake_response_textbox, self.chatbot_state] + self.additional_outputs,
[self.api_response, self.chatbot_state] + self.additional_outputs,
api_name=cast(Union[str, Literal[False]], self.api_name),
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
postprocess=False,
)

if (
Expand Down

0 comments on commit 424365b

Please sign in to comment.