diff --git a/.changeset/shaggy-mirrors-nail.md b/.changeset/shaggy-mirrors-nail.md new file mode 100644 index 0000000000000..8a404299b6a2f --- /dev/null +++ b/.changeset/shaggy-mirrors-nail.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Allow accepting user-provided-tokens in `gr.load` diff --git a/gradio/external.py b/gradio/external.py index 9777ebd7f58b6..bbd8aa73f9861 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -41,6 +41,7 @@ def load( | None = None, token: str | None = None, hf_token: str | None = None, + accept_token: bool = False, **kwargs, ) -> Blocks: """ @@ -49,6 +50,7 @@ def load( name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided. src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`. token: optional token that is passed as the second parameter to the `src` function. For Hugging Face repos, uses the local HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find HF tokens here: https://huggingface.co/settings/tokens. + accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space. kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor. Returns: a Gradio Blocks app for the given model @@ -64,23 +66,44 @@ def load( ) if src is None: # Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224") - tokens = name.split("/") - if len(tokens) <= 1: + parts = name.split("/") + if len(parts) <= 1: raise ValueError( "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}" ) - src = tokens[0] # type: ignore - name = "/".join(tokens[1:]) - if src in ["huggingface", "models", "spaces"]: + src = parts[0] # type: ignore + name = "/".join(parts[1:]) + assert src is not None # noqa: S101 + if not isinstance(src, Callable) and src not in ["models", "spaces", "huggingface"]: + raise ValueError( + "The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app." + ) + + if not accept_token: + if isinstance(src, Callable): + return src(name, token, **kwargs) return load_blocks_from_huggingface( name=name, src=src, hf_token=token, **kwargs ) - elif isinstance(src, Callable): - return src(name, token, **kwargs) else: - raise ValueError( - "The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app." - ) + import gradio as gr + + with gr.Blocks(fill_height=True) as demo: + textbox = gr.Textbox( + type="password", + label="Token", + info="Enter your token and press enter.", + ) + + @gr.render(inputs=[textbox], triggers=[textbox.submit]) + def create(token_value): + if isinstance(src, Callable): + return src(name, token_value, **kwargs) + return load_blocks_from_huggingface( + name=name, src=src, hf_token=token_value, **kwargs + ) + + return demo def load_blocks_from_huggingface(