Skip to content

Commit

Permalink
Allow accepting user-provided-tokens in gr.load (#9807)
Browse files Browse the repository at this point in the history
* load

* add changeset

* lint

* external

* lint

* changes

* format

* changes

* renamed

* external

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Oct 23, 2024
1 parent 90d9d14 commit 5e89b6d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .changeset/shaggy-mirrors-nail.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Allow accepting user-provided-tokens in `gr.load`
43 changes: 33 additions & 10 deletions gradio/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def load(
| None = None,
token: str | None = None,
hf_token: str | None = None,
accept_token: bool = False,
**kwargs,
) -> Blocks:
"""
Expand All @@ -49,6 +50,7 @@ def load(
name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided.
src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`.
token: optional token that is passed as the second parameter to the `src` function. For Hugging Face repos, uses the local HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find HF tokens here: https://huggingface.co/settings/tokens.
accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space.
kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor.
Returns:
a Gradio Blocks app for the given model
Expand All @@ -64,23 +66,44 @@ def load(
)
if src is None:
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
tokens = name.split("/")
if len(tokens) <= 1:
parts = name.split("/")
if len(parts) <= 1:
raise ValueError(
"Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
)
src = tokens[0] # type: ignore
name = "/".join(tokens[1:])
if src in ["huggingface", "models", "spaces"]:
src = parts[0] # type: ignore
name = "/".join(parts[1:])
assert src is not None # noqa: S101
if not isinstance(src, Callable) and src not in ["models", "spaces", "huggingface"]:
raise ValueError(
"The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app."
)

if not accept_token:
if isinstance(src, Callable):
return src(name, token, **kwargs)
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token, **kwargs
)
elif isinstance(src, Callable):
return src(name, token, **kwargs)
else:
raise ValueError(
"The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app."
)
import gradio as gr

with gr.Blocks(fill_height=True) as demo:
textbox = gr.Textbox(
type="password",
label="Token",
info="Enter your token and press enter.",
)

@gr.render(inputs=[textbox], triggers=[textbox.submit])
def create(token_value):
if isinstance(src, Callable):
return src(name, token_value, **kwargs)
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token_value, **kwargs
)

return demo


def load_blocks_from_huggingface(
Expand Down

0 comments on commit 5e89b6d

Please sign in to comment.