Skip to content

Commit

Permalink
Improve how server/js client handle unexpected errors (#6798)
Browse files Browse the repository at this point in the history
* Client fixes

* fix

* add changeset

* commented out code

* add changeset

* Log error give generic message

* Add client side catch

* remove exception 😂

* Add test

* Fix info and warning

* lint

* Use BaseException

* Use event_callbacks

* Use event_id not present

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
freddyaboulton and gradio-pr-bot authored Dec 15, 2023
1 parent 526fb6c commit 245d58e
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 66 deletions.
6 changes: 6 additions & 0 deletions .changeset/dirty-experts-cry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---

feat:Improve how server/js client handle unexpected errors
181 changes: 120 additions & 61 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -821,84 +821,121 @@ export function api_factory(
});
} else {
event_id = response.event_id as string;
if (!stream_open) {
open_stream();
}
let callback = async function (_data: object): void {
const { type, status, data } = handle_message(
_data,
last_status[fn_index]
);

if (type === "update" && status && !complete) {
// call 'status' listeners
fire_event({
type: "status",
endpoint: _endpoint,
fn_index,
time: new Date(),
...status
});
} else if (type === "complete") {
complete = status;
} else if (type === "log") {
fire_event({
type: "log",
log: data.log,
level: data.level,
endpoint: _endpoint,
fn_index
});
} else if (type === "generating") {
fire_event({
type: "status",
time: new Date(),
...status,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
if (data) {
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
endpoint: _endpoint,
fn_index
});
try {
const { type, status, data } = handle_message(
_data,
last_status[fn_index]
);

// TODO: Find out how to print this information
// only during testing
// console.info("data", type, status, data);

if (type == "heartbeat") {
return;
}

if (complete) {
if (type === "update" && status && !complete) {
// call 'status' listeners
fire_event({
type: "status",
endpoint: _endpoint,
fn_index,
time: new Date(),
...status
});
} else if (type === "complete") {
complete = status;
} else if (type == "unexpected_error") {
console.error("Unexpected error", status?.message);
fire_event({
type: "status",
stage: "error",
message: "An Unexpected Error Occurred!",
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
} else if (type === "log") {
fire_event({
type: "log",
log: data.log,
level: data.level,
endpoint: _endpoint,
fn_index
});
return;
} else if (type === "generating") {
fire_event({
type: "status",
time: new Date(),
...complete,
...status,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
}
if (data) {
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
endpoint: _endpoint,
fn_index
});

if (complete) {
fire_event({
type: "status",
time: new Date(),
...complete,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
}

if (status.stage === "complete" || status.stage === "error") {
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
if (Object.keys(event_callbacks).length === 0) {
close_stream();
if (
status.stage === "complete" ||
status.stage === "error"
) {
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
if (Object.keys(event_callbacks).length === 0) {
close_stream();
}
}
}
} catch (e) {
console.error("Unexpected client exception", e);
fire_event({
type: "status",
stage: "error",
message: "An Unexpected Error Occurred!",
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
close_stream();
}
};
event_callbacks[event_id] = callback;
if (!stream_open) {
open_stream();
}
}
});
}
Expand Down Expand Up @@ -1014,6 +1051,14 @@ export function api_factory(
event_stream = new EventSource(url);
event_stream.onmessage = async function (event) {
let _data = JSON.parse(event.data);
if (!("event_id" in _data)) {
await Promise.all(
Object.keys(event_callbacks).map((event_id) =>
event_callbacks[event_id](_data)
)
);
return;
}
await event_callbacks[_data.event_id](_data);
};
}
Expand Down Expand Up @@ -1583,6 +1628,20 @@ function handle_message(
success: data.success
}
};
case "heartbeat":
return {
type: "heartbeat"
};
case "unexpected_error":
return {
type: "unexpected_error",
status: {
queue,
message: data.message,
stage: "error",
success: false
}
};
case "estimation":
return {
type: "update",
Expand Down
21 changes: 16 additions & 5 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,19 @@ async def sse_stream(request: fastapi.Request):
except EmptyQueue:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
message = {"msg": ServerMessage.heartbeat}
# Fix this
message = {
"msg": ServerMessage.heartbeat,
}
# Need to reset last_heartbeat with perf_counter
# otherwise only a single hearbeat msg will be sent
# and then the stream will retry leading to infinite queue 😬
last_heartbeat = time.perf_counter()

if blocks._queue.stopped:
message = {
"msg": ServerMessage.server_stopped,
"msg": "unexpected_error",
"message": "Server stopped unexpectedly.",
"success": False,
}
if message:
Expand All @@ -644,9 +648,16 @@ async def sse_stream(request: fastapi.Request):
)
):
return
except asyncio.CancelledError as e:
del blocks._queue.pending_messages_per_session[session_hash]
await blocks._queue.clean_events(session_hash=session_hash)
except BaseException as e:
message = {
"msg": "unexpected_error",
"success": False,
"message": str(e),
}
yield f"data: {json.dumps(message)}\n\n"
if isinstance(e, asyncio.CancelledError):
del blocks._queue.pending_messages_per_session[session_hash]
await blocks._queue.clean_events(session_hash=session_hash)
raise e

return StreamingResponse(
Expand Down

0 comments on commit 245d58e

Please sign in to comment.