diff --git a/google/auth/_default.py b/google/auth/_default.py index d038438d5..3a4190389 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -36,11 +36,13 @@ _SERVICE_ACCOUNT_TYPE = "service_account" _EXTERNAL_ACCOUNT_TYPE = "external_account" _IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account" +_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account" _VALID_TYPES = ( _AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE, _EXTERNAL_ACCOUNT_TYPE, _IMPERSONATED_SERVICE_ACCOUNT_TYPE, + _GDCH_SERVICE_ACCOUNT_TYPE, ) # Help message when no credentials can be found. @@ -134,6 +136,8 @@ def load_credentials_from_file( def _load_credentials_from_info( filename, info, scopes, default_scopes, quota_project_id, request ): + from google.auth.credentials import CredentialsWithQuotaProject + credential_type = info.get("type") if credential_type == _AUTHORIZED_USER_TYPE: @@ -158,6 +162,8 @@ def _load_credentials_from_info( credentials, project_id = _get_impersonated_service_account_credentials( filename, info, scopes ) + elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE: + credentials, project_id = _get_gdch_service_account_credentials(filename, info) else: raise exceptions.DefaultCredentialsError( "The file {file} does not have a valid type. " @@ -165,7 +171,8 @@ def _load_credentials_from_info( file=filename, type=credential_type, valid_types=_VALID_TYPES ) ) - credentials = _apply_quota_project_id(credentials, quota_project_id) + if isinstance(credentials, CredentialsWithQuotaProject): + credentials = _apply_quota_project_id(credentials, quota_project_id) return credentials, project_id @@ -421,6 +428,20 @@ def _get_impersonated_service_account_credentials(filename, info, scopes): return credentials, None +def _get_gdch_service_account_credentials(filename, info): + from google.oauth2 import gdch_credentials + + try: + credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info( + info + ) + except ValueError as caught_exc: + msg = "Failed to load GDCH service account credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + six.raise_from(new_exc, caught_exc) + return credentials, info.get("project") + + def _apply_quota_project_id(credentials, quota_project_id): if quota_project_id: credentials = credentials.with_quota_project(quota_project_id) @@ -456,6 +477,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non endpoint. The project ID returned in this case is the one corresponding to the underlying workload identity pool resource if determinable. + + If the environment variable is set to the path of a valid GDCH service + account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH + credential will be returned. The project ID returned is the project + specified in the JSON file. 2. If the `Google Cloud SDK`_ is installed and has application default credentials set they are loaded and returned. @@ -490,6 +516,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non .. _Metadata Service: https://cloud.google.com/compute/docs\ /storing-retrieving-metadata .. _Cloud Run: https://cloud.google.com/run + .. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\ + /hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted Example:: diff --git a/google/auth/_service_account_info.py b/google/auth/_service_account_info.py index 3d340c78d..157099273 100644 --- a/google/auth/_service_account_info.py +++ b/google/auth/_service_account_info.py @@ -22,7 +22,7 @@ from google.auth import crypt -def from_dict(data, require=None): +def from_dict(data, require=None, use_rsa_signer=True): """Validates a dictionary containing Google service account data. Creates and returns a :class:`google.auth.crypt.Signer` instance from the @@ -32,6 +32,8 @@ def from_dict(data, require=None): data (Mapping[str, str]): The service account data require (Sequence[str]): List of keys required to be present in the info. + use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer. + We use RSA signer by default. Returns: google.auth.crypt.Signer: A signer created from the private key in the @@ -52,18 +54,23 @@ def from_dict(data, require=None): ) # Create a signer. - signer = crypt.RSASigner.from_service_account_info(data) + if use_rsa_signer: + signer = crypt.RSASigner.from_service_account_info(data) + else: + signer = crypt.ES256Signer.from_service_account_info(data) return signer -def from_filename(filename, require=None): +def from_filename(filename, require=None, use_rsa_signer=True): """Reads a Google service account JSON file and returns its parsed info. Args: filename (str): The path to the service account .json file. require (Sequence[str]): List of keys required to be present in the info. + use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer. + We use RSA signer by default. Returns: Tuple[ Mapping[str, str], google.auth.crypt.Signer ]: The verified @@ -71,4 +78,4 @@ def from_filename(filename, require=None): """ with io.open(filename, "r", encoding="utf-8") as json_file: data = json.load(json_file) - return data, from_dict(data, require=require) + return data, from_dict(data, require=require, use_rsa_signer=use_rsa_signer) diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 2f4e8474b..847c5db8a 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -44,11 +44,13 @@ def _handle_error_response(response_data): """Translates an error response into an exception. Args: - response_data (Mapping): The decoded response data. + response_data (Mapping | str): The decoded response data. Raises: google.auth.exceptions.RefreshError: The errors contained in response_data. """ + if isinstance(response_data, six.string_types): + raise exceptions.RefreshError(response_data) try: error_details = "{}: {}".format( response_data["error"], response_data.get("error_description") @@ -79,7 +81,7 @@ def _parse_expiry(response_data): def _token_endpoint_request_no_throw( - request, token_uri, body, access_token=None, use_json=False + request, token_uri, body, access_token=None, use_json=False, **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. This function doesn't throw on response errors. @@ -93,6 +95,13 @@ def _token_endpoint_request_no_throw( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. Returns: Tuple(bool, Mapping[str, str]): A boolean indicating if the request is @@ -112,32 +121,40 @@ def _token_endpoint_request_no_throw( # retry to fetch token for maximum of two times if any internal failure # occurs. while True: - response = request(method="POST", url=token_uri, headers=headers, body=body) + response = request( + method="POST", url=token_uri, headers=headers, body=body, **kwargs + ) response_body = ( response.data.decode("utf-8") if hasattr(response.data, "decode") else response.data ) - response_data = json.loads(response_body) if response.status == http_client.OK: + # response_body should be a JSON + response_data = json.loads(response_body) break else: - error_desc = response_data.get("error_description") or "" - error_code = response_data.get("error") or "" - if ( - any(e == "internal_failure" for e in (error_code, error_desc)) - and retry < 1 - ): - retry += 1 - continue - return response.status == http_client.OK, response_data - - return response.status == http_client.OK, response_data + # For a failed response, response_body could be a string + try: + response_data = json.loads(response_body) + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + if ( + any(e == "internal_failure" for e in (error_code, error_desc)) + and retry < 1 + ): + retry += 1 + continue + except ValueError: + response_data = response_body + return False, response_data + + return True, response_data def _token_endpoint_request( - request, token_uri, body, access_token=None, use_json=False + request, token_uri, body, access_token=None, use_json=False, **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. @@ -150,6 +167,13 @@ def _token_endpoint_request( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. Returns: Mapping[str, str]: The JSON-decoded response data. @@ -159,7 +183,7 @@ def _token_endpoint_request( an error. """ response_status_ok, response_data = _token_endpoint_request_no_throw( - request, token_uri, body, access_token=access_token, use_json=use_json + request, token_uri, body, access_token=access_token, use_json=use_json, **kwargs ) if not response_status_ok: _handle_error_response(response_data) diff --git a/google/oauth2/gdch_credentials.py b/google/oauth2/gdch_credentials.py new file mode 100644 index 000000000..7410cfc2e --- /dev/null +++ b/google/oauth2/gdch_credentials.py @@ -0,0 +1,251 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental GDCH credentials support. +""" + +import datetime + +from google.auth import _helpers +from google.auth import _service_account_info +from google.auth import credentials +from google.auth import exceptions +from google.auth import jwt +from google.oauth2 import _client + + +TOKEN_EXCHANGE_TYPE = "urn:ietf:params:oauth:token-type:token-exchange" +ACCESS_TOKEN_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +SERVICE_ACCOUNT_TOKEN_TYPE = "urn:k8s:params:oauth:token-type:serviceaccount" +JWT_LIFETIME = datetime.timedelta(seconds=3600) # 1 hour + + +class ServiceAccountCredentials(credentials.Credentials): + """Credentials for GDCH (`Google Distributed Cloud Hosted`_) for service + account users. + + .. _Google Distributed Cloud Hosted: + https://cloud.google.com/blog/topics/hybrid-cloud/\ + announcing-google-distributed-cloud-edge-and-hosted + + To create a GDCH service account credential, first create a JSON file of + the following format:: + + { + "type": "gdch_service_account", + "format_version": "1", + "project": "", + "private_key_id": "", + "private_key": "-----BEGIN EC PRIVATE KEY-----\n\n-----END EC PRIVATE KEY-----\n", + "name": "", + "ca_cert_path": "", + "token_uri": "https://service-identity./authenticate" + } + + The "format_version" field stands for the format of the JSON file. For now + it is always "1". The `private_key_id` and `private_key` is used for signing. + The `ca_cert_path` is used for token server TLS certificate verification. + + After the JSON file is created, set `GOOGLE_APPLICATION_CREDENTIALS` environment + variable to the JSON file path, then use the following code to create the + credential:: + + import google.auth + + credential, _ = google.auth.default() + credential = credential.with_gdch_audience("") + + We can also create the credential directly:: + + from google.oauth import gdch_credentials + + credential = gdch_credentials.ServiceAccountCredentials.from_service_account_file("") + credential = credential.with_gdch_audience("") + + The token is obtained in the following way. This class first creates a + self signed JWT. It uses the `name` value as the `iss` and `sub` claim, and + the `token_uri` as the `aud` claim, and signs the JWT with the `private_key`. + It then sends the JWT to the `token_uri` to exchange a final token for + `audience`. + """ + + def __init__( + self, signer, service_identity_name, project, audience, token_uri, ca_cert_path + ): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + service_identity_name (str): The service identity name. It will be + used as the `iss` and `sub` claim in the self signed JWT. + project (str): The project. + audience (str): The audience for the final token. + token_uri (str): The token server uri. + ca_cert_path (str): The CA cert path for token server side TLS + certificate verification. If the token server uses well known + CA, then this parameter can be `None`. + """ + super(ServiceAccountCredentials, self).__init__() + self._signer = signer + self._service_identity_name = service_identity_name + self._project = project + self._audience = audience + self._token_uri = token_uri + self._ca_cert_path = ca_cert_path + + def _create_jwt(self): + now = _helpers.utcnow() + expiry = now + JWT_LIFETIME + iss_sub_value = "system:serviceaccount:{}:{}".format( + self._project, self._service_identity_name + ) + + payload = { + "iss": iss_sub_value, + "sub": iss_sub_value, + "aud": self._token_uri, + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + } + + return _helpers.from_bytes(jwt.encode(self._signer, payload)) + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + import google.auth.transport.requests + + if not isinstance(request, google.auth.transport.requests.Request): + raise exceptions.RefreshError( + "For GDCH service account credentials, request must be a google.auth.transport.requests.Request object" + ) + + # Create a self signed JWT, and do token exchange. + jwt_token = self._create_jwt() + request_body = { + "grant_type": TOKEN_EXCHANGE_TYPE, + "audience": self._audience, + "requested_token_type": ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": jwt_token, + "subject_token_type": SERVICE_ACCOUNT_TOKEN_TYPE, + } + response_data = _client._token_endpoint_request( + request, + self._token_uri, + request_body, + access_token=None, + use_json=True, + verify=self._ca_cert_path, + ) + + self.token, _, self.expiry, _ = _client._handle_refresh_grant_response( + response_data, None + ) + + def with_gdch_audience(self, audience): + """Create a copy of GDCH credentials with the specified audience. + + Args: + audience (str): The intended audience for GDCH credentials. + """ + return self.__class__( + self._signer, + self._service_identity_name, + self._project, + audience, + self._token_uri, + self._ca_cert_path, + ) + + @classmethod + def _from_signer_and_info(cls, signer, info): + """Creates a Credentials instance from a signer and service account + info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + + Returns: + google.oauth2.gdch_credentials.ServiceAccountCredentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + if info["format_version"] != "1": + raise ValueError("Only format version 1 is supported") + + return cls( + signer, + info["name"], # service_identity_name + info["project"], + None, # audience + info["token_uri"], + info.get("ca_cert_path", None), + ) + + @classmethod + def from_service_account_info(cls, info): + """Creates a Credentials instance from parsed service account info. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.oauth2.gdch_credentials.ServiceAccountCredentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict( + info, + require=[ + "format_version", + "private_key_id", + "private_key", + "name", + "project", + "token_uri", + ], + use_rsa_signer=False, + ) + return cls._from_signer_and_info(signer, info) + + @classmethod + def from_service_account_file(cls, filename): + """Creates a Credentials instance from a service account json file. + + Args: + filename (str): The path to the service account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.oauth2.gdch_credentials.ServiceAccountCredentials: The constructed + credentials. + """ + info, signer = _service_account_info.from_filename( + filename, + require=[ + "format_version", + "private_key_id", + "private_key", + "name", + "project", + "token_uri", + ], + use_rsa_signer=False, + ) + return cls._from_signer_and_info(signer, info) diff --git a/tests/data/gdch_service_account.json b/tests/data/gdch_service_account.json new file mode 100644 index 000000000..172164e9f --- /dev/null +++ b/tests/data/gdch_service_account.json @@ -0,0 +1,11 @@ +{ + "type": "gdch_service_account", + "format_version": "1", + "project": "project_foo", + "private_key_id": "key_foo", + "private_key": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIIGb2np7v54Hs6++NiLE7CQtQg7rzm4znstHvrOUlcMMoAoGCCqGSM49\nAwEHoUQDQgAECvv0VyZS9nYOa8tdwKCbkNxlWgrAZVClhJXqrvOZHlH4N3d8Rplk\n2DEJvzp04eMxlHw1jm6JCs3iJR6KAokG+w==\n-----END EC PRIVATE KEY-----\n", + "name": "service_identity_name", + "ca_cert_path": "/path/to/ca/cert", + "token_uri": "https://service-identity./authenticate" +} + diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index 5485bed84..bd4cc5001 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -56,7 +56,7 @@ def test__handle_error_response(): assert excinfo.match(r"help: I\'m alive") -def test__handle_error_response_non_json(): +def test__handle_error_response_no_error(): response_data = {"foo": "bar"} with pytest.raises(exceptions.RefreshError) as excinfo: @@ -65,6 +65,15 @@ def test__handle_error_response_non_json(): assert excinfo.match(r"{\"foo\": \"bar\"}") +def test__handle_error_response_not_json(): + response_data = "this is an error message" + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(response_data) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test__parse_expiry(unused_utcnow): result = _client._parse_expiry({"expires_in": 500}) @@ -145,6 +154,8 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error_description": "internal_failure"} ) + # request should be called twice due to the retry + assert request.call_count == 2 request = make_request( {"error": "internal_failure"}, status=http_client.BAD_REQUEST @@ -154,6 +165,20 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error": "internal_failure"} ) + # request should be called twice due to the retry + assert request.call_count == 2 + + +def test__token_endpoint_request_string_error(): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.BAD_REQUEST + response.data = "this is an error message" + request = mock.create_autospec(transport.Request) + request.return_value = response + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._token_endpoint_request(request, "http://example.com", {}) + assert excinfo.match("this is an error message") def verify_request_params(request, params): diff --git a/tests/oauth2/test_gdch_credentials.py b/tests/oauth2/test_gdch_credentials.py new file mode 100644 index 000000000..60944ed41 --- /dev/null +++ b/tests/oauth2/test_gdch_credentials.py @@ -0,0 +1,174 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import datetime +import json +import os + +import mock +import pytest # type: ignore +import requests +import six + +from google.auth import exceptions +from google.auth import jwt +import google.auth.transport.requests +from google.oauth2 import gdch_credentials +from google.oauth2.gdch_credentials import ServiceAccountCredentials + + +class TestServiceAccountCredentials(object): + AUDIENCE = "https://service-identity./authenticate" + PROJECT = "project_foo" + PRIVATE_KEY_ID = "key_foo" + NAME = "service_identity_name" + CA_CERT_PATH = "/path/to/ca/cert" + TOKEN_URI = "https://service-identity./authenticate" + + JSON_PATH = os.path.join( + os.path.dirname(__file__), "..", "data", "gdch_service_account.json" + ) + with open(JSON_PATH, "rb") as fh: + INFO = json.load(fh) + + def test_with_gdch_audience(self): + mock_signer = mock.Mock() + creds = ServiceAccountCredentials._from_signer_and_info(mock_signer, self.INFO) + assert creds._signer == mock_signer + assert creds._service_identity_name == self.NAME + assert creds._audience is None + assert creds._token_uri == self.TOKEN_URI + assert creds._ca_cert_path == self.CA_CERT_PATH + + new_creds = creds.with_gdch_audience(self.AUDIENCE) + assert new_creds._signer == mock_signer + assert new_creds._service_identity_name == self.NAME + assert new_creds._audience == self.AUDIENCE + assert new_creds._token_uri == self.TOKEN_URI + assert new_creds._ca_cert_path == self.CA_CERT_PATH + + def test__create_jwt(self): + creds = ServiceAccountCredentials.from_service_account_file(self.JSON_PATH) + with mock.patch("google.auth._helpers.utcnow") as utcnow: + utcnow.return_value = datetime.datetime.now() + jwt_token = creds._create_jwt() + header, payload, _, _ = jwt._unverified_decode(jwt_token) + + expected_iss_sub_value = ( + "system:serviceaccount:project_foo:service_identity_name" + ) + assert isinstance(jwt_token, six.text_type) + assert header["alg"] == "ES256" + assert header["kid"] == self.PRIVATE_KEY_ID + assert payload["iss"] == expected_iss_sub_value + assert payload["sub"] == expected_iss_sub_value + assert payload["aud"] == self.AUDIENCE + assert payload["exp"] == (payload["iat"] + 3600) + + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._create_jwt", + autospec=True, + ) + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + def test_refresh(self, token_endpoint_request, create_jwt): + creds = ServiceAccountCredentials.from_service_account_info(self.INFO) + creds = creds.with_gdch_audience(self.AUDIENCE) + req = google.auth.transport.requests.Request() + + mock_jwt_token = "jwt token" + create_jwt.return_value = mock_jwt_token + sts_token = "STS token" + token_endpoint_request.return_value = { + "access_token": sts_token, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 3600, + } + + creds.refresh(req) + + token_endpoint_request.assert_called_with( + req, + self.TOKEN_URI, + { + "grant_type": gdch_credentials.TOKEN_EXCHANGE_TYPE, + "audience": self.AUDIENCE, + "requested_token_type": gdch_credentials.ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": mock_jwt_token, + "subject_token_type": gdch_credentials.SERVICE_ACCOUNT_TOKEN_TYPE, + }, + access_token=None, + use_json=True, + verify=self.CA_CERT_PATH, + ) + assert creds.token == sts_token + + def test_refresh_wrong_requests_object(self): + creds = ServiceAccountCredentials.from_service_account_info(self.INFO) + creds = creds.with_gdch_audience(self.AUDIENCE) + req = requests.Request() + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(req) + assert excinfo.match( + "request must be a google.auth.transport.requests.Request object" + ) + + def test__from_signer_and_info_wrong_format_version(self): + with pytest.raises(ValueError) as excinfo: + ServiceAccountCredentials._from_signer_and_info( + mock.Mock(), {"format_version": "2"} + ) + assert excinfo.match("Only format version 1 is supported") + + def test_from_service_account_info_miss_field(self): + for field in [ + "format_version", + "private_key_id", + "private_key", + "name", + "project", + "token_uri", + ]: + info_with_missing_field = copy.deepcopy(self.INFO) + del info_with_missing_field[field] + with pytest.raises(ValueError) as excinfo: + ServiceAccountCredentials.from_service_account_info( + info_with_missing_field + ) + assert excinfo.match("missing fields") + + @mock.patch("google.auth._service_account_info.from_filename") + def test_from_service_account_file(self, from_filename): + mock_signer = mock.Mock() + from_filename.return_value = (self.INFO, mock_signer) + creds = ServiceAccountCredentials.from_service_account_file(self.JSON_PATH) + from_filename.assert_called_with( + self.JSON_PATH, + require=[ + "format_version", + "private_key_id", + "private_key", + "name", + "project", + "token_uri", + ], + use_rsa_signer=False, + ) + assert creds._signer == mock_signer + assert creds._service_identity_name == self.NAME + assert creds._audience is None + assert creds._token_uri == self.TOKEN_URI + assert creds._ca_cert_path == self.CA_CERT_PATH diff --git a/tests/test__default.py b/tests/test__default.py index ed64bc723..61772c2e3 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -28,6 +28,7 @@ from google.auth import external_account from google.auth import identity_pool from google.auth import impersonated_credentials +from google.oauth2 import gdch_credentials from google.oauth2 import service_account import google.oauth2.credentials @@ -50,6 +51,8 @@ CLIENT_SECRETS_FILE = os.path.join(DATA_DIR, "client_secrets.json") +GDCH_SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "gdch_service_account.json") + with open(SERVICE_ACCOUNT_FILE) as fh: SERVICE_ACCOUNT_FILE_DATA = json.load(fh) @@ -637,6 +640,14 @@ def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_proj assert get_project_id.called +def test__get_gdch_service_account_credentials_invalid_format_version(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default._get_gdch_service_account_credentials( + "file_name", {"format_version": "2"} + ) + assert excinfo.match("Failed to load GDCH service account credentials") + + class _AppIdentityModule(object): """The interface of the App Idenity app engine module. See https://cloud.google.com/appengine/docs/standard/python/refdocs\ @@ -1140,3 +1151,19 @@ def test_default_impersonated_service_account_set_both_scopes_and_default_scopes credentials, _ = _default.default(scopes=scopes, default_scopes=default_scopes) assert credentials._target_scopes == scopes + + +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +def test_default_gdch_service_account_credentials(get_adc_path): + get_adc_path.return_value = GDCH_SERVICE_ACCOUNT_FILE + + creds, project = _default.default(quota_project_id="project-foo") + + assert isinstance(creds, gdch_credentials.ServiceAccountCredentials) + assert creds._service_identity_name == "service_identity_name" + assert creds._audience is None + assert creds._token_uri == "https://service-identity./authenticate" + assert creds._ca_cert_path == "/path/to/ca/cert" + assert project == "project_foo" diff --git a/tests/test__service_account_info.py b/tests/test__service_account_info.py index d5529bcce..9ad9f0fc8 100644 --- a/tests/test__service_account_info.py +++ b/tests/test__service_account_info.py @@ -24,10 +24,14 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "data") SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") +GDCH_SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "gdch_service_account.json") with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: SERVICE_ACCOUNT_INFO = json.load(fh) +with open(GDCH_SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + GDCH_SERVICE_ACCOUNT_INFO = json.load(fh) + def test_from_dict(): signer = _service_account_info.from_dict(SERVICE_ACCOUNT_INFO) @@ -35,6 +39,14 @@ def test_from_dict(): assert signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"] +def test_from_dict_es256_signer(): + signer = _service_account_info.from_dict( + GDCH_SERVICE_ACCOUNT_INFO, use_rsa_signer=False + ) + assert isinstance(signer, crypt.ES256Signer) + assert signer.key_id == GDCH_SERVICE_ACCOUNT_INFO["private_key_id"] + + def test_from_dict_bad_private_key(): info = SERVICE_ACCOUNT_INFO.copy() info["private_key"] = "garbage" @@ -60,3 +72,12 @@ def test_from_filename(): assert isinstance(signer, crypt.RSASigner) assert signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"] + + +def test_from_filename_es256_signer(): + _, signer = _service_account_info.from_filename( + GDCH_SERVICE_ACCOUNT_JSON_FILE, use_rsa_signer=False + ) + + assert isinstance(signer, crypt.ES256Signer) + assert signer.key_id == GDCH_SERVICE_ACCOUNT_INFO["private_key_id"]