Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enh: Make client's handling of error responses more robust and user-friendly #418

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,15 @@ def generate(
timeout=self.timeout,
)

# TODO: expose better error messages for 422 and similar errors
payload = resp.json()
try:
payload = resp.json()
except requests.JSONDecodeError as e:
# If the status code is success-like, reset it to 500 since the server is sending an invalid response.
if 200 <= resp.status_code < 400:
resp.status_code = 500

payload = {"message": e.msg}

if resp.status_code != 200:
raise parse_error(resp.status_code, payload)

Expand Down
35 changes: 21 additions & 14 deletions clients/python/lorax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,17 @@ def __init__(self, model_id: str):
super(NotSupportedError, self).__init__(message)


# Unknown error
class UnknownError(Exception):
class UnprocessableEntityError(Exception):
def __init__(self, message: str):
super().__init__(message)


# Unknown error
class UnknownError(Exception):
def __init__(self, message: str, code: int):
super().__init__(f"Error status {code}: {message}")


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Expand All @@ -75,17 +80,17 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:

"""
# Try to parse a LoRAX error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
message = payload.get("error", "")

error_type = payload.get("error_type", "")
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)

# Try to parse a APIInference error
if status_code == 400:
Expand All @@ -98,6 +103,8 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
if status_code == 422:
return UnprocessableEntityError(message)

# Fallback to an unknown error
return UnknownError(message)
return UnknownError(message, status_code)
7 changes: 6 additions & 1 deletion clients/python/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
UnknownError, UnprocessableEntityError,
)


Expand Down Expand Up @@ -59,6 +59,11 @@ def test_rate_limit_exceeded_error():
assert isinstance(parse_error(429, payload), RateLimitExceededError)


def test_unprocessable_entity_error():
payload = {"error": "test"}
assert isinstance(parse_error(422, payload), UnprocessableEntityError)


def test_unknown_error():
payload = {"error": "test"}
assert isinstance(parse_error(500, payload), UnknownError)
Loading