diff --git a/.circleci/config.yml b/.circleci/config.yml index f4b0051e..3f6ca2aa 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -30,7 +30,7 @@ jobs: - python/install-packages: pkg-manager: pip-dist path-args: ".[test]" - - run: coverage run -m unittest discover + - run: coverage run -m unittest discover -s auth0/v3/test -t . - run: bash <(curl -s https://codecov.io/bash) workflows: diff --git a/.flake8 b/.flake8 index eb68c91f..7981a1c5 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] -ignore = E501 +ignore = E501 F401 max-line-length = 88 \ No newline at end of file diff --git a/README.rst b/README.rst index 1989f159..637064b0 100644 --- a/README.rst +++ b/README.rst @@ -316,6 +316,54 @@ When consuming methods from the API clients, the requests could fail for a numbe resets is exposed in the ``reset_at`` property. When the header is unset, this value will be ``-1``. - Network timeouts: Adjustable by passing a ``timeout`` argument to the client. See the `rate limit docs `__ for details. +========================= +Asynchronous Environments +========================= + +This SDK provides async methods built on top of `asyncio `__. To make them available you must have Python >=3.6 and the `aiohttp `__ module installed. + +Then additional methods with the ``_async`` suffix will be added to modules created by the ``management.Auth0`` class or to classes that are passed to the ``asyncify`` method. For example: + +.. code-block:: python + + import asyncio + import aiohttp + from auth0.v3.asyncify import asyncify + from auth0.v3.management import Auth0, Users, Connections + from auth0.v3.authentication import Users as AuthUsers + + auth0 = Auth0('domain', 'mgmt_api_token') + + async def main(): + # users = auth0.users.all() <= sync + users = await auth0.users.all_async() # <= async + + # To share a session amongst multiple calls to the same service + async with auth0.users as users: + data = await users.get_async(id) + users.update_async(id, data) + + # Use asyncify directly on services + Users = asyncify(Users) + Connections = asyncify(Connections) + users = Users(domain, mgmt_api_token) + connections = Connections(domain, mgmt_api_token) + + # Create a session and share it among the services + session = aiohttp.ClientSession() + users.set_session(session) + connections.set_session(session) + u = await auth0.users.all_async() + c = await auth0.connections.all_async() + session.close() + + # Use auth api + U = asyncify(AuthUsers) + u = U(domain=domain) + await u.userinfo_async(access_token) + + + asyncio.run(main()) ============== Supported APIs diff --git a/auth0/v3/asyncify.py b/auth0/v3/asyncify.py new file mode 100644 index 00000000..18cf7d43 --- /dev/null +++ b/auth0/v3/asyncify.py @@ -0,0 +1,84 @@ +import aiohttp + +from auth0.v3.rest_async import AsyncRestClient + + +def _gen_async(client, method): + m = getattr(client, method) + + async def closure(*args, **kwargs): + return await m(*args, **kwargs) + + return closure + + +def asyncify(cls): + methods = [ + func + for func in dir(cls) + if callable(getattr(cls, func)) and not func.startswith("_") + ] + + class AsyncClient(cls): + def __init__( + self, + domain, + token, + telemetry=True, + timeout=5.0, + protocol="https", + rest_options=None, + ): + if token is None: + # Wrap the auth client + super(AsyncClient, self).__init__(domain, telemetry, timeout, protocol) + else: + # Wrap the mngtmt client + super(AsyncClient, self).__init__( + domain, token, telemetry, timeout, protocol, rest_options + ) + self.client = AsyncRestClient( + jwt=token, telemetry=telemetry, timeout=timeout, options=rest_options + ) + + class Wrapper(cls): + def __init__( + self, + domain, + token=None, + telemetry=True, + timeout=5.0, + protocol="https", + rest_options=None, + ): + if token is None: + # Wrap the auth client + super(Wrapper, self).__init__(domain, telemetry, timeout, protocol) + else: + # Wrap the mngtmt client + super(Wrapper, self).__init__( + domain, token, telemetry, timeout, protocol, rest_options + ) + + self._async_client = AsyncClient( + domain, token, telemetry, timeout, protocol, rest_options + ) + for method in methods: + setattr( + self, + "{}_async".format(method), + _gen_async(self._async_client, method), + ) + + async def __aenter__(self): + """Automatically create and set session within context manager.""" + async_rest_client = self._async_client.client + self._session = aiohttp.ClientSession() + async_rest_client.set_session(self._session) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Automatically close session within context manager.""" + await self._session.close() + + return Wrapper diff --git a/auth0/v3/authentication/base.py b/auth0/v3/authentication/base.py index afd534dc..2c8819ac 100644 --- a/auth0/v3/authentication/base.py +++ b/auth0/v3/authentication/base.py @@ -5,6 +5,8 @@ import requests +from auth0.v3.rest import RestClient, RestClientOptions + from ..exceptions import Auth0Error, RateLimitError UNKNOWN_ERROR = "a0.sdk.internal.unknown" @@ -24,132 +26,14 @@ class AuthenticationBase(object): def __init__(self, domain, telemetry=True, timeout=5.0, protocol="https"): self.domain = domain - self.timeout = timeout self.protocol = protocol - self.base_headers = {"Content-Type": "application/json"} - - if telemetry: - py_version = platform.python_version() - version = sys.modules["auth0"].__version__ - - auth0_client = json.dumps( - { - "name": "auth0-python", - "version": version, - "env": { - "python": py_version, - }, - } - ).encode("utf-8") - - self.base_headers.update( - { - "User-Agent": "Python/{}".format(py_version), - "Auth0-Client": base64.b64encode(auth0_client), - } - ) + self.client = RestClient( + None, + options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0), + ) def post(self, url, data=None, headers=None): - request_headers = self.base_headers.copy() - request_headers.update(headers or {}) - response = requests.post( - url=url, json=data, headers=request_headers, timeout=self.timeout - ) - return self._process_response(response) + return self.client.post(url, data, headers) def get(self, url, params=None, headers=None): - request_headers = self.base_headers.copy() - request_headers.update(headers or {}) - response = requests.get( - url=url, params=params, headers=request_headers, timeout=self.timeout - ) - return self._process_response(response) - - def _process_response(self, response): - return self._parse(response).content() - - def _parse(self, response): - if not response.text: - return EmptyResponse(response.status_code) - try: - return JsonResponse(response) - except ValueError: - return PlainResponse(response) - - -class Response(object): - def __init__(self, status_code, content, headers): - self._status_code = status_code - self._content = content - self._headers = headers - - def content(self): - if not self._is_error(): - return self._content - - if self._status_code == 429: - reset_at = int(self._headers.get("x-ratelimit-reset", "-1")) - raise RateLimitError( - error_code=self._error_code(), - message=self._error_message(), - reset_at=reset_at, - ) - - raise Auth0Error( - status_code=self._status_code, - error_code=self._error_code(), - message=self._error_message(), - ) - - def _is_error(self): - return self._status_code is None or self._status_code >= 400 - - # Adding these methods to force implementation in subclasses because they are references in this parent class - def _error_code(self): - raise NotImplementedError - - def _error_message(self): - raise NotImplementedError - - -class JsonResponse(Response): - def __init__(self, response): - content = json.loads(response.text) - super(JsonResponse, self).__init__( - response.status_code, content, response.headers - ) - - def _error_code(self): - if "error" in self._content: - return self._content.get("error") - elif "code" in self._content: - return self._content.get("code") - else: - return UNKNOWN_ERROR - - def _error_message(self): - return self._content.get("error_description", "") - - -class PlainResponse(Response): - def __init__(self, response): - super(PlainResponse, self).__init__( - response.status_code, response.text, response.headers - ) - - def _error_code(self): - return UNKNOWN_ERROR - - def _error_message(self): - return self._content - - -class EmptyResponse(Response): - def __init__(self, status_code): - super(EmptyResponse, self).__init__(status_code, "", {}) - - def _error_code(self): - return UNKNOWN_ERROR - - def _error_message(self): - return "" + return self.client.get(url, params, headers) diff --git a/auth0/v3/management/actions.py b/auth0/v3/management/actions.py index 0c3023cf..c9884f31 100644 --- a/auth0/v3/management/actions.py +++ b/auth0/v3/management/actions.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Actions(object): diff --git a/auth0/v3/management/attack_protection.py b/auth0/v3/management/attack_protection.py index 6fc6f376..1455f088 100644 --- a/auth0/v3/management/attack_protection.py +++ b/auth0/v3/management/attack_protection.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class AttackProtection(object): diff --git a/auth0/v3/management/auth0.py b/auth0/v3/management/auth0.py index 26d7693d..fb6bc905 100644 --- a/auth0/v3/management/auth0.py +++ b/auth0/v3/management/auth0.py @@ -1,3 +1,4 @@ +from ..utils import is_async_available from .actions import Actions from .attack_protection import AttackProtection from .blacklists import Blacklists @@ -27,6 +28,37 @@ from .users import Users from .users_by_email import UsersByEmail +modules = { + "actions": Actions, + "attack_protection": AttackProtection, + "blacklists": Blacklists, + "client_grants": ClientGrants, + "clients": Clients, + "connections": Connections, + "custom_domains": CustomDomains, + "device_credentials": DeviceCredentials, + "email_templates": EmailTemplates, + "emails": Emails, + "grants": Grants, + "guardian": Guardian, + "hooks": Hooks, + "jobs": Jobs, + "log_streams": LogStreams, + "logs": Logs, + "organizations": Organizations, + "prompts": Prompts, + "resource_servers": ResourceServers, + "roles": Roles, + "rules_configs": RulesConfigs, + "rules": Rules, + "stats": Stats, + "tenants": Tenants, + "tickets": Tickets, + "user_blocks": UserBlocks, + "users_by_email": UsersByEmail, + "users": Users, +} + class Auth0(object): """Provides easy access to all endpoint classes @@ -43,57 +75,12 @@ class Auth0(object): """ def __init__(self, domain, token, rest_options=None): - self.actions = Actions(domain=domain, token=token, rest_options=rest_options) - self.attack_protection = AttackProtection( - domain=domain, token=token, rest_options=rest_options - ) - self.blacklists = Blacklists( - domain=domain, token=token, rest_options=rest_options - ) - self.client_grants = ClientGrants( - domain=domain, token=token, rest_options=rest_options - ) - self.clients = Clients(domain=domain, token=token, rest_options=rest_options) - self.connections = Connections( - domain=domain, token=token, rest_options=rest_options - ) - self.custom_domains = CustomDomains( - domain=domain, token=token, rest_options=rest_options - ) - self.device_credentials = DeviceCredentials( - domain=domain, token=token, rest_options=rest_options - ) - self.email_templates = EmailTemplates( - domain=domain, token=token, rest_options=rest_options - ) - self.emails = Emails(domain=domain, token=token, rest_options=rest_options) - self.grants = Grants(domain=domain, token=token, rest_options=rest_options) - self.guardian = Guardian(domain=domain, token=token, rest_options=rest_options) - self.hooks = Hooks(domain=domain, token=token, rest_options=rest_options) - self.jobs = Jobs(domain=domain, token=token, rest_options=rest_options) - self.log_streams = LogStreams( - domain=domain, token=token, rest_options=rest_options - ) - self.logs = Logs(domain=domain, token=token, rest_options=rest_options) - self.organizations = Organizations( - domain=domain, token=token, rest_options=rest_options - ) - self.prompts = Prompts(domain=domain, token=token, rest_options=rest_options) - self.resource_servers = ResourceServers( - domain=domain, token=token, rest_options=rest_options - ) - self.roles = Roles(domain=domain, token=token, rest_options=rest_options) - self.rules_configs = RulesConfigs( - domain=domain, token=token, rest_options=rest_options - ) - self.rules = Rules(domain=domain, token=token, rest_options=rest_options) - self.stats = Stats(domain=domain, token=token, rest_options=rest_options) - self.tenants = Tenants(domain=domain, token=token, rest_options=rest_options) - self.tickets = Tickets(domain=domain, token=token, rest_options=rest_options) - self.user_blocks = UserBlocks( - domain=domain, token=token, rest_options=rest_options - ) - self.users_by_email = UsersByEmail( - domain=domain, token=token, rest_options=rest_options - ) - self.users = Users(domain=domain, token=token, rest_options=rest_options) + if is_async_available(): + from ..asyncify import asyncify + + for name, cls in modules.items(): + cls = asyncify(cls) + setattr(self, name, cls(domain=domain, token=token, rest_options=None)) + else: + for name, cls in modules.items(): + setattr(self, name, cls(domain=domain, token=token, rest_options=None)) diff --git a/auth0/v3/management/blacklists.py b/auth0/v3/management/blacklists.py index 9e1320a2..b1d23c9d 100644 --- a/auth0/v3/management/blacklists.py +++ b/auth0/v3/management/blacklists.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Blacklists(object): diff --git a/auth0/v3/management/branding.py b/auth0/v3/management/branding.py index 2064d545..644e4410 100644 --- a/auth0/v3/management/branding.py +++ b/auth0/v3/management/branding.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Branding(object): diff --git a/auth0/v3/management/client_grants.py b/auth0/v3/management/client_grants.py index 5fd5f767..35cd3808 100644 --- a/auth0/v3/management/client_grants.py +++ b/auth0/v3/management/client_grants.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class ClientGrants(object): diff --git a/auth0/v3/management/clients.py b/auth0/v3/management/clients.py index d9ec4953..24e22603 100644 --- a/auth0/v3/management/clients.py +++ b/auth0/v3/management/clients.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Clients(object): diff --git a/auth0/v3/management/connections.py b/auth0/v3/management/connections.py index e11d632c..d9ea5fdc 100644 --- a/auth0/v3/management/connections.py +++ b/auth0/v3/management/connections.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Connections(object): diff --git a/auth0/v3/management/custom_domains.py b/auth0/v3/management/custom_domains.py index 2f276aa2..fb9f69dd 100644 --- a/auth0/v3/management/custom_domains.py +++ b/auth0/v3/management/custom_domains.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class CustomDomains(object): diff --git a/auth0/v3/management/device_credentials.py b/auth0/v3/management/device_credentials.py index 3be2b662..88fca78b 100644 --- a/auth0/v3/management/device_credentials.py +++ b/auth0/v3/management/device_credentials.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class DeviceCredentials(object): diff --git a/auth0/v3/management/email_templates.py b/auth0/v3/management/email_templates.py index c0ff8ecd..dbd3c76c 100644 --- a/auth0/v3/management/email_templates.py +++ b/auth0/v3/management/email_templates.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class EmailTemplates(object): diff --git a/auth0/v3/management/emails.py b/auth0/v3/management/emails.py index 9b1b5940..08995d08 100644 --- a/auth0/v3/management/emails.py +++ b/auth0/v3/management/emails.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Emails(object): diff --git a/auth0/v3/management/grants.py b/auth0/v3/management/grants.py index c10646fb..b1db5de5 100644 --- a/auth0/v3/management/grants.py +++ b/auth0/v3/management/grants.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Grants(object): diff --git a/auth0/v3/management/guardian.py b/auth0/v3/management/guardian.py index d54d397b..3118fe59 100644 --- a/auth0/v3/management/guardian.py +++ b/auth0/v3/management/guardian.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Guardian(object): diff --git a/auth0/v3/management/hooks.py b/auth0/v3/management/hooks.py index e108f621..4a50fc40 100644 --- a/auth0/v3/management/hooks.py +++ b/auth0/v3/management/hooks.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Hooks(object): diff --git a/auth0/v3/management/jobs.py b/auth0/v3/management/jobs.py index 38f7da9a..8f56aa72 100644 --- a/auth0/v3/management/jobs.py +++ b/auth0/v3/management/jobs.py @@ -1,6 +1,6 @@ import warnings -from .rest import RestClient +from ..rest import RestClient class Jobs(object): diff --git a/auth0/v3/management/log_streams.py b/auth0/v3/management/log_streams.py index 26326b2e..ad45e709 100644 --- a/auth0/v3/management/log_streams.py +++ b/auth0/v3/management/log_streams.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class LogStreams(object): diff --git a/auth0/v3/management/logs.py b/auth0/v3/management/logs.py index efabe6fb..70b0a0bc 100644 --- a/auth0/v3/management/logs.py +++ b/auth0/v3/management/logs.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Logs(object): diff --git a/auth0/v3/management/organizations.py b/auth0/v3/management/organizations.py index 8c5b8a8d..f9f2afed 100644 --- a/auth0/v3/management/organizations.py +++ b/auth0/v3/management/organizations.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Organizations(object): diff --git a/auth0/v3/management/prompts.py b/auth0/v3/management/prompts.py index 80d0d6ae..1e08c516 100644 --- a/auth0/v3/management/prompts.py +++ b/auth0/v3/management/prompts.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Prompts(object): diff --git a/auth0/v3/management/resource_servers.py b/auth0/v3/management/resource_servers.py index 732b0411..c4e0a102 100644 --- a/auth0/v3/management/resource_servers.py +++ b/auth0/v3/management/resource_servers.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class ResourceServers(object): diff --git a/auth0/v3/management/roles.py b/auth0/v3/management/roles.py index 4b3d9088..0c6327b5 100644 --- a/auth0/v3/management/roles.py +++ b/auth0/v3/management/roles.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Roles(object): diff --git a/auth0/v3/management/rules.py b/auth0/v3/management/rules.py index 1ed10d23..c98480ef 100644 --- a/auth0/v3/management/rules.py +++ b/auth0/v3/management/rules.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Rules(object): diff --git a/auth0/v3/management/rules_configs.py b/auth0/v3/management/rules_configs.py index 3b2b89a8..e0e4b13a 100644 --- a/auth0/v3/management/rules_configs.py +++ b/auth0/v3/management/rules_configs.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class RulesConfigs(object): diff --git a/auth0/v3/management/stats.py b/auth0/v3/management/stats.py index a711b4f8..c9d9c584 100644 --- a/auth0/v3/management/stats.py +++ b/auth0/v3/management/stats.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Stats(object): diff --git a/auth0/v3/management/tenants.py b/auth0/v3/management/tenants.py index a589593e..7b1cbedf 100644 --- a/auth0/v3/management/tenants.py +++ b/auth0/v3/management/tenants.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Tenants(object): diff --git a/auth0/v3/management/tickets.py b/auth0/v3/management/tickets.py index 63334d11..a9207d6a 100644 --- a/auth0/v3/management/tickets.py +++ b/auth0/v3/management/tickets.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class Tickets(object): diff --git a/auth0/v3/management/user_blocks.py b/auth0/v3/management/user_blocks.py index e85fde92..03d0c58d 100644 --- a/auth0/v3/management/user_blocks.py +++ b/auth0/v3/management/user_blocks.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class UserBlocks(object): diff --git a/auth0/v3/management/users.py b/auth0/v3/management/users.py index 14a48cb0..74be1149 100644 --- a/auth0/v3/management/users.py +++ b/auth0/v3/management/users.py @@ -1,6 +1,6 @@ import warnings -from .rest import RestClient +from ..rest import RestClient class Users(object): diff --git a/auth0/v3/management/users_by_email.py b/auth0/v3/management/users_by_email.py index 440d130c..8a23506e 100644 --- a/auth0/v3/management/users_by_email.py +++ b/auth0/v3/management/users_by_email.py @@ -1,4 +1,4 @@ -from .rest import RestClient +from ..rest import RestClient class UsersByEmail(object): diff --git a/auth0/v3/management/rest.py b/auth0/v3/rest.py similarity index 81% rename from auth0/v3/management/rest.py rename to auth0/v3/rest.py index 21c11475..41be5eaa 100644 --- a/auth0/v3/management/rest.py +++ b/auth0/v3/rest.py @@ -7,7 +7,7 @@ import requests -from ..exceptions import Auth0Error, RateLimitError +from auth0.v3.exceptions import Auth0Error, RateLimitError UNKNOWN_ERROR = "a0.sdk.internal.unknown" @@ -75,10 +75,12 @@ def __init__(self, jwt, telemetry=True, timeout=5.0, options=None): self._skip_sleep = False self.base_headers = { - "Authorization": "Bearer {}".format(self.jwt), "Content-Type": "application/json", } + if jwt is not None: + self.base_headers["Authorization"] = "Bearer {}".format(self.jwt) + if options.telemetry: py_version = platform.python_version() version = sys.modules["auth0"].__version__ @@ -96,10 +98,13 @@ def __init__(self, jwt, telemetry=True, timeout=5.0, options=None): self.base_headers.update( { "User-Agent": "Python/{}".format(py_version), - "Auth0-Client": base64.b64encode(auth0_client), + "Auth0-Client": base64.b64encode(auth0_client).decode(), } ) + # Cap the maximum number of retries to 10 or fewer. Floor the retries at 0. + self._retries = min(self.MAX_REQUEST_RETRIES(), max(0, options.retries)) + # For backwards compatibility reasons only # TODO: Deprecate in the next major so we can prune these arguments. Guidance should be to use RestClient.options.* self.telemetry = options.telemetry @@ -121,8 +126,9 @@ def MAX_REQUEST_RETRY_DELAY(self): def MIN_REQUEST_RETRY_DELAY(self): return 100 - def get(self, url, params=None): - headers = self.base_headers.copy() + def get(self, url, params=None, headers=None): + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) # Track the API request attempt number attempt = 0 @@ -130,39 +136,23 @@ def get(self, url, params=None): # Reset the metrics tracker self._metrics = {"retries": 0, "backoff": []} - # Cap the maximum number of retries to 10 or fewer. Floor the retries at 0. - retries = min(self.MAX_REQUEST_RETRIES(), max(0, self.options.retries)) - while True: # Increment attempt number attempt += 1 # Issue the request response = requests.get( - url, params=params, headers=headers, timeout=self.options.timeout + url, + params=params, + headers=request_headers, + timeout=self.options.timeout, ) - # If the response did not have a 429 header, or the retries were configured at 0, or the attempt number is equal to or greater than the configured retries, break - if response.status_code != 429 or retries <= 0 or attempt > retries: + # If the response did not have a 429 header, or the attempt number is greater than the configured retries, break + if response.status_code != 429 or attempt > self._retries: break - # Retry the request. Apply a exponential backoff for subsequent attempts, using this formula: - # max(MIN_REQUEST_RETRY_DELAY, min(MAX_REQUEST_RETRY_DELAY, (100ms * (2 ** attempt - 1)) + random_between(1, MAX_REQUEST_RETRY_JITTER))) - - # Increases base delay by (100ms * (2 ** attempt - 1)) - wait = 100 * 2 ** (attempt - 1) - - # Introduces jitter to the base delay; increases delay between 1ms to MAX_REQUEST_RETRY_JITTER (100ms) - wait += randint(1, self.MAX_REQUEST_RETRY_JITTER()) - - # Is never more than MAX_REQUEST_RETRY_DELAY (1s) - wait = min(self.MAX_REQUEST_RETRY_DELAY(), wait) - - # Is never less than MIN_REQUEST_RETRY_DELAY (100ms) - wait = max(self.MIN_REQUEST_RETRY_DELAY(), wait) - - self._metrics["retries"] = attempt - self._metrics["backoff"].append(wait) + wait = self._calculate_wait(attempt) # Skip calling sleep() when running unit tests if self._skip_sleep is False: @@ -172,11 +162,12 @@ def get(self, url, params=None): # Return the final Response return self._process_response(response) - def post(self, url, data=None): - headers = self.base_headers.copy() + def post(self, url, data=None, headers=None): + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) response = requests.post( - url, json=data, headers=headers, timeout=self.options.timeout + url, json=data, headers=request_headers, timeout=self.options.timeout ) return self._process_response(response) @@ -217,6 +208,27 @@ def delete(self, url, params=None, data=None): ) return self._process_response(response) + def _calculate_wait(self, attempt): + # Retry the request. Apply a exponential backoff for subsequent attempts, using this formula: + # max(MIN_REQUEST_RETRY_DELAY, min(MAX_REQUEST_RETRY_DELAY, (100ms * (2 ** attempt - 1)) + random_between(1, MAX_REQUEST_RETRY_JITTER))) + + # Increases base delay by (100ms * (2 ** attempt - 1)) + wait = 100 * 2 ** (attempt - 1) + + # Introduces jitter to the base delay; increases delay between 1ms to MAX_REQUEST_RETRY_JITTER (100ms) + wait += randint(1, self.MAX_REQUEST_RETRY_JITTER()) + + # Is never more than MAX_REQUEST_RETRY_DELAY (1s) + wait = min(self.MAX_REQUEST_RETRY_DELAY(), wait) + + # Is never less than MIN_REQUEST_RETRY_DELAY (100ms) + wait = max(self.MIN_REQUEST_RETRY_DELAY(), wait) + + self._metrics["retries"] = attempt + self._metrics["backoff"].append(wait) + + return wait + def _process_response(self, response): return self._parse(response).content() @@ -276,10 +288,14 @@ def _error_code(self): return self._content.get("errorCode") elif "error" in self._content: return self._content.get("error") + elif "code" in self._content: + return self._content.get("code") else: return UNKNOWN_ERROR def _error_message(self): + if "error_description" in self._content: + return self._content.get("error_description") message = self._content.get("message", "") if message is not None and message != "": return message diff --git a/auth0/v3/rest_async.py b/auth0/v3/rest_async.py new file mode 100644 index 00000000..40493930 --- /dev/null +++ b/auth0/v3/rest_async.py @@ -0,0 +1,138 @@ +import asyncio +import json + +import aiohttp + +from auth0.v3.exceptions import RateLimitError + +from .rest import EmptyResponse, JsonResponse, PlainResponse +from .rest import Response as _Response +from .rest import RestClient + + +def _clean_params(params): + if params is None: + return params + return {k: v for k, v in params.items() if v is not None} + + +class AsyncRestClient(RestClient): + """Provides simple methods for handling all RESTful api endpoints. + + Args: + telemetry (bool, optional): Enable or disable Telemetry + (defaults to True) + timeout (float or tuple, optional): Change the requests + connect and read timeout. Pass a tuple to specify + both values separately or a float to set both to it. + (defaults to 5.0 for both) + options (RestClientOptions): Pass an instance of + RestClientOptions to configure additional RestClient + options, such as rate-limit retries. Overrides matching + options passed to the constructor. + (defaults to 3) + """ + + def __init__(self, *args, **kwargs): + super(AsyncRestClient, self).__init__(*args, **kwargs) + self._session = None + sock_connect, sock_read = ( + self.timeout + if isinstance(self.timeout, tuple) + else (self.timeout, self.timeout) + ) + self.timeout = aiohttp.ClientTimeout( + sock_connect=sock_connect, sock_read=sock_read + ) + + def set_session(self, session): + """Set Client Session to improve performance by reusing session. + Session should be closed manually or within context manager. + """ + self._session = session + + async def _request(self, *args, **kwargs): + kwargs["headers"] = kwargs.get("headers", self.base_headers) + kwargs["timeout"] = self.timeout + if self._session is not None: + # Request with re-usable session + async with self._session.request(*args, **kwargs) as response: + return await self._process_response(response) + else: + # Request without re-usable session + async with aiohttp.ClientSession() as session: + async with session.request(*args, **kwargs) as response: + return await self._process_response(response) + + async def get(self, url, params=None, headers=None): + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) + # Track the API request attempt number + attempt = 0 + + # Reset the metrics tracker + self._metrics = {"retries": 0, "backoff": []} + + params = _clean_params(params) + while True: + # Increment attempt number + attempt += 1 + + try: + response = await self._request( + "get", url, params=params, headers=request_headers + ) + return response + except RateLimitError as e: + # If the attempt number is greater than the configured retries, raise RateLimitError + if attempt > self._retries: + raise e + + wait = self._calculate_wait(attempt) + + # Skip calling sleep() when running unit tests + if self._skip_sleep is False: + # sleep() functions in seconds, so convert the milliseconds formula above accordingly + await asyncio.sleep(wait / 1000) + + async def post(self, url, data=None, headers=None): + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) + return await self._request("post", url, json=data, headers=request_headers) + + async def file_post(self, url, data=None, files=None): + headers = self.base_headers.copy() + headers.pop("Content-Type", None) + return await self._request("post", url, data={**data, **files}, headers=headers) + + async def patch(self, url, data=None): + return await self._request("patch", url, json=data) + + async def put(self, url, data=None): + return await self._request("put", url, json=data) + + async def delete(self, url, params=None, data=None): + return await self._request( + "delete", url, json=data, params=_clean_params(params) or {} + ) + + async def _process_response(self, response): + parsed_response = await self._parse(response) + return parsed_response.content() + + async def _parse(self, response): + text = await response.text() + requests_response = RequestsResponse(response, text) + if not text: + return EmptyResponse(response.status) + try: + return JsonResponse(requests_response) + except ValueError: + return PlainResponse(requests_response) + + +class RequestsResponse(object): + def __init__(self, response, text): + self.status_code = response.status + self.headers = response.headers + self.text = text diff --git a/auth0/v3/test/authentication/test_base.py b/auth0/v3/test/authentication/test_base.py index 9207d3fe..dcb5ce2b 100644 --- a/auth0/v3/test/authentication/test_base.py +++ b/auth0/v3/test/authentication/test_base.py @@ -13,12 +13,13 @@ class TestBase(unittest.TestCase): def test_telemetry_enabled_by_default(self): ab = AuthenticationBase("auth0.com") + base_headers = ab.client.base_headers - user_agent = ab.base_headers["User-Agent"] - auth0_client_bytes = base64.b64decode(ab.base_headers["Auth0-Client"]) + user_agent = base_headers["User-Agent"] + auth0_client_bytes = base64.b64decode(base_headers["Auth0-Client"]) auth0_client_json = auth0_client_bytes.decode("utf-8") auth0_client = json.loads(auth0_client_json) - content_type = ab.base_headers["Content-Type"] + content_type = base_headers["Content-Type"] from auth0 import __version__ as auth0_version @@ -39,7 +40,7 @@ def test_telemetry_enabled_by_default(self): def test_telemetry_disabled(self): ab = AuthenticationBase("auth0.com", telemetry=False) - self.assertEqual(ab.base_headers, {"Content-Type": "application/json"}) + self.assertEqual(ab.client.base_headers, {"Content-Type": "application/json"}) @mock.patch("requests.post") def test_post(self, mock_post): @@ -51,7 +52,7 @@ def test_post(self, mock_post): data = ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) mock_post.assert_called_with( - url="the-url", + "the-url", json={"a": "b"}, headers={"c": "d", "Content-Type": "application/json"}, timeout=(10, 2), @@ -70,7 +71,7 @@ def test_post_with_defaults(self, mock_post): data = ab.post("the-url") mock_post.assert_called_with( - url="the-url", + "the-url", json=None, headers={"Content-Type": "application/json"}, timeout=5.0, @@ -88,8 +89,8 @@ def test_post_includes_telemetry(self, mock_post): data = ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) self.assertEqual(mock_post.call_count, 1) - call_kwargs = mock_post.call_args[1] - self.assertEqual(call_kwargs["url"], "the-url") + call_args, call_kwargs = mock_post.call_args + self.assertEqual(call_args[0], "the-url") self.assertEqual(call_kwargs["json"], {"a": "b"}) headers = call_kwargs["headers"] self.assertEqual(headers["c"], "d") @@ -228,7 +229,7 @@ def test_get(self, mock_get): data = ab.get("the-url", params={"a": "b"}, headers={"c": "d"}) mock_get.assert_called_with( - url="the-url", + "the-url", params={"a": "b"}, headers={"c": "d", "Content-Type": "application/json"}, timeout=(10, 2), @@ -247,7 +248,7 @@ def test_get_with_defaults(self, mock_get): data = ab.get("the-url") mock_get.assert_called_with( - url="the-url", + "the-url", params=None, headers={"Content-Type": "application/json"}, timeout=5.0, @@ -265,8 +266,8 @@ def test_get_includes_telemetry(self, mock_get): data = ab.get("the-url", params={"a": "b"}, headers={"c": "d"}) self.assertEqual(mock_get.call_count, 1) - call_kwargs = mock_get.call_args[1] - self.assertEqual(call_kwargs["url"], "the-url") + call_args, call_kwargs = mock_get.call_args + self.assertEqual(call_args[0], "the-url") self.assertEqual(call_kwargs["params"], {"a": "b"}) headers = call_kwargs["headers"] self.assertEqual(headers["c"], "d") diff --git a/auth0/v3/test/management/test_rest.py b/auth0/v3/test/management/test_rest.py index dabc8355..b929b7a9 100644 --- a/auth0/v3/test/management/test_rest.py +++ b/auth0/v3/test/management/test_rest.py @@ -6,8 +6,9 @@ import mock import requests +from auth0.v3.rest import RestClient, RestClientOptions + from ...exceptions import Auth0Error, RateLimitError -from ...management.rest import RestClient, RestClientOptions class TestRest(unittest.TestCase): diff --git a/auth0/v3/test_async/__init__.py b/auth0/v3/test_async/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/auth0/v3/test_async/test_asyncify.py b/auth0/v3/test_async/test_asyncify.py new file mode 100644 index 00000000..f8a7a0c5 --- /dev/null +++ b/auth0/v3/test_async/test_asyncify.py @@ -0,0 +1,191 @@ +import base64 +import json +import platform +import re +import sys +from tempfile import TemporaryFile +from unittest import IsolatedAsyncioTestCase + +import aiohttp +from aioresponses import CallbackResult, aioresponses +from callee import Attrs +from mock import ANY, MagicMock + +from auth0.v3.asyncify import asyncify +from auth0.v3.management import Clients, Guardian, Jobs + +clients = re.compile(r"^https://example\.com/api/v2/clients.*") +factors = re.compile(r"^https://example\.com/api/v2/guardian/factors.*") +users_imports = re.compile(r"^https://example\.com/api/v2/jobs/users-imports.*") +payload = {"foo": "bar"} + +telemetry = base64.b64encode( + json.dumps( + { + "name": "auth0-python", + "version": sys.modules["auth0"].__version__, + "env": { + "python": platform.python_version(), + }, + } + ).encode("utf-8") +).decode() + +headers = { + "User-Agent": "Python/{}".format(platform.python_version()), + "Authorization": "Bearer jwt", + "Content-Type": "application/json", + "Auth0-Client": telemetry, +} + + +def get_callback(status=200): + mock = MagicMock(return_value=CallbackResult(status=status, payload=payload)) + + def callback(url, **kwargs): + return mock(url, **kwargs) + + return callback, mock + + +class TestAsyncify(IsolatedAsyncioTestCase): + @aioresponses() + async def test_get(self, mocked): + callback, mock = get_callback() + mocked.get(clients, callback=callback) + c = asyncify(Clients)(domain="example.com", token="jwt") + self.assertEqual(await c.all_async(), payload) + mock.assert_called_with( + Attrs(path="/api/v2/clients"), + allow_redirects=True, + params={"include_fields": "true"}, + headers=headers, + timeout=ANY, + ) + + @aioresponses() + async def test_post(self, mocked): + callback, mock = get_callback() + mocked.post(clients, callback=callback) + c = asyncify(Clients)(domain="example.com", token="jwt") + data = {"client": 1} + self.assertEqual(await c.create_async(data), payload) + mock.assert_called_with( + Attrs(path="/api/v2/clients"), + allow_redirects=True, + json=data, + headers=headers, + timeout=ANY, + ) + + @aioresponses() + async def test_file_post(self, mocked): + callback, mock = get_callback() + mocked.post(users_imports, callback=callback) + j = asyncify(Jobs)(domain="example.com", token="jwt") + users = TemporaryFile() + self.assertEqual(await j.import_users_async("connection-1", users), payload) + file_port_headers = headers.copy() + file_port_headers.pop("Content-Type") + mock.assert_called_with( + Attrs(path="/api/v2/jobs/users-imports"), + allow_redirects=True, + data={ + "connection_id": "connection-1", + "upsert": "false", + "send_completion_email": "true", + "external_id": None, + "users": users, + }, + headers=file_port_headers, + timeout=ANY, + ) + users.close() + + @aioresponses() + async def test_patch(self, mocked): + callback, mock = get_callback() + mocked.patch(clients, callback=callback) + c = asyncify(Clients)(domain="example.com", token="jwt") + data = {"client": 1} + self.assertEqual(await c.update_async("client-1", data), payload) + mock.assert_called_with( + Attrs(path="/api/v2/clients/client-1"), + allow_redirects=True, + json=data, + headers=headers, + timeout=ANY, + ) + + @aioresponses() + async def test_put(self, mocked): + callback, mock = get_callback() + mocked.put(factors, callback=callback) + g = asyncify(Guardian)(domain="example.com", token="jwt") + data = {"factor": 1} + self.assertEqual(await g.update_factor_async("factor-1", data), payload) + mock.assert_called_with( + Attrs(path="/api/v2/guardian/factors/factor-1"), + allow_redirects=True, + json=data, + headers=headers, + timeout=ANY, + ) + + @aioresponses() + async def test_delete(self, mocked): + callback, mock = get_callback() + mocked.delete(clients, callback=callback) + c = asyncify(Clients)(domain="example.com", token="jwt") + self.assertEqual(await c.delete_async("client-1"), payload) + mock.assert_called_with( + Attrs(path="/api/v2/clients/client-1"), + allow_redirects=True, + params={}, + json=None, + headers=headers, + timeout=ANY, + ) + + @aioresponses() + async def test_shared_session(self, mocked): + callback, mock = get_callback() + mocked.get(clients, callback=callback) + async with asyncify(Clients)(domain="example.com", token="jwt") as c: + self.assertEqual(await c.all_async(), payload) + mock.assert_called_with( + Attrs(path="/api/v2/clients"), + allow_redirects=True, + params={"include_fields": "true"}, + headers=headers, + timeout=ANY, + ) + + @aioresponses() + async def test_rate_limit(self, mocked): + callback, mock = get_callback(status=429) + mocked.get(clients, callback=callback) + mocked.get(clients, callback=callback) + mocked.get(clients, callback=callback) + mocked.get(clients, payload=payload) + c = asyncify(Clients)(domain="example.com", token="jwt") + rest_client = c._async_client.client + rest_client._skip_sleep = True + self.assertEqual(await c.all_async(), payload) + self.assertEqual(3, mock.call_count) + (a, b, c) = rest_client._metrics["backoff"] + self.assertTrue(100 <= a < b < c <= 1000) + + @aioresponses() + async def test_timeout(self, mocked): + callback, mock = get_callback() + mocked.get(clients, callback=callback) + c = asyncify(Clients)(domain="example.com", token="jwt", timeout=(8.8, 9.9)) + self.assertEqual(await c.all_async(), payload) + mock.assert_called_with( + ANY, + allow_redirects=ANY, + params=ANY, + headers=ANY, + timeout=aiohttp.ClientTimeout(sock_connect=8.8, sock_read=9.9), + ) diff --git a/auth0/v3/utils.py b/auth0/v3/utils.py new file mode 100644 index 00000000..07eade8c --- /dev/null +++ b/auth0/v3/utils.py @@ -0,0 +1,15 @@ +import sys + + +def is_async_available(): + if sys.version_info >= (3, 6): + try: + import asyncio + + import aiohttp + + return True + except ImportError: + pass + + return False diff --git a/docs/source/v3.management.rst b/docs/source/v3.management.rst index 8bc17b22..e7fc0138 100644 --- a/docs/source/v3.management.rst +++ b/docs/source/v3.management.rst @@ -145,14 +145,6 @@ management.resource\_servers module :undoc-members: :show-inheritance: -management.rest module -------------------------- - -.. automodule:: auth0.v3.management.rest - :members: - :undoc-members: - :show-inheritance: - management.roles module -------------------------- diff --git a/requirements.txt b/requirements.txt index eb5e55cc..de51b1df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,14 @@ -e . +aiohttp==3.8.1 +aioresponses==0.7.3 +aiosignal==1.2.0 +alabaster==0.7.12 +async-timeout==4.0.2 +attrs==21.4.0 +Authlib==1.0.0 Babel==2.9.1 black==22.3.0 +callee==0.3.1 certifi==2021.10.8 cffi==1.15.0 cfgv==3.3.1 @@ -8,30 +16,43 @@ charset-normalizer==2.0.12 click==8.0.4 coverage==6.3.2 cryptography==36.0.2 +Deprecated==1.2.13 distlib==0.3.4 docutils==0.17.1 filelock==3.6.0 flake8==4.0.1 +Flask==2.0.3 +Flask-Cors==3.0.10 +frozenlist==1.3.0 identify==2.4.12 idna==3.3 imagesize==1.3.0 +iniconfig==1.1.1 isort==5.10.1 +itsdangerous==2.1.1 Jinja2==3.1.1 +jwcrypto==1.0 MarkupSafe==2.1.1 mccabe==0.6.1 mock==4.0.3 +multidict==6.0.2 mypy-extensions==0.4.3 nodeenv==1.6.0 packaging==21.3 pathspec==0.9.0 platformdirs==2.5.1 +pluggy==1.0.0 pre-commit==2.17.0 +py==1.11.0 pycodestyle==2.8.0 pycparser==2.21 pyflakes==2.4.0 Pygments==2.11.2 PyJWT==2.3.0 pyparsing==3.0.7 +pytest==7.1.0 +pytest-mock==3.7.0 +python-dotenv==0.19.2 pytz==2022.1 pyupgrade==2.31.1 PyYAML==6.0 @@ -51,3 +72,6 @@ toml==0.10.2 tomli==2.0.1 urllib3==1.26.9 virtualenv==20.13.4 +Werkzeug==2.0.3 +wrapt==1.14.0 +yarl==1.7.2