Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK-3181] Asyncio Support #312

Merged
merged 15 commits into from
May 4, 2022
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- python/install-packages:
pkg-manager: pip-dist
path-args: ".[test]"
- run: coverage run -m unittest discover
- run: coverage run -m unittest discover -s auth0/v3/test -t .
- run: bash <(curl -s https://codecov.io/bash)

workflows:
Expand Down
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
ignore = E501
ignore = E501 F401
max-line-length = 88
84 changes: 84 additions & 0 deletions auth0/v3/asyncify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import aiohttp

from auth0.v3.rest_async import AsyncRestClient


def _gen_async(client, method):
m = getattr(client, method)

async def closure(*args, **kwargs):
return await m(*args, **kwargs)

return closure


def asyncify(cls):
methods = [
func
for func in dir(cls)
if callable(getattr(cls, func)) and not func.startswith("_")
]

class AsyncClient(cls):
def __init__(
self,
domain,
token,
telemetry=True,
timeout=5.0,
protocol="https",
rest_options=None,
):
if token is None:
# Wrap the auth client
super(AsyncClient, self).__init__(domain, telemetry, timeout, protocol)
else:
# Wrap the mngtmt client
super(AsyncClient, self).__init__(
domain, token, telemetry, timeout, protocol, rest_options
)
self.client = AsyncRestClient(
jwt=token, telemetry=telemetry, timeout=timeout, options=rest_options
)

class Wrapper(cls):
def __init__(
self,
domain,
token=None,
telemetry=True,
timeout=5.0,
protocol="https",
rest_options=None,
):
if token is None:
# Wrap the auth client
super(Wrapper, self).__init__(domain, telemetry, timeout, protocol)
else:
# Wrap the mngtmt client
super(Wrapper, self).__init__(
domain, token, telemetry, timeout, protocol, rest_options
)

self._async_client = AsyncClient(
domain, token, telemetry, timeout, protocol, rest_options
)
for method in methods:
setattr(
self,
"{}_async".format(method),
_gen_async(self._async_client, method),
)

async def __aenter__(self):
"""Automatically create and set session within context manager."""
async_rest_client = self._async_client.client
self._session = aiohttp.ClientSession()
async_rest_client.set_session(self._session)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Automatically close session within context manager."""
await self._session.close()

return Wrapper
132 changes: 8 additions & 124 deletions auth0/v3/authentication/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import requests

from auth0.v3.rest import RestClient, RestClientOptions

from ..exceptions import Auth0Error, RateLimitError

UNKNOWN_ERROR = "a0.sdk.internal.unknown"
Expand All @@ -24,132 +26,14 @@ class AuthenticationBase(object):

def __init__(self, domain, telemetry=True, timeout=5.0, protocol="https"):
self.domain = domain
self.timeout = timeout
self.protocol = protocol
self.base_headers = {"Content-Type": "application/json"}

if telemetry:
py_version = platform.python_version()
version = sys.modules["auth0"].__version__

auth0_client = json.dumps(
{
"name": "auth0-python",
"version": version,
"env": {
"python": py_version,
},
}
).encode("utf-8")

self.base_headers.update(
{
"User-Agent": "Python/{}".format(py_version),
"Auth0-Client": base64.b64encode(auth0_client),
}
)
self.client = RestClient(
None,
options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0),
)

def post(self, url, data=None, headers=None):
request_headers = self.base_headers.copy()
request_headers.update(headers or {})
response = requests.post(
url=url, json=data, headers=request_headers, timeout=self.timeout
)
return self._process_response(response)
return self.client.post(url, data, headers)

def get(self, url, params=None, headers=None):
request_headers = self.base_headers.copy()
request_headers.update(headers or {})
response = requests.get(
url=url, params=params, headers=request_headers, timeout=self.timeout
)
return self._process_response(response)

def _process_response(self, response):
return self._parse(response).content()

def _parse(self, response):
if not response.text:
return EmptyResponse(response.status_code)
try:
return JsonResponse(response)
except ValueError:
return PlainResponse(response)


class Response(object):
def __init__(self, status_code, content, headers):
self._status_code = status_code
self._content = content
self._headers = headers

def content(self):
if not self._is_error():
return self._content

if self._status_code == 429:
reset_at = int(self._headers.get("x-ratelimit-reset", "-1"))
raise RateLimitError(
error_code=self._error_code(),
message=self._error_message(),
reset_at=reset_at,
)

raise Auth0Error(
status_code=self._status_code,
error_code=self._error_code(),
message=self._error_message(),
)

def _is_error(self):
return self._status_code is None or self._status_code >= 400

# Adding these methods to force implementation in subclasses because they are references in this parent class
def _error_code(self):
raise NotImplementedError

def _error_message(self):
raise NotImplementedError


class JsonResponse(Response):
def __init__(self, response):
content = json.loads(response.text)
super(JsonResponse, self).__init__(
response.status_code, content, response.headers
)

def _error_code(self):
if "error" in self._content:
return self._content.get("error")
elif "code" in self._content:
return self._content.get("code")
else:
return UNKNOWN_ERROR

def _error_message(self):
return self._content.get("error_description", "")


class PlainResponse(Response):
def __init__(self, response):
super(PlainResponse, self).__init__(
response.status_code, response.text, response.headers
)

def _error_code(self):
return UNKNOWN_ERROR

def _error_message(self):
return self._content


class EmptyResponse(Response):
def __init__(self, status_code):
super(EmptyResponse, self).__init__(status_code, "", {})

def _error_code(self):
return UNKNOWN_ERROR

def _error_message(self):
return ""
return self.client.get(url, params, headers)
2 changes: 1 addition & 1 deletion auth0/v3/management/actions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class Actions(object):
Expand Down
2 changes: 1 addition & 1 deletion auth0/v3/management/attack_protection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class AttackProtection(object):
Expand Down
95 changes: 41 additions & 54 deletions auth0/v3/management/auth0.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..utils import is_async_available
from .actions import Actions
from .attack_protection import AttackProtection
from .blacklists import Blacklists
Expand Down Expand Up @@ -27,6 +28,37 @@
from .users import Users
from .users_by_email import UsersByEmail

modules = {
"actions": Actions,
"attack_protection": AttackProtection,
"blacklists": Blacklists,
"client_grants": ClientGrants,
"clients": Clients,
"connections": Connections,
"custom_domains": CustomDomains,
"device_credentials": DeviceCredentials,
"email_templates": EmailTemplates,
"emails": Emails,
"grants": Grants,
"guardian": Guardian,
"hooks": Hooks,
"jobs": Jobs,
"log_streams": LogStreams,
"logs": Logs,
"organizations": Organizations,
"prompts": Prompts,
"resource_servers": ResourceServers,
"roles": Roles,
"rules_configs": RulesConfigs,
"rules": Rules,
"stats": Stats,
"tenants": Tenants,
"tickets": Tickets,
"user_blocks": UserBlocks,
"users_by_email": UsersByEmail,
"users": Users,
}


class Auth0(object):
"""Provides easy access to all endpoint classes
Expand All @@ -43,57 +75,12 @@ class Auth0(object):
"""

def __init__(self, domain, token, rest_options=None):
self.actions = Actions(domain=domain, token=token, rest_options=rest_options)
self.attack_protection = AttackProtection(
domain=domain, token=token, rest_options=rest_options
)
self.blacklists = Blacklists(
domain=domain, token=token, rest_options=rest_options
)
self.client_grants = ClientGrants(
domain=domain, token=token, rest_options=rest_options
)
self.clients = Clients(domain=domain, token=token, rest_options=rest_options)
self.connections = Connections(
domain=domain, token=token, rest_options=rest_options
)
self.custom_domains = CustomDomains(
domain=domain, token=token, rest_options=rest_options
)
self.device_credentials = DeviceCredentials(
domain=domain, token=token, rest_options=rest_options
)
self.email_templates = EmailTemplates(
domain=domain, token=token, rest_options=rest_options
)
self.emails = Emails(domain=domain, token=token, rest_options=rest_options)
self.grants = Grants(domain=domain, token=token, rest_options=rest_options)
self.guardian = Guardian(domain=domain, token=token, rest_options=rest_options)
self.hooks = Hooks(domain=domain, token=token, rest_options=rest_options)
self.jobs = Jobs(domain=domain, token=token, rest_options=rest_options)
self.log_streams = LogStreams(
domain=domain, token=token, rest_options=rest_options
)
self.logs = Logs(domain=domain, token=token, rest_options=rest_options)
self.organizations = Organizations(
domain=domain, token=token, rest_options=rest_options
)
self.prompts = Prompts(domain=domain, token=token, rest_options=rest_options)
self.resource_servers = ResourceServers(
domain=domain, token=token, rest_options=rest_options
)
self.roles = Roles(domain=domain, token=token, rest_options=rest_options)
self.rules_configs = RulesConfigs(
domain=domain, token=token, rest_options=rest_options
)
self.rules = Rules(domain=domain, token=token, rest_options=rest_options)
self.stats = Stats(domain=domain, token=token, rest_options=rest_options)
self.tenants = Tenants(domain=domain, token=token, rest_options=rest_options)
self.tickets = Tickets(domain=domain, token=token, rest_options=rest_options)
self.user_blocks = UserBlocks(
domain=domain, token=token, rest_options=rest_options
)
self.users_by_email = UsersByEmail(
domain=domain, token=token, rest_options=rest_options
)
self.users = Users(domain=domain, token=token, rest_options=rest_options)
if is_async_available():
from ..asyncify import asyncify

for name, cls in modules.items():
cls = asyncify(cls)
setattr(self, name, cls(domain=domain, token=token, rest_options=None))
else:
for name, cls in modules.items():
setattr(self, name, cls(domain=domain, token=token, rest_options=None))
2 changes: 1 addition & 1 deletion auth0/v3/management/blacklists.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class Blacklists(object):
Expand Down
2 changes: 1 addition & 1 deletion auth0/v3/management/client_grants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class ClientGrants(object):
Expand Down
2 changes: 1 addition & 1 deletion auth0/v3/management/clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class Clients(object):
Expand Down
2 changes: 1 addition & 1 deletion auth0/v3/management/connections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class Connections(object):
Expand Down
2 changes: 1 addition & 1 deletion auth0/v3/management/custom_domains.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .rest import RestClient
from ..rest import RestClient


class CustomDomains(object):
Expand Down
Loading