Skip to content

Commit

Permalink
prefix api routes (#9200)
Browse files Browse the repository at this point in the history
  • Loading branch information
pngwn authored Sep 2, 2024
1 parent 38cf712 commit 2e179d3
Show file tree
Hide file tree
Showing 30 changed files with 432 additions and 297 deletions.
8 changes: 8 additions & 0 deletions .changeset/witty-rice-fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"@gradio/client": minor
"@gradio/core": minor
"gradio": minor
"gradio_client": minor
---

feat:prefix api routes
30 changes: 22 additions & 8 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,19 @@ import {
} from "./helpers/init_helpers";
import { check_and_wake_space, check_space_status } from "./helpers/spaces";
import { open_stream, readable_stream, close_stream } from "./utils/stream";
import { API_INFO_ERROR_MSG, CONFIG_ERROR_MSG } from "./constants";
import {
API_INFO_ERROR_MSG,
CONFIG_ERROR_MSG,
HEARTBEAT_URL,
COMPONENT_SERVER_URL
} from "./constants";

export class Client {
app_reference: string;
options: ClientOptions;

config: Config | undefined;
api_prefix = "";
api_info: ApiInfo<JsApiData> | undefined;
api_map: Record<string, number> = {};
session_hash: string = Math.random().toString(36).substring(2);
Expand Down Expand Up @@ -175,6 +181,8 @@ export class Client {
async _resolve_hearbeat(_config: Config): Promise<void> {
if (_config) {
this.config = _config;
this.api_prefix = _config.api_prefix || "";

if (this.config && this.config.connect_heartbeat) {
if (this.config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(
Expand All @@ -193,7 +201,7 @@ export class Client {
if (this.config && this.config.connect_heartbeat) {
// connect to the heartbeat endpoint via GET request
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
`${this.config.root}${this.api_prefix}/${HEARTBEAT_URL}/${this.session_hash}`
);

// if the jwt is available, add it to the query params
Expand Down Expand Up @@ -282,6 +290,7 @@ export class Client {
_config: Config
): Promise<Config | client_return> {
this.config = _config;
this.api_prefix = _config.api_prefix || "";

if (typeof window !== "undefined" && typeof document !== "undefined") {
if (window.location.protocol === "https:") {
Expand Down Expand Up @@ -311,6 +320,8 @@ export class Client {
if (status.status === "running") {
try {
this.config = await this._resolve_config();
this.api_prefix = this?.config?.api_prefix || "";

if (!this.config) {
throw new Error(CONFIG_ERROR_MSG);
}
Expand Down Expand Up @@ -390,12 +401,15 @@ export class Client {
}

try {
const response = await this.fetch(`${root_url}/component_server/`, {
method: "POST",
body: body,
headers,
credentials: "include"
});
const response = await this.fetch(
`${root_url}${this.api_prefix}/${COMPONENT_SERVER_URL}/`,
{
method: "POST",
body: body,
headers,
credentials: "include"
}
);

if (!response.ok) {
throw new Error(
Expand Down
32 changes: 18 additions & 14 deletions client/js/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
// endpoints
export const HOST_URL = "host";
export const API_URL = "api/predict/";
export const SSE_URL_V0 = "queue/join";
export const SSE_DATA_URL_V0 = "queue/data";
export const SSE_URL = "queue/data";
export const SSE_DATA_URL = "queue/join";
export const UPLOAD_URL = "upload";
export const LOGIN_URL = "login";
export const CONFIG_URL = "config";
export const API_INFO_URL = "info";
export const RUNTIME_URL = "runtime";
export const SLEEPTIME_URL = "sleeptime";
export const RAW_API_INFO_URL = "info?serialize=False";
export const HOST_URL = `host`;
export const API_URL = `predict/`;
export const SSE_URL_V0 = `queue/join`;
export const SSE_DATA_URL_V0 = `queue/data`;
export const SSE_URL = `queue/data`;
export const SSE_DATA_URL = `queue/join`;
export const UPLOAD_URL = `upload`;
export const LOGIN_URL = `login`;
export const CONFIG_URL = `config`;
export const API_INFO_URL = `info`;
export const RUNTIME_URL = `runtime`;
export const SLEEPTIME_URL = `sleeptime`;
export const HEARTBEAT_URL = `heartbeat`;
export const COMPONENT_SERVER_URL = `component_server`;
export const RESET_URL = `reset`;
export const CANCEL_URL = `cancel`;

export const RAW_API_INFO_URL = `info?serialize=False`;
export const SPACE_FETCHER_URL =
"https://gradio-space-api-fetcher-v2.hf.space/api";
export const RESET_URL = "reset";
export const SPACE_URL = "https://hf.space/{}";

// messages
Expand Down
1 change: 1 addition & 0 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ export interface Config {
max_file_size?: number;
theme_hash?: number;
username: string | null;
api_prefix?: string;
}

// todo: DRY up types
Expand Down
3 changes: 1 addition & 2 deletions client/js/src/upload.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { UploadResponse } from "./types";
import type { Client } from "./client";

export async function upload(
Expand Down Expand Up @@ -34,7 +33,7 @@ export async function upload(
const file = new FileData({
...file_data[i],
path: f,
url: root_url + "/file=" + f
url: `${root_url}${this.api_prefix}/file=${f}`
});
return file;
});
Expand Down
4 changes: 2 additions & 2 deletions client/js/src/utils/stream.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BROKEN_CONNECTION_MSG } from "../constants";
import { BROKEN_CONNECTION_MSG, SSE_URL } from "../constants";
import type { Client } from "../client";
import { stream } from "fetch-event-stream";

Expand All @@ -25,7 +25,7 @@ export async function open_stream(this: Client): Promise<void> {
session_hash: this.session_hash
}).toString();

let url = new URL(`${config.root}/queue/data?${params}`);
let url = new URL(`${config.root}${this.api_prefix}/${SSE_URL}?${params}`);

if (jwt) {
url.searchParams.set("__sign", jwt);
Expand Down
35 changes: 23 additions & 12 deletions client/js/src/utils/submit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ import {
process_endpoint
} from "../helpers/api_info";
import semiver from "semiver";
import { BROKEN_CONNECTION_MSG, QUEUE_FULL_MSG } from "../constants";
import {
BROKEN_CONNECTION_MSG,
QUEUE_FULL_MSG,
SSE_URL,
SSE_DATA_URL,
RESET_URL,
CANCEL_URL
} from "../constants";
import { apply_diff_stream, close_stream } from "./stream";
import { Client } from "../client";

Expand All @@ -46,7 +53,8 @@ export function submit(
event_callbacks,
unclosed_events,
post_data,
options
options,
api_prefix
} = this;

const that = this;
Expand Down Expand Up @@ -133,14 +141,14 @@ export function submit(
}

if ("event_id" in cancel_request) {
await fetch(`${config.root}/cancel`, {
await fetch(`${config.root}${api_prefix}/${CANCEL_URL}`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify(cancel_request)
});
}

await fetch(`${config.root}/reset`, {
await fetch(`${config.root}${api_prefix}/${RESET_URL}`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify(reset_request)
Expand Down Expand Up @@ -207,7 +215,7 @@ export function submit(
});

post_data(
`${config.root}/run${
`${config.root}${api_prefix}/run${
_endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
}${url_params ? "?" + url_params : ""}`,
{
Expand Down Expand Up @@ -413,7 +421,7 @@ export function submit(
session_hash: session_hash
}).toString();
let url = new URL(
`${config.root}/queue/join?${
`${config.root}${api_prefix}/${SSE_URL}?${
url_params ? url_params + "&" : ""
}${params}`
);
Expand Down Expand Up @@ -451,11 +459,14 @@ export function submit(
close();
}
} else if (type === "data") {
let [_, status] = await post_data(`${config.root}/queue/data`, {
...payload,
session_hash,
event_id
});
let [_, status] = await post_data(
`${config.root}${api_prefix}/queue/data`,
{
...payload,
session_hash,
event_id
}
);
if (status !== 200) {
fire_event({
type: "status",
Expand Down Expand Up @@ -564,7 +575,7 @@ export function submit(
: Promise.resolve(null);
const post_data_promise = zerogpu_auth_promise.then((headers) => {
return post_data(
`${config.root}/queue/join?${url_params}`,
`${config.root}${api_prefix}/${SSE_DATA_URL}?${url_params}`,
{
...payload,
session_hash
Expand Down
4 changes: 2 additions & 2 deletions client/js/src/utils/upload_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ export async function upload_files(
});
try {
const upload_url = upload_id
? `${root_url}/${UPLOAD_URL}?upload_id=${upload_id}`
: `${root_url}/${UPLOAD_URL}`;
? `${root_url}${this.api_prefix}/${UPLOAD_URL}?upload_id=${upload_id}`
: `${root_url}${this.api_prefix}/${UPLOAD_URL}`;

response = await this.fetch(upload_url, {
method: "POST",
Expand Down
2 changes: 1 addition & 1 deletion client/js/src/utils/view_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export async function view_api(this: Client): Promise<any> {
credentials: "include"
});
} else {
const url = join_urls(config.root, API_INFO_URL);
const url = join_urls(config.root, this.api_prefix, API_INFO_URL);
response = await this.fetch(url, {
headers,
credentials: "include"
Expand Down
30 changes: 19 additions & 11 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,26 @@ def __init__(
self.protocol: Literal["ws", "sse", "sse_v1", "sse_v2", "sse_v2.1"] = (
self.config.get("protocol", "ws")
)
self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
self.api_prefix: str = self.config.get("api_prefix", "").lstrip("/") + "/"
self.src_prefixed = urllib.parse.urljoin(self.src, self.api_prefix) + "/"

self.api_url = urllib.parse.urljoin(self.src_prefixed, utils.API_URL)
self.sse_url = urllib.parse.urljoin(
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
self.src_prefixed,
utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL,
)
self.heartbeat_url = urllib.parse.urljoin(
self.src_prefixed, utils.HEARTBEAT_URL
)
self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL)
self.sse_data_url = urllib.parse.urljoin(
self.src,
self.src_prefixed,
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
)
self.ws_url = urllib.parse.urljoin(
self.src.replace("http", "ws", 1), utils.WS_URL
self.src_prefixed.replace("http", "ws", 1), utils.WS_URL
)
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
self.upload_url = urllib.parse.urljoin(self.src_prefixed, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src_prefixed, utils.RESET_URL)
self.app_version = version.parse(self.config.get("version", "2.0"))
self._info = self._get_api_info()
self.session_hash = str(uuid.uuid4())
Expand Down Expand Up @@ -552,7 +558,9 @@ def fn(future):
return job

def _get_api_info(self):
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
print("SRC PREFIXED", self.src_prefixed, utils.RAW_API_INFO_URL)
api_info_url = urllib.parse.urljoin(self.src_prefixed, utils.RAW_API_INFO_URL)
print("API INFO URL", api_info_url)
if self.app_version > version.Version("3.36.1"):
r = httpx.get(
api_info_url,
Expand Down Expand Up @@ -864,7 +872,7 @@ def _get_config(self) -> dict:
)
else: # to support older versions of Gradio
r = httpx.get(
self.src,
self.src_prefixed,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
Expand Down Expand Up @@ -1071,7 +1079,7 @@ def __init__(
]
self.parameters_info = self._get_parameters_info()

self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
self.root_url = client.src.rstrip("/") + "/" + client.api_prefix

# Disallow hitting endpoints that the Gradio app has disabled
self.is_valid = self.api_name is not False
Expand Down Expand Up @@ -1139,7 +1147,7 @@ def make_cancel(
if helper is None:
return
if self.client.app_version > version.Version("4.29.0"):
url = urllib.parse.urljoin(self.client.src, utils.CANCEL_URL)
url = urllib.parse.urljoin(self.client.src_prefixed, utils.CANCEL_URL)

# The event_id won't be set on the helper until later
# so need to create the data in a function that's run at cancel time
Expand Down
11 changes: 9 additions & 2 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
InvalidComponentError,
)
from gradio.helpers import create_tracker, skip, special_args
from gradio.route_utils import MediaStream
from gradio.route_utils import API_PREFIX, MediaStream
from gradio.state_holder import SessionState, StateHolder
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass as Theme
Expand Down Expand Up @@ -2049,6 +2049,7 @@ def get_config(self):
def get_config_file(self) -> BlocksConfigDict:
config: BlocksConfigDict = {
"version": routes.VERSION,
"api_prefix": API_PREFIX,
"mode": self.mode,
"app_id": self.app_id,
"dev_mode": self.dev_mode,
Expand Down Expand Up @@ -2403,6 +2404,7 @@ def reverse(text):
)
self.server_name = server_name
self.local_url = local_url
self.local_api_url = f"{self.local_url.rstrip('/')}{API_PREFIX}/"
self.server_port = server_port
self.server = server
self.is_running = True
Expand Down Expand Up @@ -2431,7 +2433,9 @@ def reverse(text):
# Cannot run async functions in background other than app's scope.
# Workaround by triggering the app endpoint
httpx.get(
f"{self.local_url}startup-events", verify=ssl_verify, timeout=None
f"{self.local_api_url}startup-events",
verify=ssl_verify,
timeout=None,
)
else:
# NOTE: One benefit of the code above dispatching `startup_events()` via a self HTTP request is
Expand Down Expand Up @@ -2486,6 +2490,9 @@ def reverse(text):
and not networking.url_ok(self.local_url)
and not self.share
):
print(self.local_url)
print(networking.url_ok(self.local_url))

raise ValueError(
"When localhost is not accessible, a shareable link must be created. Please set share=True or check your proxy settings to allow access to localhost."
)
Expand Down
2 changes: 1 addition & 1 deletion gradio/components/logout_button.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
icon: str
| None = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
# Link to logout page (which will delete the session cookie and redirect to landing page).
link: str | None = "/logout",
link: str | None = "/gradio_api/logout",
visible: bool = True,
interactive: bool = True,
elem_id: str | None = None,
Expand Down
Loading

0 comments on commit 2e179d3

Please sign in to comment.