Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add experimental GDCH support #1044

Merged
merged 8 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion google/auth/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
_SERVICE_ACCOUNT_TYPE = "service_account"
_EXTERNAL_ACCOUNT_TYPE = "external_account"
_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account"
_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account"
_VALID_TYPES = (
_AUTHORIZED_USER_TYPE,
_SERVICE_ACCOUNT_TYPE,
_EXTERNAL_ACCOUNT_TYPE,
_IMPERSONATED_SERVICE_ACCOUNT_TYPE,
_GDCH_SERVICE_ACCOUNT_TYPE,
)

# Help message when no credentials can be found.
Expand Down Expand Up @@ -134,6 +136,8 @@ def load_credentials_from_file(
def _load_credentials_from_info(
filename, info, scopes, default_scopes, quota_project_id, request
):
from google.auth.credentials import CredentialsWithQuotaProject

credential_type = info.get("type")

if credential_type == _AUTHORIZED_USER_TYPE:
Expand All @@ -158,14 +162,17 @@ def _load_credentials_from_info(
credentials, project_id = _get_impersonated_service_account_credentials(
filename, info, scopes
)
elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE:
credentials, project_id = _get_gdch_service_account_credentials(filename, info)
else:
raise exceptions.DefaultCredentialsError(
"The file {file} does not have a valid type. "
"Type is {type}, expected one of {valid_types}.".format(
file=filename, type=credential_type, valid_types=_VALID_TYPES
)
)
credentials = _apply_quota_project_id(credentials, quota_project_id)
if isinstance(credentials, CredentialsWithQuotaProject):
credentials = _apply_quota_project_id(credentials, quota_project_id)
return credentials, project_id


Expand Down Expand Up @@ -421,6 +428,20 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
return credentials, None


def _get_gdch_service_account_credentials(filename, info):
from google.oauth2 import gdch_credentials

try:
credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info(
info
)
except ValueError as caught_exc:
msg = "Failed to load GDCH service account credentials from {}".format(filename)
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
new_exc = exceptions.DefaultCredentialsError(msg, caught_exc)
six.raise_from(new_exc, caught_exc)
return credentials, info.get("project")


def _apply_quota_project_id(credentials, quota_project_id):
if quota_project_id:
credentials = credentials.with_quota_project(quota_project_id)
Expand Down Expand Up @@ -456,6 +477,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
endpoint.
The project ID returned in this case is the one corresponding to the
underlying workload identity pool resource if determinable.

If the environment variable is set to the path of a valid GDCH service
account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH
credential will be returned. The project ID returned is the project
specified in the JSON file.
2. If the `Google Cloud SDK`_ is installed and has application default
credentials set they are loaded and returned.

Expand Down Expand Up @@ -490,6 +516,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
.. _Metadata Service: https://cloud.google.com/compute/docs\
/storing-retrieving-metadata
.. _Cloud Run: https://cloud.google.com/run
.. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\
/hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted

Example::

Expand Down
15 changes: 11 additions & 4 deletions google/auth/_service_account_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from google.auth import crypt


def from_dict(data, require=None):
def from_dict(data, require=None, use_rsa_signer=True):
"""Validates a dictionary containing Google service account data.

Creates and returns a :class:`google.auth.crypt.Signer` instance from the
Expand All @@ -32,6 +32,8 @@ def from_dict(data, require=None):
data (Mapping[str, str]): The service account data
require (Sequence[str]): List of keys required to be present in the
info.
use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer.
We use RSA signer by default.

Returns:
google.auth.crypt.Signer: A signer created from the private key in the
Expand All @@ -52,23 +54,28 @@ def from_dict(data, require=None):
)

# Create a signer.
signer = crypt.RSASigner.from_service_account_info(data)
if use_rsa_signer:
signer = crypt.RSASigner.from_service_account_info(data)
else:
signer = crypt.ES256Signer.from_service_account_info(data)

return signer


def from_filename(filename, require=None):
def from_filename(filename, require=None, use_rsa_signer=True):
"""Reads a Google service account JSON file and returns its parsed info.

Args:
filename (str): The path to the service account .json file.
require (Sequence[str]): List of keys required to be present in the
info.
use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer.
We use RSA signer by default.

Returns:
Tuple[ Mapping[str, str], google.auth.crypt.Signer ]: The verified
info and a signer instance.
"""
with io.open(filename, "r", encoding="utf-8") as json_file:
data = json.load(json_file)
return data, from_dict(data, require=require)
return data, from_dict(data, require=require, use_rsa_signer=use_rsa_signer)
58 changes: 41 additions & 17 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ def _handle_error_response(response_data):
"""Translates an error response into an exception.

Args:
response_data (Mapping): The decoded response data.
response_data (Mapping | str): The decoded response data.

Raises:
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
if isinstance(response_data, six.string_types):
raise exceptions.RefreshError(response_data)
try:
error_details = "{}: {}".format(
response_data["error"], response_data.get("error_description")
Expand Down Expand Up @@ -79,7 +81,7 @@ def _parse_expiry(response_data):


def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
request, token_uri, body, access_token=None, use_json=False, **kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Expand All @@ -93,6 +95,13 @@ def _token_endpoint_request_no_throw(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
kwargs: Additional arguments passed on to the request method. The
kwargs will be passed to `requests.request` method, see:
https://docs.python-requests.org/en/latest/api/#requests.request.
For example, you can use `cert=("cert_pem_path", "key_pem_path")`
to set up client side SSL certificate, and use
`verify="ca_bundle_path"` to set up the CA certificates for sever
side SSL certificate verification.

Returns:
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
Expand All @@ -112,32 +121,40 @@ def _token_endpoint_request_no_throw(
# retry to fetch token for maximum of two times if any internal failure
# occurs.
while True:
response = request(method="POST", url=token_uri, headers=headers, body=body)
response = request(
method="POST", url=token_uri, headers=headers, body=body, **kwargs
)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)
response_data = json.loads(response_body)

if response.status == http_client.OK:
# response_body should be a JSON
response_data = json.loads(response_body)
break
else:
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data
# For a failed response, response_body could be a string
try:
response_data = json.loads(response_body)
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
except ValueError:
response_data = response_body
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
return False, response_data

return True, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
request, token_uri, body, access_token=None, use_json=False, **kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.

Expand All @@ -150,6 +167,13 @@ def _token_endpoint_request(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
kwargs: Additional arguments passed on to the request method. The
kwargs will be passed to `requests.request` method, see:
https://docs.python-requests.org/en/latest/api/#requests.request.
For example, you can use `cert=("cert_pem_path", "key_pem_path")`
to set up client side SSL certificate, and use
`verify="ca_bundle_path"` to set up the CA certificates for sever
side SSL certificate verification.

Returns:
Mapping[str, str]: The JSON-decoded response data.
Expand All @@ -159,7 +183,7 @@ def _token_endpoint_request(
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
request, token_uri, body, access_token=access_token, use_json=use_json, **kwargs
)
if not response_status_ok:
_handle_error_response(response_data)
Expand Down
Loading