From cb83fa80e1a137ce588d7b65c6145eb09837484d Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Sun, 9 Feb 2020 11:54:00 -0800 Subject: [PATCH] Replace hardcoded session with http_client param Remove timeout parameter, requests and httpx behaviors are incompatible anyway --- oauth2cli/oauth2.py | 77 +++++++++++++++++++++++++++++++++----------- tests/http_client.py | 30 +++++++++++++++++ tests/test_client.py | 4 +++ 3 files changed, 92 insertions(+), 19 deletions(-) create mode 100644 tests/http_client.py diff --git a/oauth2cli/oauth2.py b/oauth2cli/oauth2.py index 9a947390..fac35f1b 100644 --- a/oauth2cli/oauth2.py +++ b/oauth2cli/oauth2.py @@ -1,6 +1,7 @@ """This OAuth2 client implementation aims to be spec-compliant, and generic.""" # OAuth2 spec https://tools.ietf.org/html/rfc6749 +import json try: from urllib.parse import urlencode, parse_qs except ImportError: @@ -11,6 +12,7 @@ import time import base64 import sys +import functools import requests @@ -35,6 +37,7 @@ def __init__( self, server_configuration, # type: dict client_id, # type: str + http_client=None, # We insert it here to match the upcoming async API client_secret=None, # type: Optional[str] client_assertion=None, # type: Union[bytes, callable, None] client_assertion_type=None, # type: Optional[str] @@ -57,6 +60,9 @@ def __init__( or https://example.com/.../.well-known/openid-configuration client_id (str): The client's id, issued by the authorization server + http_client (http.HttpClient): + Your implementation of abstract class :class:`http.HttpClient`. + Defaults to a requests session instance. client_secret (str): Triggers HTTP AUTH for Confidential Client client_assertion (bytes, callable): The client assertion to authenticate this client, per RFC 7521. @@ -76,20 +82,51 @@ def __init__( you could choose to set this as {"client_secret": "your secret"} if your authorization server wants it to be in the request body (rather than in the request header). + + verify (boolean): + It will be passed to the + `verify parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + proxies (dict): + It will be passed to the + `proxies parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + timeout (object): + It will be passed to the + `timeout parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + + There is no session-wide `timeout` parameter defined here. + The timeout behavior is determined by the actual http client you use. + If you happen to use Requests, it chose to not support session-wide timeout + (https://github.com/psf/requests/issues/3341), but you can patch that by: + + s = requests.Session() + s.request = functools.partial(s.request, timeout=3) + + and then feed that patched session instance to this class. """ self.configuration = server_configuration self.client_id = client_id self.client_secret = client_secret self.client_assertion = client_assertion + self.default_headers = default_headers or {} self.default_body = default_body or {} if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) - self.session = s = requests.Session() - s.headers.update(default_headers or {}) - s.verify = verify - s.proxies = proxies or {} - self.timeout = timeout + if http_client: + self.http_client = http_client + else: + self.http_client = requests.Session() + self.http_client.verify = verify + self.http_client.proxies = proxies + self.http_client.request = functools.partial( + # A workaround for requests not supporting session-wide timeout + self.http_client.request, timeout=timeout) def _build_auth_request_params(self, response_type, **kwargs): # response_type is a string defined in @@ -110,7 +147,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 params=None, # a dict to be sent as query string to the endpoint data=None, # All relevant data, which will go into the http body headers=None, # a dict to be sent as request headers - timeout=None, post=None, # A callable to replace requests.post(), for testing. # Such as: lambda url, **kwargs: # Mock(status_code=200, json=Mock(return_value={})) @@ -128,11 +164,15 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails - # We don't have to clean up None values here, because requests lib will. + _data = {k: v for k, v in _data.items() if v} # Clean up None values if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) + _headers = {'Accept': 'application/json'} + _headers.update(self.default_headers) + _headers.update(headers or {}) + # Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1 # Clients in possession of a client password MAY use the HTTP Basic # authentication. @@ -140,18 +180,16 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 # the authorization server MAY support including the # client credentials in the request-body using the following # parameters: client_id, client_secret. - auth = None if self.client_secret and self.client_id: - auth = (self.client_id, self.client_secret) # for HTTP Basic Auth + _headers["Authorization"] = "Basic " + base64.b64encode( + "{}:{}".format(self.client_id, self.client_secret) + .encode("ascii")).decode("ascii") if "token_endpoint" not in self.configuration: raise ValueError("token_endpoint not found in configuration") - _headers = {'Accept': 'application/json'} - _headers.update(headers or {}) - resp = (post or self.session.post)( + resp = (post or self.http_client.post)( self.configuration["token_endpoint"], - headers=_headers, params=params, data=_data, auth=auth, - timeout=timeout or self.timeout, + headers=_headers, params=params, data=_data, **kwargs) if resp.status_code >= 500: resp.raise_for_status() # TODO: Will probably retry here @@ -159,7 +197,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says # even an error response will be a valid json structure, # so we simply return it here, without needing to invent an exception. - return resp.json() + return json.loads(resp.text) except ValueError: self.logger.exception( "Token response is not in json format: %s", resp.text) @@ -200,7 +238,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion} - def initiate_device_flow(self, scope=None, timeout=None, **kwargs): + def initiate_device_flow(self, scope=None, **kwargs): # type: (list, **dict) -> dict # The naming of this method is following the wording of this specs # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1 @@ -218,10 +256,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): DAE = "device_authorization_endpoint" if not self.configuration.get(DAE): raise ValueError("You need to provide device authorization endpoint") - flow = self.session.post(self.configuration[DAE], + resp = self.http_client.post(self.configuration[DAE], data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - timeout=timeout or self.timeout, - **kwargs).json() + headers=dict(self.default_headers, **kwargs.pop("headers", {})), + **kwargs) + flow = json.loads(resp.text) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) flow["expires_at"] = time.time() + flow["expires_in"] # We invent this diff --git a/tests/http_client.py b/tests/http_client.py new file mode 100644 index 00000000..4bff9b45 --- /dev/null +++ b/tests/http_client.py @@ -0,0 +1,30 @@ +import requests + + +class MinimalHttpClient: + + def __init__(self, verify=True, proxies=None, timeout=None): + self.session = requests.Session() + self.session.verify = verify + self.session.proxies = proxies + self.timeout = timeout + + def post(self, url, params=None, data=None, headers=None, **kwargs): + return MinimalResponse(requests_resp=self.session.post( + url, params=params, data=data, headers=headers, + timeout=self.timeout)) + + def get(self, url, params=None, headers=None, **kwargs): + return MinimalResponse(requests_resp=self.session.get( + url, params=params, headers=headers, timeout=self.timeout)) + + +class MinimalResponse(object): # Not for production use + def __init__(self, requests_resp=None, status_code=None, text=None): + self.status_code = status_code or requests_resp.status_code + self.text = text or requests_resp.text + self._raw_resp = requests_resp + + def raise_for_status(self): + if self._raw_resp: + self._raw_resp.raise_for_status() diff --git a/tests/test_client.py b/tests/test_client.py index 051e3a0d..601c05a6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,6 +13,7 @@ from oauth2cli.authcode import obtain_auth_code from oauth2cli.assertion import JwtSigner from tests import unittest, Oauth2TestCase +from tests.http_client import MinimalHttpClient logging.basicConfig(level=logging.DEBUG) @@ -83,6 +84,7 @@ class TestClient(Oauth2TestCase): @classmethod def setUpClass(cls): + http_client = MinimalHttpClient() if "client_certificate" in CONFIG: private_key_path = CONFIG["client_certificate"]["private_key_path"] with open(os.path.join(THIS_FOLDER, private_key_path)) as f: @@ -90,6 +92,7 @@ def setUpClass(cls): cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], + http_client=http_client, client_assertion=JwtSigner( private_key, algorithm="RS256", @@ -103,6 +106,7 @@ def setUpClass(cls): else: cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], + http_client=http_client, client_secret=CONFIG.get('client_secret')) @unittest.skipIf(