From 92f2b4e32035a35f5f2a4956fee443fe3061bc32 Mon Sep 17 00:00:00 2001 From: Sasha Sobran Date: Thu, 16 Nov 2023 18:07:11 -0800 Subject: [PATCH] feat: support user provided api endpoint. PiperOrigin-RevId: 583223550 --- google/cloud/aiplatform/initializer.py | 52 +++++++++++++++-------- tests/unit/aiplatform/test_initializer.py | 50 ++++++++++++++++++---- 2 files changed, 77 insertions(+), 25 deletions(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 33491277ac..53d5a9d759 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -103,6 +103,7 @@ def __init__(self): self._encryption_spec_key_name = None self._network = None self._service_account = None + self._api_endpoint = None def init( self, @@ -119,6 +120,7 @@ def init( encryption_spec_key_name: Optional[str] = None, network: Optional[str] = None, service_account: Optional[str] = None, + api_endpoint: Optional[str] = None, ): """Updates common initialization parameters with provided options. @@ -174,11 +176,17 @@ def init( PipelineJob, HyperparameterTuningJob, CustomTrainingJob, CustomPythonPackageTrainingJob, CustomContainerTrainingJob, ModelEvaluationJob. + api_endpoint (str): + Optional. The desired API endpoint, + e.g., us-central1-aiplatform.googleapis.com Raises: ValueError: If experiment_description is provided but experiment is not. """ + if api_endpoint is not None: + self._api_endpoint = api_endpoint + if experiment_description and experiment is None: raise ValueError( "Experiment needs to be set in `init` in order to add experiment descriptions." @@ -252,6 +260,11 @@ def get_encryption_spec( ) return encryption_spec + @property + def api_endpoint(self) -> Optional[str]: + """Default API endpoint, if provided.""" + return self._api_endpoint + @property def project(self) -> str: """Default project.""" @@ -351,27 +364,32 @@ def get_client_options( { "api_endpoint": "us-central1-aiplatform.googleapis.com" } or { "api_endpoint": "asia-east1-aiplatform.googleapis.com" } """ - if not (self.location or location_override): - raise ValueError( - "No location found. Provide or initialize SDK with a location." - ) - region = location_override or self.location - region = region.lower() + api_endpoint = self.api_endpoint - utils.validate_region(region) + if api_endpoint is None: + if not (self.location or location_override): + raise ValueError( + "No location found. Provide or initialize SDK with a location." + ) - service_base_path = api_base_path_override or ( - constants.PREDICTION_API_BASE_PATH - if prediction_client - else constants.API_BASE_PATH - ) + region = location_override or self.location + region = region.lower() + + utils.validate_region(region) + + service_base_path = api_base_path_override or ( + constants.PREDICTION_API_BASE_PATH + if prediction_client + else constants.API_BASE_PATH + ) + + api_endpoint = ( + f"{region}-{service_base_path}" + if not api_path_override + else api_path_override + ) - api_endpoint = ( - f"{region}-{service_base_path}" - if not api_path_override - else api_path_override - ) return client_options.ClientOptions(api_endpoint=api_endpoint) def common_location_path( diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 6995aae170..021a23c16b 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -17,19 +17,19 @@ import importlib import os -import pytest +from typing import Optional from unittest import mock from unittest.mock import patch +import pytest + import google.auth from google.auth import credentials - from google.cloud.aiplatform import initializer from google.cloud.aiplatform.metadata.metadata import _experiment_tracker from google.cloud.aiplatform.constants import base as constants from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import resource_manager_utils - from google.cloud.aiplatform.compat.services import ( model_service_client, ) @@ -307,30 +307,64 @@ def test_create_client_appended_user_agent(self): assert " " + appended_user_agent[0] in user_agent assert " " + appended_user_agent[1] in user_agent + def test_set_api_endpoint(self): + initializer.global_config.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + api_endpoint="test.googleapis.com", + ) + + assert initializer.global_config.api_endpoint == "test.googleapis.com" + + def test_not_set_api_endpoint(self): + initializer.global_config.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + assert initializer.global_config.api_endpoint is None + @pytest.mark.parametrize( - "init_location, location_override, expected_endpoint", + "init_location, location_override, api_endpoint, expected_endpoint", [ - ("us-central1", None, "us-central1-aiplatform.googleapis.com"), + ("us-central1", None, None, "us-central1-aiplatform.googleapis.com"), ( "us-central1", "europe-west4", + None, "europe-west4-aiplatform.googleapis.com", ), - ("asia-east1", None, "asia-east1-aiplatform.googleapis.com"), + ("asia-east1", None, None, "asia-east1-aiplatform.googleapis.com"), ( "asia-southeast1", "australia-southeast1", + None, "australia-southeast1-aiplatform.googleapis.com", ), + ( + "asia-east1", + None, + "us-central1-aiplatform.googleapis.com", + "us-central1-aiplatform.googleapis.com", + ), + ( + "us-central1", + None, + "test.aiplatform.googleapis.com", + "test.aiplatform.googleapis.com", + ), ], ) def test_get_client_options( self, init_location: str, - location_override: str, + location_override: Optional[str], + api_endpoint: Optional[str], expected_endpoint: str, ): - initializer.global_config.init(location=init_location) + initializer.global_config.init( + location=init_location, api_endpoint=api_endpoint + ) assert ( initializer.global_config.get_client_options(