Skip to content

Commit

Permalink
feat: support self-signed JWT flow for service accounts (#774)
Browse files Browse the repository at this point in the history
See [RFC (internal only)](https://docs.google.com/document/d/1SNCVTmW6Rtr__u-_V7nsT9PhSzjj1z0P9fAD3YUgRoc/edit#) and https://aip.dev/auth/4111.

Support the self-signed JWT flow for service accounts by passing `default_scopes` and `default_host` in calls to the auth library and `create_channel`. This depends on features exposed in the following PRs: googleapis/python-api-core#134, googleapis/google-auth-library-python#665.

It may be easier to look at https://github.com/googleapis/python-translate/pull/107/files for a diff on a real library.

This change is written so that the library is (temporarily) compatible with older `google-api-core` and `google-auth` versions. Because of this it not possible to reach 100% coverage on a single unit test run. `pytest` runs twice in two of the `nox` sessions.

Miscellaneous changes:
- sprinkled in `__init__.py` files in subdirs of the `test/` directory, as otherwise pytest-cov seems to fail to collect coverage properly in some instances.
- new dependency on `packaging` for Version comparison https://pypi.org/project/packaging/

Co-authored-by: Brent Shaffer <[email protected]>
  • Loading branch information
busunkim96 and bshaffer authored Apr 21, 2021
1 parent 7ca9222 commit 89d6f35
Show file tree
Hide file tree
Showing 12 changed files with 458 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

{% block content %}
import abc
import typing
from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
import packaging.version
import pkg_resources

from google import auth # type: ignore
import google.api_core # type: ignore
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
Expand Down Expand Up @@ -34,6 +36,18 @@ try:
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()

try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class {{ service.name }}Transport(abc.ABC):
"""Abstract transport class for {{ service.name }}."""

Expand All @@ -43,13 +57,15 @@ class {{ service.name }}Transport(abc.ABC):
{%- endfor %}
)

DEFAULT_HOST: str = {% if service.host %}'{{ service.host }}'{% else %}{{ '' }}{% endif %}

def __init__(
self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
host: str = DEFAULT_HOST,
credentials: credentials.Credentials = None,
credentials_file: typing.Optional[str] = None,
scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
quota_project_id: typing.Optional[str] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
**kwargs,
) -> None:
Expand All @@ -66,7 +82,7 @@ class {{ service.name }}Transport(abc.ABC):
credentials_file (Optional[str]): A file with credentials that can
be loaded with :func:`google.auth.load_credentials_from_file`.
This argument is mutually exclusive with credentials.
scope (Optional[Sequence[str]]): A list of scopes.
scopes (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -80,6 +96,8 @@ class {{ service.name }}Transport(abc.ABC):
host += ':443'
self._host = host

scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)

# Save the scopes.
self._scopes = scopes or self.AUTH_SCOPES

Expand All @@ -91,17 +109,59 @@ class {{ service.name }}Transport(abc.ABC):
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
credentials_file,
scopes=self._scopes,
**scopes_kwargs,
quota_project_id=quota_project_id
)

elif credentials is None:
credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id)
credentials, _ = auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# Save the credentials.
self._credentials = credentials


# TODO(busunkim): These two class methods are in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-api-core
# and google-auth are increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
def _get_scopes_kwargs(cls, host: str, scopes: Optional[Sequence[str]]) -> Dict[str, Optional[Sequence[str]]]:
"""Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""

scopes_kwargs = {}

if _GOOGLE_AUTH_VERSION and (
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}

return scopes_kwargs

# TODO: Remove this function once google-api-core >= 1.26.0 is required
@classmethod
def _get_self_signed_jwt_kwargs(cls, host: str, scopes: Optional[Sequence[str]]) -> Dict[str, Union[Optional[Sequence[str]], str]]:
"""Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""

self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}

if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return self_signed_jwt_kwargs


def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down Expand Up @@ -138,11 +198,11 @@ class {{ service.name }}Transport(abc.ABC):
{%- for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> typing.Callable[
def {{ method.name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
typing.Union[
Union[
{{ method.output.ident }},
typing.Awaitable[{{ method.output.ident }}]
Awaitable[{{ method.output.ident }}]
]]:
raise NotImplementedError()
{%- endfor %}
Expand All @@ -152,29 +212,29 @@ class {{ service.name }}Transport(abc.ABC):
@property
def set_iam_policy(
self,
) -> typing.Callable[
) -> Callable[
[iam_policy.SetIamPolicyRequest],
typing.Union[policy.Policy, typing.Awaitable[policy.Policy]],
Union[policy.Policy, Awaitable[policy.Policy]],
]:
raise NotImplementedError()

@property
def get_iam_policy(
self,
) -> typing.Callable[
) -> Callable[
[iam_policy.GetIamPolicyRequest],
typing.Union[policy.Policy, typing.Awaitable[policy.Policy]],
Union[policy.Policy, Awaitable[policy.Policy]],
]:
raise NotImplementedError()

@property
def test_iam_permissions(
self,
) -> typing.Callable[
) -> Callable[
[iam_policy.TestIamPermissionsRequest],
typing.Union[
Union[
iam_policy.TestIamPermissionsResponse,
typing.Awaitable[iam_policy.TestIamPermissionsResponse],
Awaitable[iam_policy.TestIamPermissionsResponse],
],
]:
raise NotImplementedError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

{% block content %}
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
Expand Down Expand Up @@ -202,13 +202,15 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
scopes = scopes or cls.AUTH_SCOPES

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

{% block content %}
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import gapic_v1 # type: ignore
from google.api_core import grpc_helpers_async # type: ignore
Expand All @@ -12,6 +12,7 @@ from google.api_core import operations_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version

import grpc # type: ignore
from grpc.experimental import aio # type: ignore
Expand Down Expand Up @@ -75,13 +76,15 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
Returns:
aio.Channel: A gRPC AsyncIO channel object.
"""
scopes = scopes or cls.AUTH_SCOPES

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers_async.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
**kwargs
)

Expand Down Expand Up @@ -163,7 +166,6 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None

else:
if api_mtls_endpoint:
host = api_mtls_endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
"""
# Run the base constructor
# TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
# TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
# credentials object
super().__init__(
host=host,
credentials=credentials,
client_info=client_info,
)
self._session = AuthorizedSession(self._credentials)
self._session = AuthorizedSession(self._credentials, default_host=self.DEFAULT_HOST)
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}
Expand All @@ -106,11 +108,14 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
# Sanity check: Only create a new client if we do not already have one.
if self._operations_client is None:
from google.api_core import grpc_helpers

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(self._host, self._scopes)

self._operations_client = operations_v1.OperationsClient(
grpc_helpers.create_channel(
self._host,
credentials=self._credentials,
scopes=self.AUTH_SCOPES,
**self_signed_jwt_kwargs,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
Expand Down
1 change: 0 additions & 1 deletion gapic/templates/.coveragerc.j2
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
branch = True

[report]
fail_under = 100
show_missing = True
omit =
{{ api.naming.module_namespace|join("/") }}/{{ api.naming.module_name }}/__init__.py
Expand Down
62 changes: 62 additions & 0 deletions gapic/templates/noxfile.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,28 @@

{% block content %}
import os
import pathlib
import shutil
import subprocess
import sys


import nox # type: ignore

CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()

LOWER_BOUND_CONSTRAINTS_FILE = CURRENT_DIRECTORY / "constraints.txt"
PACKAGE_NAME = subprocess.check_output([sys.executable, "setup.py", "--name"], encoding="utf-8")


nox.sessions = [
"unit",
"cover",
"mypy",
"check_lower_bounds"
# exclude update_lower_bounds from default
"docs",
]

@nox.session(python=['3.6', '3.7', '3.8', '3.9'])
def unit(session):
Expand All @@ -25,6 +43,18 @@ def unit(session):
)


@nox.session(python='3.7')
def cover(session):
"""Run the final coverage report.
This outputs the coverage report aggregating coverage from the unit
test runs (not system test runs), and then erases coverage data.
"""
session.install("coverage", "pytest-cov")
session.run("coverage", "report", "--show-missing", "--fail-under=100")

session.run("coverage", "erase")


@nox.session(python=['3.6', '3.7'])
def mypy(session):
"""Run the type checker."""
Expand All @@ -40,6 +70,38 @@ def mypy(session):
{%- endif %}
)


@nox.session
def update_lower_bounds(session):
"""Update lower bounds in constraints.txt to match setup.py"""
session.install('google-cloud-testutils')
session.install('.')

session.run(
'lower-bound-checker',
'update',
'--package-name',
PACKAGE_NAME,
'--constraints-file',
str(LOWER_BOUND_CONSTRAINTS_FILE),
)


@nox.session
def check_lower_bounds(session):
"""Check lower bounds in setup.py are reflected in constraints file"""
session.install('google-cloud-testutils')
session.install('.')

session.run(
'lower-bound-checker',
'check',
'--package-name',
PACKAGE_NAME,
'--constraints-file',
str(LOWER_BOUND_CONSTRAINTS_FILE),
)

@nox.session(python='3.6')
def docs(session):
"""Build the docs for this library."""
Expand Down
3 changes: 2 additions & 1 deletion gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ setuptools.setup(
'google-api-core[grpc] >= 1.22.2, < 2.0.0dev',
'libcst >= 0.2.5',
'proto-plus >= 1.15.0',
'packaging >= 14.3',
{%- if api.requires_package(('google', 'iam', 'v1')) or opts.add_iam_methods %}
'grpc-google-iam-v1',
'grpc-google-iam-v1 >= 0.12.3, < 0.13dev',
{%- endif %}
),
python_requires='>=3.6',
Expand Down
2 changes: 2 additions & 0 deletions gapic/templates/tests/__init__.py.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

{% extends '_base.py.j2' %}
Loading

0 comments on commit 89d6f35

Please sign in to comment.