diff --git a/python/ee/__init__.py b/python/ee/__init__.py index 05736ecb0..8482358ad 100644 --- a/python/ee/__init__.py +++ b/python/ee/__init__.py @@ -69,6 +69,16 @@ NO_PROJECT_EXCEPTION = ('ee.Initialize: no project found. Call with project=' ' or see http://goo.gle/ee-auth.') +# Environment variables used to set the project ID. GOOGLE_CLOUD_PROJECT so that +# we interoperate with other Cloud libraries in the common case. EE_PROJECT_ID +# is a more specific value so it should take precedence if both values are +# present. See the following for more details: +# https://google-auth.readthedocs.io/en/master/reference/google.auth.environment_vars.html#google.auth.environment_vars.PROJECT. +_PROJECT_ENV_VARS = [ + 'EE_PROJECT_ID', + 'GOOGLE_CLOUD_PROJECT', +] + class _AlgorithmsContainer(dict): """A lightweight class that is used as a dictionary with dot notation.""" @@ -179,12 +189,17 @@ def Initialize( url: The base url for the EarthEngine REST API to connect to. cloud_api_key: An optional API key to use the Cloud API. http_transport: The http transport method to use when making requests. - project: The client project ID or number to use when making API calls. + project: The client project ID or number to use when making API calls. If + None, project is inferred from credentials or environment variables. """ if credentials == 'persistent': credentials = data.get_persistent_credentials() if not project and credentials and hasattr(credentials, 'quota_project_id'): project = credentials.quota_project_id + if not project: + for env_var in _PROJECT_ENV_VARS: + if project := _utils.get_environment_variable(env_var): + break # SDK credentials are not authorized for EE so a project must be given. if not project and oauth.is_sdk_credentials(credentials): raise EEException(NO_PROJECT_EXCEPTION) diff --git a/python/ee/_utils.py b/python/ee/_utils.py index 260771284..99f0d36bf 100644 --- a/python/ee/_utils.py +++ b/python/ee/_utils.py @@ -1,9 +1,18 @@ """General decorators and helper methods which should not import ee.""" import functools +import os from typing import Any, Callable +# Optional imports used for specific shells. +# pylint: disable=g-import-not-at-top +try: + import IPython +except ImportError: + pass + + def accept_opt_prefix(*opt_args) -> Callable[..., Any]: """Decorator to maintain support for "opt_" prefixed kwargs. @@ -40,3 +49,54 @@ def wrapper(*args, **kwargs): return wrapper return opt_fixed + + +def in_colab_shell() -> bool: + """Tests if the code is being executed within Google Colab.""" + try: + import google.colab # pylint: disable=unused-import,redefined-outer-name + + return True + except ImportError: + return False + + +def in_jupyter_shell() -> bool: + """Tests if the code is being executed within Jupyter.""" + try: + import ipykernel.zmqshell + + return isinstance( + IPython.get_ipython(), ipykernel.zmqshell.ZMQInteractiveShell + ) + except ImportError: + return False + except NameError: + return False + + +def get_environment_variable(key: str) -> Any: + """Retrieves a Colab secret or environment variable for the given key. + + Colab secrets have precedence over environment variables. + + Args: + key (str): The key that's used to fetch the environment variable. + + Returns: + Optional[str]: The retrieved key, or None if no environment variable was + found. + """ + if in_colab_shell(): + from google.colab import userdata # pylint: disable=g-import-not-at-top + + try: + return userdata.get(key) + except ( + userdata.SecretNotFoundError, + userdata.NotebookAccessError, + AttributeError, + ): + pass + + return os.environ.get(key) diff --git a/python/ee/cli/commands.py b/python/ee/cli/commands.py index 30c66702c..ee6afffd8 100644 --- a/python/ee/cli/commands.py +++ b/python/ee/cli/commands.py @@ -409,7 +409,7 @@ def run( if args.scopes: args_auth['scopes'] = args.scopes.split(',') - if ee.oauth.in_colab_shell(): + if ee._utils.in_colab_shell(): # pylint: disable=protected-access print( 'Authenticate: Limited support in Colab. Use ee.Authenticate()' ' or --auth_mode=notebook instead.' diff --git a/python/ee/oauth.py b/python/ee/oauth.py index 908cd935e..9d071c745 100644 --- a/python/ee/oauth.py +++ b/python/ee/oauth.py @@ -27,6 +27,7 @@ from google.auth import _cloud_sdk import google.auth.transport.requests +from ee import _utils from ee import data as ee_data from ee import ee_exception @@ -205,27 +206,6 @@ def write_private_json(json_path: str, info_dict: Dict[str, Any]) -> None: f.write(file_content) -def in_colab_shell() -> bool: - """Tests if the code is being executed within Google Colab.""" - try: - import google.colab # pylint: disable=unused-import,redefined-outer-name - return True - except ImportError: - return False - - -def _in_jupyter_shell() -> bool: - """Tests if the code is being executed within Jupyter.""" - try: - import ipykernel.zmqshell - return isinstance(IPython.get_ipython(), - ipykernel.zmqshell.ZMQInteractiveShell) - except ImportError: - return False - except NameError: - return False - - def _project_number_from_client_id(client_id: Optional[str]) -> Optional[str]: """Returns the project number associated with the given OAuth client ID.""" # Client IDs are of the form: @@ -507,9 +487,9 @@ def authenticate( return True if not auth_mode: - if in_colab_shell(): + if _utils.in_colab_shell(): auth_mode = 'colab' - elif _in_jupyter_shell(): + elif _utils.in_jupyter_shell(): auth_mode = 'notebook' elif _localhost_is_viable() and _no_gcloud(): auth_mode = 'localhost' @@ -596,9 +576,9 @@ def display_instructions(self, quiet: Optional[bool] = None) -> bool: return True coda = WAITING_CODA if self.server else None - if in_colab_shell(): + if _utils.in_colab_shell(): _display_auth_instructions_with_print(self.auth_url, coda) - elif _in_jupyter_shell(): + elif _utils.in_jupyter_shell(): _display_auth_instructions_with_html(self.auth_url, coda) else: _display_auth_instructions_with_print(self.auth_url, coda) diff --git a/python/ee/tests/_utils_test.py b/python/ee/tests/_utils_test.py index 1f6eace22..0bebc5f01 100644 --- a/python/ee/tests/_utils_test.py +++ b/python/ee/tests/_utils_test.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 """Tests for _utils decorators.""" +import sys +from unittest import mock + import unittest from ee import _utils @@ -84,6 +87,13 @@ def test_function(arg1=None, arg2_=None): # pylint: enable=unexpected-keyword-arg # pytype: enable=wrong-keyword-args + def test_in_colab_shell(self): + with mock.patch.dict(sys.modules, {'google.colab': None}): + self.assertFalse(_utils.in_colab_shell()) + + with mock.patch.dict(sys.modules, {'google.colab': mock.MagicMock()}): + self.assertTrue(_utils.in_colab_shell()) + if __name__ == '__main__': unittest.main() diff --git a/python/ee/tests/ee_test.py b/python/ee/tests/ee_test.py index 2c3379a82..8e63df4a7 100644 --- a/python/ee/tests/ee_test.py +++ b/python/ee/tests/ee_test.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 """Test for the ee.__init__ file.""" +import os from unittest import mock import google.auth @@ -87,8 +88,22 @@ def CheckDataInit(**kwargs): google_creds = google_creds.with_quota_project(None) expected_project = None ee.Initialize() - self.assertEqual(5, inits.call_count) + expected_project = 'qp3' + with mock.patch.dict( + os.environ, + {'EE_PROJECT_ID': expected_project, 'GOOGLE_CLOUD_PROJECT': 'qp4'}, + ): + ee.Initialize() + + expected_project = 'qp4' + with mock.patch.dict( + os.environ, {'GOOGLE_CLOUD_PROJECT': expected_project} + ): + ee.Initialize() + self.assertEqual(7, inits.call_count) + + expected_project = None msg = 'Earth Engine API has not been used in project 764086051850 before' with moc(ee.ApiFunction, 'initialize', side_effect=ee.EEException(msg)): with self.assertRaisesRegex(ee.EEException, '.*no project found..*'): @@ -98,7 +113,7 @@ def CheckDataInit(**kwargs): cred_args['refresh_token'] = 'rt' with self.assertRaisesRegex(ee.EEException, '.*no project found..*'): ee.Initialize() - self.assertEqual(6, inits.call_count) + self.assertEqual(8, inits.call_count) def testCallAndApply(self): """Verifies library initialization.""" diff --git a/python/ee/tests/oauth_test.py b/python/ee/tests/oauth_test.py index c7577eb28..2ac16c15c 100644 --- a/python/ee/tests/oauth_test.py +++ b/python/ee/tests/oauth_test.py @@ -54,13 +54,6 @@ def mock_credentials_path(): token = json.load(f) self.assertEqual({'refresh_token': '123'}, token) - def test_in_colab_shell(self): - with mock.patch.dict(sys.modules, {'google.colab': None}): - self.assertFalse(oauth.in_colab_shell()) - - with mock.patch.dict(sys.modules, {'google.colab': mock.MagicMock()}): - self.assertTrue(oauth.in_colab_shell()) - def test_is_sdk_credentials(self): sdk_project = oauth.SDK_PROJECTS[0] self.assertFalse(oauth.is_sdk_credentials(None))