From 44feef7429fe5ce70a0b346cf2120bfe994a1731 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Thu, 1 Feb 2024 14:50:04 -0800 Subject: [PATCH] feat: Allow users to explicitly configure universe domain (#737) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Allow users to explicitly configure universe domain chore: Update gapic-generator-python to v1.14.0 PiperOrigin-RevId: 603108274 Source-Link: https://github.com/googleapis/googleapis/commit/3d83e3652f689ab51c3f95f876458c6faef619bf Source-Link: https://github.com/googleapis/googleapis-gen/commit/baf5e9bbb14a768b2b4c9eae9feb78f18f1757fa Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiYmFmNWU5YmJiMTRhNzY4YjJiNGM5ZWFlOWZlYjc4ZjE4ZjE3NTdmYSJ9 * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Update test_client_v1.py * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot Co-authored-by: Lingqing Gan --- .../services/big_query_read/async_client.py | 75 ++- .../services/big_query_read/client.py | 293 +++++++- .../big_query_read/transports/base.py | 6 +- .../big_query_read/transports/grpc.py | 2 +- .../big_query_read/transports/grpc_asyncio.py | 2 +- .../services/big_query_write/async_client.py | 84 ++- .../services/big_query_write/client.py | 305 ++++++++- .../big_query_write/transports/base.py | 6 +- .../big_query_write/transports/grpc.py | 2 +- .../transports/grpc_asyncio.py | 2 +- .../services/big_query_read/async_client.py | 75 ++- .../services/big_query_read/client.py | 293 +++++++- .../big_query_read/transports/base.py | 6 +- .../big_query_read/transports/grpc.py | 2 +- .../big_query_read/transports/grpc_asyncio.py | 2 +- .../services/big_query_write/async_client.py | 84 ++- .../services/big_query_write/client.py | 305 ++++++++- .../big_query_write/transports/base.py | 6 +- .../big_query_write/transports/grpc.py | 2 +- .../transports/grpc_asyncio.py | 2 +- .../test_big_query_read.py | 567 +++++++++++++--- .../test_big_query_write.py | 625 +++++++++++++++--- .../test_big_query_read.py | 567 +++++++++++++--- .../test_big_query_write.py | 625 +++++++++++++++--- tests/unit/test_client_v1.py | 3 + 25 files changed, 3385 insertions(+), 556 deletions(-) diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/async_client.py b/google/cloud/bigquery_storage_v1/services/big_query_read/async_client.py index d98411e1..2a35cb2c 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_read/async_client.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_read/async_client.py @@ -40,9 +40,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore from google.cloud.bigquery_storage_v1.types import arrow from google.cloud.bigquery_storage_v1.types import avro @@ -62,8 +62,12 @@ class BigQueryReadAsyncClient: _client: BigQueryReadClient + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = BigQueryReadClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = BigQueryReadClient._DEFAULT_UNIVERSE read_session_path = staticmethod(BigQueryReadClient.read_session_path) parse_read_session_path = staticmethod(BigQueryReadClient.parse_read_session_path) @@ -170,6 +174,25 @@ def transport(self) -> BigQueryReadTransport: """ return self._client.transport + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + get_transport_class = functools.partial( type(BigQueryReadClient).get_transport_class, type(BigQueryReadClient) ) @@ -182,7 +205,7 @@ def __init__( client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiates the big query read client. + """Instantiates the big query read async client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -193,23 +216,38 @@ def __init__( transport (Union[str, ~.BigQueryReadTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. @@ -372,6 +410,9 @@ async def sample_create_read_session(): ), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -503,6 +544,9 @@ async def sample_read_rows(): ), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = rpc( request, @@ -602,6 +646,9 @@ async def sample_split_read_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/client.py b/google/cloud/bigquery_storage_v1/services/big_query_read/client.py index fa171256..1d8373b8 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_read/client.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_read/client.py @@ -29,6 +29,7 @@ Union, cast, ) +import warnings from google.cloud.bigquery_storage_v1 import gapic_version as package_version @@ -43,9 +44,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore + OptionalRetry = Union[retries.Retry, object, None] # type: ignore from google.cloud.bigquery_storage_v1.types import arrow from google.cloud.bigquery_storage_v1.types import avro @@ -127,11 +128,15 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = "bigquerystorage.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) + _DEFAULT_ENDPOINT_TEMPLATE = "bigquerystorage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -328,7 +333,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]: def get_mtls_endpoint_and_cert_source( cls, client_options: Optional[client_options_lib.ClientOptions] = None ): - """Return the API endpoint and client cert source for mutual TLS. + """Deprecated. Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the @@ -358,6 +363,11 @@ def get_mtls_endpoint_and_cert_source( Raises: google.auth.exceptions.MutualTLSChannelError: If any errors happen. """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) if client_options is None: client_options = client_options_lib.ClientOptions() use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") @@ -391,6 +401,175 @@ def get_mtls_endpoint_and_cert_source( return api_endpoint, client_cert_source + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = BigQueryReadClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + @staticmethod + def _compare_universes( + client_universe: str, credentials: ga_credentials.Credentials + ) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials (ga_credentials.Credentials): The credentials being used in the client. + + Returns: + bool: True iff client_universe matches the universe in credentials. + + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + if credentials: + credentials_universe = credentials.universe_domain + if client_universe != credentials_universe: + default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + raise ValueError( + "The configured universe domain " + f"({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{default_universe}` is the default." + ) + return True + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + self._is_universe_domain_valid = ( + self._is_universe_domain_valid + or BigQueryReadClient._compare_universes( + self.universe_domain, self.transport._credentials + ) + ) + return self._is_universe_domain_valid + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + def __init__( self, *, @@ -410,22 +589,32 @@ def __init__( transport (Union[str, BigQueryReadTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): The client info used to send a user-agent string along with API requests. If ``None``, then default info will be used. @@ -436,17 +625,34 @@ def __init__( google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( - client_options + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = BigQueryReadClient._read_environment_variables() + self._client_cert_source = BigQueryReadClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert ) + self._universe_domain = BigQueryReadClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False - api_key_value = getattr(client_options, "api_key", None) + api_key_value = getattr(self._client_options, "api_key", None) if api_key_value and credentials: raise ValueError( "client_options.api_key and credentials are mutually exclusive" @@ -455,20 +661,30 @@ def __init__( # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. - if isinstance(transport, BigQueryReadTransport): + transport_provided = isinstance(transport, BigQueryReadTransport) + if transport_provided: # transport is a BigQueryReadTransport instance. - if credentials or client_options.credentials_file or api_key_value: + if credentials or self._client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: + if self._client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " "directly." ) - self._transport = transport - else: + self._transport = cast(BigQueryReadTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = self._api_endpoint or BigQueryReadClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + + if not transport_provided: import google.auth._default # type: ignore if api_key_value and hasattr( @@ -478,17 +694,17 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(transport) + Transport = type(self).get_transport_class(cast(str, transport)) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, client_info=client_info, always_use_jwt_access=True, - api_audience=client_options.api_audience, + api_audience=self._client_options.api_audience, ) def create_read_session( @@ -632,6 +848,9 @@ def sample_create_read_session(): ), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -754,6 +973,9 @@ def sample_read_rows(): ), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -844,6 +1066,9 @@ def sample_split_read_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py index c10c765a..ee4576df 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py @@ -61,7 +61,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -124,6 +124,10 @@ def __init__( host += ":443" self._host = host + @property + def host(self): + return self._host + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py index 50fe40bc..f7569bbe 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py @@ -67,7 +67,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py index e4adf208..3fa6aed7 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py @@ -112,7 +112,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1/services/big_query_write/async_client.py b/google/cloud/bigquery_storage_v1/services/big_query_write/async_client.py index 0738971b..0aec5966 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_write/async_client.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_write/async_client.py @@ -41,9 +41,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore from google.cloud.bigquery_storage_v1.types import storage from google.cloud.bigquery_storage_v1.types import stream @@ -67,8 +67,12 @@ class BigQueryWriteAsyncClient: _client: BigQueryWriteClient + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = BigQueryWriteClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = BigQueryWriteClient._DEFAULT_UNIVERSE table_path = staticmethod(BigQueryWriteClient.table_path) parse_table_path = staticmethod(BigQueryWriteClient.parse_table_path) @@ -177,6 +181,25 @@ def transport(self) -> BigQueryWriteTransport: """ return self._client.transport + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + get_transport_class = functools.partial( type(BigQueryWriteClient).get_transport_class, type(BigQueryWriteClient) ) @@ -189,7 +212,7 @@ def __init__( client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiates the big query write client. + """Instantiates the big query write async client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -200,23 +223,38 @@ def __init__( transport (Union[str, ~.BigQueryWriteTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. @@ -345,6 +383,9 @@ async def sample_create_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -481,6 +522,9 @@ def request_generator(): # add these here. metadata = tuple(metadata) + (gapic_v1.routing_header.to_grpc_metadata(()),) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = rpc( requests, @@ -594,6 +638,9 @@ async def sample_get_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -705,6 +752,9 @@ async def sample_finalize_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -822,6 +872,9 @@ async def sample_batch_commit_write_streams(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -943,6 +996,9 @@ async def sample_flush_rows(): ), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, diff --git a/google/cloud/bigquery_storage_v1/services/big_query_write/client.py b/google/cloud/bigquery_storage_v1/services/big_query_write/client.py index 9de8cd2c..e2a714cb 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_write/client.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_write/client.py @@ -30,6 +30,7 @@ Union, cast, ) +import warnings from google.cloud.bigquery_storage_v1 import gapic_version as package_version @@ -44,9 +45,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore + OptionalRetry = Union[retries.Retry, object, None] # type: ignore from google.cloud.bigquery_storage_v1.types import storage from google.cloud.bigquery_storage_v1.types import stream @@ -132,11 +133,15 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = "bigquerystorage.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) + _DEFAULT_ENDPOINT_TEMPLATE = "bigquerystorage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -311,7 +316,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]: def get_mtls_endpoint_and_cert_source( cls, client_options: Optional[client_options_lib.ClientOptions] = None ): - """Return the API endpoint and client cert source for mutual TLS. + """Deprecated. Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the @@ -341,6 +346,11 @@ def get_mtls_endpoint_and_cert_source( Raises: google.auth.exceptions.MutualTLSChannelError: If any errors happen. """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) if client_options is None: client_options = client_options_lib.ClientOptions() use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") @@ -374,6 +384,175 @@ def get_mtls_endpoint_and_cert_source( return api_endpoint, client_cert_source + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = BigQueryWriteClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + @staticmethod + def _compare_universes( + client_universe: str, credentials: ga_credentials.Credentials + ) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials (ga_credentials.Credentials): The credentials being used in the client. + + Returns: + bool: True iff client_universe matches the universe in credentials. + + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + if credentials: + credentials_universe = credentials.universe_domain + if client_universe != credentials_universe: + default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + raise ValueError( + "The configured universe domain " + f"({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{default_universe}` is the default." + ) + return True + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + self._is_universe_domain_valid = ( + self._is_universe_domain_valid + or BigQueryWriteClient._compare_universes( + self.universe_domain, self.transport._credentials + ) + ) + return self._is_universe_domain_valid + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + def __init__( self, *, @@ -393,22 +572,32 @@ def __init__( transport (Union[str, BigQueryWriteTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): The client info used to send a user-agent string along with API requests. If ``None``, then default info will be used. @@ -419,17 +608,34 @@ def __init__( google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( - client_options + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = BigQueryWriteClient._read_environment_variables() + self._client_cert_source = BigQueryWriteClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert ) + self._universe_domain = BigQueryWriteClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False - api_key_value = getattr(client_options, "api_key", None) + api_key_value = getattr(self._client_options, "api_key", None) if api_key_value and credentials: raise ValueError( "client_options.api_key and credentials are mutually exclusive" @@ -438,20 +644,33 @@ def __init__( # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. - if isinstance(transport, BigQueryWriteTransport): + transport_provided = isinstance(transport, BigQueryWriteTransport) + if transport_provided: # transport is a BigQueryWriteTransport instance. - if credentials or client_options.credentials_file or api_key_value: + if credentials or self._client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: + if self._client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " "directly." ) - self._transport = transport - else: + self._transport = cast(BigQueryWriteTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or BigQueryWriteClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: import google.auth._default # type: ignore if api_key_value and hasattr( @@ -461,17 +680,17 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(transport) + Transport = type(self).get_transport_class(cast(str, transport)) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, client_info=client_info, always_use_jwt_access=True, - api_audience=client_options.api_audience, + api_audience=self._client_options.api_audience, ) def create_write_stream( @@ -580,6 +799,9 @@ def sample_create_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -703,6 +925,9 @@ def request_generator(): # add these here. metadata = tuple(metadata) + (gapic_v1.routing_header.to_grpc_metadata(()),) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( requests, @@ -805,6 +1030,9 @@ def sample_get_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -905,6 +1133,9 @@ def sample_finalize_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -1013,6 +1244,9 @@ def sample_batch_commit_write_streams(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -1123,6 +1357,9 @@ def sample_flush_rows(): ), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, diff --git a/google/cloud/bigquery_storage_v1/services/big_query_write/transports/base.py b/google/cloud/bigquery_storage_v1/services/big_query_write/transports/base.py index fecc6e7f..d9ae63d2 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_write/transports/base.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_write/transports/base.py @@ -62,7 +62,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -125,6 +125,10 @@ def __init__( host += ":443" self._host = host + @property + def host(self): + return self._host + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc.py b/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc.py index 9605ba1c..ce651165 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc.py @@ -71,7 +71,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc_asyncio.py index 86e65c05..8121d6b8 100644 --- a/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc_asyncio.py +++ b/google/cloud/bigquery_storage_v1/services/big_query_write/transports/grpc_asyncio.py @@ -116,7 +116,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/async_client.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/async_client.py index ee862804..20ece340 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/async_client.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/async_client.py @@ -40,9 +40,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore from google.cloud.bigquery_storage_v1beta2.types import arrow from google.cloud.bigquery_storage_v1beta2.types import avro @@ -65,8 +65,12 @@ class BigQueryReadAsyncClient: _client: BigQueryReadClient + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = BigQueryReadClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = BigQueryReadClient._DEFAULT_UNIVERSE read_session_path = staticmethod(BigQueryReadClient.read_session_path) parse_read_session_path = staticmethod(BigQueryReadClient.parse_read_session_path) @@ -173,6 +177,25 @@ def transport(self) -> BigQueryReadTransport: """ return self._client.transport + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + get_transport_class = functools.partial( type(BigQueryReadClient).get_transport_class, type(BigQueryReadClient) ) @@ -185,7 +208,7 @@ def __init__( client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiates the big query read client. + """Instantiates the big query read async client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -196,23 +219,38 @@ def __init__( transport (Union[str, ~.BigQueryReadTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. @@ -376,6 +414,9 @@ async def sample_create_read_session(): ), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -507,6 +548,9 @@ async def sample_read_rows(): ), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = rpc( request, @@ -606,6 +650,9 @@ async def sample_split_read_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/client.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/client.py index f84bcddd..948aa552 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/client.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/client.py @@ -29,6 +29,7 @@ Union, cast, ) +import warnings from google.cloud.bigquery_storage_v1beta2 import gapic_version as package_version @@ -43,9 +44,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore + OptionalRetry = Union[retries.Retry, object, None] # type: ignore from google.cloud.bigquery_storage_v1beta2.types import arrow from google.cloud.bigquery_storage_v1beta2.types import avro @@ -130,11 +131,15 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = "bigquerystorage.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) + _DEFAULT_ENDPOINT_TEMPLATE = "bigquerystorage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -331,7 +336,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]: def get_mtls_endpoint_and_cert_source( cls, client_options: Optional[client_options_lib.ClientOptions] = None ): - """Return the API endpoint and client cert source for mutual TLS. + """Deprecated. Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the @@ -361,6 +366,11 @@ def get_mtls_endpoint_and_cert_source( Raises: google.auth.exceptions.MutualTLSChannelError: If any errors happen. """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) if client_options is None: client_options = client_options_lib.ClientOptions() use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") @@ -394,6 +404,175 @@ def get_mtls_endpoint_and_cert_source( return api_endpoint, client_cert_source + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = BigQueryReadClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + @staticmethod + def _compare_universes( + client_universe: str, credentials: ga_credentials.Credentials + ) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials (ga_credentials.Credentials): The credentials being used in the client. + + Returns: + bool: True iff client_universe matches the universe in credentials. + + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + if credentials: + credentials_universe = credentials.universe_domain + if client_universe != credentials_universe: + default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + raise ValueError( + "The configured universe domain " + f"({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{default_universe}` is the default." + ) + return True + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + self._is_universe_domain_valid = ( + self._is_universe_domain_valid + or BigQueryReadClient._compare_universes( + self.universe_domain, self.transport._credentials + ) + ) + return self._is_universe_domain_valid + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + def __init__( self, *, @@ -413,22 +592,32 @@ def __init__( transport (Union[str, BigQueryReadTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): The client info used to send a user-agent string along with API requests. If ``None``, then default info will be used. @@ -439,17 +628,34 @@ def __init__( google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( - client_options + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = BigQueryReadClient._read_environment_variables() + self._client_cert_source = BigQueryReadClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert ) + self._universe_domain = BigQueryReadClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False - api_key_value = getattr(client_options, "api_key", None) + api_key_value = getattr(self._client_options, "api_key", None) if api_key_value and credentials: raise ValueError( "client_options.api_key and credentials are mutually exclusive" @@ -458,20 +664,30 @@ def __init__( # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. - if isinstance(transport, BigQueryReadTransport): + transport_provided = isinstance(transport, BigQueryReadTransport) + if transport_provided: # transport is a BigQueryReadTransport instance. - if credentials or client_options.credentials_file or api_key_value: + if credentials or self._client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: + if self._client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " "directly." ) - self._transport = transport - else: + self._transport = cast(BigQueryReadTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = self._api_endpoint or BigQueryReadClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + + if not transport_provided: import google.auth._default # type: ignore if api_key_value and hasattr( @@ -481,17 +697,17 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(transport) + Transport = type(self).get_transport_class(cast(str, transport)) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, client_info=client_info, always_use_jwt_access=True, - api_audience=client_options.api_audience, + api_audience=self._client_options.api_audience, ) def create_read_session( @@ -636,6 +852,9 @@ def sample_create_read_session(): ), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -758,6 +977,9 @@ def sample_read_rows(): ), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -848,6 +1070,9 @@ def sample_split_read_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py index 9acf1c15..96ca1839 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py @@ -61,7 +61,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -124,6 +124,10 @@ def __init__( host += ":443" self._host = host + @property + def host(self): + return self._host + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py index e466e8ae..696e550c 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py @@ -70,7 +70,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py index 5655c5c2..8fb8937f 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py @@ -115,7 +115,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/async_client.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/async_client.py index d1f41788..0821b58f 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/async_client.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/async_client.py @@ -41,9 +41,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore from google.cloud.bigquery_storage_v1beta2.types import storage from google.cloud.bigquery_storage_v1beta2.types import stream @@ -68,8 +68,12 @@ class BigQueryWriteAsyncClient: _client: BigQueryWriteClient + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = BigQueryWriteClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = BigQueryWriteClient._DEFAULT_UNIVERSE table_path = staticmethod(BigQueryWriteClient.table_path) parse_table_path = staticmethod(BigQueryWriteClient.parse_table_path) @@ -178,6 +182,25 @@ def transport(self) -> BigQueryWriteTransport: """ return self._client.transport + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + get_transport_class = functools.partial( type(BigQueryWriteClient).get_transport_class, type(BigQueryWriteClient) ) @@ -190,7 +213,7 @@ def __init__( client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiates the big query write client. + """Instantiates the big query write async client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -201,23 +224,38 @@ def __init__( transport (Union[str, ~.BigQueryWriteTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. @@ -346,6 +384,9 @@ async def sample_create_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -460,6 +501,9 @@ def request_generator(): # add these here. metadata = tuple(metadata) + (gapic_v1.routing_header.to_grpc_metadata(()),) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = rpc( requests, @@ -572,6 +616,9 @@ async def sample_get_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -682,6 +729,9 @@ async def sample_finalize_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -797,6 +847,9 @@ async def sample_batch_commit_write_streams(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, @@ -913,6 +966,9 @@ async def sample_flush_rows(): ), ) + # Validate the universe domain. + self._client._validate_universe_domain() + # Send the request. response = await rpc( request, diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/client.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/client.py index 021a729f..ca3b6d78 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/client.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/client.py @@ -30,6 +30,7 @@ Union, cast, ) +import warnings from google.cloud.bigquery_storage_v1beta2 import gapic_version as package_version @@ -44,9 +45,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore + OptionalRetry = Union[retries.Retry, object, None] # type: ignore from google.cloud.bigquery_storage_v1beta2.types import storage from google.cloud.bigquery_storage_v1beta2.types import stream @@ -133,11 +134,15 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. DEFAULT_ENDPOINT = "bigquerystorage.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) + _DEFAULT_ENDPOINT_TEMPLATE = "bigquerystorage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -312,7 +317,7 @@ def parse_common_location_path(path: str) -> Dict[str, str]: def get_mtls_endpoint_and_cert_source( cls, client_options: Optional[client_options_lib.ClientOptions] = None ): - """Return the API endpoint and client cert source for mutual TLS. + """Deprecated. Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the @@ -342,6 +347,11 @@ def get_mtls_endpoint_and_cert_source( Raises: google.auth.exceptions.MutualTLSChannelError: If any errors happen. """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) if client_options is None: client_options = client_options_lib.ClientOptions() use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") @@ -375,6 +385,175 @@ def get_mtls_endpoint_and_cert_source( return api_endpoint, client_cert_source + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = BigQueryWriteClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + @staticmethod + def _compare_universes( + client_universe: str, credentials: ga_credentials.Credentials + ) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials (ga_credentials.Credentials): The credentials being used in the client. + + Returns: + bool: True iff client_universe matches the universe in credentials. + + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + if credentials: + credentials_universe = credentials.universe_domain + if client_universe != credentials_universe: + default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + raise ValueError( + "The configured universe domain " + f"({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{default_universe}` is the default." + ) + return True + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + self._is_universe_domain_valid = ( + self._is_universe_domain_valid + or BigQueryWriteClient._compare_universes( + self.universe_domain, self.transport._credentials + ) + ) + return self._is_universe_domain_valid + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + def __init__( self, *, @@ -394,22 +573,32 @@ def __init__( transport (Union[str, BigQueryWriteTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If + to provide a client certificate for mTLS transport. If not provided, the default SSL client certificate will be used if present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): The client info used to send a user-agent string along with API requests. If ``None``, then default info will be used. @@ -420,17 +609,34 @@ def __init__( google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( - client_options + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = BigQueryWriteClient._read_environment_variables() + self._client_cert_source = BigQueryWriteClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert ) + self._universe_domain = BigQueryWriteClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False - api_key_value = getattr(client_options, "api_key", None) + api_key_value = getattr(self._client_options, "api_key", None) if api_key_value and credentials: raise ValueError( "client_options.api_key and credentials are mutually exclusive" @@ -439,20 +645,33 @@ def __init__( # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. - if isinstance(transport, BigQueryWriteTransport): + transport_provided = isinstance(transport, BigQueryWriteTransport) + if transport_provided: # transport is a BigQueryWriteTransport instance. - if credentials or client_options.credentials_file or api_key_value: + if credentials or self._client_options.credentials_file or api_key_value: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: + if self._client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " "directly." ) - self._transport = transport - else: + self._transport = cast(BigQueryWriteTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or BigQueryWriteClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: import google.auth._default # type: ignore if api_key_value and hasattr( @@ -462,17 +681,17 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(transport) + Transport = type(self).get_transport_class(cast(str, transport)) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, client_info=client_info, always_use_jwt_access=True, - api_audience=client_options.api_audience, + api_audience=self._client_options.api_audience, ) def create_write_stream( @@ -581,6 +800,9 @@ def sample_create_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -681,6 +903,9 @@ def request_generator(): # add these here. metadata = tuple(metadata) + (gapic_v1.routing_header.to_grpc_metadata(()),) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( requests, @@ -783,6 +1008,9 @@ def sample_get_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -883,6 +1111,9 @@ def sample_finalize_write_stream(): gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -990,6 +1221,9 @@ def sample_batch_commit_write_streams(): gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -1096,6 +1330,9 @@ def sample_flush_rows(): ), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py index fac40db9..4530d3bb 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py @@ -62,7 +62,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -125,6 +125,10 @@ def __init__( host += ":443" self._host = host + @property + def host(self): + return self._host + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py index f9108d87..64b15515 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py @@ -72,7 +72,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py index 3c23a87e..98ba2cd2 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py @@ -117,7 +117,7 @@ def __init__( Args: host (Optional[str]): - The hostname to connect to. + The hostname to connect to (default: 'bigquerystorage.googleapis.com'). credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none diff --git a/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py b/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py index 7cda213a..7f67f517 100644 --- a/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py +++ b/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py @@ -26,6 +26,7 @@ from grpc.experimental import aio import math import pytest +from google.api_core import api_core_version from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers @@ -66,6 +67,29 @@ def modify_default_endpoint(client): ) +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +# Anonymous Credentials with universe domain property. If no universe domain is provided, then +# the default universe domain is "googleapis.com". +class _AnonymousCredentialsWithUniverseDomain(ga_credentials.AnonymousCredentials): + def __init__(self, universe_domain="googleapis.com"): + super(_AnonymousCredentialsWithUniverseDomain, self).__init__() + self._universe_domain = universe_domain + + @property + def universe_domain(self): + return self._universe_domain + + def test__get_default_mtls_endpoint(): api_endpoint = "example.googleapis.com" api_mtls_endpoint = "example.mtls.googleapis.com" @@ -92,6 +116,254 @@ def test__get_default_mtls_endpoint(): assert BigQueryReadClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi +def test__read_environment_variables(): + assert BigQueryReadClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert BigQueryReadClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert BigQueryReadClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + BigQueryReadClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert BigQueryReadClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert BigQueryReadClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert BigQueryReadClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryReadClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert BigQueryReadClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert BigQueryReadClient._get_client_cert_source(None, False) is None + assert ( + BigQueryReadClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + BigQueryReadClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + BigQueryReadClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + BigQueryReadClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), +) +@mock.patch.object( + BigQueryReadAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + BigQueryReadClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + BigQueryReadClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, default_universe, "always") + == BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryReadClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryReadClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + BigQueryReadClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + BigQueryReadClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + BigQueryReadClient._get_universe_domain(None, None) + == BigQueryReadClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + BigQueryReadClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (BigQueryReadClient, transports.BigQueryReadGrpcTransport, "grpc"), + ], +) +def test__validate_universe_domain(client_class, transport_class, transport_name): + client = client_class( + transport=transport_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + ) + assert client._validate_universe_domain() == True + + # Test the case when universe is already validated. + assert client._validate_universe_domain() == True + + if transport_name == "grpc": + # Test the case where credentials are provided by the + # `local_channel_credentials`. The default universes in both match. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + client = client_class(transport=transport_class(channel=channel)) + assert client._validate_universe_domain() == True + + # Test the case where credentials do not exist: e.g. a transport is provided + # with no credentials. Validation should still succeed because there is no + # mismatch with non-existent credentials. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + transport = transport_class(channel=channel) + transport._credentials = None + client = client_class(transport=transport) + assert client._validate_universe_domain() == True + + # Test the case when there is a universe mismatch from the credentials. + client = client_class( + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain( + universe_domain="foo.com" + ) + ) + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + # Test the case when there is a universe mismatch from the client. + # + # TODO: Make this test unconditional once the minimum supported version of + # google-api-core becomes 2.15.0 or higher. + api_core_major, api_core_minor, _ = [ + int(part) for part in api_core_version.__version__.split(".") + ] + if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15): + client = client_class( + client_options={"universe_domain": "bar.com"}, + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain(), + ), + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + @pytest.mark.parametrize( "client_class,transport_name", [ @@ -100,7 +372,7 @@ def test__get_default_mtls_endpoint(): ], ) def test_big_query_read_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: @@ -146,7 +418,7 @@ def test_big_query_read_client_service_account_always_use_jwt( ], ) def test_big_query_read_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_file" ) as factory: @@ -189,19 +461,23 @@ def test_big_query_read_client_get_transport_class(): ], ) @mock.patch.object( - BigQueryReadClient, "DEFAULT_ENDPOINT", modify_default_endpoint(BigQueryReadClient) + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), ) @mock.patch.object( BigQueryReadAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryReadAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), ) def test_big_query_read_client_client_options( client_class, transport_class, transport_name ): # Check that if channel is provided we won't create a new one. with mock.patch.object(BigQueryReadClient, "get_transport_class") as gtc: - transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + transport = transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain() + ) client = client_class(transport=transport) gtc.assert_not_called() @@ -236,7 +512,9 @@ def test_big_query_read_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -266,15 +544,23 @@ def test_big_query_read_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): + with pytest.raises(MutualTLSChannelError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") @@ -284,7 +570,9 @@ def test_big_query_read_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id="octopus", @@ -302,7 +590,9 @@ def test_big_query_read_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -332,12 +622,14 @@ def test_big_query_read_client_client_options( ], ) @mock.patch.object( - BigQueryReadClient, "DEFAULT_ENDPOINT", modify_default_endpoint(BigQueryReadClient) + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), ) @mock.patch.object( BigQueryReadAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryReadAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), ) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) def test_big_query_read_client_mtls_env_auto( @@ -360,7 +652,9 @@ def test_big_query_read_client_mtls_env_auto( if use_client_cert_env == "false": expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) else: expected_client_cert_source = client_cert_source_callback expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -392,7 +686,9 @@ def test_big_query_read_client_mtls_env_auto( return_value=client_cert_source_callback, ): if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) expected_client_cert_source = None else: expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -426,7 +722,9 @@ def test_big_query_read_client_mtls_env_auto( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -512,6 +810,116 @@ def test_big_query_read_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize("client_class", [BigQueryReadClient, BigQueryReadAsyncClient]) +@mock.patch.object( + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), +) +@mock.patch.object( + BigQueryReadAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), +) +def test_big_query_read_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + else: + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == default_endpoint + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -537,7 +945,9 @@ def test_big_query_read_client_client_options_scopes( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=["1", "2"], client_cert_source_for_mtls=None, quota_project_id=None, @@ -576,7 +986,9 @@ def test_big_query_read_client_client_options_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -634,7 +1046,9 @@ def test_big_query_read_client_create_channel_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -651,8 +1065,8 @@ def test_big_query_read_client_create_channel_credentials_file( ) as adc, mock.patch.object( grpc_helpers, "create_channel" ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() + file_creds = _AnonymousCredentialsWithUniverseDomain() load_creds.return_value = (file_creds, None) adc.return_value = (creds, None) client = client_class(client_options=options, transport=transport_name) @@ -684,7 +1098,7 @@ def test_big_query_read_client_create_channel_credentials_file( ) def test_create_read_session(request_type, transport: str = "grpc"): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -728,7 +1142,7 @@ def test_create_read_session_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -747,7 +1161,7 @@ async def test_create_read_session_async( transport: str = "grpc_asyncio", request_type=storage.CreateReadSessionRequest ): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -796,7 +1210,7 @@ async def test_create_read_session_async_from_dict(): def test_create_read_session_field_headers(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -828,7 +1242,7 @@ def test_create_read_session_field_headers(): @pytest.mark.asyncio async def test_create_read_session_field_headers_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -859,7 +1273,7 @@ async def test_create_read_session_field_headers_async(): def test_create_read_session_flattened(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -893,7 +1307,7 @@ def test_create_read_session_flattened(): def test_create_read_session_flattened_error(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -910,7 +1324,7 @@ def test_create_read_session_flattened_error(): @pytest.mark.asyncio async def test_create_read_session_flattened_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -947,7 +1361,7 @@ async def test_create_read_session_flattened_async(): @pytest.mark.asyncio async def test_create_read_session_flattened_error_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -970,7 +1384,7 @@ async def test_create_read_session_flattened_error_async(): ) def test_read_rows(request_type, transport: str = "grpc"): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -998,7 +1412,7 @@ def test_read_rows_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1015,7 +1429,7 @@ async def test_read_rows_async( transport: str = "grpc_asyncio", request_type=storage.ReadRowsRequest ): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1049,7 +1463,7 @@ async def test_read_rows_async_from_dict(): def test_read_rows_field_headers(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1079,7 +1493,7 @@ def test_read_rows_field_headers(): @pytest.mark.asyncio async def test_read_rows_field_headers_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1111,7 +1525,7 @@ async def test_read_rows_field_headers_async(): def test_read_rows_flattened(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1139,7 +1553,7 @@ def test_read_rows_flattened(): def test_read_rows_flattened_error(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1155,7 +1569,7 @@ def test_read_rows_flattened_error(): @pytest.mark.asyncio async def test_read_rows_flattened_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1186,7 +1600,7 @@ async def test_read_rows_flattened_async(): @pytest.mark.asyncio async def test_read_rows_flattened_error_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1208,7 +1622,7 @@ async def test_read_rows_flattened_error_async(): ) def test_split_read_stream(request_type, transport: str = "grpc"): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1237,7 +1651,7 @@ def test_split_read_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1256,7 +1670,7 @@ async def test_split_read_stream_async( transport: str = "grpc_asyncio", request_type=storage.SplitReadStreamRequest ): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1290,7 +1704,7 @@ async def test_split_read_stream_async_from_dict(): def test_split_read_stream_field_headers(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1322,7 +1736,7 @@ def test_split_read_stream_field_headers(): @pytest.mark.asyncio async def test_split_read_stream_field_headers_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1356,17 +1770,17 @@ async def test_split_read_stream_field_headers_async(): def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryReadClient( @@ -1376,7 +1790,7 @@ def test_credentials_transport_error(): # It is an error to provide an api_key and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) options = client_options.ClientOptions() options.api_key = "api_key" @@ -1387,16 +1801,17 @@ def test_credentials_transport_error(): ) # It is an error to provide an api_key and a credential. - options = mock.Mock() + options = client_options.ClientOptions() options.api_key = "api_key" with pytest.raises(ValueError): client = BigQueryReadClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # It is an error to provide scopes and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryReadClient( @@ -1408,7 +1823,7 @@ def test_credentials_transport_error(): def test_transport_instance(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) client = BigQueryReadClient(transport=transport) assert client.transport is transport @@ -1417,13 +1832,13 @@ def test_transport_instance(): def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel transport = transports.BigQueryReadGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel @@ -1439,7 +1854,7 @@ def test_transport_get_channel(): def test_transport_adc(transport_class): # Test default credentials are used if not provided. with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class() adc.assert_called_once() @@ -1452,7 +1867,7 @@ def test_transport_adc(transport_class): ) def test_transport_kind(transport_name): transport = BigQueryReadClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert transport.kind == transport_name @@ -1460,7 +1875,7 @@ def test_transport_kind(transport_name): def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert isinstance( client.transport, @@ -1472,7 +1887,7 @@ def test_big_query_read_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.BigQueryReadTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), credentials_file="credentials.json", ) @@ -1484,7 +1899,7 @@ def test_big_query_read_base_transport(): ) as Transport: Transport.return_value = None transport = transports.BigQueryReadTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Every method on the transport should just blindly @@ -1518,7 +1933,7 @@ def test_big_query_read_base_transport_with_credentials_file(): "google.cloud.bigquery_storage_v1.services.big_query_read.transports.BigQueryReadTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + load_creds.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryReadTransport( credentials_file="credentials.json", quota_project_id="octopus", @@ -1540,7 +1955,7 @@ def test_big_query_read_base_transport_with_adc(): "google.cloud.bigquery_storage_v1.services.big_query_read.transports.BigQueryReadTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryReadTransport() adc.assert_called_once() @@ -1548,7 +1963,7 @@ def test_big_query_read_base_transport_with_adc(): def test_big_query_read_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) BigQueryReadClient() adc.assert_called_once_with( scopes=None, @@ -1571,7 +1986,7 @@ def test_big_query_read_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( scopes=["1", "2"], @@ -1620,7 +2035,7 @@ def test_big_query_read_transport_create_channel(transport_class, grpc_helpers): ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() adc.return_value = (creds, None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) @@ -1648,7 +2063,7 @@ def test_big_query_read_transport_create_channel(transport_class, grpc_helpers): [transports.BigQueryReadGrpcTransport, transports.BigQueryReadGrpcAsyncIOTransport], ) def test_big_query_read_grpc_transport_client_cert_source_for_mtls(transport_class): - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() # Check ssl_channel_credentials is used if provided. with mock.patch.object(transport_class, "create_channel") as mock_create_channel: @@ -1694,7 +2109,7 @@ def test_big_query_read_grpc_transport_client_cert_source_for_mtls(transport_cla ) def test_big_query_read_host_no_port(transport_name): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com" ), @@ -1712,7 +2127,7 @@ def test_big_query_read_host_no_port(transport_name): ) def test_big_query_read_host_with_port(transport_name): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com:8000" ), @@ -1766,7 +2181,7 @@ def test_big_query_read_transport_channel_mtls_with_client_cert_source(transport mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() with pytest.warns(DeprecationWarning): with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) @@ -2030,7 +2445,7 @@ def test_client_with_default_client_info(): transports.BigQueryReadTransport, "_prep_wrapped_messages" ) as prep: client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2040,7 +2455,7 @@ def test_client_with_default_client_info(): ) as prep: transport_class = BigQueryReadClient.get_transport_class() transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2049,7 +2464,7 @@ def test_client_with_default_client_info(): @pytest.mark.asyncio async def test_transport_close_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc_asyncio", ) with mock.patch.object( @@ -2067,7 +2482,7 @@ def test_transport_close(): for transport, close_name in transports.items(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) with mock.patch.object( type(getattr(client.transport, close_name)), "close" @@ -2083,7 +2498,7 @@ def test_client_ctx(): ] for transport in transports: client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) # Test client calls underlying transport. with mock.patch.object(type(client.transport), "close") as close: @@ -2114,7 +2529,9 @@ def test_api_key_credentials(client_class, transport_class): patched.assert_called_once_with( credentials=mock_cred, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, diff --git a/tests/unit/gapic/bigquery_storage_v1/test_big_query_write.py b/tests/unit/gapic/bigquery_storage_v1/test_big_query_write.py index f343e494..090151dd 100644 --- a/tests/unit/gapic/bigquery_storage_v1/test_big_query_write.py +++ b/tests/unit/gapic/bigquery_storage_v1/test_big_query_write.py @@ -26,6 +26,7 @@ from grpc.experimental import aio import math import pytest +from google.api_core import api_core_version from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers @@ -71,6 +72,29 @@ def modify_default_endpoint(client): ) +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +# Anonymous Credentials with universe domain property. If no universe domain is provided, then +# the default universe domain is "googleapis.com". +class _AnonymousCredentialsWithUniverseDomain(ga_credentials.AnonymousCredentials): + def __init__(self, universe_domain="googleapis.com"): + super(_AnonymousCredentialsWithUniverseDomain, self).__init__() + self._universe_domain = universe_domain + + @property + def universe_domain(self): + return self._universe_domain + + def test__get_default_mtls_endpoint(): api_endpoint = "example.googleapis.com" api_mtls_endpoint = "example.mtls.googleapis.com" @@ -100,6 +124,262 @@ def test__get_default_mtls_endpoint(): ) +def test__read_environment_variables(): + assert BigQueryWriteClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert BigQueryWriteClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + BigQueryWriteClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryWriteClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert BigQueryWriteClient._get_client_cert_source(None, False) is None + assert ( + BigQueryWriteClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + BigQueryWriteClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + BigQueryWriteClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + BigQueryWriteClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + BigQueryWriteClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), +) +@mock.patch.object( + BigQueryWriteAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + BigQueryWriteClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + BigQueryWriteClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, default_universe, "always") + == BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryWriteClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryWriteClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + BigQueryWriteClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + BigQueryWriteClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + BigQueryWriteClient._get_universe_domain(None, None) + == BigQueryWriteClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + BigQueryWriteClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (BigQueryWriteClient, transports.BigQueryWriteGrpcTransport, "grpc"), + ], +) +def test__validate_universe_domain(client_class, transport_class, transport_name): + client = client_class( + transport=transport_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + ) + assert client._validate_universe_domain() == True + + # Test the case when universe is already validated. + assert client._validate_universe_domain() == True + + if transport_name == "grpc": + # Test the case where credentials are provided by the + # `local_channel_credentials`. The default universes in both match. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + client = client_class(transport=transport_class(channel=channel)) + assert client._validate_universe_domain() == True + + # Test the case where credentials do not exist: e.g. a transport is provided + # with no credentials. Validation should still succeed because there is no + # mismatch with non-existent credentials. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + transport = transport_class(channel=channel) + transport._credentials = None + client = client_class(transport=transport) + assert client._validate_universe_domain() == True + + # Test the case when there is a universe mismatch from the credentials. + client = client_class( + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain( + universe_domain="foo.com" + ) + ) + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + # Test the case when there is a universe mismatch from the client. + # + # TODO: Make this test unconditional once the minimum supported version of + # google-api-core becomes 2.15.0 or higher. + api_core_major, api_core_minor, _ = [ + int(part) for part in api_core_version.__version__.split(".") + ] + if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15): + client = client_class( + client_options={"universe_domain": "bar.com"}, + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain(), + ), + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + @pytest.mark.parametrize( "client_class,transport_name", [ @@ -108,7 +388,7 @@ def test__get_default_mtls_endpoint(): ], ) def test_big_query_write_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: @@ -154,7 +434,7 @@ def test_big_query_write_client_service_account_always_use_jwt( ], ) def test_big_query_write_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_file" ) as factory: @@ -198,20 +478,22 @@ def test_big_query_write_client_get_transport_class(): ) @mock.patch.object( BigQueryWriteClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), ) @mock.patch.object( BigQueryWriteAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), ) def test_big_query_write_client_client_options( client_class, transport_class, transport_name ): # Check that if channel is provided we won't create a new one. with mock.patch.object(BigQueryWriteClient, "get_transport_class") as gtc: - transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + transport = transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain() + ) client = client_class(transport=transport) gtc.assert_not_called() @@ -246,7 +528,9 @@ def test_big_query_write_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -276,15 +560,23 @@ def test_big_query_write_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): + with pytest.raises(MutualTLSChannelError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") @@ -294,7 +586,9 @@ def test_big_query_write_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id="octopus", @@ -312,7 +606,9 @@ def test_big_query_write_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -343,13 +639,13 @@ def test_big_query_write_client_client_options( ) @mock.patch.object( BigQueryWriteClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), ) @mock.patch.object( BigQueryWriteAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), ) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) def test_big_query_write_client_mtls_env_auto( @@ -372,7 +668,9 @@ def test_big_query_write_client_mtls_env_auto( if use_client_cert_env == "false": expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) else: expected_client_cert_source = client_cert_source_callback expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -404,7 +702,9 @@ def test_big_query_write_client_mtls_env_auto( return_value=client_cert_source_callback, ): if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) expected_client_cert_source = None else: expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -438,7 +738,9 @@ def test_big_query_write_client_mtls_env_auto( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -528,6 +830,118 @@ def test_big_query_write_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [BigQueryWriteClient, BigQueryWriteAsyncClient] +) +@mock.patch.object( + BigQueryWriteClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), +) +@mock.patch.object( + BigQueryWriteAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), +) +def test_big_query_write_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + else: + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == default_endpoint + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -553,7 +967,9 @@ def test_big_query_write_client_client_options_scopes( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=["1", "2"], client_cert_source_for_mtls=None, quota_project_id=None, @@ -592,7 +1008,9 @@ def test_big_query_write_client_client_options_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -652,7 +1070,9 @@ def test_big_query_write_client_create_channel_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -669,8 +1089,8 @@ def test_big_query_write_client_create_channel_credentials_file( ) as adc, mock.patch.object( grpc_helpers, "create_channel" ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() + file_creds = _AnonymousCredentialsWithUniverseDomain() load_creds.return_value = (file_creds, None) adc.return_value = (creds, None) client = client_class(client_options=options, transport=transport_name) @@ -703,7 +1123,7 @@ def test_big_query_write_client_create_channel_credentials_file( ) def test_create_write_stream(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -741,7 +1161,7 @@ def test_create_write_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -760,7 +1180,7 @@ async def test_create_write_stream_async( transport: str = "grpc_asyncio", request_type=storage.CreateWriteStreamRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -803,7 +1223,7 @@ async def test_create_write_stream_async_from_dict(): def test_create_write_stream_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -835,7 +1255,7 @@ def test_create_write_stream_field_headers(): @pytest.mark.asyncio async def test_create_write_stream_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -866,7 +1286,7 @@ async def test_create_write_stream_field_headers_async(): def test_create_write_stream_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -896,7 +1316,7 @@ def test_create_write_stream_flattened(): def test_create_write_stream_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -912,7 +1332,7 @@ def test_create_write_stream_flattened_error(): @pytest.mark.asyncio async def test_create_write_stream_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -945,7 +1365,7 @@ async def test_create_write_stream_flattened_async(): @pytest.mark.asyncio async def test_create_write_stream_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -967,7 +1387,7 @@ async def test_create_write_stream_flattened_error_async(): ) def test_append_rows(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -997,7 +1417,7 @@ async def test_append_rows_async( transport: str = "grpc_asyncio", request_type=storage.AppendRowsRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1039,7 +1459,7 @@ async def test_append_rows_async_from_dict(): ) def test_get_write_stream(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1075,7 +1495,7 @@ def test_get_write_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1092,7 +1512,7 @@ async def test_get_write_stream_async( transport: str = "grpc_asyncio", request_type=storage.GetWriteStreamRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1133,7 +1553,7 @@ async def test_get_write_stream_async_from_dict(): def test_get_write_stream_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1163,7 +1583,7 @@ def test_get_write_stream_field_headers(): @pytest.mark.asyncio async def test_get_write_stream_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1192,7 +1612,7 @@ async def test_get_write_stream_field_headers_async(): def test_get_write_stream_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1216,7 +1636,7 @@ def test_get_write_stream_flattened(): def test_get_write_stream_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1231,7 +1651,7 @@ def test_get_write_stream_flattened_error(): @pytest.mark.asyncio async def test_get_write_stream_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1258,7 +1678,7 @@ async def test_get_write_stream_flattened_async(): @pytest.mark.asyncio async def test_get_write_stream_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1279,7 +1699,7 @@ async def test_get_write_stream_flattened_error_async(): ) def test_finalize_write_stream(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1311,7 +1731,7 @@ def test_finalize_write_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1330,7 +1750,7 @@ async def test_finalize_write_stream_async( transport: str = "grpc_asyncio", request_type=storage.FinalizeWriteStreamRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1367,7 +1787,7 @@ async def test_finalize_write_stream_async_from_dict(): def test_finalize_write_stream_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1399,7 +1819,7 @@ def test_finalize_write_stream_field_headers(): @pytest.mark.asyncio async def test_finalize_write_stream_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1432,7 +1852,7 @@ async def test_finalize_write_stream_field_headers_async(): def test_finalize_write_stream_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1458,7 +1878,7 @@ def test_finalize_write_stream_flattened(): def test_finalize_write_stream_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1473,7 +1893,7 @@ def test_finalize_write_stream_flattened_error(): @pytest.mark.asyncio async def test_finalize_write_stream_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1504,7 +1924,7 @@ async def test_finalize_write_stream_flattened_async(): @pytest.mark.asyncio async def test_finalize_write_stream_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1525,7 +1945,7 @@ async def test_finalize_write_stream_flattened_error_async(): ) def test_batch_commit_write_streams(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1554,7 +1974,7 @@ def test_batch_commit_write_streams_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1573,7 +1993,7 @@ async def test_batch_commit_write_streams_async( transport: str = "grpc_asyncio", request_type=storage.BatchCommitWriteStreamsRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1607,7 +2027,7 @@ async def test_batch_commit_write_streams_async_from_dict(): def test_batch_commit_write_streams_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1639,7 +2059,7 @@ def test_batch_commit_write_streams_field_headers(): @pytest.mark.asyncio async def test_batch_commit_write_streams_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1672,7 +2092,7 @@ async def test_batch_commit_write_streams_field_headers_async(): def test_batch_commit_write_streams_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1698,7 +2118,7 @@ def test_batch_commit_write_streams_flattened(): def test_batch_commit_write_streams_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1713,7 +2133,7 @@ def test_batch_commit_write_streams_flattened_error(): @pytest.mark.asyncio async def test_batch_commit_write_streams_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1744,7 +2164,7 @@ async def test_batch_commit_write_streams_flattened_async(): @pytest.mark.asyncio async def test_batch_commit_write_streams_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1765,7 +2185,7 @@ async def test_batch_commit_write_streams_flattened_error_async(): ) def test_flush_rows(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1795,7 +2215,7 @@ def test_flush_rows_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1812,7 +2232,7 @@ async def test_flush_rows_async( transport: str = "grpc_asyncio", request_type=storage.FlushRowsRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1847,7 +2267,7 @@ async def test_flush_rows_async_from_dict(): def test_flush_rows_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1877,7 +2297,7 @@ def test_flush_rows_field_headers(): @pytest.mark.asyncio async def test_flush_rows_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1908,7 +2328,7 @@ async def test_flush_rows_field_headers_async(): def test_flush_rows_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1932,7 +2352,7 @@ def test_flush_rows_flattened(): def test_flush_rows_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1947,7 +2367,7 @@ def test_flush_rows_flattened_error(): @pytest.mark.asyncio async def test_flush_rows_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1976,7 +2396,7 @@ async def test_flush_rows_flattened_async(): @pytest.mark.asyncio async def test_flush_rows_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1991,17 +2411,17 @@ async def test_flush_rows_flattened_error_async(): def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryWriteClient( @@ -2011,7 +2431,7 @@ def test_credentials_transport_error(): # It is an error to provide an api_key and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) options = client_options.ClientOptions() options.api_key = "api_key" @@ -2022,16 +2442,17 @@ def test_credentials_transport_error(): ) # It is an error to provide an api_key and a credential. - options = mock.Mock() + options = client_options.ClientOptions() options.api_key = "api_key" with pytest.raises(ValueError): client = BigQueryWriteClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # It is an error to provide scopes and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryWriteClient( @@ -2043,7 +2464,7 @@ def test_credentials_transport_error(): def test_transport_instance(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) client = BigQueryWriteClient(transport=transport) assert client.transport is transport @@ -2052,13 +2473,13 @@ def test_transport_instance(): def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel transport = transports.BigQueryWriteGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel @@ -2074,7 +2495,7 @@ def test_transport_get_channel(): def test_transport_adc(transport_class): # Test default credentials are used if not provided. with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class() adc.assert_called_once() @@ -2087,7 +2508,7 @@ def test_transport_adc(transport_class): ) def test_transport_kind(transport_name): transport = BigQueryWriteClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert transport.kind == transport_name @@ -2095,7 +2516,7 @@ def test_transport_kind(transport_name): def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert isinstance( client.transport, @@ -2107,7 +2528,7 @@ def test_big_query_write_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.BigQueryWriteTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), credentials_file="credentials.json", ) @@ -2119,7 +2540,7 @@ def test_big_query_write_base_transport(): ) as Transport: Transport.return_value = None transport = transports.BigQueryWriteTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Every method on the transport should just blindly @@ -2156,7 +2577,7 @@ def test_big_query_write_base_transport_with_credentials_file(): "google.cloud.bigquery_storage_v1.services.big_query_write.transports.BigQueryWriteTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + load_creds.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryWriteTransport( credentials_file="credentials.json", quota_project_id="octopus", @@ -2179,7 +2600,7 @@ def test_big_query_write_base_transport_with_adc(): "google.cloud.bigquery_storage_v1.services.big_query_write.transports.BigQueryWriteTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryWriteTransport() adc.assert_called_once() @@ -2187,7 +2608,7 @@ def test_big_query_write_base_transport_with_adc(): def test_big_query_write_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) BigQueryWriteClient() adc.assert_called_once_with( scopes=None, @@ -2211,7 +2632,7 @@ def test_big_query_write_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( scopes=["1", "2"], @@ -2261,7 +2682,7 @@ def test_big_query_write_transport_create_channel(transport_class, grpc_helpers) ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() adc.return_value = (creds, None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) @@ -2293,7 +2714,7 @@ def test_big_query_write_transport_create_channel(transport_class, grpc_helpers) ], ) def test_big_query_write_grpc_transport_client_cert_source_for_mtls(transport_class): - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() # Check ssl_channel_credentials is used if provided. with mock.patch.object(transport_class, "create_channel") as mock_create_channel: @@ -2339,7 +2760,7 @@ def test_big_query_write_grpc_transport_client_cert_source_for_mtls(transport_cl ) def test_big_query_write_host_no_port(transport_name): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com" ), @@ -2357,7 +2778,7 @@ def test_big_query_write_host_no_port(transport_name): ) def test_big_query_write_host_with_port(transport_name): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com:8000" ), @@ -2416,7 +2837,7 @@ def test_big_query_write_transport_channel_mtls_with_client_cert_source( mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() with pytest.warns(DeprecationWarning): with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) @@ -2659,7 +3080,7 @@ def test_client_with_default_client_info(): transports.BigQueryWriteTransport, "_prep_wrapped_messages" ) as prep: client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2669,7 +3090,7 @@ def test_client_with_default_client_info(): ) as prep: transport_class = BigQueryWriteClient.get_transport_class() transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2678,7 +3099,7 @@ def test_client_with_default_client_info(): @pytest.mark.asyncio async def test_transport_close_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc_asyncio", ) with mock.patch.object( @@ -2696,7 +3117,7 @@ def test_transport_close(): for transport, close_name in transports.items(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) with mock.patch.object( type(getattr(client.transport, close_name)), "close" @@ -2712,7 +3133,7 @@ def test_client_ctx(): ] for transport in transports: client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) # Test client calls underlying transport. with mock.patch.object(type(client.transport), "close") as close: @@ -2743,7 +3164,9 @@ def test_api_key_credentials(client_class, transport_class): patched.assert_called_once_with( credentials=mock_cred, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, diff --git a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py index 630437b5..5a6d55ff 100644 --- a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py +++ b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py @@ -26,6 +26,7 @@ from grpc.experimental import aio import math import pytest +from google.api_core import api_core_version from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers @@ -68,6 +69,29 @@ def modify_default_endpoint(client): ) +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +# Anonymous Credentials with universe domain property. If no universe domain is provided, then +# the default universe domain is "googleapis.com". +class _AnonymousCredentialsWithUniverseDomain(ga_credentials.AnonymousCredentials): + def __init__(self, universe_domain="googleapis.com"): + super(_AnonymousCredentialsWithUniverseDomain, self).__init__() + self._universe_domain = universe_domain + + @property + def universe_domain(self): + return self._universe_domain + + def test__get_default_mtls_endpoint(): api_endpoint = "example.googleapis.com" api_mtls_endpoint = "example.mtls.googleapis.com" @@ -94,6 +118,254 @@ def test__get_default_mtls_endpoint(): assert BigQueryReadClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi +def test__read_environment_variables(): + assert BigQueryReadClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert BigQueryReadClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert BigQueryReadClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + BigQueryReadClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert BigQueryReadClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert BigQueryReadClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert BigQueryReadClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryReadClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert BigQueryReadClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert BigQueryReadClient._get_client_cert_source(None, False) is None + assert ( + BigQueryReadClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + BigQueryReadClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + BigQueryReadClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + BigQueryReadClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), +) +@mock.patch.object( + BigQueryReadAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + BigQueryReadClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + BigQueryReadClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, default_universe, "always") + == BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryReadClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == BigQueryReadClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + BigQueryReadClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryReadClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + BigQueryReadClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + BigQueryReadClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + BigQueryReadClient._get_universe_domain(None, None) + == BigQueryReadClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + BigQueryReadClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (BigQueryReadClient, transports.BigQueryReadGrpcTransport, "grpc"), + ], +) +def test__validate_universe_domain(client_class, transport_class, transport_name): + client = client_class( + transport=transport_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + ) + assert client._validate_universe_domain() == True + + # Test the case when universe is already validated. + assert client._validate_universe_domain() == True + + if transport_name == "grpc": + # Test the case where credentials are provided by the + # `local_channel_credentials`. The default universes in both match. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + client = client_class(transport=transport_class(channel=channel)) + assert client._validate_universe_domain() == True + + # Test the case where credentials do not exist: e.g. a transport is provided + # with no credentials. Validation should still succeed because there is no + # mismatch with non-existent credentials. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + transport = transport_class(channel=channel) + transport._credentials = None + client = client_class(transport=transport) + assert client._validate_universe_domain() == True + + # Test the case when there is a universe mismatch from the credentials. + client = client_class( + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain( + universe_domain="foo.com" + ) + ) + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + # Test the case when there is a universe mismatch from the client. + # + # TODO: Make this test unconditional once the minimum supported version of + # google-api-core becomes 2.15.0 or higher. + api_core_major, api_core_minor, _ = [ + int(part) for part in api_core_version.__version__.split(".") + ] + if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15): + client = client_class( + client_options={"universe_domain": "bar.com"}, + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain(), + ), + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + @pytest.mark.parametrize( "client_class,transport_name", [ @@ -102,7 +374,7 @@ def test__get_default_mtls_endpoint(): ], ) def test_big_query_read_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: @@ -148,7 +420,7 @@ def test_big_query_read_client_service_account_always_use_jwt( ], ) def test_big_query_read_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_file" ) as factory: @@ -191,19 +463,23 @@ def test_big_query_read_client_get_transport_class(): ], ) @mock.patch.object( - BigQueryReadClient, "DEFAULT_ENDPOINT", modify_default_endpoint(BigQueryReadClient) + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), ) @mock.patch.object( BigQueryReadAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryReadAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), ) def test_big_query_read_client_client_options( client_class, transport_class, transport_name ): # Check that if channel is provided we won't create a new one. with mock.patch.object(BigQueryReadClient, "get_transport_class") as gtc: - transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + transport = transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain() + ) client = client_class(transport=transport) gtc.assert_not_called() @@ -238,7 +514,9 @@ def test_big_query_read_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -268,15 +546,23 @@ def test_big_query_read_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): + with pytest.raises(MutualTLSChannelError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") @@ -286,7 +572,9 @@ def test_big_query_read_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id="octopus", @@ -304,7 +592,9 @@ def test_big_query_read_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -334,12 +624,14 @@ def test_big_query_read_client_client_options( ], ) @mock.patch.object( - BigQueryReadClient, "DEFAULT_ENDPOINT", modify_default_endpoint(BigQueryReadClient) + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), ) @mock.patch.object( BigQueryReadAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryReadAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), ) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) def test_big_query_read_client_mtls_env_auto( @@ -362,7 +654,9 @@ def test_big_query_read_client_mtls_env_auto( if use_client_cert_env == "false": expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) else: expected_client_cert_source = client_cert_source_callback expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -394,7 +688,9 @@ def test_big_query_read_client_mtls_env_auto( return_value=client_cert_source_callback, ): if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) expected_client_cert_source = None else: expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -428,7 +724,9 @@ def test_big_query_read_client_mtls_env_auto( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -514,6 +812,116 @@ def test_big_query_read_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize("client_class", [BigQueryReadClient, BigQueryReadAsyncClient]) +@mock.patch.object( + BigQueryReadClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadClient), +) +@mock.patch.object( + BigQueryReadAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryReadAsyncClient), +) +def test_big_query_read_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = BigQueryReadClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryReadClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + else: + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == default_endpoint + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -539,7 +947,9 @@ def test_big_query_read_client_client_options_scopes( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=["1", "2"], client_cert_source_for_mtls=None, quota_project_id=None, @@ -578,7 +988,9 @@ def test_big_query_read_client_client_options_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -636,7 +1048,9 @@ def test_big_query_read_client_create_channel_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -653,8 +1067,8 @@ def test_big_query_read_client_create_channel_credentials_file( ) as adc, mock.patch.object( grpc_helpers, "create_channel" ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() + file_creds = _AnonymousCredentialsWithUniverseDomain() load_creds.return_value = (file_creds, None) adc.return_value = (creds, None) client = client_class(client_options=options, transport=transport_name) @@ -686,7 +1100,7 @@ def test_big_query_read_client_create_channel_credentials_file( ) def test_create_read_session(request_type, transport: str = "grpc"): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -722,7 +1136,7 @@ def test_create_read_session_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -741,7 +1155,7 @@ async def test_create_read_session_async( transport: str = "grpc_asyncio", request_type=storage.CreateReadSessionRequest ): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -782,7 +1196,7 @@ async def test_create_read_session_async_from_dict(): def test_create_read_session_field_headers(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -814,7 +1228,7 @@ def test_create_read_session_field_headers(): @pytest.mark.asyncio async def test_create_read_session_field_headers_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -845,7 +1259,7 @@ async def test_create_read_session_field_headers_async(): def test_create_read_session_flattened(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -879,7 +1293,7 @@ def test_create_read_session_flattened(): def test_create_read_session_flattened_error(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -896,7 +1310,7 @@ def test_create_read_session_flattened_error(): @pytest.mark.asyncio async def test_create_read_session_flattened_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -933,7 +1347,7 @@ async def test_create_read_session_flattened_async(): @pytest.mark.asyncio async def test_create_read_session_flattened_error_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -956,7 +1370,7 @@ async def test_create_read_session_flattened_error_async(): ) def test_read_rows(request_type, transport: str = "grpc"): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -984,7 +1398,7 @@ def test_read_rows_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1001,7 +1415,7 @@ async def test_read_rows_async( transport: str = "grpc_asyncio", request_type=storage.ReadRowsRequest ): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1035,7 +1449,7 @@ async def test_read_rows_async_from_dict(): def test_read_rows_field_headers(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1065,7 +1479,7 @@ def test_read_rows_field_headers(): @pytest.mark.asyncio async def test_read_rows_field_headers_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1097,7 +1511,7 @@ async def test_read_rows_field_headers_async(): def test_read_rows_flattened(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1125,7 +1539,7 @@ def test_read_rows_flattened(): def test_read_rows_flattened_error(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1141,7 +1555,7 @@ def test_read_rows_flattened_error(): @pytest.mark.asyncio async def test_read_rows_flattened_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1172,7 +1586,7 @@ async def test_read_rows_flattened_async(): @pytest.mark.asyncio async def test_read_rows_flattened_error_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1194,7 +1608,7 @@ async def test_read_rows_flattened_error_async(): ) def test_split_read_stream(request_type, transport: str = "grpc"): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1223,7 +1637,7 @@ def test_split_read_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1242,7 +1656,7 @@ async def test_split_read_stream_async( transport: str = "grpc_asyncio", request_type=storage.SplitReadStreamRequest ): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1276,7 +1690,7 @@ async def test_split_read_stream_async_from_dict(): def test_split_read_stream_field_headers(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1308,7 +1722,7 @@ def test_split_read_stream_field_headers(): @pytest.mark.asyncio async def test_split_read_stream_field_headers_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1342,17 +1756,17 @@ async def test_split_read_stream_field_headers_async(): def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryReadClient( @@ -1362,7 +1776,7 @@ def test_credentials_transport_error(): # It is an error to provide an api_key and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) options = client_options.ClientOptions() options.api_key = "api_key" @@ -1373,16 +1787,17 @@ def test_credentials_transport_error(): ) # It is an error to provide an api_key and a credential. - options = mock.Mock() + options = client_options.ClientOptions() options.api_key = "api_key" with pytest.raises(ValueError): client = BigQueryReadClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # It is an error to provide scopes and a transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryReadClient( @@ -1394,7 +1809,7 @@ def test_credentials_transport_error(): def test_transport_instance(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) client = BigQueryReadClient(transport=transport) assert client.transport is transport @@ -1403,13 +1818,13 @@ def test_transport_instance(): def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryReadGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel transport = transports.BigQueryReadGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel @@ -1425,7 +1840,7 @@ def test_transport_get_channel(): def test_transport_adc(transport_class): # Test default credentials are used if not provided. with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class() adc.assert_called_once() @@ -1438,7 +1853,7 @@ def test_transport_adc(transport_class): ) def test_transport_kind(transport_name): transport = BigQueryReadClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert transport.kind == transport_name @@ -1446,7 +1861,7 @@ def test_transport_kind(transport_name): def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert isinstance( client.transport, @@ -1458,7 +1873,7 @@ def test_big_query_read_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.BigQueryReadTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), credentials_file="credentials.json", ) @@ -1470,7 +1885,7 @@ def test_big_query_read_base_transport(): ) as Transport: Transport.return_value = None transport = transports.BigQueryReadTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Every method on the transport should just blindly @@ -1504,7 +1919,7 @@ def test_big_query_read_base_transport_with_credentials_file(): "google.cloud.bigquery_storage_v1beta2.services.big_query_read.transports.BigQueryReadTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + load_creds.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryReadTransport( credentials_file="credentials.json", quota_project_id="octopus", @@ -1526,7 +1941,7 @@ def test_big_query_read_base_transport_with_adc(): "google.cloud.bigquery_storage_v1beta2.services.big_query_read.transports.BigQueryReadTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryReadTransport() adc.assert_called_once() @@ -1534,7 +1949,7 @@ def test_big_query_read_base_transport_with_adc(): def test_big_query_read_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) BigQueryReadClient() adc.assert_called_once_with( scopes=None, @@ -1557,7 +1972,7 @@ def test_big_query_read_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( scopes=["1", "2"], @@ -1606,7 +2021,7 @@ def test_big_query_read_transport_create_channel(transport_class, grpc_helpers): ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() adc.return_value = (creds, None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) @@ -1634,7 +2049,7 @@ def test_big_query_read_transport_create_channel(transport_class, grpc_helpers): [transports.BigQueryReadGrpcTransport, transports.BigQueryReadGrpcAsyncIOTransport], ) def test_big_query_read_grpc_transport_client_cert_source_for_mtls(transport_class): - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() # Check ssl_channel_credentials is used if provided. with mock.patch.object(transport_class, "create_channel") as mock_create_channel: @@ -1680,7 +2095,7 @@ def test_big_query_read_grpc_transport_client_cert_source_for_mtls(transport_cla ) def test_big_query_read_host_no_port(transport_name): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com" ), @@ -1698,7 +2113,7 @@ def test_big_query_read_host_no_port(transport_name): ) def test_big_query_read_host_with_port(transport_name): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com:8000" ), @@ -1752,7 +2167,7 @@ def test_big_query_read_transport_channel_mtls_with_client_cert_source(transport mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() with pytest.warns(DeprecationWarning): with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) @@ -2016,7 +2431,7 @@ def test_client_with_default_client_info(): transports.BigQueryReadTransport, "_prep_wrapped_messages" ) as prep: client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2026,7 +2441,7 @@ def test_client_with_default_client_info(): ) as prep: transport_class = BigQueryReadClient.get_transport_class() transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2035,7 +2450,7 @@ def test_client_with_default_client_info(): @pytest.mark.asyncio async def test_transport_close_async(): client = BigQueryReadAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc_asyncio", ) with mock.patch.object( @@ -2053,7 +2468,7 @@ def test_transport_close(): for transport, close_name in transports.items(): client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) with mock.patch.object( type(getattr(client.transport, close_name)), "close" @@ -2069,7 +2484,7 @@ def test_client_ctx(): ] for transport in transports: client = BigQueryReadClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) # Test client calls underlying transport. with mock.patch.object(type(client.transport), "close") as close: @@ -2100,7 +2515,9 @@ def test_api_key_credentials(client_class, transport_class): patched.assert_called_once_with( credentials=mock_cred, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, diff --git a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py index a3af181a..d729506c 100644 --- a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py +++ b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py @@ -26,6 +26,7 @@ from grpc.experimental import aio import math import pytest +from google.api_core import api_core_version from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers @@ -71,6 +72,29 @@ def modify_default_endpoint(client): ) +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +# Anonymous Credentials with universe domain property. If no universe domain is provided, then +# the default universe domain is "googleapis.com". +class _AnonymousCredentialsWithUniverseDomain(ga_credentials.AnonymousCredentials): + def __init__(self, universe_domain="googleapis.com"): + super(_AnonymousCredentialsWithUniverseDomain, self).__init__() + self._universe_domain = universe_domain + + @property + def universe_domain(self): + return self._universe_domain + + def test__get_default_mtls_endpoint(): api_endpoint = "example.googleapis.com" api_mtls_endpoint = "example.mtls.googleapis.com" @@ -100,6 +124,262 @@ def test__get_default_mtls_endpoint(): ) +def test__read_environment_variables(): + assert BigQueryWriteClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert BigQueryWriteClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + BigQueryWriteClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryWriteClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert BigQueryWriteClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert BigQueryWriteClient._get_client_cert_source(None, False) is None + assert ( + BigQueryWriteClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + BigQueryWriteClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + BigQueryWriteClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + BigQueryWriteClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + BigQueryWriteClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), +) +@mock.patch.object( + BigQueryWriteAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + BigQueryWriteClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + BigQueryWriteClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, default_universe, "always") + == BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryWriteClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == BigQueryWriteClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + BigQueryWriteClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + BigQueryWriteClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + BigQueryWriteClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + BigQueryWriteClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + BigQueryWriteClient._get_universe_domain(None, None) + == BigQueryWriteClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + BigQueryWriteClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (BigQueryWriteClient, transports.BigQueryWriteGrpcTransport, "grpc"), + ], +) +def test__validate_universe_domain(client_class, transport_class, transport_name): + client = client_class( + transport=transport_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + ) + assert client._validate_universe_domain() == True + + # Test the case when universe is already validated. + assert client._validate_universe_domain() == True + + if transport_name == "grpc": + # Test the case where credentials are provided by the + # `local_channel_credentials`. The default universes in both match. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + client = client_class(transport=transport_class(channel=channel)) + assert client._validate_universe_domain() == True + + # Test the case where credentials do not exist: e.g. a transport is provided + # with no credentials. Validation should still succeed because there is no + # mismatch with non-existent credentials. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + transport = transport_class(channel=channel) + transport._credentials = None + client = client_class(transport=transport) + assert client._validate_universe_domain() == True + + # Test the case when there is a universe mismatch from the credentials. + client = client_class( + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain( + universe_domain="foo.com" + ) + ) + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + # Test the case when there is a universe mismatch from the client. + # + # TODO: Make this test unconditional once the minimum supported version of + # google-api-core becomes 2.15.0 or higher. + api_core_major, api_core_minor, _ = [ + int(part) for part in api_core_version.__version__.split(".") + ] + if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15): + client = client_class( + client_options={"universe_domain": "bar.com"}, + transport=transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain(), + ), + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + @pytest.mark.parametrize( "client_class,transport_name", [ @@ -108,7 +388,7 @@ def test__get_default_mtls_endpoint(): ], ) def test_big_query_write_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: @@ -154,7 +434,7 @@ def test_big_query_write_client_service_account_always_use_jwt( ], ) def test_big_query_write_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() with mock.patch.object( service_account.Credentials, "from_service_account_file" ) as factory: @@ -198,20 +478,22 @@ def test_big_query_write_client_get_transport_class(): ) @mock.patch.object( BigQueryWriteClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), ) @mock.patch.object( BigQueryWriteAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), ) def test_big_query_write_client_client_options( client_class, transport_class, transport_name ): # Check that if channel is provided we won't create a new one. with mock.patch.object(BigQueryWriteClient, "get_transport_class") as gtc: - transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + transport = transport_class( + credentials=_AnonymousCredentialsWithUniverseDomain() + ) client = client_class(transport=transport) gtc.assert_not_called() @@ -246,7 +528,9 @@ def test_big_query_write_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -276,15 +560,23 @@ def test_big_query_write_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has # unsupported value. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): + with pytest.raises(MutualTLSChannelError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as excinfo: client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") @@ -294,7 +586,9 @@ def test_big_query_write_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id="octopus", @@ -312,7 +606,9 @@ def test_big_query_write_client_client_options( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -343,13 +639,13 @@ def test_big_query_write_client_client_options( ) @mock.patch.object( BigQueryWriteClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), ) @mock.patch.object( BigQueryWriteAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(BigQueryWriteAsyncClient), + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), ) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) def test_big_query_write_client_mtls_env_auto( @@ -372,7 +668,9 @@ def test_big_query_write_client_mtls_env_auto( if use_client_cert_env == "false": expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) else: expected_client_cert_source = client_cert_source_callback expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -404,7 +702,9 @@ def test_big_query_write_client_mtls_env_auto( return_value=client_cert_source_callback, ): if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) expected_client_cert_source = None else: expected_host = client.DEFAULT_MTLS_ENDPOINT @@ -438,7 +738,9 @@ def test_big_query_write_client_mtls_env_auto( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -528,6 +830,118 @@ def test_big_query_write_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [BigQueryWriteClient, BigQueryWriteAsyncClient] +) +@mock.patch.object( + BigQueryWriteClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteClient), +) +@mock.patch.object( + BigQueryWriteAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(BigQueryWriteAsyncClient), +) +def test_big_query_write_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = BigQueryWriteClient._DEFAULT_UNIVERSE + default_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = BigQueryWriteClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=_AnonymousCredentialsWithUniverseDomain()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + else: + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), + ) + assert client.api_endpoint == default_endpoint + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -553,7 +967,9 @@ def test_big_query_write_client_client_options_scopes( patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=["1", "2"], client_cert_source_for_mtls=None, quota_project_id=None, @@ -592,7 +1008,9 @@ def test_big_query_write_client_client_options_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -652,7 +1070,9 @@ def test_big_query_write_client_create_channel_credentials_file( patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, @@ -669,8 +1089,8 @@ def test_big_query_write_client_create_channel_credentials_file( ) as adc, mock.patch.object( grpc_helpers, "create_channel" ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() + file_creds = _AnonymousCredentialsWithUniverseDomain() load_creds.return_value = (file_creds, None) adc.return_value = (creds, None) client = client_class(client_options=options, transport=transport_name) @@ -703,7 +1123,7 @@ def test_big_query_write_client_create_channel_credentials_file( ) def test_create_write_stream(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -737,7 +1157,7 @@ def test_create_write_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -756,7 +1176,7 @@ async def test_create_write_stream_async( transport: str = "grpc_asyncio", request_type=storage.CreateWriteStreamRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -795,7 +1215,7 @@ async def test_create_write_stream_async_from_dict(): def test_create_write_stream_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -827,7 +1247,7 @@ def test_create_write_stream_field_headers(): @pytest.mark.asyncio async def test_create_write_stream_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -858,7 +1278,7 @@ async def test_create_write_stream_field_headers_async(): def test_create_write_stream_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -888,7 +1308,7 @@ def test_create_write_stream_flattened(): def test_create_write_stream_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -904,7 +1324,7 @@ def test_create_write_stream_flattened_error(): @pytest.mark.asyncio async def test_create_write_stream_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -937,7 +1357,7 @@ async def test_create_write_stream_flattened_async(): @pytest.mark.asyncio async def test_create_write_stream_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -959,7 +1379,7 @@ async def test_create_write_stream_flattened_error_async(): ) def test_append_rows(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -989,7 +1409,7 @@ async def test_append_rows_async( transport: str = "grpc_asyncio", request_type=storage.AppendRowsRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1031,7 +1451,7 @@ async def test_append_rows_async_from_dict(): ) def test_get_write_stream(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1063,7 +1483,7 @@ def test_get_write_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1080,7 +1500,7 @@ async def test_get_write_stream_async( transport: str = "grpc_asyncio", request_type=storage.GetWriteStreamRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1117,7 +1537,7 @@ async def test_get_write_stream_async_from_dict(): def test_get_write_stream_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1147,7 +1567,7 @@ def test_get_write_stream_field_headers(): @pytest.mark.asyncio async def test_get_write_stream_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1176,7 +1596,7 @@ async def test_get_write_stream_field_headers_async(): def test_get_write_stream_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1200,7 +1620,7 @@ def test_get_write_stream_flattened(): def test_get_write_stream_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1215,7 +1635,7 @@ def test_get_write_stream_flattened_error(): @pytest.mark.asyncio async def test_get_write_stream_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1242,7 +1662,7 @@ async def test_get_write_stream_flattened_async(): @pytest.mark.asyncio async def test_get_write_stream_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1263,7 +1683,7 @@ async def test_get_write_stream_flattened_error_async(): ) def test_finalize_write_stream(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1295,7 +1715,7 @@ def test_finalize_write_stream_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1314,7 +1734,7 @@ async def test_finalize_write_stream_async( transport: str = "grpc_asyncio", request_type=storage.FinalizeWriteStreamRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1351,7 +1771,7 @@ async def test_finalize_write_stream_async_from_dict(): def test_finalize_write_stream_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1383,7 +1803,7 @@ def test_finalize_write_stream_field_headers(): @pytest.mark.asyncio async def test_finalize_write_stream_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1416,7 +1836,7 @@ async def test_finalize_write_stream_field_headers_async(): def test_finalize_write_stream_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1442,7 +1862,7 @@ def test_finalize_write_stream_flattened(): def test_finalize_write_stream_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1457,7 +1877,7 @@ def test_finalize_write_stream_flattened_error(): @pytest.mark.asyncio async def test_finalize_write_stream_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1488,7 +1908,7 @@ async def test_finalize_write_stream_flattened_async(): @pytest.mark.asyncio async def test_finalize_write_stream_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1509,7 +1929,7 @@ async def test_finalize_write_stream_flattened_error_async(): ) def test_batch_commit_write_streams(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1538,7 +1958,7 @@ def test_batch_commit_write_streams_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1557,7 +1977,7 @@ async def test_batch_commit_write_streams_async( transport: str = "grpc_asyncio", request_type=storage.BatchCommitWriteStreamsRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1591,7 +2011,7 @@ async def test_batch_commit_write_streams_async_from_dict(): def test_batch_commit_write_streams_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1623,7 +2043,7 @@ def test_batch_commit_write_streams_field_headers(): @pytest.mark.asyncio async def test_batch_commit_write_streams_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1656,7 +2076,7 @@ async def test_batch_commit_write_streams_field_headers_async(): def test_batch_commit_write_streams_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1682,7 +2102,7 @@ def test_batch_commit_write_streams_flattened(): def test_batch_commit_write_streams_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1697,7 +2117,7 @@ def test_batch_commit_write_streams_flattened_error(): @pytest.mark.asyncio async def test_batch_commit_write_streams_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1728,7 +2148,7 @@ async def test_batch_commit_write_streams_flattened_async(): @pytest.mark.asyncio async def test_batch_commit_write_streams_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1749,7 +2169,7 @@ async def test_batch_commit_write_streams_flattened_error_async(): ) def test_flush_rows(request_type, transport: str = "grpc"): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1779,7 +2199,7 @@ def test_flush_rows_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc", ) @@ -1796,7 +2216,7 @@ async def test_flush_rows_async( transport: str = "grpc_asyncio", request_type=storage.FlushRowsRequest ): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) @@ -1831,7 +2251,7 @@ async def test_flush_rows_async_from_dict(): def test_flush_rows_field_headers(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1861,7 +2281,7 @@ def test_flush_rows_field_headers(): @pytest.mark.asyncio async def test_flush_rows_field_headers_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Any value that is part of the HTTP/1.1 URI should be sent as @@ -1892,7 +2312,7 @@ async def test_flush_rows_field_headers_async(): def test_flush_rows_flattened(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1916,7 +2336,7 @@ def test_flush_rows_flattened(): def test_flush_rows_flattened_error(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1931,7 +2351,7 @@ def test_flush_rows_flattened_error(): @pytest.mark.asyncio async def test_flush_rows_flattened_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Mock the actual call within the gRPC stub, and fake the request. @@ -1960,7 +2380,7 @@ async def test_flush_rows_flattened_async(): @pytest.mark.asyncio async def test_flush_rows_flattened_error_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Attempting to call a method with both a request object and flattened @@ -1975,17 +2395,17 @@ async def test_flush_rows_flattened_error_async(): def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryWriteClient( @@ -1995,7 +2415,7 @@ def test_credentials_transport_error(): # It is an error to provide an api_key and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) options = client_options.ClientOptions() options.api_key = "api_key" @@ -2006,16 +2426,17 @@ def test_credentials_transport_error(): ) # It is an error to provide an api_key and a credential. - options = mock.Mock() + options = client_options.ClientOptions() options.api_key = "api_key" with pytest.raises(ValueError): client = BigQueryWriteClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() + client_options=options, + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # It is an error to provide scopes and a transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) with pytest.raises(ValueError): client = BigQueryWriteClient( @@ -2027,7 +2448,7 @@ def test_credentials_transport_error(): def test_transport_instance(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) client = BigQueryWriteClient(transport=transport) assert client.transport is transport @@ -2036,13 +2457,13 @@ def test_transport_instance(): def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.BigQueryWriteGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel transport = transports.BigQueryWriteGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) channel = transport.grpc_channel assert channel @@ -2058,7 +2479,7 @@ def test_transport_get_channel(): def test_transport_adc(transport_class): # Test default credentials are used if not provided. with mock.patch.object(google.auth, "default") as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class() adc.assert_called_once() @@ -2071,7 +2492,7 @@ def test_transport_adc(transport_class): ) def test_transport_kind(transport_name): transport = BigQueryWriteClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert transport.kind == transport_name @@ -2079,7 +2500,7 @@ def test_transport_kind(transport_name): def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) assert isinstance( client.transport, @@ -2091,7 +2512,7 @@ def test_big_query_write_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.BigQueryWriteTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), credentials_file="credentials.json", ) @@ -2103,7 +2524,7 @@ def test_big_query_write_base_transport(): ) as Transport: Transport.return_value = None transport = transports.BigQueryWriteTransport( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), ) # Every method on the transport should just blindly @@ -2140,7 +2561,7 @@ def test_big_query_write_base_transport_with_credentials_file(): "google.cloud.bigquery_storage_v1beta2.services.big_query_write.transports.BigQueryWriteTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + load_creds.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryWriteTransport( credentials_file="credentials.json", quota_project_id="octopus", @@ -2163,7 +2584,7 @@ def test_big_query_write_base_transport_with_adc(): "google.cloud.bigquery_storage_v1beta2.services.big_query_write.transports.BigQueryWriteTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport = transports.BigQueryWriteTransport() adc.assert_called_once() @@ -2171,7 +2592,7 @@ def test_big_query_write_base_transport_with_adc(): def test_big_query_write_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) BigQueryWriteClient() adc.assert_called_once_with( scopes=None, @@ -2195,7 +2616,7 @@ def test_big_query_write_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) + adc.return_value = (_AnonymousCredentialsWithUniverseDomain(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( scopes=["1", "2"], @@ -2245,7 +2666,7 @@ def test_big_query_write_transport_create_channel(transport_class, grpc_helpers) ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: - creds = ga_credentials.AnonymousCredentials() + creds = _AnonymousCredentialsWithUniverseDomain() adc.return_value = (creds, None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) @@ -2277,7 +2698,7 @@ def test_big_query_write_transport_create_channel(transport_class, grpc_helpers) ], ) def test_big_query_write_grpc_transport_client_cert_source_for_mtls(transport_class): - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() # Check ssl_channel_credentials is used if provided. with mock.patch.object(transport_class, "create_channel") as mock_create_channel: @@ -2323,7 +2744,7 @@ def test_big_query_write_grpc_transport_client_cert_source_for_mtls(transport_cl ) def test_big_query_write_host_no_port(transport_name): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com" ), @@ -2341,7 +2762,7 @@ def test_big_query_write_host_no_port(transport_name): ) def test_big_query_write_host_with_port(transport_name): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_options=client_options.ClientOptions( api_endpoint="bigquerystorage.googleapis.com:8000" ), @@ -2400,7 +2821,7 @@ def test_big_query_write_transport_channel_mtls_with_client_cert_source( mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel - cred = ga_credentials.AnonymousCredentials() + cred = _AnonymousCredentialsWithUniverseDomain() with pytest.warns(DeprecationWarning): with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) @@ -2643,7 +3064,7 @@ def test_client_with_default_client_info(): transports.BigQueryWriteTransport, "_prep_wrapped_messages" ) as prep: client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2653,7 +3074,7 @@ def test_client_with_default_client_info(): ) as prep: transport_class = BigQueryWriteClient.get_transport_class() transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -2662,7 +3083,7 @@ def test_client_with_default_client_info(): @pytest.mark.asyncio async def test_transport_close_async(): client = BigQueryWriteAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=_AnonymousCredentialsWithUniverseDomain(), transport="grpc_asyncio", ) with mock.patch.object( @@ -2680,7 +3101,7 @@ def test_transport_close(): for transport, close_name in transports.items(): client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) with mock.patch.object( type(getattr(client.transport, close_name)), "close" @@ -2696,7 +3117,7 @@ def test_client_ctx(): ] for transport in transports: client = BigQueryWriteClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport + credentials=_AnonymousCredentialsWithUniverseDomain(), transport=transport ) # Test client calls underlying transport. with mock.patch.object(type(client.transport), "close") as close: @@ -2727,7 +3148,9 @@ def test_api_key_credentials(client_class, transport_class): patched.assert_called_once_with( credentials=mock_cred, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), scopes=None, client_cert_source_for_mtls=None, quota_project_id=None, diff --git a/tests/unit/test_client_v1.py b/tests/unit/test_client_v1.py index 63f9f086..8ee701c5 100644 --- a/tests/unit/test_client_v1.py +++ b/tests/unit/test_client_v1.py @@ -43,6 +43,9 @@ def mock_transport(monkeypatch): transport.read_rows: fake_read_rows_rpc, } + # _credentials property for TPC support + transport._credentials = "" + return transport