diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 01aefe125e..16a8249558 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -69,7 +69,9 @@ def _set_project_as_env_var_or_google_auth_default(self): # See https://github.com/googleapis/google-auth-library-python/issues/924 # TODO: Remove when google.auth.default() learns the # CLOUD_ML_PROJECT_ID env variable or Vertex AI starts setting GOOGLE_CLOUD_PROJECT env variable. - project_number = os.environ.get("CLOUD_ML_PROJECT_ID") + project_number = os.environ.get("GOOGLE_CLOUD_PROJECT") or os.environ.get( + "CLOUD_ML_PROJECT_ID" + ) if project_number: if not self._credentials: credentials, _ = google.auth.default() @@ -312,7 +314,7 @@ def location(self) -> str: if self._location: return self._location - location = os.getenv("CLOUD_ML_REGION") + location = os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("CLOUD_ML_REGION") if location: utils.validate_region(location) return location diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 9fa1c13b56..d3993014f0 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -86,6 +86,28 @@ def mock_get_project_id(project_number: str, **_): ): assert initializer.global_config.project == _TEST_PROJECT + def test_infer_project_id_with_precedence(self): + lower_precedence_cloud_project_number = "456" + higher_precedence_cloud_project_number = "123" + + def mock_get_project_id(project_number: str, **_): + assert project_number == higher_precedence_cloud_project_number + return _TEST_PROJECT + + with mock.patch.object( + target=resource_manager_utils, + attribute="get_project_id", + new=mock_get_project_id, + ), mock.patch.dict( + os.environ, + { + "GOOGLE_CLOUD_PROJECT": higher_precedence_cloud_project_number, + "CLOUD_ML_PROJECT_ID": lower_precedence_cloud_project_number, + }, + clear=True, + ): + assert initializer.global_config.project == _TEST_PROJECT + def test_init_location_sets_location(self): initializer.global_config.init(location=_TEST_LOCATION) assert initializer.global_config.location == _TEST_LOCATION