Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to fix the generic response parsing and handling #817

Merged
merged 6 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions src/oic/extension/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def do_op(
if http_args is None:
http_args = ht_args
else:
http_args.update(http_args)
http_args.update(ht_args)

resp = self.request_and_return(
url, response_cls, method, body, body_type, http_args=http_args
Expand Down Expand Up @@ -266,26 +266,53 @@ def do_token_introspection(

def do_token_revocation(
self,
body_type="",
body_type=None,
method="POST",
request_args=None,
extra_args=None,
http_args=None,
**kwargs,
):
request = self.message_factory.get_request_type("revocation_endpoint")
response_cls = self.message_factory.get_response_type("revocation_endpoint")
return self.do_op(
request=request,
body_type=body_type,
method=method,
request_args=request_args,
extra_args=extra_args,
http_args=http_args,
response_cls=response_cls,
**kwargs,
# There is no expected response, only the status code is important,
# so do not use do_op().

url, body, ht_args, _ = self.request_info(
request, method, request_args, extra_args, **kwargs
)

if http_args is None:
http_args = ht_args
else:
http_args.update(ht_args)

resp = self.http_request(url, method, data=body, **http_args)

if resp.status_code == 200:
return 200
if resp.status_code == 503:
# Revoke failed, should retry later
raise PyoidcError("Retry revocation later.")

if 400 <= resp.status_code < 500:
# check for error response
try:
err = ErrorResponse().deserialize(resp.text)
try:
err.verify()
except PyoidcError:
pass
else:
return err
except Exception:
logger.exception(
"Failed to decode error response (%d) %s",
resp.status_code,
sanitize(resp.text),
)

return resp

def add_code_challenge(self):
try:
cv_len = self.config["code_challenge"]["length"]
Expand Down
67 changes: 39 additions & 28 deletions src/oic/oauth2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import cast
from urllib.parse import urlparse

import requests
from jwkest import b64e
from typing_extensions import Literal

from oic import CC_METHOD
from oic import OIDCONF_PATTERN
Expand Down Expand Up @@ -47,6 +47,7 @@
from oic.oauth2.message import ROPCAccessTokenRequest
from oic.oauth2.message import TokenErrorResponse
from oic.oauth2.message import sanitize
from oic.oauth2.util import ENCODINGS
from oic.oauth2.util import get_or_post
from oic.oauth2.util import verify_header
from oic.utils.http_util import BadRequest
Expand Down Expand Up @@ -88,8 +89,6 @@

ENDPOINTS = ["authorization_endpoint", "token_endpoint", "token_revocation_endpoint"]

ENCODINGS = Literal["json", "urlencoded", "dict"]


class ExpiredToken(PyoidcError):
pass
Expand Down Expand Up @@ -726,43 +725,55 @@ def init_authentication_method(
else:
return http_args

def parse_request_response(self, reqresp, response, body_type, state="", **kwargs):
def parse_request_response(
self,
reqresp: requests.Response,
response: Type[Message] = None,
body_type: ENCODINGS = None,
state="",
**kwargs,
) -> Union[Message, requests.Response]:

if reqresp.status_code in SUCCESSFUL:
body_type = verify_header(reqresp, body_type)
elif reqresp.status_code in [302, 303]: # redirect
# Handle the early return stuff
if reqresp.status_code in [302, 303]: # redirect
return reqresp
elif reqresp.status_code == 500:
logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text)))
raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
elif reqresp.status_code in [400, 401]:
# expecting an error response
if issubclass(response, ErrorResponse):
pass

if reqresp.status_code in SUCCESSFUL:
verified_body_type = verify_header(reqresp, body_type)
elif (
reqresp.status_code in [400, 401]
and response
and issubclass(response, ErrorResponse)
):
# This is okay if we are expecting an error response, do not log
verified_body_type = verify_header(reqresp, body_type)
else:
# Any other error
logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text)))
raise HttpError(
"HTTP ERROR: %s [%s] on %s"
% (reqresp.text, reqresp.status_code, reqresp.url)
)

# we expect some specific response message type, try to parse it
if response:
if body_type is None:
# There is no content-type for zero content length. Return the status code.
return reqresp.status_code
elif body_type == "txt":
# no meaning trying to parse unstructured text
return reqresp.text
# Just let the parser throw if we cannot parse it.
if verified_body_type is None:
verified_body_type = "urlencoded"

return self.parse_response(
response, reqresp.text, body_type, state, **kwargs
response, reqresp.text, verified_body_type, state, **kwargs
)

# could be an error response
# No one told us what to expect, so try to decode an error response
if reqresp.status_code in [200, 400, 401]:
if body_type == "txt":
if verified_body_type == "txt":
body_type = "urlencoded"
try:
err = ErrorResponse().deserialize(reqresp.message, method=body_type)
err = ErrorResponse().deserialize(reqresp.text, method=body_type)
try:
err.verify()
except PyoidcError:
Expand Down Expand Up @@ -796,21 +807,21 @@ def request_and_return(
:param response: Response type
:param method: Which HTTP method to use
:param body: A message body if any
:param body_type: The format of the body of the return message
:param body_type: The expected format of the body of the return message.
:param http_args: Arguments for the HTTP client
:return: A cls or ErrorResponse instance or the HTTP response instance if no response body was expected.
"""
# FIXME: Cannot annotate return value as Message since it disrupts all other methods
if http_args is None:
http_args = {}

try:
resp = self.http_request(url, method, data=body, **http_args)
except Exception:
raise
resp = self.http_request(url, method, data=body, **http_args)

kwargs.setdefault("keyjar", self.keyjar)

if "keyjar" not in kwargs:
kwargs["keyjar"] = self.keyjar
# Handle older usage, does not match type annotations
if body_type == "":
body_type = None

return self.parse_request_response(resp, response, body_type, state, **kwargs)

Expand Down
11 changes: 7 additions & 4 deletions src/oic/oauth2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from http import cookiejar as cookielib
from http.cookies import CookieError
from http.cookies import SimpleCookie
from typing import cast

import requests

Expand Down Expand Up @@ -103,7 +104,7 @@ def _cookies(self):

return cookie_dict

def http_request(self, url, method="GET", **kwargs):
def http_request(self, url: str, method="GET", **kwargs) -> requests.Response:
"""
Run a HTTP request to fetch the given url.

Expand All @@ -117,8 +118,7 @@ def http_request(self, url, method="GET", **kwargs):

"""
_kwargs = copy.copy(self.request_args)
if kwargs:
_kwargs.update(kwargs)
_kwargs.update(kwargs)

if self.cookiejar:
_kwargs["cookies"] = self._cookies()
Expand All @@ -129,7 +129,10 @@ def http_request(self, url, method="GET", **kwargs):

try:
if getattr(self.settings, "requests_session", None) is not None:
r = self.settings.requests_session.request(method, url, **_kwargs) # type: ignore
r = cast(
requests.Response,
self.settings.requests_session.request(method, url, **_kwargs), # type: ignore
)
else:
r = requests.request(method, url, **_kwargs) # type: ignore
except Exception as err:
Expand Down
46 changes: 32 additions & 14 deletions src/oic/oauth2/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from http.cookiejar import http2time # type: ignore
from typing import Any
from typing import Dict
from typing import Optional
from urllib.parse import parse_qs
from urllib.parse import urlsplit
from urllib.parse import urlunsplit

from typing_extensions import Literal

from oic.exception import UnSupported
from oic.oauth2.exception import TimeFormatError
from oic.utils.sanitize import sanitize
Expand Down Expand Up @@ -46,6 +49,9 @@
"rfc2109": True,
}

# The encodings understood by our Message class
ENCODINGS = Literal["json", "urlencoded", "dict", "jwt", "jwe"]


def get_or_post(
uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, accept=None, **kwargs
Expand Down Expand Up @@ -180,23 +186,35 @@ def match_to_(val, vlist):
return False


def verify_header(reqresp, body_type):
def guess_body_type(reqresp):
"""Try to guess the body type of the message from a response object."""
# try to handle chunked responses.
if "chunked" not in reqresp.headers.get("transfer-encoding", ""):
if int(reqresp.headers["content-length"]) == 0:
return None

_ctype = reqresp.headers["content-type"]
if match_to_("application/json", _ctype):
body_type = "json"
elif match_to_("application/jwt", _ctype):
body_type = "jwt"
elif match_to_("application/jwe", _ctype):
body_type = "jwe"
elif match_to_(URL_ENCODED, _ctype):
body_type = "urlencoded"
else:
body_type = None
return body_type


def verify_header(reqresp, body_type: Optional[ENCODINGS]) -> Optional[ENCODINGS]:
logger.debug("resp.headers: %s" % (sanitize(reqresp.headers),))
logger.debug("resp.txt: %s" % (sanitize(reqresp.text),))

if body_type == "":
if int(reqresp.headers["content-length"]) == 0:
return None
_ctype = reqresp.headers["content-type"]
if match_to_("application/json", _ctype):
body_type = "json"
elif match_to_("application/jwt", _ctype):
body_type = "jwt"
elif match_to_(URL_ENCODED, _ctype):
body_type = "urlencoded"
else:
body_type = "txt" # reasonable default ??
elif body_type == "json":
if body_type is None:
return guess_body_type(reqresp)

if body_type == "json":
if not match_to_("application/json", reqresp.headers["content-type"]):
if match_to_("application/jwt", reqresp.headers["content-type"]):
body_type = "jwt"
Expand Down
4 changes: 2 additions & 2 deletions src/oic/oic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2320,7 +2320,7 @@ def do_verified_logout(
logger.info("logging out from {} at {}".format(_cid, _url))

try:
res = self.httpc.http_request(
resp = self.httpc.http_request(
_url,
"POST",
data="logout_token={}".format(sjwt),
Expand All @@ -2334,7 +2334,7 @@ def do_verified_logout(
failed.append(_cid)
continue

if res.status_code < 300:
if resp.status_code < 300:
logger.info("Logged out from {}".format(_cid))
else:
_errstr = "failed to logout from {}".format(_cid)
Expand Down
15 changes: 0 additions & 15 deletions tests/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from urllib.parse import urlparse

import pytest
import requests
import responses

from oic.oauth2 import Client
Expand All @@ -26,7 +25,6 @@
from oic.oauth2.message import ExtensionTokenRequest
from oic.oauth2.message import FormatError
from oic.oauth2.message import GrantExpired
from oic.oauth2.message import Message
from oic.oauth2.message import MessageTuple
from oic.oauth2.message import MissingRequiredAttribute
from oic.oauth2.message import OauthMessageFactory
Expand Down Expand Up @@ -628,19 +626,6 @@ class ExtensionMessageFactory(OauthMessageFactory):
assert isinstance(resp, AccessTokenResponse)
assert resp["access_token"] == "Token"

def test_parse_request_response_should_return_status_code_if_content_length_zero(
self,
):

resp = requests.Response()
resp.headers = requests.models.CaseInsensitiveDict(data={"content-length": "0"})
resp.status_code = 200
parsed_response = self.client.parse_request_response(
reqresp=resp, response=Message, body_type=""
)

assert parsed_response == 200


class TestServer(object):
@pytest.fixture(autouse=True)
Expand Down
Loading