Skip to content

Commit

Permalink
Add an option to define your own AWS client provider (#620)
Browse files Browse the repository at this point in the history
You can now specify a function that returns an AWS client. This is
useful if you want to use something other than boto3, and it can be used
with the metaflow_custom mechanism to provide your own authentication.
  • Loading branch information
romain-intel authored Jul 20, 2021
1 parent 54b1099 commit a11f7ea
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 50 deletions.
1 change: 1 addition & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def from_conf(name, default=None):
DEFAULT_METADATA = from_conf('METAFLOW_DEFAULT_METADATA', 'local')
DEFAULT_MONITOR = from_conf('METAFLOW_DEFAULT_MONITOR', 'nullSidecarMonitor')
DEFAULT_PACKAGE_SUFFIXES = from_conf('METAFLOW_DEFAULT_PACKAGE_SUFFIXES', '.py,.R,.RDS')
DEFAULT_AWS_CLIENT_PROVIDER = from_conf('METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER', 'boto3')


###
Expand Down
56 changes: 38 additions & 18 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import sys
import types

# Fallback table for the attributes an extension plugins module is expected
# to expose. The `__getattr__` hooks of the `_fake` and `_wrap` helpers in
# this file return the entry stored here when the requested name is not
# otherwise found, so an extension only has to define what it customizes.
# NOTE: entries are returned by reference (not copied), so mutating a
# returned list/dict mutates the default held here.
_expected_extensions = {
'FLOW_DECORATORS': [],
'STEP_DECORATORS': [],
'ENVIRONMENTS': [],
'METADATA_PROVIDERS': [],
'SIDECARS': {},
'LOGGING_SIDECARS': {},
'MONITOR_SIDECARS': {},
'AWS_CLIENT_PROVIDERS': [],
# Default for the `get_plugin_cli` hook: a callable returning an empty list.
'get_plugin_cli': lambda : []
}

try:
import metaflow_custom.plugins as _ext_plugins
except ImportError as e:
Expand All @@ -16,20 +28,12 @@
"if you want to ignore, uninstall metaflow_custom package")
raise
class _fake(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)

def get_plugin_cli(self):
return []

_ext_plugins = _fake(
FLOW_DECORATORS=[],
STEP_DECORATORS=[],
ENVIRONMENTS=[],
METADATA_PROVIDERS=[],
SIDECARS={},
LOGGING_SIDECARS={},
MONITOR_SIDECARS={})
def __getattr__(self, name):
if name in _expected_extensions:
return _expected_extensions[name]
raise AttributeError

_ext_plugins = _fake()
else:
# We load into globals whatever we have in extension_module
# We specifically exclude any modules that may be included (like sys, os, etc)
Expand Down Expand Up @@ -61,6 +65,18 @@ def get_plugin_cli(self):
# This keeps it cleaner.
from metaflow import _LazyLoader
sys.meta_path = [_LazyLoader(lazy_load_custom_modules)] + sys.meta_path

class _wrap(object):
    """Proxy around the loaded extension plugins object that supplies defaults.

    Attributes defined by the wrapped object are served directly; any name
    the wrapped object does not define falls back to the matching entry in
    `_expected_extensions`.

    Parameters
    ----------
    obj : object
        The object to wrap; it must have a `__dict__`.

    Raises
    ------
    AttributeError
        On access to a name that is neither defined by the wrapped object
        nor listed in `_expected_extensions`.
    """

    def __init__(self, obj):
        # Share (rather than copy) the wrapped object's namespace so that
        # everything it defines resolves through normal attribute lookup.
        self.__dict__ = obj.__dict__

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. the wrapped object
        # does not define `name`.
        if name in _expected_extensions:
            return _expected_extensions[name]
        # Include the attribute name so a missing/misspelled extension
        # attribute is diagnosable (a bare AttributeError has no message).
        raise AttributeError(
            "'%s' object has no attribute '%s'"
            % (type(self).__name__, name))

_ext_plugins = _wrap(_ext_plugins)



def get_plugin_cli():
Expand Down Expand Up @@ -160,11 +176,15 @@ def _merge_lists(base, overrides, attr):
SIDECARS.update(LOGGING_SIDECARS)
SIDECARS.update(MONITOR_SIDECARS)

from .aws.aws_client import Boto3ClientProvider
AWS_CLIENT_PROVIDERS = _merge_lists(
[Boto3ClientProvider], _ext_plugins.AWS_CLIENT_PROVIDERS, 'name')

# Erase all temporary names to avoid leaking things
# We leave '_ext_plugins' because it is used in a function (so it needs
# to stick around)
for _n in ['ver', 'n', 'o', 'e', 'lazy_load_custom_modules',
'_LazyLoader', '_merge_lists', '_fake', 'addl_modules']:
# We leave '_ext_plugins' and '_expected_extensions' because they are used in
# a function (so they need to stick around)
for _n in ['ver', 'n', 'o', 'e', 'lazy_load_custom_modules', '_LazyLoader',
'_merge_lists', '_fake', '_wrap', 'addl_modules']:
try:
del globals()[_n]
except KeyError:
Expand Down
84 changes: 52 additions & 32 deletions metaflow/plugins/aws/aws_client.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,57 @@
cached_aws_sandbox_creds = None
cached_provider_class = None


def get_aws_client(module, with_error=False, params={}):
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import AWS_SANDBOX_ENABLED, \
AWS_SANDBOX_STS_ENDPOINT_URL, AWS_SANDBOX_API_KEY
import requests
try:
import boto3
from botocore.exceptions import ClientError
except (NameError, ImportError):
raise MetaflowException(
"Could not import module 'boto3'. Install boto3 first.")
class Boto3ClientProvider(object):
name = "boto3"

if AWS_SANDBOX_ENABLED:
global cached_aws_sandbox_creds
if cached_aws_sandbox_creds is None:
# authenticate using STS
url = "%s/auth/token" % AWS_SANDBOX_STS_ENDPOINT_URL
headers = {
'x-api-key': AWS_SANDBOX_API_KEY
}
try:
r = requests.get(url, headers=headers)
r.raise_for_status()
cached_aws_sandbox_creds = r.json()
except requests.exceptions.HTTPError as e:
raise MetaflowException(repr(e))
if with_error:
@staticmethod
def get_client(module, with_error=False, params={}):
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import AWS_SANDBOX_ENABLED, \
AWS_SANDBOX_STS_ENDPOINT_URL, AWS_SANDBOX_API_KEY
import requests
try:
import boto3
from botocore.exceptions import ClientError
except (NameError, ImportError):
raise MetaflowException(
"Could not import module 'boto3'. Install boto3 first.")

if AWS_SANDBOX_ENABLED:
global cached_aws_sandbox_creds
if cached_aws_sandbox_creds is None:
# authenticate using STS
url = "%s/auth/token" % AWS_SANDBOX_STS_ENDPOINT_URL
headers = {
'x-api-key': AWS_SANDBOX_API_KEY
}
try:
r = requests.get(url, headers=headers)
r.raise_for_status()
cached_aws_sandbox_creds = r.json()
except requests.exceptions.HTTPError as e:
raise MetaflowException(repr(e))
if with_error:
return boto3.session.Session(
**cached_aws_sandbox_creds).client(module, **params), ClientError
return boto3.session.Session(
**cached_aws_sandbox_creds).client(module, **params), ClientError
return boto3.session.Session(
**cached_aws_sandbox_creds).client(module, **params)
if with_error:
return boto3.client(module, **params), ClientError
return boto3.client(module, **params)
**cached_aws_sandbox_creds).client(module, **params)
if with_error:
return boto3.client(module, **params), ClientError
return boto3.client(module, **params)


def get_aws_client(module, with_error=False, params=None):
    """Return an AWS client for `module` from the configured client provider.

    On first use, the provider class is selected by matching the `name` of
    each entry in `metaflow.plugins.AWS_CLIENT_PROVIDERS` against
    `DEFAULT_AWS_CLIENT_PROVIDER`; the selected class is stored in the
    module-level `cached_provider_class` and reused by later calls.

    Parameters
    ----------
    module : str
        Forwarded to the provider's `get_client` (for the boto3 provider in
        this file it is passed on to `boto3.client`).
    with_error : bool, optional
        Forwarded unchanged to the provider's `get_client`.
    params : dict, optional
        Extra arguments forwarded to the provider's `get_client`. When
        omitted (or None), a fresh empty dict is passed on every call so
        that a provider mutating it cannot affect later calls.

    Returns
    -------
    Whatever the selected provider's `get_client` returns.

    Raises
    ------
    ValueError
        If no registered provider is named `DEFAULT_AWS_CLIENT_PROVIDER`.
    """
    global cached_provider_class
    if params is None:
        # Avoid a shared mutable default: providers are pluggable and may
        # modify the dict they receive.
        params = {}
    if cached_provider_class is None:
        # Imported lazily, only when the provider has to be resolved.
        from metaflow.metaflow_config import DEFAULT_AWS_CLIENT_PROVIDER
        from metaflow.plugins import AWS_CLIENT_PROVIDERS
        for p in AWS_CLIENT_PROVIDERS:
            if p.name == DEFAULT_AWS_CLIENT_PROVIDER:
                cached_provider_class = p
                break
        else:
            # Loop finished without `break`: no provider matched.
            raise ValueError("Cannot find AWS Client provider %s"
                             % DEFAULT_AWS_CLIENT_PROVIDER)
    return cached_provider_class.get_client(module, with_error, params)

0 comments on commit a11f7ea

Please sign in to comment.