Skip to content

Commit

Permalink
feat: add GOOGLE_API_USE_MTLS support (#420)
Browse files Browse the repository at this point in the history
Co-authored-by: Dov Shlachter <[email protected]>
  • Loading branch information
arithmetic1728 and software-dov authored May 27, 2020
1 parent 4957090 commit 41fa725
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

{% block content %}
from collections import OrderedDict
import os
import re
from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
import pkg_resources
Expand All @@ -11,6 +12,8 @@ from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport import mtls # type: ignore
from google.auth.exceptions import MutualTLSChannelError # type: ignore
from google.oauth2 import service_account # type: ignore

{% filter sort_lines -%}
Expand Down Expand Up @@ -144,21 +147,47 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
transport (Union[str, ~.{{ service.name }}Transport]): The
transport to use. If set to None, a transport is chosen
automatically.
client_options (ClientOptions): Custom options for the client.
client_options (ClientOptions): Custom options for the client. It
won't take effect unless ``transport`` is None.
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``client_cert_source``
is provided, mutual TLS transport will be created with the given
``api_endpoint`` or the default mTLS endpoint, and the client
SSL credentials obtained from ``client_cert_source``.
default endpoint provided by the client. GOOGLE_API_USE_MTLS
environment variable can also be used to override the endpoint:
"Always" (always use the default mTLS endpoint), "Never" (always
use the default regular endpoint, this is the default value for
the environment variable) and "Auto" (auto switch to the default
mTLS endpoint if client SSL credentials is present). However,
the ``api_endpoint`` property takes precedence if provided.
(2) The ``client_cert_source`` property is used to provide client
SSL credentials for mutual TLS transport. If not provided, the
default SSL credentials will be used if present.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
"""
if isinstance(client_options, dict):
client_options = ClientOptions.from_dict(client_options)
if client_options is None:
client_options = ClientOptions.ClientOptions()

if transport is None and client_options.api_endpoint is None:
use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "Never")
if use_mtls_env == "Never":
client_options.api_endpoint = self.DEFAULT_ENDPOINT
elif use_mtls_env == "Always":
client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT
elif use_mtls_env == "Auto":
has_client_cert_source = (
client_options.client_cert_source is not None
or mtls.has_default_client_cert_source()
)
client_options.api_endpoint = (
self.DEFAULT_MTLS_ENDPOINT if has_client_cert_source else self.DEFAULT_ENDPOINT
)
else:
raise MutualTLSChannelError(
"Unsupported GOOGLE_API_USE_MTLS value. Accepted values: Never, Auto, Always"
)

# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
Expand All @@ -169,38 +198,16 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
raise ValueError('When providing a transport instance, '
'provide its credentials directly.')
self._transport = transport
elif client_options is None or (
client_options.api_endpoint is None
and client_options.client_cert_source is None
):
# Don't trigger mTLS if we get an empty ClientOptions.
elif isinstance(transport, str):
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials, host=self.DEFAULT_ENDPOINT
)
else:
# We have a non-empty ClientOptions. If client_cert_source is
# provided, trigger mTLS with user provided endpoint or the default
# mTLS endpoint.
if client_options.client_cert_source:
api_mtls_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_MTLS_ENDPOINT
)
else:
api_mtls_endpoint = None

api_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_ENDPOINT
)

self._transport = {{ service.name }}GrpcTransport(
credentials=credentials,
host=api_endpoint,
api_mtls_endpoint=api_mtls_endpoint,
host=client_options.api_endpoint,
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
{%- endif %}
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore

Expand Down Expand Up @@ -63,7 +64,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
is None.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
"""
if channel:
Expand All @@ -76,6 +77,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
elif api_mtls_endpoint:
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
Expand All @@ -96,7 +100,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
self._stubs = {} # type: Dict[str, Callable]
self._stubs = {} # type: Dict[str, Callable]


@classmethod
Expand Down
118 changes: 90 additions & 28 deletions gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{% extends "_base.py.j2" %}

{% block content %}
import os
from unittest import mock

import grpc
Expand All @@ -11,6 +12,7 @@ import pytest
{% filter sort_lines -%}
from google import auth
from google.auth import credentials
from google.auth.exceptions import MutualTLSChannelError
from google.oauth2 import service_account
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }}
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports
Expand Down Expand Up @@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
{% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}


def test_{{ service.client_name|snake_case }}_get_transport_class():
transport = {{ service.client_name }}.get_transport_class()
assert transport == transports.{{ service.name }}GrpcTransport

transport = {{ service.client_name }}.get_transport_class("grpc")
assert transport == transports.{{ service.name }}GrpcTransport


def test_{{ service.client_name|snake_case }}_client_options():
# Check that if channel is provided we won't create a new one.
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
Expand All @@ -72,58 +82,99 @@ def test_{{ service.client_name|snake_case }}_client_options():
client = {{ service.client_name }}(transport=transport)
gtc.assert_not_called()

# Check mTLS is not triggered with empty client options.
options = client_options.ClientOptions()
# Check that if channel is provided via str we will create a new one.
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
transport = gtc.return_value = mock.MagicMock()
client = {{ service.client_name }}(client_options=options)
transport.assert_called_once_with(
credentials=None,
host=client.DEFAULT_ENDPOINT,
)
client = {{ service.client_name }}(transport="grpc")
gtc.assert_called()

# Check mTLS is not triggered if api_endpoint is provided but
# client_cert_source is None.
# Check the case api_endpoint is provided.
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint=None,
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
)

# Check mTLS is triggered if client_cert_source is provided.
options = client_options.ClientOptions(
client_cert_source=client_cert_source_callback
)
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
# "Never".
os.environ["GOOGLE_API_USE_MTLS"] = "Never"
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
client = {{ service.client_name }}()
grpc_transport.assert_called_once_with(
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=client_cert_source_callback,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
credentials=None,
host=client.DEFAULT_ENDPOINT,
)

# Check mTLS is triggered if api_endpoint and client_cert_source are provided.
options = client_options.ClientOptions(
api_endpoint="squid.clam.whelk",
client_cert_source=client_cert_source_callback
)
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
# "Always".
os.environ["GOOGLE_API_USE_MTLS"] = "Always"
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}()
grpc_transport.assert_called_once_with(
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=None,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
# "Auto", and client_cert_source is provided.
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
grpc_transport.return_value = None
client = {{ service.client_name }}(client_options=options)
grpc_transport.assert_called_once_with(
api_mtls_endpoint="squid.clam.whelk",
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=client_cert_source_callback,
credentials=None,
host="squid.clam.whelk",
host=client.DEFAULT_MTLS_ENDPOINT,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
# "Auto", and default_client_cert_source is provided.
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
grpc_transport.return_value = None
client = {{ service.client_name }}()
grpc_transport.assert_called_once_with(
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=None,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
# "Auto", but client_cert_source and default_client_cert_source are None.
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
grpc_transport.return_value = None
client = {{ service.client_name }}()
grpc_transport.assert_called_once_with(
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
credentials=None,
host=client.DEFAULT_ENDPOINT,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
# unsupported value.
os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
with pytest.raises(MutualTLSChannelError):
client = {{ service.client_name }}()

del os.environ["GOOGLE_API_USE_MTLS"]


def test_{{ service.client_name|snake_case }}_client_options_from_dict():
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
Expand All @@ -132,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
client_options={'api_endpoint': 'squid.clam.whelk'}
)
grpc_transport.assert_called_once_with(
api_mtls_endpoint=None,
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
Expand Down Expand Up @@ -490,12 +541,24 @@ def test_{{ service.name|snake_case }}_auth_adc():
))


def test_{{ service.name|snake_case }}_transport_auth_adc():
# If credentials and host are not provided, the transport class should use
# ADC credentials.
with mock.patch.object(auth, 'default') as adc:
adc.return_value = (credentials.AnonymousCredentials(), None)
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
adc.assert_called_once_with(scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))


def test_{{ service.name|snake_case }}_host_no_port():
{% with host = (service.host|default('localhost', true)).split(':')[0] -%}
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
transport='grpc',
)
assert client._transport._host == '{{ host }}:443'
{% endwith %}
Expand All @@ -506,7 +569,6 @@ def test_{{ service.name|snake_case }}_host_with_port():
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
transport='grpc',
)
assert client._transport._host == '{{ host }}:8000'
{% endwith %}
Expand Down
Loading

0 comments on commit 41fa725

Please sign in to comment.