Skip to content

Commit

Permalink
refactor: remove duplication in grpc transport init (#808)
Browse files Browse the repository at this point in the history
* refactor: remove duplication in grpc transport init

* refactor: let base transport set final host, creds, scopes

* refactor: remove scopes or self.AUTH_SCOPES in grpc

* docs: explain self._prep_wrapped_messages
  • Loading branch information
busunkim96 authored Mar 10, 2021
1 parent 203cdf1 commit 6a2622a
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class {{ service.name }}Transport(abc.ABC):
host += ':443'
self._host = host

# Save the scopes.
self._scopes = scopes or self.AUTH_SCOPES

# If no credentials are provided, then determine the appropriate
# defaults.
if credentials and credentials_file:
Expand All @@ -88,19 +91,16 @@ class {{ service.name }}Transport(abc.ABC):
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
credentials_file,
scopes=scopes,
scopes=self._scopes,
quota_project_id=quota_project_id
)

elif credentials is None:
credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)
credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id)

# Save the credentials.
self._credentials = credentials

# Lifted into its own function so it can be stubbed out during tests.
self._prep_wrapped_messages(client_info)


def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,91 +101,73 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
# Ignore credentials if a channel was passed.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
ssl_credentials = SslCredentials().ssl_credentials

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)
self._ssl_channel_credentials = ssl_credentials

else:
host = host if ":" in host else host + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
if api_mtls_endpoint:
host = api_mtls_endpoint

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
self._ssl_channel_credentials = SslCredentials().ssl_credentials

# create a new channel. The provided one is ignored.
else:
if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# The base transport sets the host, credentials and scopes
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
)

if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
self._host,
credentials=self._credentials,
credentials_file=credentials_file,
scopes=self._scopes,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)

self._stubs = {} # type: Dict[str, Callable]
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}
# Wrap messages. This must be done after self._grpc_channel exists
self._prep_wrapped_messages(client_info)

# Run the base constructor.
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
client_info=client_info,
)

@classmethod
def create_channel(cls,
Expand All @@ -197,7 +179,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
**kwargs) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
address (Optional[str]): The host for the channel to use.
host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
**kwargs) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
address (Optional[str]): The host for the channel to use.
host (Optional[str]): The host for the channel to use.
credentials (Optional[~.Credentials]): The
authorization credentials to attach to requests. These
credentials identify this application to the service. If
Expand Down Expand Up @@ -145,91 +145,72 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
# Ignore credentials if a channel was passed.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
ssl_credentials = SslCredentials().ssl_credentials

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)
self._ssl_channel_credentials = ssl_credentials
else:
host = host if ":" in host else host + ":443"
if api_mtls_endpoint:
host = api_mtls_endpoint

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
self._ssl_channel_credentials = SslCredentials().ssl_credentials

else:
if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
# The base transport sets the host, credentials and scopes
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
self._host,
credentials=self._credentials,
credentials_file=credentials_file,
scopes=self._scopes,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)

# Run the base constructor.
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
client_info=client_info,
)

self._stubs = {}
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}
# Wrap messages. This must be done after self._grpc_channel exists
self._prep_wrapped_messages(client_info)

@property
def grpc_channel(self) -> aio.Channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{%- endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._prep_wrapped_messages(client_info)

{%- if service.has_lro %}

Expand Down

0 comments on commit 6a2622a

Please sign in to comment.