Skip to content

Commit

Permalink
feat: add mtls support to rest transport
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Jan 17, 2021
1 parent e5b7648 commit 644a085
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,21 +239,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
# Create SSL credentials for mutual TLS if needed.
use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))

ssl_credentials = None
client_cert_source = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
import grpc # type: ignore

cert, key = client_options.client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
is_mtls = True
client_cert_source = client_options.client_cert_source
else:
creds = SslCredentials()
is_mtls = creds.is_mtls
ssl_credentials = creds.ssl_credentials if is_mtls else None
is_mtls = mtls.has_default_client_cert_source()
client_cert_source = mtls.default_client_cert_source() if is_mtls else None

# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
Expand Down Expand Up @@ -292,7 +286,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
ssl_channel_credentials=ssl_credentials,
client_cert_source_for_mtls=client_cert_source,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -80,8 +81,12 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
ssl_channel_credentials (grpc.ChannelCredentials): Deprecated. SSL
credentials for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` is None.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -98,6 +103,13 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if ssl_channel_credentials:
warnings.warn("ssl_channel_credentials is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
Expand All @@ -107,8 +119,6 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
Expand Down Expand Up @@ -144,12 +154,18 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -124,8 +125,12 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
ssl_channel_credentials (grpc.ChannelCredentials): Deprecated. SSL
credentials for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` is None.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -141,6 +146,13 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
and ``credentials_file`` are passed.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
if ssl_channel_credentials:
warnings.warn("ssl_channel_credentials is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
Expand All @@ -151,8 +163,6 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
Expand Down Expand Up @@ -188,12 +198,18 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
credentials: credentials.Credentials = None,
credentials_file: str = None,
scopes: Sequence[str] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand All @@ -68,8 +68,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
This argument is ignored if ``channel`` is provided.
scopes (Optional(Sequence[str])): A list of scopes. This argument is
ignored if ``channel`` is provided.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
certificate to configure mutual TLS HTTP channel. It is ignored
if ``channel`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -89,6 +90,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)

{%- if service.has_lro %}

Expand Down
Loading

0 comments on commit 644a085

Please sign in to comment.