diff --git a/.changeset/warm-dragons-carry.md b/.changeset/warm-dragons-carry.md new file mode 100644 index 0000000000000..fd0f7b23c8626 --- /dev/null +++ b/.changeset/warm-dragons-carry.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"gradio_client": patch +--- + +feat:JSON type fix in Client and and typing fix for `/chat` endpoint in `gr.ChatInterface` diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index d9f6ed5e90f55..1d5b150f66a26 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -926,7 +926,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str: type_ = get_type(schema) if type_ == {}: if "json" in schema.get("description", {}): - return "Dict[Any, Any]" + return "str | float | bool | list | dict" else: return "Any" elif type_ == "$ref": diff --git a/client/python/test/test_utils.py b/client/python/test/test_utils.py index ee89e16bd754e..2144be8477109 100644 --- a/client/python/test/test_utils.py +++ b/client/python/test/test_utils.py @@ -170,7 +170,7 @@ def test_json_schema_to_python_type(schema): elif schema == "FileSerializable": answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]" elif schema == "JSONSerializable": - answer = "Dict[Any, Any]" + answer = "str | float | bool | list | dict" elif schema == "GallerySerializable": answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]" elif schema == "SingleFileSerializable": diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 84eef5e5f99a4..825c54f489ac6 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -18,6 +18,7 @@ from gradio import utils from gradio.blocks import Blocks from gradio.components import ( + JSON, Button, Chatbot, Component, @@ -283,7 +284,7 @@ def __init__( self.textbox.stop_btn = False self.fake_api_btn = Button("Fake API", visible=False) - self.fake_response_textbox = Textbox( + self.api_response = JSON( label="Response", visible=False ) # Used to store the response from the API call @@ -311,6 +312,7 @@ def __init__( input_component.render() self.saved_input = State() # Stores the most recent user message + self.null_component = State() # Used to discard unneeded values self.chatbot_state = ( State(self.chatbot.value) if self.chatbot.value else State([]) ) @@ -357,8 +359,7 @@ def _setup_events(self) -> None: submit_fn_kwargs = { "fn": submit_fn, "inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs, - "outputs": [self.fake_response_textbox, self.chatbot] - + self.additional_outputs, + "outputs": [self.null_component, self.chatbot] + self.additional_outputs, "show_api": False, "concurrency_limit": cast( Union[int, Literal["default"], None], self.concurrency_limit @@ -395,11 +396,12 @@ def _setup_events(self) -> None: self.fake_api_btn.click( submit_fn, [self.textbox, self.chatbot_state] + self.additional_inputs, - [self.fake_response_textbox, self.chatbot_state] + self.additional_outputs, + [self.api_response, self.chatbot_state] + self.additional_outputs, api_name=cast(Union[str, Literal[False]], self.api_name), concurrency_limit=cast( Union[int, Literal["default"], None], self.concurrency_limit ), + postprocess=False, ) if (