-
Notifications
You must be signed in to change notification settings - Fork 202
/
Copy pathoauth2.py
551 lines (475 loc) · 26.3 KB
/
oauth2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
"""This OAuth2 client implementation aims to be spec-compliant, and generic."""
# OAuth2 spec https://tools.ietf.org/html/rfc6749
import json
try:
from urllib.parse import urlencode, parse_qs
except ImportError:
from urlparse import parse_qs
from urllib import urlencode
import logging
import warnings
import time
import base64
import sys
import functools
import requests
string_types = (str,) if sys.version_info[0] >= 3 else (basestring, )
class BaseClient(object):
# This low-level interface works. Yet you'll find its sub-class
# more friendly to remind you what parameters are needed in each scenario.
# More on Client Types at https://tools.ietf.org/html/rfc6749#section-2.1
@staticmethod
def encode_saml_assertion(assertion):
return base64.urlsafe_b64encode(assertion).rstrip(b'=') # Per RFC 7522
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
CLIENT_ASSERTION_TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
client_assertion_encoders = {CLIENT_ASSERTION_TYPE_SAML2: encode_saml_assertion}
@property
def session(self):
warnings.warn("Will be gone in next major release", DeprecationWarning)
return self._http_client
@session.setter
def session(self, value):
warnings.warn("Will be gone in next major release", DeprecationWarning)
self._http_client = value
def __init__(
self,
server_configuration, # type: dict
client_id, # type: str
http_client=None, # We insert it here to match the upcoming async API
client_secret=None, # type: Optional[str]
client_assertion=None, # type: Union[bytes, callable, None]
client_assertion_type=None, # type: Optional[str]
default_headers=None, # type: Optional[dict]
default_body=None, # type: Optional[dict]
verify=None, # type: Union[str, True, False, None]
proxies=None, # type: Optional[dict]
timeout=None, # type: Union[tuple, float, None]
):
"""Initialize a client object to talk all the OAuth2 grants to the server.
Args:
server_configuration (dict):
It contains the configuration (i.e. metadata) of the auth server.
The actual content typically contains keys like
"authorization_endpoint", "token_endpoint", etc..
Based on RFC 8414 (https://tools.ietf.org/html/rfc8414),
you can probably fetch it online from either
https://example.com/.../.well-known/oauth-authorization-server
or
https://example.com/.../.well-known/openid-configuration
client_id (str): The client's id, issued by the authorization server
http_client (http.HttpClient):
Your implementation of abstract class :class:`http.HttpClient`.
Defaults to a requests session instance.
There is no session-wide `timeout` parameter defined here.
Timeout behavior is determined by the actual http client you use.
If you happen to use Requests, it disallows session-wide timeout
(https://github.com/psf/requests/issues/3341). The workaround is:
s = requests.Session()
s.request = functools.partial(s.request, timeout=3)
and then feed that patched session instance to this class.
client_secret (str): Triggers HTTP AUTH for Confidential Client
client_assertion (bytes, callable):
The client assertion to authenticate this client, per RFC 7521.
It can be a raw SAML2 assertion (this method will encode it for you),
or a raw JWT assertion.
It can also be a callable (recommended),
so that we will do lazy creation of an assertion.
client_assertion_type (str):
The type of your :attr:`client_assertion` parameter.
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
:attr:`CLIENT_ASSERTION_TYPE_JWT`, the only two defined in RFC 7521.
default_headers (dict):
A dict to be sent in each request header.
It is not required by OAuth2 specs, but you may use it for telemetry.
default_body (dict):
A dict to be sent in each token request body. For example,
you could choose to set this as {"client_secret": "your secret"}
if your authorization server wants it to be in the request body
(rather than in the request header).
verify (boolean):
It will be passed to the
`verify parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_.
When leaving it with default value (None), we will use True instead.
This does not apply if you have chosen to pass your own Http client.
proxies (dict):
It will be passed to the
`proxies parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_.
This does not apply if you have chosen to pass your own Http client.
timeout (object):
It will be passed to the
`timeout parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_.
This does not apply if you have chosen to pass your own Http client.
"""
self.configuration = server_configuration
self.client_id = client_id
self.client_secret = client_secret
self.client_assertion = client_assertion
self.default_headers = default_headers or {}
self.default_body = default_body or {}
if client_assertion_type is not None:
self.default_body["client_assertion_type"] = client_assertion_type
self.logger = logging.getLogger(__name__)
if http_client:
if verify is not None or proxies is not None or timeout is not None:
raise ValueError(
"verify, proxies, or timeout is not allowed "
"when http_client is in use")
self._http_client = http_client
else:
self._http_client = requests.Session()
self._http_client.verify = True if verify is None else verify
self._http_client.proxies = proxies
self._http_client.request = functools.partial(
# A workaround for requests not supporting session-wide timeout
self._http_client.request, timeout=timeout)
def _build_auth_request_params(self, response_type, **kwargs):
# response_type is a string defined in
# https://tools.ietf.org/html/rfc6749#section-3.1.1
# or it can be a space-delimited string as defined in
# https://tools.ietf.org/html/rfc6749#section-8.4
response_type = self._stringify(response_type)
params = {'client_id': self.client_id, 'response_type': response_type}
params.update(kwargs) # Note: None values will override params
params = {k: v for k, v in params.items() if v is not None} # clean up
if params.get('scope'):
params['scope'] = self._stringify(params['scope'])
return params # A dict suitable to be used in http request
def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
self, grant_type,
params=None, # a dict to be sent as query string to the endpoint
data=None, # All relevant data, which will go into the http body
headers=None, # a dict to be sent as request headers
post=None, # A callable to replace requests.post(), for testing.
# Such as: lambda url, **kwargs:
# Mock(status_code=200, text='{}')
**kwargs # Relay all extra parameters to underlying requests
): # Returns the json object came from the OAUTH2 response
_data = {'client_id': self.client_id, 'grant_type': grant_type}
if self.default_body.get("client_assertion_type") and self.client_assertion:
# See https://tools.ietf.org/html/rfc7521#section-4.2
encoder = self.client_assertion_encoders.get(
self.default_body["client_assertion_type"], lambda a: a)
_data["client_assertion"] = encoder(
self.client_assertion() # Do lazy on-the-fly computation
if callable(self.client_assertion) else self.client_assertion)
_data.update(self.default_body) # It may contain authen parameters
_data.update(data or {}) # So the content in data param prevails
_data = {k: v for k, v in _data.items() if v} # Clean up None values
if _data.get('scope'):
_data['scope'] = self._stringify(_data['scope'])
_headers = {'Accept': 'application/json'}
_headers.update(self.default_headers)
_headers.update(headers or {})
# Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1
# Clients in possession of a client password MAY use the HTTP Basic
# authentication.
# Alternatively, (but NOT RECOMMENDED,)
# the authorization server MAY support including the
# client credentials in the request-body using the following
# parameters: client_id, client_secret.
if self.client_secret and self.client_id:
_headers["Authorization"] = "Basic " + base64.b64encode(
"{}:{}".format(self.client_id, self.client_secret)
.encode("ascii")).decode("ascii")
if "token_endpoint" not in self.configuration:
raise ValueError("token_endpoint not found in configuration")
resp = (post or self._http_client.post)(
self.configuration["token_endpoint"],
headers=_headers, params=params, data=_data,
**kwargs)
if resp.status_code >= 500:
resp.raise_for_status() # TODO: Will probably retry here
try:
# The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says
# even an error response will be a valid json structure,
# so we simply return it here, without needing to invent an exception.
return json.loads(resp.text)
except ValueError:
self.logger.exception(
"Token response is not in json format: %s", resp.text)
raise
def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs):
# type: (str, Union[str, list, set, tuple]) -> dict
"""Obtain an access token via a refresh token.
:param refresh_token: The refresh token issued to the client
:param scope: If omitted, is treated as equal to the scope originally
granted by the resource owner,
according to https://tools.ietf.org/html/rfc6749#section-6
"""
assert isinstance(refresh_token, string_types)
data = kwargs.pop('data', {})
data.update(refresh_token=refresh_token, scope=scope)
return self._obtain_token("refresh_token", data=data, **kwargs)
def _stringify(self, sequence):
if isinstance(sequence, (list, set, tuple)):
return ' '.join(sorted(sequence)) # normalizing it, ascendingly
return sequence # as-is
class Client(BaseClient): # We choose to implement all 4 grants in 1 class
"""This is the main API for oauth2 client.
Its methods define and document parameters mentioned in OAUTH2 RFC 6749.
"""
DEVICE_FLOW = { # consts for device flow, that can be customized by sub-class
"GRANT_TYPE": "urn:ietf:params:oauth:grant-type:device_code",
"DEVICE_CODE": "device_code",
}
DEVICE_FLOW_RETRIABLE_ERRORS = ("authorization_pending", "slow_down")
GRANT_TYPE_SAML2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" # RFC7522
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}
def initiate_device_flow(self, scope=None, **kwargs):
# type: (list, **dict) -> dict
# The naming of this method is following the wording of this specs
# https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1
"""Initiate a device flow.
Returns the data defined in Device Flow specs.
https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.2
You should then orchestrate the User Interaction as defined in here
https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.3
And possibly here
https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.3.1
"""
DAE = "device_authorization_endpoint"
if not self.configuration.get(DAE):
raise ValueError("You need to provide device authorization endpoint")
resp = self._http_client.post(self.configuration[DAE],
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
headers=dict(self.default_headers, **kwargs.pop("headers", {})),
**kwargs)
flow = json.loads(resp.text)
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
flow["expires_in"] = int(flow.get("expires_in", 1800))
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this
return flow
def _obtain_token_by_device_flow(self, flow, **kwargs):
# type: (dict, **dict) -> dict
# This method updates flow during each run. And it is non-blocking.
now = time.time()
skew = 1
if flow.get("latest_attempt_at", 0) + flow.get("interval", 5) - skew > now:
warnings.warn('Attempted too soon. Please do time.sleep(flow["interval"])')
data = kwargs.pop("data", {})
data.update({
"client_id": self.client_id,
self.DEVICE_FLOW["DEVICE_CODE"]: flow["device_code"],
})
result = self._obtain_token(
self.DEVICE_FLOW["GRANT_TYPE"], data=data, **kwargs)
if result.get("error") == "slow_down":
# Respecting https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.5
flow["interval"] = flow.get("interval", 5) + 5
flow["latest_attempt_at"] = now
return result
def obtain_token_by_device_flow(self,
flow,
exit_condition=lambda flow: flow.get("expires_at", 0) < time.time(),
**kwargs):
# type: (dict, Callable) -> dict
"""Obtain token by a device flow object, with customizable polling effect.
Args:
flow (dict):
An object previously generated by initiate_device_flow(...).
Its content WILL BE CHANGED by this method during each run.
We share this object with you, so that you could implement
your own loop, should you choose to do so.
exit_condition (Callable):
This method implements a loop to provide polling effect.
The loop's exit condition is calculated by this callback.
The default callback makes the loop run until the flow expires.
Therefore, one of the ways to exit the polling early,
is to change the flow["expires_at"] to a small number such as 0.
In case you are doing async programming, you may want to
completely turn off the loop. You can do so by using a callback as:
exit_condition = lambda flow: True
to make the loop run only once, i.e. no polling, hence non-block.
"""
while True:
result = self._obtain_token_by_device_flow(flow, **kwargs)
if result.get("error") not in self.DEVICE_FLOW_RETRIABLE_ERRORS:
return result
for i in range(flow.get("interval", 5)): # Wait interval seconds
if exit_condition(flow):
return result
time.sleep(1) # Shorten each round, to make exit more responsive
def build_auth_request_uri(
self,
response_type, redirect_uri=None, scope=None, state=None, **kwargs):
"""Generate an authorization uri to be visited by resource owner.
Later when the response reaches your redirect_uri,
you can use parse_auth_response() to check the returned state.
This method could be named build_authorization_request_uri() instead,
but then there would be a build_authentication_request_uri() in the OIDC
subclass doing almost the same thing. So we use a loose term "auth" here.
:param response_type:
Must be "code" when you are using Authorization Code Grant,
"token" when you are using Implicit Grant, or other
(possibly space-delimited) strings as registered extension value.
See https://tools.ietf.org/html/rfc6749#section-3.1.1
:param redirect_uri: Optional. Server will use the pre-registered one.
:param scope: It is a space-delimited, case-sensitive string.
Some ID provider can accept empty string to represent default scope.
:param state: Recommended. An opaque value used by the client to
maintain state between the request and callback.
:param kwargs: Other parameters, typically defined in OpenID Connect.
"""
if "authorization_endpoint" not in self.configuration:
raise ValueError("authorization_endpoint not found in configuration")
authorization_endpoint = self.configuration["authorization_endpoint"]
params = self._build_auth_request_params(
response_type, redirect_uri=redirect_uri, scope=scope, state=state,
**kwargs)
sep = '&' if '?' in authorization_endpoint else '?'
return "%s%s%s" % (authorization_endpoint, sep, urlencode(params))
@staticmethod
def parse_auth_response(params, state=None):
"""Parse the authorization response being redirected back.
:param params: A string or dict of the query string
:param state: REQUIRED if the state parameter was present in the client
authorization request. This function will compare it with response.
"""
if not isinstance(params, dict):
params = parse_qs(params)
if params.get('state') != state:
raise ValueError('state mismatch')
return params
def obtain_token_by_authorization_code(
self, code, redirect_uri=None, scope=None, **kwargs):
"""Get a token via authorization code. a.k.a. Authorization Code Grant.
This is typically used by a server-side app (Confidential Client),
but it can also be used by a device-side native app (Public Client).
See more detail at https://tools.ietf.org/html/rfc6749#section-4.1.3
:param code: The authorization code received from authorization server.
:param redirect_uri:
Required, if the "redirect_uri" parameter was included in the
authorization request, and their values MUST be identical.
:param scope:
It is both unnecessary and harmless to use scope here, per RFC 6749.
We suggest to use the same scope already used in auth request uri,
so that this library can link the obtained tokens with their scope.
"""
data = kwargs.pop("data", {})
data.update(code=code, redirect_uri=redirect_uri)
if scope:
data["scope"] = scope
if not self.client_secret:
# client_id is required, if the client is not authenticating itself.
# See https://tools.ietf.org/html/rfc6749#section-4.1.3
data["client_id"] = self.client_id
return self._obtain_token("authorization_code", data=data, **kwargs)
def obtain_token_by_username_password(
self, username, password, scope=None, **kwargs):
"""The Resource Owner Password Credentials Grant, used by legacy app."""
data = kwargs.pop("data", {})
data.update(username=username, password=password, scope=scope)
return self._obtain_token("password", data=data, **kwargs)
def obtain_token_for_client(self, scope=None, **kwargs):
"""Obtain token for this client (rather than for an end user),
a.k.a. the Client Credentials Grant, used by Backend Applications.
We don't name it obtain_token_by_client_credentials(...) because those
credentials are typically already provided in class constructor, not here.
You can still explicitly provide an optional client_secret parameter,
or you can provide such extra parameters as `default_body` during the
class initialization.
"""
data = kwargs.pop("data", {})
data.update(scope=scope)
return self._obtain_token("client_credentials", data=data, **kwargs)
def __init__(self,
server_configuration, client_id,
on_obtaining_tokens=lambda event: None, # event is defined in _obtain_token(...)
on_removing_rt=lambda token_item: None,
on_updating_rt=lambda token_item, new_rt: None,
**kwargs):
super(Client, self).__init__(server_configuration, client_id, **kwargs)
self.on_obtaining_tokens = on_obtaining_tokens
self.on_removing_rt = on_removing_rt
self.on_updating_rt = on_updating_rt
def _obtain_token(
self, grant_type, params=None, data=None,
also_save_rt=False,
*args, **kwargs):
_data = data.copy() # to prevent side effect
resp = super(Client, self)._obtain_token(
grant_type, params, _data, *args, **kwargs)
if "error" not in resp:
_resp = resp.copy()
RT = "refresh_token"
if grant_type == RT and RT in _resp and not also_save_rt:
# Then we skip it from on_obtaining_tokens();
# Leave it to self.obtain_token_by_refresh_token()
_resp.pop(RT, None)
if "scope" in _resp:
scope = _resp["scope"].split() # It is conceptually a set,
# but we represent it as a list which can be persisted to JSON
else:
# Note: The scope will generally be absent in authorization grant,
# but our obtain_token_by_authorization_code(...) encourages
# app developer to still explicitly provide a scope here.
scope = _data.get("scope")
self.on_obtaining_tokens({
"client_id": self.client_id,
"scope": scope,
"token_endpoint": self.configuration["token_endpoint"],
"grant_type": grant_type, # can be used to know an IdToken-less
# response is for an app or for a user
"response": _resp, "params": params, "data": _data,
})
return resp
def obtain_token_by_refresh_token(self, token_item, scope=None,
rt_getter=lambda token_item: token_item["refresh_token"],
on_removing_rt=None,
on_updating_rt=None,
**kwargs):
# type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
"""This is an overload which will trigger token storage callbacks.
:param token_item:
A refresh token (RT) item, in flexible format. It can be a string,
or a whatever data structure containing RT string and its metadata,
in such case the `rt_getter` callable must be able to
extract the RT string out from the token item data structure.
Either way, this token_item will be passed into other callbacks as-is.
:param scope: If omitted, is treated as equal to the scope originally
granted by the resource owner,
according to https://tools.ietf.org/html/rfc6749#section-6
:param rt_getter: A callable to translate the token_item to a raw RT string
:param on_removing_rt: If absent, fall back to the one defined in initialization
:param on_updating_rt:
Default to None, it will fall back to the one defined in initialization.
This is the most common case.
As a special case, you can pass in a False,
then this function will NOT trigger on_updating_rt() for RT UPDATE,
instead it will allow the RT to be added by on_obtaining_tokens().
This behavior is useful when you are migrating RTs from elsewhere
into a token storage managed by this library.
"""
resp = super(Client, self).obtain_token_by_refresh_token(
rt_getter(token_item)
if not isinstance(token_item, string_types) else token_item,
scope=scope,
also_save_rt=on_updating_rt is False,
**kwargs)
if resp.get('error') == 'invalid_grant':
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
RT = "refresh_token"
if on_updating_rt is not False and RT in resp:
(on_updating_rt or self.on_updating_rt)(token_item, resp[RT])
return resp
def obtain_token_by_assertion(
self, assertion, grant_type, scope=None, **kwargs):
# type: (bytes, Union[str, None], Union[str, list, set, tuple]) -> dict
"""This method implements Assertion Framework for OAuth2 (RFC 7521).
See details at https://tools.ietf.org/html/rfc7521#section-4.1
:param assertion:
The assertion bytes can be a raw SAML2 assertion, or a JWT assertion.
:param grant_type:
It is typically either the value of :attr:`GRANT_TYPE_SAML2`,
or :attr:`GRANT_TYPE_JWT`, the only two profiles defined in RFC 7521.
:param scope: Optional. It must be a subset of previously granted scopes.
"""
encoder = self.grant_assertion_encoders.get(grant_type, lambda a: a)
data = kwargs.pop("data", {})
data.update(scope=scope, assertion=encoder(assertion))
return self._obtain_token(grant_type, data=data, **kwargs)