Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support quota project override via client options #496

Merged
merged 4 commits into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
scopes=client_options.scopes,
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
quota_project_id=client_options.quota_project_id,
)

{% for method in service.methods.values() -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class {{ service.name }}Transport(abc.ABC):
credentials: credentials.Credentials = None,
credentials_file: typing.Optional[str] = None,
scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
quota_project_id: typing.Optional[str] = None,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -49,6 +50,8 @@ class {{ service.name }}Transport(abc.ABC):
be loaded with :func:`google.auth.load_credentials_from_file`.
This argument is mutually exclusive with credentials.
scope (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -61,9 +64,14 @@ class {{ service.name }}Transport(abc.ABC):
raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")

if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(credentials_file, scopes=scopes)
credentials, _ = auth.load_credentials_from_file(
credentials_file,
scopes=scopes,
quota_project_id=quota_project_id
)

elif credentials is None:
credentials, _ = auth.default(scopes=scopes)
credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)

# Save the credentials.
self._credentials = credentials
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes: Sequence[str] = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None) -> None:
"""Instantiate the transport.

Args:
Expand All @@ -71,6 +72,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand All @@ -89,7 +92,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
Expand All @@ -108,14 +111,16 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

# Run the base constructor.
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

self._stubs = {} # type: Dict[str, Callable]
Expand All @@ -126,6 +131,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials: credentials.Credentials = None,
credentials_file: str = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
**kwargs) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
Expand All @@ -141,6 +147,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
service. These are only used when credentials are not specified and
are passed to :func:`google.auth.default`.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
kwargs (Optional[dict]): Keyword arguments, which are passed to the
channel creation.
Returns:
Expand All @@ -156,6 +164,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
credentials: credentials.Credentials = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
**kwargs) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
Expand All @@ -60,6 +61,8 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
service. These are only used when credentials are not specified and
are passed to :func:`google.auth.default`.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
kwargs (Optional[dict]): Keyword arguments, which are passed to the
channel creation.
Returns:
Expand All @@ -71,6 +74,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**kwargs
)

Expand All @@ -81,7 +85,9 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
scopes: Optional[Sequence[str]] = None,
channel: aio.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
) -> None:
"""Instantiate the transport.

Args:
Expand Down Expand Up @@ -109,6 +115,8 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
Expand Down Expand Up @@ -143,14 +151,16 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

# Run the base constructor.
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

self._stubs = {}
Expand Down
2 changes: 1 addition & 1 deletion gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ setuptools.setup(
platforms='Posix; MacOS X; Windows',
include_package_data=True,
install_requires=(
'google-api-core[grpc] >= 1.21.0, < 2.0.0dev',
'google-api-core[grpc] >= 1.22.0, < 2.0.0dev',
'libcst >= 0.2.5',
'proto-plus >= 1.1.0',
{%- if api.requires_package(('google', 'iam', 'v1')) %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -127,6 +128,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -142,6 +144,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -158,6 +161,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=client_cert_source_callback,
quota_project_id=None,

)

Expand All @@ -175,6 +179,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -191,15 +196,29 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
# unsupported value.
os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
with pytest.raises(MutualTLSChannelError):
client = client_class()
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}):
with pytest.raises(MutualTLSChannelError):
client = client_class()

del os.environ["GOOGLE_API_USE_MTLS"]
# Check the case quota_project_id is provided
options = client_options.ClientOptions(quota_project_id="octopus")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id="octopus",
)


@pytest.mark.parametrize("client_class,transport_class,transport_name", [
Expand All @@ -221,6 +240,7 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class
scopes=["1", "2"],
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)


Expand All @@ -243,6 +263,7 @@ def test_{{ service.client_name|snake_case }}_client_options_credentials_file(cl
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)


Expand All @@ -259,6 +280,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
scopes=None,
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=None,
quota_project_id=None,
)


Expand Down Expand Up @@ -1001,12 +1023,15 @@ def test_{{ service.name|snake_case }}_base_transport_with_credentials_file():
load_creds.return_value = (credentials.AnonymousCredentials(), None)
transport = transports.{{ service.name }}Transport(
credentials_file="credentials.json",
quota_project_id="octopus",
)
load_creds.assert_called_once_with("credentials.json", scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))
),
quota_project_id="octopus",
)


def test_{{ service.name|snake_case }}_auth_adc():
Expand All @@ -1017,22 +1042,23 @@ def test_{{ service.name|snake_case }}_auth_adc():
adc.assert_called_once_with(scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))
{%- endfor %}),
quota_project_id=None,
)


def test_{{ service.name|snake_case }}_transport_auth_adc():
# If credentials and host are not provided, the transport class should use
# ADC credentials.
with mock.patch.object(auth, 'default') as adc:
adc.return_value = (credentials.AnonymousCredentials(), None)
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
adc.assert_called_once_with(scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))

{%- endfor %}),
quota_project_id="octopus",
)

def test_{{ service.name|snake_case }}_host_no_port():
{% with host = (service.host|default('localhost', true)).split(':')[0] -%}
Expand Down Expand Up @@ -1122,6 +1148,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_c
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down Expand Up @@ -1160,6 +1187,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel_mtls_with_
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down Expand Up @@ -1200,6 +1228,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down Expand Up @@ -1240,6 +1269,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel_mtls_with_
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
click==7.1.2
google-api-core==1.21.0
google-api-core==1.22.0
googleapis-common-protos==1.52.0
jinja2==2.11.2
MarkupSafe==1.1.1
Expand Down