Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read GCP project from an environment variable. #445

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion python/ee/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@
NO_PROJECT_EXCEPTION = ('ee.Initialize: no project found. Call with project='
' or see http://goo.gle/ee-auth.')

# Environment variables used to set the project ID. GOOGLE_CLOUD_PROJECT so that
# we interoperate with other Cloud libraries in the common case. EE_PROJECT_ID
# is a more specific value so it should take precedence if both values are
# present. See the following for more details:
# https://google-auth.readthedocs.io/en/master/reference/google.auth.environment_vars.html#google.auth.environment_vars.PROJECT.
_PROJECT_ENV_VARS = [
'EE_PROJECT_ID',
'GOOGLE_CLOUD_PROJECT',
]


class _AlgorithmsContainer(dict):
"""A lightweight class that is used as a dictionary with dot notation."""
Expand Down Expand Up @@ -179,12 +189,17 @@ def Initialize(
url: The base url for the EarthEngine REST API to connect to.
cloud_api_key: An optional API key to use the Cloud API.
http_transport: The http transport method to use when making requests.
project: The client project ID or number to use when making API calls.
project: The client project ID or number to use when making API calls. If
None, project is inferred from credentials or environment variables.
"""
if credentials == 'persistent':
credentials = data.get_persistent_credentials()
if not project and credentials and hasattr(credentials, 'quota_project_id'):
project = credentials.quota_project_id
if not project:
for env_var in _PROJECT_ENV_VARS:
if project := _utils.get_environment_variable(env_var):
break
# SDK credentials are not authorized for EE so a project must be given.
if not project and oauth.is_sdk_credentials(credentials):
raise EEException(NO_PROJECT_EXCEPTION)
Expand Down
60 changes: 60 additions & 0 deletions python/ee/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
"""General decorators and helper methods which should not import ee."""

import functools
import os
from typing import Any, Callable


# Optional imports used for specific shells.
# pylint: disable=g-import-not-at-top
try:
import IPython
except ImportError:
pass


def accept_opt_prefix(*opt_args) -> Callable[..., Any]:
"""Decorator to maintain support for "opt_" prefixed kwargs.

Expand Down Expand Up @@ -40,3 +49,54 @@ def wrapper(*args, **kwargs):
return wrapper

return opt_fixed


def in_colab_shell() -> bool:
"""Tests if the code is being executed within Google Colab."""
try:
import google.colab # pylint: disable=unused-import,redefined-outer-name

return True
except ImportError:
return False


def in_jupyter_shell() -> bool:
"""Tests if the code is being executed within Jupyter."""
try:
import ipykernel.zmqshell

return isinstance(
IPython.get_ipython(), ipykernel.zmqshell.ZMQInteractiveShell
)
except ImportError:
return False
except NameError:
return False


def get_environment_variable(key: str) -> Any:
"""Retrieves a Colab secret or environment variable for the given key.

Colab secrets have precedence over environment variables.

Args:
key (str): The key that's used to fetch the environment variable.

Returns:
Optional[str]: The retrieved key, or None if no environment variable was
found.
"""
if in_colab_shell():
from google.colab import userdata # pylint: disable=g-import-not-at-top

try:
return userdata.get(key)
except (
userdata.SecretNotFoundError,
userdata.NotebookAccessError,
AttributeError,
):
pass

return os.environ.get(key)
2 changes: 1 addition & 1 deletion python/ee/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def run(
if args.scopes:
args_auth['scopes'] = args.scopes.split(',')

if ee.oauth.in_colab_shell():
if ee._utils.in_colab_shell(): # pylint: disable=protected-access
print(
'Authenticate: Limited support in Colab. Use ee.Authenticate()'
' or --auth_mode=notebook instead.'
Expand Down
30 changes: 5 additions & 25 deletions python/ee/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from google.auth import _cloud_sdk
import google.auth.transport.requests

from ee import _utils
from ee import data as ee_data
from ee import ee_exception

Expand Down Expand Up @@ -205,27 +206,6 @@ def write_private_json(json_path: str, info_dict: Dict[str, Any]) -> None:
f.write(file_content)


def in_colab_shell() -> bool:
"""Tests if the code is being executed within Google Colab."""
try:
import google.colab # pylint: disable=unused-import,redefined-outer-name
return True
except ImportError:
return False


def _in_jupyter_shell() -> bool:
"""Tests if the code is being executed within Jupyter."""
try:
import ipykernel.zmqshell
return isinstance(IPython.get_ipython(),
ipykernel.zmqshell.ZMQInteractiveShell)
except ImportError:
return False
except NameError:
return False


def _project_number_from_client_id(client_id: Optional[str]) -> Optional[str]:
"""Returns the project number associated with the given OAuth client ID."""
# Client IDs are of the form:
Expand Down Expand Up @@ -507,9 +487,9 @@ def authenticate(
return True

if not auth_mode:
if in_colab_shell():
if _utils.in_colab_shell():
auth_mode = 'colab'
elif _in_jupyter_shell():
elif _utils.in_jupyter_shell():
auth_mode = 'notebook'
elif _localhost_is_viable() and _no_gcloud():
auth_mode = 'localhost'
Expand Down Expand Up @@ -596,9 +576,9 @@ def display_instructions(self, quiet: Optional[bool] = None) -> bool:
return True

coda = WAITING_CODA if self.server else None
if in_colab_shell():
if _utils.in_colab_shell():
_display_auth_instructions_with_print(self.auth_url, coda)
elif _in_jupyter_shell():
elif _utils.in_jupyter_shell():
_display_auth_instructions_with_html(self.auth_url, coda)
else:
_display_auth_instructions_with_print(self.auth_url, coda)
Expand Down
10 changes: 10 additions & 0 deletions python/ee/tests/_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#!/usr/bin/env python3
"""Tests for _utils decorators."""

import sys
from unittest import mock

import unittest
from ee import _utils

Expand Down Expand Up @@ -84,6 +87,13 @@ def test_function(arg1=None, arg2_=None):
# pylint: enable=unexpected-keyword-arg
# pytype: enable=wrong-keyword-args

def test_in_colab_shell(self):
with mock.patch.dict(sys.modules, {'google.colab': None}):
self.assertFalse(_utils.in_colab_shell())

with mock.patch.dict(sys.modules, {'google.colab': mock.MagicMock()}):
self.assertTrue(_utils.in_colab_shell())


if __name__ == '__main__':
unittest.main()
19 changes: 17 additions & 2 deletions python/ee/tests/ee_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Test for the ee.__init__ file."""

import os
from unittest import mock

import google.auth
Expand Down Expand Up @@ -87,8 +88,22 @@ def CheckDataInit(**kwargs):
google_creds = google_creds.with_quota_project(None)
expected_project = None
ee.Initialize()
self.assertEqual(5, inits.call_count)

expected_project = 'qp3'
with mock.patch.dict(
os.environ,
{'EE_PROJECT_ID': expected_project, 'GOOGLE_CLOUD_PROJECT': 'qp4'},
):
ee.Initialize()

expected_project = 'qp4'
with mock.patch.dict(
os.environ, {'GOOGLE_CLOUD_PROJECT': expected_project}
):
ee.Initialize()
self.assertEqual(7, inits.call_count)

expected_project = None
msg = 'Earth Engine API has not been used in project 764086051850 before'
with moc(ee.ApiFunction, 'initialize', side_effect=ee.EEException(msg)):
with self.assertRaisesRegex(ee.EEException, '.*no project found..*'):
Expand All @@ -98,7 +113,7 @@ def CheckDataInit(**kwargs):
cred_args['refresh_token'] = 'rt'
with self.assertRaisesRegex(ee.EEException, '.*no project found..*'):
ee.Initialize()
self.assertEqual(6, inits.call_count)
self.assertEqual(8, inits.call_count)

def testCallAndApply(self):
"""Verifies library initialization."""
Expand Down
7 changes: 0 additions & 7 deletions python/ee/tests/oauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,6 @@ def mock_credentials_path():
token = json.load(f)
self.assertEqual({'refresh_token': '123'}, token)

def test_in_colab_shell(self):
with mock.patch.dict(sys.modules, {'google.colab': None}):
self.assertFalse(oauth.in_colab_shell())

with mock.patch.dict(sys.modules, {'google.colab': mock.MagicMock()}):
self.assertTrue(oauth.in_colab_shell())

def test_is_sdk_credentials(self):
sdk_project = oauth.SDK_PROJECTS[0]
self.assertFalse(oauth.is_sdk_credentials(None))
Expand Down
Loading