Skip to content

Commit

Permalink
Merge branch 'main' into ud
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 authored Nov 27, 2023
2 parents 4910f14 + 7ab0fce commit 6768878
Show file tree
Hide file tree
Showing 12 changed files with 456 additions and 285 deletions.
4 changes: 2 additions & 2 deletions .github/.OwlBot.lock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.
docker:
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
digest: sha256:3e3800bb100af5d7f9e810d48212b37812c1856d20ffeafb99ebe66461b61fc7
# created: 2023-08-02T10:53:29.114535628Z
digest: sha256:caffe0a9277daeccc4d1de5c9b55ebba0901b57c2f713ec9c876b0d4ec064f61
# created: 2023-11-08T19:46:45.022803742Z
533 changes: 276 additions & 257 deletions .kokoro/requirements.txt

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions google/auth/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@

from google.auth import exceptions

# Token server doesn't provide a new a token when doing refresh unless the
# token is expiring within 30 seconds, so refresh threshold should not be
# more than 30 seconds. Otherwise auth lib will send tons of refresh requests
# until 30 seconds before the expiration, and cause a spike of CPU usage.
REFRESH_THRESHOLD = datetime.timedelta(seconds=20)
# The smallest MDS cache used by this library stores tokens until 4 minutes from
# expiry.
REFRESH_THRESHOLD = datetime.timedelta(minutes=3, seconds=45)


def copy_docstring(source_class):
Expand Down
53 changes: 44 additions & 9 deletions google/auth/compute_engine/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def get(
recursive=False,
retry_count=5,
headers=None,
return_none_for_not_found_error=False,
):
"""Fetch a resource from the metadata server.
Expand All @@ -173,6 +174,8 @@ def get(
retry_count (int): How many times to attempt connecting to metadata
server using above timeout.
headers (Optional[Mapping[str, str]]): Headers for the request.
return_none_for_not_found_error (Optional[bool]): If True, returns None
for 404 error instead of throwing an exception.
Returns:
Union[Mapping, str]: If the metadata server returns JSON, a mapping of
Expand Down Expand Up @@ -216,8 +219,17 @@ def get(
"metadata service. Compute Engine Metadata server unavailable".format(url)
)

content = _helpers.from_bytes(response.data)

if response.status == http_client.NOT_FOUND and return_none_for_not_found_error:
_LOGGER.info(
"Compute Engine Metadata server call to %s returned 404, reason: %s",
path,
content,
)
return None

if response.status == http_client.OK:
content = _helpers.from_bytes(response.data)
if (
_helpers.parse_content_type(response.headers["content-type"])
== "application/json"
Expand All @@ -232,14 +244,14 @@ def get(
raise new_exc from caught_exc
else:
return content
else:
raise exceptions.TransportError(
"Failed to retrieve {} from the Google Compute Engine "
"metadata service. Status: {} Response:\n{}".format(
url, response.status, response.data
),
response,
)

raise exceptions.TransportError(
"Failed to retrieve {} from the Google Compute Engine "
"metadata service. Status: {} Response:\n{}".format(
url, response.status, response.data
),
response,
)


def get_project_id(request):
Expand All @@ -259,6 +271,29 @@ def get_project_id(request):
return get(request, "project/project-id")


def get_universe_domain(request):
"""Get the universe domain value from the metadata server.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
Returns:
str: The universe domain value. If the universe domain endpoint is not
not found, return the default value, which is googleapis.com
Raises:
google.auth.exceptions.TransportError: if an error other than
404 occurs while retrieving metadata.
"""
universe_domain = get(
request, "universe/universe_domain", return_none_for_not_found_error=True
)
if not universe_domain:
return "googleapis.com"
return universe_domain


def get_service_account_info(request, service_account="default"):
"""Get information about a service account from the metadata server.
Expand Down
9 changes: 9 additions & 0 deletions google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
self._quota_project_id = quota_project_id
self._scopes = scopes
self._default_scopes = default_scopes
self._universe_domain_cached = False

def _retrieve_info(self, request):
"""Retrieve information about the service account.
Expand Down Expand Up @@ -131,6 +132,14 @@ def service_account_email(self):
def requires_scopes(self):
return not self._scopes

@property
def universe_domain(self):
if self._universe_domain_cached:
return self._universe_domain
self._universe_domain = _metadata.get_universe_domain()
self._universe_domain_cached = True
return self._universe_domain

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
return self.__class__(
Expand Down
6 changes: 6 additions & 0 deletions google/oauth2/_credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ async def refresh(self, request):
)
)

@_helpers.copy_docstring(credentials.Credentials)
async def before_request(self, request, method, url, headers):
if not self.valid:
await self.refresh(request)
self.apply(headers)


class UserAccessTokenCredentials(oauth2_credentials.UserAccessTokenCredentials):
"""Access token credentials for user account.
Expand Down
12 changes: 5 additions & 7 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,11 @@ def _metric_header_for_usage(self):

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
if (
self._universe_domain != _DEFAULT_UNIVERSE_DOMAIN
and not self._jwt_credentials
):
raise exceptions.RefreshError(
"self._jwt_credentials is missing for non-default universe domain"
)
if self._always_use_jwt_access and not self._jwt_credentials:
# If self signed jwt should be used but jwt credential is not
# created, try to create one with scopes
self._create_self_signed_jwt(None)

if self._universe_domain != _DEFAULT_UNIVERSE_DOMAIN and self._subject:
raise exceptions.RefreshError(
"domain wide delegation is not supported for non-default universe domain"
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
59 changes: 59 additions & 0 deletions tests/compute_engine/test__metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ def test_get_failure():
)


def test_get_return_none_for_not_found_error():
request = make_request("Metadata error", status=http_client.NOT_FOUND)

assert _metadata.get(request, PATH, return_none_for_not_found_error=True) is None

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + PATH,
headers=_metadata._METADATA_HEADERS,
)


def test_get_failure_connection_failed():
request = make_request("")
request.side_effect = exceptions.TransportError()
Expand Down Expand Up @@ -371,6 +383,53 @@ def test_get_project_id():
assert project_id == project


def test_get_universe_domain_success():
request = make_request(
"fake_universe_domain", headers={"content-type": "text/plain"}
)

universe_domain = _metadata.get_universe_domain(request)

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + "universe/universe_domain",
headers=_metadata._METADATA_HEADERS,
)
assert universe_domain == "fake_universe_domain"


def test_get_universe_domain_not_found():
# Test that if the universe domain endpoint returns 404 error, we should
# use googleapis.com as the universe domain
request = make_request("not found", status=http_client.NOT_FOUND)

universe_domain = _metadata.get_universe_domain(request)

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + "universe/universe_domain",
headers=_metadata._METADATA_HEADERS,
)
assert universe_domain == "googleapis.com"


def test_get_universe_domain_other_error():
# Test that if the universe domain endpoint returns an error other than 404
# we should throw the error
request = make_request("unauthorized", status=http_client.UNAUTHORIZED)

with pytest.raises(exceptions.TransportError) as excinfo:
_metadata.get_universe_domain(request)

assert excinfo.match(r"unauthorized")

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + "universe/universe_domain",
headers=_metadata._METADATA_HEADERS,
)


@mock.patch(
"google.auth.metrics.token_request_access_token_mds",
return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
Expand Down
20 changes: 20 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,26 @@ def test_token_usage_metrics(self):
assert headers["authorization"] == "Bearer token"
assert headers["x-goog-api-client"] == "cred-type/mds"

@mock.patch(
"google.auth.compute_engine._metadata.get_universe_domain",
return_value="fake_universe_domain",
)
def test_universe_domain(self, get_universe_domain):
self.credentials._universe_domain_cached = False
self.credentials._universe_domain = "googleapis.com"

# calling the universe_domain property should trigger a call to
# get_universe_domain to fetch the value. The value should be cached.
assert self.credentials.universe_domain == "fake_universe_domain"
assert self.credentials._universe_domain == "fake_universe_domain"
assert self.credentials._universe_domain_cached
get_universe_domain.assert_called_once()

# calling the universe_domain property the second time should use the
# cached value instead of calling get_universe_domain
assert self.credentials.universe_domain == "fake_universe_domain"
get_universe_domain.assert_called_once()


class TestIDTokenCredentials(object):
credentials = None
Expand Down
14 changes: 9 additions & 5 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,16 @@ def test_refresh_jwt_not_used_for_domain_wide_delegation(
assert jwt_grant.called
assert not self_signed_jwt_refresh.called

def test_refresh_non_gdu_missing_jwt_credentials(self):
credentials = self.make_credentials(universe_domain="foo")
def test_refresh_missing_jwt_credentials(self):
credentials = self.make_credentials()
credentials = credentials.with_scopes(["foo", "bar"])
credentials = credentials.with_always_use_jwt_access(True)
assert not credentials._jwt_credentials

with pytest.raises(exceptions.RefreshError) as excinfo:
credentials.refresh(None)
assert excinfo.match("self._jwt_credentials is missing")
credentials.refresh(mock.Mock())

# jwt credentials should have been automatically created with scopes
assert credentials._jwt_credentials is not None

def test_refresh_non_gdu_domain_wide_delegation_not_supported(self):
credentials = self.make_credentials(universe_domain="foo")
Expand Down
23 changes: 23 additions & 0 deletions tests_async/oauth2/test_credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,29 @@ def test_unpickle_old_credentials_pickle(self):
credentials = pickle.load(f)
assert credentials.quota_project_id is None

@mock.patch("google.oauth2._credentials_async.Credentials.apply", autospec=True)
@mock.patch("google.oauth2._credentials_async.Credentials.refresh", autospec=True)
@pytest.mark.asyncio
async def test_before_request(self, refresh, apply):
cred = self.make_credentials()
assert not cred.valid
await cred.before_request(mock.Mock(), "GET", "https://example.com", {})
refresh.assert_called()
apply.assert_called()

@mock.patch("google.oauth2._credentials_async.Credentials.apply", autospec=True)
@mock.patch("google.oauth2._credentials_async.Credentials.refresh", autospec=True)
@pytest.mark.asyncio
async def test_before_request_no_refresh(self, refresh, apply):
cred = self.make_credentials()
cred.token = refresh
cred.expiry = None

assert cred.valid
await cred.before_request(mock.Mock(), "GET", "https://example.com", {})
refresh.assert_not_called()
apply.assert_called()


class TestUserAccessTokenCredentials(object):
def test_instance(self):
Expand Down

0 comments on commit 6768878

Please sign in to comment.