From d19d5f0fdcebbf14967e18536c3af5121249df90 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 16 Apr 2024 11:05:25 -0500 Subject: [PATCH 1/2] enh: Make client's handling of error responses more robust and user-friendly --- clients/python/lorax/client.py | 11 +++++++++-- clients/python/lorax/errors.py | 35 ++++++++++++++++++++-------------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 46627c3d4..0efe746dd 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -186,8 +186,15 @@ def generate( timeout=self.timeout, ) - # TODO: expose better error messages for 422 and similar errors - payload = resp.json() + try: + payload = resp.json() + except requests.JSONDecodeError as e: + # If the status code is success-like, reset it to 500 since the server is sending an invalid response. + if 200 <= resp.status_code < 400: + resp.status_code = 500 + + payload = {"message": e.msg} + if resp.status_code != 200: raise parse_error(resp.status_code, payload) diff --git a/clients/python/lorax/errors.py b/clients/python/lorax/errors.py index 108b3159d..172e8a47a 100644 --- a/clients/python/lorax/errors.py +++ b/clients/python/lorax/errors.py @@ -54,12 +54,17 @@ def __init__(self, model_id: str): super(NotSupportedError, self).__init__(message) -# Unknown error -class UnknownError(Exception): +class UnprocessableEntityError(Exception): def __init__(self, message: str): super().__init__(message) +# Unknown error +class UnknownError(Exception): + def __init__(self, message: str, code: int): + super().__init__(f"Error status {code}: {message}") + + def parse_error(status_code: int, payload: Dict[str, str]) -> Exception: """ Parse error given an HTTP status code and a json payload @@ -75,17 +80,17 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception: """ # Try to parse a LoRAX error - message = payload["error"] - if "error_type" in payload: - error_type = payload["error_type"] - if error_type == "generation": - return GenerationError(message) - if error_type == "incomplete_generation": - return IncompleteGenerationError(message) - if error_type == "overloaded": - return OverloadedError(message) - if error_type == "validation": - return ValidationError(message) + message = payload.get("error", "") + + error_type = payload.get("error_type", "") + if error_type == "generation": + return GenerationError(message) + if error_type == "incomplete_generation": + return IncompleteGenerationError(message) + if error_type == "overloaded": + return OverloadedError(message) + if error_type == "validation": + return ValidationError(message) # Try to parse a APIInference error if status_code == 400: @@ -98,6 +103,8 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception: return NotFoundError(message) if status_code == 429: return RateLimitExceededError(message) + if status_code == 422: + return UnprocessableEntityError(message) # Fallback to an unknown error - return UnknownError(message) + return UnknownError(message, status_code) From 95a01f6ba2a8f26926c56ec3fc688465d72495a5 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 16 Apr 2024 11:06:36 -0500 Subject: [PATCH 2/2] add test --- clients/python/tests/test_errors.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/clients/python/tests/test_errors.py b/clients/python/tests/test_errors.py index 8ad80e594..7eb773984 100644 --- a/clients/python/tests/test_errors.py +++ b/clients/python/tests/test_errors.py @@ -9,7 +9,7 @@ ShardTimeoutError, NotFoundError, RateLimitExceededError, - UnknownError, + UnknownError, UnprocessableEntityError, ) @@ -59,6 +59,11 @@ def test_rate_limit_exceeded_error(): assert isinstance(parse_error(429, payload), RateLimitExceededError) +def test_unprocessable_entity_error(): + payload = {"error": "test"} + assert isinstance(parse_error(422, payload), UnprocessableEntityError) + + def test_unknown_error(): payload = {"error": "test"} assert isinstance(parse_error(500, payload), UnknownError)