diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py index 5a496eaf..499bf894 100644 --- a/tests/unit/oauth_test_utils.py +++ b/tests/unit/oauth_test_utils.py @@ -54,11 +54,14 @@ def __call__(self, request, uri, response_headers): if authorization and authorization.replace("Bearer ", "") in self.tokens: return [200, response_headers, json.dumps(self.sample_post_response_data)] elif self.redirect_server is None and self.token_server is not None: - return [401, {'Www-Authenticate': f'Bearer x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' - f'x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [401, + {'Www-Authenticate': f'Bearer realm="Trino", token_type="JWT", Bearer x_token_server="{self.token_server}"', + 'Basic realm': '"Trino"'}, + ""] + return [401, + {'Www-Authenticate': f'Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{self.redirect_server}", x_token_server="{self.token_server}"', + 'Basic realm': '"Trino"'}, + ""] class GetTokenCallback: diff --git a/trino/auth.py b/trino/auth.py index e5939403..8a24ecd7 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -459,7 +459,7 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: auth_info_headers = self._parse_authenticate_header(auth_info) auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server')) - token_server = auth_info_headers.get('x_token_server') + token_server = auth_info_headers.get('bearer x_token_server', auth_info_headers.get('x_token_server')) if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")