From 6b82d457ed59a703a22b1a04947fda6647ea0f06 Mon Sep 17 00:00:00 2001
From: Austin Kline
Date: Tue, 29 Dec 2020 08:37:37 -0800
Subject: [PATCH] NeptuneML integration (#48)

- Integration with the NeptuneML feature set in AWS Neptune
- Add a helper library to perform SigV4 signing for `%neptune_ml export ...`; our other signing will be moved over to it at a later date.
- Swap how credentials are obtained for the `ROLE` IAM credentials provider so that it uses a botocore session instead of calling the EC2 metadata service directly. This should make the module more usable outside of SageMaker.

New line magics:

- `%neptune_ml export status`
- `%neptune_ml dataprocessing start`
- `%neptune_ml dataprocessing status`
- `%neptune_ml training start`
- `%neptune_ml training status`
- `%neptune_ml endpoint create`
- `%neptune_ml endpoint status`

New cell magics:

- `%%neptune_ml export start`
- `%%neptune_ml dataprocessing start`
- `%%neptune_ml training start`
- `%%neptune_ml endpoint create`

NOTE: If a cell magic is used, line inputs that specify parts of the command, such as `--job-id`, will be ignored.

Inject variable as cell input:
Currently this only works for the new cell magic commands detailed above. You can specify a variable to use as the cell input received by the `neptune_ml` magics using the syntax `${var_name}`. For example...

```
# in one notebook cell:
foo = {'foo': 'bar'}

# in another notebook cell:
%%neptune_ml export start
${foo}
```

NOTE: The above only works if the `${...}` reference is the sole content of the cell body. You cannot inline multiple variables at this time.
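For reference, an end-to-end run with the new magics looks roughly like the following. This is a sketch, not output from a real run: the export URL, job ids, S3 URIs, and instance type are all placeholders, and each magic goes in its own cell.

```
%%neptune_ml export start --export-url foo.execute-api.us-east-1.amazonaws.com/v1 --export-iam --wait
${export_params}
```

```
# each of these in its own subsequent cell:
%neptune_ml dataprocessing start --job-id my-job --s3-input-uri s3://my-bucket/export --s3-processed-uri s3://my-bucket/processed --wait
%neptune_ml training start --job-id my-training-job --data-processing-id my-job --s3-output-uri s3://my-bucket/model --instance-type ml.p3.2xlarge --wait
%neptune_ml endpoint create --job-id my-endpoint-job --model-job-id my-training-job --wait
```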
---
 .github/workflows/integration.yml             |   4 +-
 MANIFEST.in                                   |   2 +-
 THIRD_PARTY_LICENSES.txt                      |  30 ++
 setup.py                                      |   4 +-
 setupbase.py                                  |   8 +
 src/graph_notebook/__init__.py                |   2 +-
 .../ec2_metadata_credentials_provider.py      |  13 +-
 .../env_credentials_provider.py               |   2 +-
 .../authentication/iam_headers.py             |  18 +-
 src/graph_notebook/magics/graph_magic.py      |  18 +
 src/graph_notebook/magics/ml.py               | 381 ++++++++++++++++++
 src/graph_notebook/magics/parsing/__init__.py |   2 +
 .../magics/parsing/replace_namespace_vars.py  |  25 ++
 src/graph_notebook/ml/__init__.py             |   0
 src/graph_notebook/ml/sagemaker.py            |  86 ++++
 .../call_and_get_response.py                  |  10 +-
 .../default_request_generator.py              |   7 +-
 .../iam_request_generator.py                  |   3 +-
 src/graph_notebook/widgets/package.json       |   2 +-
 .../NeptuneIntegrationWorkflowSteps.py        |  15 +-
 test/unit/graph_magic/parsing/__init__.py     |   0
 .../parsing/test_str_to_namespace_var.py      |  21 +
 22 files changed, 622 insertions(+), 31 deletions(-)
 create mode 100644 src/graph_notebook/magics/ml.py
 create mode 100644 src/graph_notebook/magics/parsing/__init__.py
 create mode 100644 src/graph_notebook/magics/parsing/replace_namespace_vars.py
 create mode 100644 src/graph_notebook/ml/__init__.py
 create mode 100644 src/graph_notebook/ml/sagemaker.py
 create mode 100644 test/unit/graph_magic/parsing/__init__.py
 create mode 100644 test/unit/graph_magic/parsing/test_str_to_namespace_var.py

diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index f1b8a336..de3fb15e 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -38,8 +38,7 @@ jobs:
       - uses: actions/checkout@v2
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          pip install flake8 pytest
+          pip install pytest
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Install
         run: |
@@ -75,7 +74,6 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install flake8 pytest
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Install
         run: |
diff --git a/MANIFEST.in b/MANIFEST.in
index 8545ebb0..24e6605d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -4,7 +4,7 @@ include webpack.config.js

 # Javascript files
 recursive-include src/graph_notebook/widgets *
-prune **/node_modules
+prune src/graph_notebook/widgets/node_modules
 prune coverage

 # Patterns to exclude from any directory
diff --git a/THIRD_PARTY_LICENSES.txt b/THIRD_PARTY_LICENSES.txt
index ca2118ae..828f7de4 100644
--- a/THIRD_PARTY_LICENSES.txt
+++ b/THIRD_PARTY_LICENSES.txt
@@ -2591,6 +2591,36 @@ SOFTWARE.

 ------

+** requests-aws4auth 1.0.1; version 1.0.1 --
+https://pypi.org/project/requests-aws4auth/
+The MIT License (MIT)
+
+Copyright (c) 2015 Sam Washington
+
+The MIT License (MIT)
+
+Copyright (c) 2015 Sam Washington
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+------
+
 ** @types/webpack-env; version 1.15.2 --
 https://github.com/DefinitelyTyped/DefinitelyTyped
 Copyright (c) Microsoft Corporation. All rights reserved.
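The requests-aws4auth dependency licensed above is what performs the SigV4 signing for exporter calls. As a minimal sketch of the API this change relies on (the credential values and URL are placeholders):

```
import requests
from requests_aws4auth import AWS4Auth

# 'execute-api' is the SigV4 service name for API Gateway endpoints
auth = AWS4Auth('ACCESS_KEY', 'SECRET_KEY', 'us-east-1', 'execute-api',
                session_token='SESSION_TOKEN')
res = requests.get('https://foo.execute-api.us-east-1.amazonaws.com/v1/neptune-export', auth=auth)
res.raise_for_status()
```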
diff --git a/setup.py b/setup.py
index f92354e2..a7fe9432 100644
--- a/setup.py
+++ b/setup.py
@@ -76,7 +76,9 @@ def get_version():
         'notebook',
         'jupyter-contrib-nbextensions',
         'widgetsnbextension',
-        'jupyter>=1.0.0'
+        'jupyter>=1.0.0',
+        'requests-aws4auth==1.0.1',
+        'botocore>=1.19.37'
     ],
     package_data={
         'graph_notebook': ['graph_notebook/widgets/nbextensions/static/*.js',
diff --git a/setupbase.py b/setupbase.py
index 8b199f16..00456fa9 100644
--- a/setupbase.py
+++ b/setupbase.py
@@ -382,6 +382,14 @@ def run(self):
             if should_build:
                 run(npm_cmd + ['run', build_cmd], cwd=node_package)

+            # ensure that __init__.py files are added to the generated directories,
+            # otherwise they will not be included in the package distribution uploaded to PyPI
+            dirs_from_node_path = ['nbextension', pjoin('nbextension', 'static'), 'lib', 'labextension']
+            for init_path in dirs_from_node_path:
+                full_path = pjoin(node_package, init_path, '__init__.py')
+                with open(full_path, 'w+'):
+                    pass
+
     return NPM
diff --git a/src/graph_notebook/__init__.py b/src/graph_notebook/__init__.py
index 37a85412..d68d2d4d 100644
--- a/src/graph_notebook/__init__.py
+++ b/src/graph_notebook/__init__.py
@@ -3,4 +3,4 @@
 SPDX-License-Identifier: Apache-2.0
 """

-__version__ = '2.0.2'
+__version__ = '2.0.3'
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py
index 491371e8..8dd9bf58 100644
--- a/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py
+++ b/src/graph_notebook/authentication/iam_credentials_provider/ec2_metadata_credentials_provider.py
@@ -3,13 +3,14 @@
 SPDX-License-Identifier: Apache-2.0
 """

+import botocore.session
 import requests
+
 from graph_notebook.authentication.iam_credentials_provider.credentials_provider import CredentialsProviderBase, \
     Credentials

 region_url = 'http://169.254.169.254/latest/meta-data/placement/availability-zone'
-iam_url = 'http://169.254.169.254/latest/meta-data/iam/security-credentials/neptune-db'


 class MetadataCredentialsProvider(CredentialsProviderBase):
@@ -20,10 +21,6 @@ def __init__(self):
         self.region = region

     def get_iam_credentials(self) -> Credentials:
-        res = requests.get(iam_url)
-        if res.status_code != 200:
-            raise Exception(f'unable to get iam credentials {res.content}')
-
-        js = res.json()
-        creds = Credentials(key=js['AccessKeyId'], secret=js['SecretAccessKey'], token=js['Token'], region=self.region)
-        return creds
+        session = botocore.session.get_session()
+        creds = session.get_credentials()
+        return Credentials(key=creds.access_key, secret=creds.secret_key, token=creds.token, region=self.region)
diff --git a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py b/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py
index e5b91413..0ba39e9d 100644
--- a/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py
+++ b/src/graph_notebook/authentication/iam_credentials_provider/env_credentials_provider.py
@@ -28,7 +28,7 @@ def load_iam_credentials(self):
         self.loaded = True
         return

-    def get_iam_credentials(self) -> Credentials:
+    def get_iam_credentials(self, service=None) -> Credentials:
         if not self.loaded:
             self.load_iam_credentials()
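For context on the `ROLE` provider change above: a botocore session resolves credentials through the standard AWS chain (environment variables, shared config files, then instance or container metadata), which is what makes this provider usable outside of SageMaker. A minimal sketch of the calls involved:

```
import botocore.session

session = botocore.session.get_session()
creds = session.get_credentials()  # walks the standard AWS credential chain
print(creds.access_key, creds.token is not None)
```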
diff --git a/src/graph_notebook/authentication/iam_headers.py b/src/graph_notebook/authentication/iam_headers.py
index 5fb713d6..4d195132 100644
--- a/src/graph_notebook/authentication/iam_headers.py
+++ b/src/graph_notebook/authentication/iam_headers.py
@@ -6,10 +6,10 @@
 import datetime
 import hashlib
 import hmac
+import json
 import logging
 import urllib

-
 logging.basicConfig()
 logger = logging.getLogger("graph_magic")

@@ -69,6 +69,12 @@ def get_canonical_uri_and_payload(query_type, query):
     elif query_type == "system":
         canonical_uri = "/system/"
         payload = query
+
+    elif query_type.startswith("ml"):
+        # covers every NeptuneML path: ml/dataprocessing, ml/modeltraining and ml/endpoints
+        canonical_uri = f'/{query_type}'
+        payload = query
+
     else:
         raise ValueError('query_type %s is not valid' % query_type)

@@ -85,7 +91,8 @@ def normalize_query_string(query):
     return normalized


-def make_signed_request(method, query_type, query, host, port, signing_access_key, signing_secret, signing_region, use_ssl=False, signing_token='', additional_headers=None):
+def make_signed_request(method, query_type, query, host, port, signing_access_key, signing_secret, signing_region,
+                        use_ssl=False, signing_token='', additional_headers=None):
     if additional_headers is None:
         additional_headers = []

@@ -103,8 +110,11 @@ def make_signed_request(method, query_type, query, host, port, signing_access_ke
     # get canonical_uri and payload
     canonical_uri, payload = get_canonical_uri_and_payload(query_type, query)

-    request_parameters = urllib.parse.urlencode(payload, quote_via=urllib.parse.quote)
-    request_parameters = request_parameters.replace('%27', '%22')
+    if 'content-type' in additional_headers and additional_headers['content-type'] == 'application/json':
+        request_parameters = payload if type(payload) is str else json.dumps(payload)
+    else:
+        request_parameters = urllib.parse.urlencode(payload, quote_via=urllib.parse.quote)
+        request_parameters = request_parameters.replace('%27', '%22')
     t = datetime.datetime.utcnow()
     amz_date = t.strftime('%Y%m%dT%H%M%SZ')
     date_stamp = t.strftime('%Y%m%d')  # Date w/o time, used in credential scope
diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py
index 86ca7358..ac15e64a 100644
--- a/src/graph_notebook/magics/graph_magic.py
+++ b/src/graph_notebook/magics/graph_magic.py
@@ -23,6 +23,7 @@
 import graph_notebook
 from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION
 from graph_notebook.decorators.decorators import display_exceptions
+from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser
 from graph_notebook.network import SPARQLNetwork
 from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
 from graph_notebook.sparql.table import get_rows_and_columns
@@ -994,3 +995,20 @@ def graph_notebook_vis_options(self, line='', cell=''):
         else:
             options_dict = json.loads(cell)
             self.graph_notebook_vis_options = vis_options_merge(self.graph_notebook_vis_options, options_dict)
+
+    @line_cell_magic
+    @display_exceptions
+    @needs_local_scope
+    def neptune_ml(self, line, cell='', local_ns: dict = None):
+        parser = generate_neptune_ml_parser()
+        args = parser.parse_args(line.split())
+        logger.info(f'received call to neptune_ml with details: {args.__dict__}, cell={cell}, local_ns={local_ns}')
+        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
+                                                     self.graph_notebook_config.iam_credentials_provider_type)
+        main_output = widgets.Output()
+        display(main_output)
+        res = neptune_ml_magic_handler(args, request_generator, self.graph_notebook_config, main_output, cell, local_ns)
+        message = json.dumps(res, indent=2) if type(res) is dict else res
+        store_to_ns(args.store_to, res, local_ns)
+        with main_output:
+            print(message)
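Since every sub-command accepts `--store-to`, the magic's result can be captured into the notebook namespace for later cells. A small, hypothetical example (the export URL and job id are placeholders):

```
%neptune_ml export status --export-url foo.execute-api.us-east-1.amazonaws.com/v1 --job-id my-export-job --store-to export_status
```

```
# in a later cell, export_status is a plain dict
print(export_status['status'])
```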
diff --git a/src/graph_notebook/magics/ml.py b/src/graph_notebook/magics/ml.py
new file mode 100644
index 00000000..b40db040
--- /dev/null
+++ b/src/graph_notebook/magics/ml.py
@@ -0,0 +1,381 @@
+import argparse
+import json
+import datetime
+import logging
+import time
+from IPython.core.display import display
+from ipywidgets import widgets
+
+from graph_notebook.authentication.iam_credentials_provider.credentials_factory import credentials_provider_factory
+from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials
+from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum
+from graph_notebook.magics.parsing import str_to_namespace_var
+from graph_notebook.ml.sagemaker import start_export, get_export_status, start_processing_job, get_processing_status, \
+    start_training, get_training_status, start_create_endpoint, get_endpoint_status, EXPORT_SERVICE_NAME
+
+logger = logging.getLogger("neptune_ml_magic_handler")
+
+DEFAULT_WAIT_INTERVAL = 60
+DEFAULT_WAIT_TIMEOUT = 3600
+
+
+def generate_neptune_ml_parser():
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers(help='sub-command help', dest='which')
+
+    # Begin export subparsers
+    parser_export = subparsers.add_parser('export', help='')
+    export_sub_parsers = parser_export.add_subparsers(help='', dest='which_sub')
+    export_start_parser = export_sub_parsers.add_parser('start', help='start a new exporter job')
+    export_start_parser.add_argument('--export-url', type=str,
+                                     help='api gateway endpoint to call the exporter, such as foo.execute-api.us-east-1.amazonaws.com/v1')
+    export_start_parser.add_argument('--export-iam', action='store_true',
+                                     help='flag for whether to sign requests to the export url with SigV4')
+    export_start_parser.add_argument('--export-no-ssl', action='store_true',
+                                     help='toggle ssl off when connecting to the exporter')
+    export_start_parser.add_argument('--wait', action='store_true', help='wait for the exporter to finish running')
+    export_start_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                     help=f'time in seconds between export status checks. default: {DEFAULT_WAIT_INTERVAL}')
+    export_start_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                     help=f'time in seconds to wait for a given export job to complete before returning the most recent status. default: {DEFAULT_WAIT_TIMEOUT}')
+    export_start_parser.add_argument('--store-to', default='', dest='store_to',
+                                     help='store result to this variable. If --wait is specified, will store the final status.')
+
+    export_status_parser = export_sub_parsers.add_parser('status', help='obtain status of exporter job')
+    export_status_parser.add_argument('--job-id', type=str, help='job id to check the status of')
+    export_status_parser.add_argument('--export-url', type=str,
+                                      help='api gateway endpoint to call the exporter, such as foo.execute-api.us-east-1.amazonaws.com/v1')
+    export_status_parser.add_argument('--export-iam', action='store_true',
+                                      help='flag for whether to sign requests to the export url with SigV4')
+    export_status_parser.add_argument('--export-no-ssl', action='store_true',
+                                      help='toggle ssl off when connecting to the exporter')
+    export_status_parser.add_argument('--store-to', default='', dest='store_to',
+                                      help='store result to this variable')
+    export_status_parser.add_argument('--wait', action='store_true', help='wait for the exporter to finish running')
+    export_status_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                      help=f'time in seconds between export status checks. default: {DEFAULT_WAIT_INTERVAL}')
+    export_status_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                      help=f'time in seconds to wait for a given export job to complete before returning the most recent status. default: {DEFAULT_WAIT_TIMEOUT}')
+
+    # Begin dataprocessing subparsers
+    parser_dataprocessing = subparsers.add_parser('dataprocessing', help='')
+    dataprocessing_subparsers = parser_dataprocessing.add_subparsers(help='dataprocessing sub-command',
+                                                                     dest='which_sub')
+    dataprocessing_start_parser = dataprocessing_subparsers.add_parser('start', help='start a new dataprocessing job')
+    dataprocessing_start_parser.add_argument('--job-id', type=str,
+                                             help='the unique identifier for this processing job')
+    dataprocessing_start_parser.add_argument('--s3-input-uri', type=str, help='input data location in s3')
+    dataprocessing_start_parser.add_argument('--s3-processed-uri', type=str, help='processed data location in s3')
+    dataprocessing_start_parser.add_argument('--config-file-name', type=str, default='')
+    dataprocessing_start_parser.add_argument('--store-to', type=str, default='',
+                                             help='store result to this variable')
+    dataprocessing_start_parser.add_argument('--wait', action='store_true',
+                                             help='wait for the dataprocessing job to finish running')
+    dataprocessing_start_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                             help='wait interval between checks for dataprocessing status')
+    dataprocessing_start_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                             help='timeout while waiting for the dataprocessing job to complete')
+
+    dataprocessing_status_parser = dataprocessing_subparsers.add_parser('status',
+                                                                        help='obtain the status of an existing dataprocessing job')
+    dataprocessing_status_parser.add_argument('--job-id', type=str)
+    dataprocessing_status_parser.add_argument('--store-to', type=str, default='',
+                                              help='store result to this variable')
+    dataprocessing_status_parser.add_argument('--wait', action='store_true',
+                                              help='wait for the dataprocessing job to finish running')
+    dataprocessing_status_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                              help='wait interval between checks for dataprocessing status')
+    dataprocessing_status_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                              help='timeout while waiting for the dataprocessing job to complete')
+
+    # Begin training subparsers
+    parser_training = subparsers.add_parser('training', help='training command help')
+    training_subparsers = parser_training.add_subparsers(help='training sub-command help',
+                                                         dest='which_sub')
+    training_start_parser = training_subparsers.add_parser('start', help='start a new training job')
+    training_start_parser.add_argument('--job-id', type=str, default='')
+    training_start_parser.add_argument('--data-processing-id', type=str, default='')
+    training_start_parser.add_argument('--s3-output-uri', type=str, default='')
+    training_start_parser.add_argument('--instance-type', type=str, default='')
+    training_start_parser.add_argument('--store-to', type=str, default='', help='store result to this variable')
+    training_start_parser.add_argument('--wait', action='store_true',
+                                       help='wait for the training job to finish running')
+    training_start_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                       help='wait interval between checks for training status')
+    training_start_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                       help='timeout while waiting for the training job to complete')
+
+    training_status_parser = training_subparsers.add_parser('status',
+                                                            help='obtain the status of an existing training job')
+    training_status_parser.add_argument('--job-id', type=str)
+    training_status_parser.add_argument('--store-to', type=str, default='', help='store result to this variable')
+    training_status_parser.add_argument('--wait', action='store_true',
+                                        help='wait for the training job to finish running')
+    training_status_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                        help='wait interval between checks for training status')
+    training_status_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                        help='timeout while waiting for the training job to complete')
+
+    # Begin endpoint subparsers
+    parser_endpoint = subparsers.add_parser('endpoint', help='endpoint command help')
+    endpoint_subparsers = parser_endpoint.add_subparsers(help='endpoint sub-command help',
+                                                         dest='which_sub')
+    endpoint_start_parser = endpoint_subparsers.add_parser('create', help='create a new endpoint')
+    endpoint_start_parser.add_argument('--job-id', type=str, default='')
+    endpoint_start_parser.add_argument('--model-job-id', type=str, default='')
+    endpoint_start_parser.add_argument('--instance-type', type=str, default='ml.r5.xlarge')
+    endpoint_start_parser.add_argument('--store-to', type=str, default='', help='store result to this variable')
+    endpoint_start_parser.add_argument('--wait', action='store_true',
+                                       help='wait for endpoint creation to finish')
+    endpoint_start_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                       help='wait interval between checks for endpoint creation status')
+    endpoint_start_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                       help='timeout while waiting for endpoint creation to complete')
+
+    endpoint_status_parser = endpoint_subparsers.add_parser('status',
+                                                            help='obtain the status of an existing endpoint creation job')
+    endpoint_status_parser.add_argument('--job-id', type=str, default='')
+    endpoint_status_parser.add_argument('--store-to', type=str, default='', help='store result to this variable')
+    endpoint_status_parser.add_argument('--wait', action='store_true',
+                                        help='wait for endpoint creation to finish')
+    endpoint_status_parser.add_argument('--wait-interval', default=DEFAULT_WAIT_INTERVAL, type=int,
+                                        help='wait interval between checks for endpoint creation status')
+    endpoint_status_parser.add_argument('--wait-timeout', default=DEFAULT_WAIT_TIMEOUT, type=int,
+                                        help='timeout while waiting for endpoint creation to complete')
+
+    return parser
+
+
+def neptune_ml_export_start(params, export_url: str, export_ssl: bool = True, creds: Credentials = None):
+    if type(params) is str:
+        params = json.loads(params)
+
+    job = start_export(export_url, params, export_ssl, creds)
+    return job
+
+
+def wait_for_export(export_url: str, job_id: str, output: widgets.Output,
+                    export_ssl: bool = True, wait_interval: int = DEFAULT_WAIT_INTERVAL,
+                    wait_timeout: int = DEFAULT_WAIT_TIMEOUT, creds: Credentials = None):
+    job_id_output = widgets.Output()
+    update_widget_output = widgets.Output()
+    with output:
+        display(job_id_output, update_widget_output)
+
+    with job_id_output:
+        print(f'Wait called on export job {job_id}')
+
+    with update_widget_output:
+        beginning_time = datetime.datetime.utcnow()
+        while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)):
+            update_widget_output.clear_output()
+            print('Checking for latest status...')
+            export_status = get_export_status(export_url, export_ssl, job_id, creds)
+            if export_status['status'] in ['succeeded', 'failed']:
+                print('Export is finished')
+                return export_status
+            else:
+                print(f'Status is {export_status["status"]}')
+                print(f'Waiting for {wait_interval} seconds before checking again...')
+                time.sleep(wait_interval)
+
+
+def neptune_ml_export(args: argparse.Namespace, config: Configuration, output: widgets.Output, cell: str):
+    auth_mode = AuthModeEnum.IAM if args.export_iam else AuthModeEnum.DEFAULT
+    creds = None
+    if auth_mode == AuthModeEnum.IAM:
+        creds = credentials_provider_factory(config.iam_credentials_provider_type).get_iam_credentials()
+
+    export_ssl = not args.export_no_ssl
+    if args.which_sub == 'start':
+        if cell == '':
+            return 'Cell body must have json payload or reference notebook variable using syntax ${payload_var}'
+        export_job = neptune_ml_export_start(cell, args.export_url, export_ssl, creds)
+        if args.wait:
+            return wait_for_export(args.export_url, export_job['jobId'],
+                                   output, export_ssl, args.wait_interval, args.wait_timeout, creds)
+        else:
+            return export_job
+    elif args.which_sub == 'status':
+        if args.wait:
+            status = wait_for_export(args.export_url, args.job_id, output, export_ssl, args.wait_interval,
+                                     args.wait_timeout, creds)
+        else:
+            status = get_export_status(args.export_url, export_ssl, args.job_id, creds)
+        return status
+
+
+def wait_for_dataprocessing(job_id: str, config: Configuration, request_param_generator, output: widgets.Output,
+                            wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT):
+    job_id_output = widgets.Output()
+    update_status_output = widgets.Output()
+    with output:
+        display(job_id_output, update_status_output)
+
+    with job_id_output:
+        print(f'Wait called on dataprocessing job {job_id}')
+
+    with update_status_output:
+        beginning_time = datetime.datetime.utcnow()
+        while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)):
+            update_status_output.clear_output()
+            status = get_processing_status(config.host, str(config.port), config.ssl, request_param_generator, job_id)
+            if status['status'] in ['Completed', 'Failed']:
+                print('Data processing is finished')
+                return status
+            else:
+                print(f'Status is {status["status"]}')
+                print(f'Waiting for {wait_interval} seconds before checking again...')
+                time.sleep(wait_interval)
+
+
+def neptune_ml_dataprocessing(args: argparse.Namespace, request_param_generator, output: widgets.Output,
+                              config: Configuration, params: dict = None):
+    if args.which_sub == 'start':
+        if params is None or params == '' or params == {}:
+            params = {
+                'inputDataS3Location': args.s3_input_uri,
+                'processedDataS3Location': args.s3_processed_uri,
+                'id': args.job_id,
+                'configFileName': args.config_file_name
+            }
+
+        processing_job = start_processing_job(config.host, str(config.port), config.ssl,
+                                              request_param_generator, params)
+        job_id = params['id'] if type(params) is dict else json.loads(params)['id']  # cell payloads arrive as json strings
+        if args.wait:
+            return wait_for_dataprocessing(job_id, config, request_param_generator,
+                                           output, args.wait_interval, args.wait_timeout)
+        else:
+            return processing_job
+    elif args.which_sub == 'status':
+        if args.wait:
+            return wait_for_dataprocessing(args.job_id, config, request_param_generator, output, args.wait_interval,
+                                           args.wait_timeout)
+        else:
+            return get_processing_status(config.host, str(config.port), config.ssl, request_param_generator,
+                                         args.job_id)
+    else:
+        return f'Sub parser "{args.which} {args.which_sub}" was not recognized'
+
+
+def wait_for_training(job_id: str, config: Configuration, request_param_generator, output: widgets.Output,
+                      wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT):
+    job_id_output = widgets.Output()
+    update_status_output = widgets.Output()
+    with output:
+        display(job_id_output, update_status_output)
+
+    with job_id_output:
+        print(f'Wait called on training job {job_id}')
+
+    with update_status_output:
+        beginning_time = datetime.datetime.utcnow()
+        while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)):
+            update_status_output.clear_output()
+            status = get_training_status(config.host, str(config.port), config.ssl, request_param_generator, job_id)
+            if status['status'] in ['Completed', 'Failed']:
+                print('Training is finished')
+                return status
+            else:
+                print(f'Status is {status["status"]}')
+                print(f'Waiting for {wait_interval} seconds before checking again...')
+                time.sleep(wait_interval)
+
+
+def neptune_ml_training(args: argparse.Namespace, request_param_generator, config: Configuration,
+                        output: widgets.Output, params):
+    if args.which_sub == 'start':
+        if params is None or params == '' or params == {}:
+            params = {
+                "id": args.job_id,
+                "dataProcessingJobId": args.data_processing_id,
+                "trainingInstanceType": args.instance_type,
+                "trainModelS3Location": args.s3_output_uri
+            }
+
+        training_job = start_training(config.host, str(config.port), config.ssl, request_param_generator, params)
+        if args.wait:
+            return wait_for_training(training_job['id'], config, request_param_generator, output, args.wait_interval,
+                                     args.wait_timeout)
+        else:
+            return training_job
+    elif args.which_sub == 'status':
+        if args.wait:
+            return wait_for_training(args.job_id, config, request_param_generator, output, args.wait_interval,
+                                     args.wait_timeout)
+        else:
+            return get_training_status(config.host, str(config.port), config.ssl, request_param_generator,
+                                       args.job_id)
+    else:
+        return f'Sub parser "{args.which} {args.which_sub}" was not recognized'
+
+
+def wait_for_endpoint(job_id: str, config: Configuration, request_param_generator, output: widgets.Output,
+                      wait_interval: int = DEFAULT_WAIT_INTERVAL, wait_timeout: int = DEFAULT_WAIT_TIMEOUT):
+    job_id_output = widgets.Output()
+    update_status_output = widgets.Output()
+    with output:
+        display(job_id_output, update_status_output)
+
+    with job_id_output:
+        print(f'Wait called on endpoint creation job {job_id}')
+
+    with update_status_output:
+        beginning_time = datetime.datetime.utcnow()
+        while datetime.datetime.utcnow() - beginning_time < (datetime.timedelta(seconds=wait_timeout)):
+            update_status_output.clear_output()
+            status = get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, job_id)
+            if status['status'] in ['InService', 'Failed']:
+                print('Endpoint creation is finished')
+                return status
+            else:
+                print(f'Status is {status["status"]}')
+                print(f'Waiting for {wait_interval} seconds before checking again...')
+                time.sleep(wait_interval)
+
+
+def neptune_ml_endpoint(args: argparse.Namespace, request_param_generator,
+                        config: Configuration, output: widgets.Output, params):
+    if args.which_sub == 'create':
+        if params is None or params == '' or params == {}:
+            params = {
+                "id": args.job_id,
+                "mlModelTrainingJobId": args.model_job_id,
+                'instanceType': args.instance_type
+            }
+
+        create_endpoint_job = start_create_endpoint(config.host, str(config.port), config.ssl,
+                                                    request_param_generator, params)
+
+        if args.wait:
+            return wait_for_endpoint(create_endpoint_job['id'], config, request_param_generator, output,
+                                     args.wait_interval, args.wait_timeout)
+        else:
+            return create_endpoint_job
+    elif args.which_sub == 'status':
+        if args.wait:
+            return wait_for_endpoint(args.job_id, config, request_param_generator, output,
+                                     args.wait_interval, args.wait_timeout)
+        else:
+            return get_endpoint_status(config.host, str(config.port), config.ssl, request_param_generator, args.job_id)
+    else:
+        return f'Sub parser "{args.which} {args.which_sub}" was not recognized'
+
+
+def neptune_ml_magic_handler(args, request_param_generator, config: Configuration, output: widgets.Output,
+                             cell: str = '', local_ns: dict = None) -> any:
+    if local_ns is None:
+        local_ns = {}
+    cell = str_to_namespace_var(cell, local_ns)
+
+    if args.which == 'export':
+        return neptune_ml_export(args, config, output, cell)
+    elif args.which == 'dataprocessing':
+        return neptune_ml_dataprocessing(args, request_param_generator, output, config, cell)
+    elif args.which == 'training':
+        return neptune_ml_training(args, request_param_generator, config, output, cell)
+    elif args.which == 'endpoint':
+        return neptune_ml_endpoint(args, request_param_generator, config, output, cell)
+    else:
+        return f'sub parser {args.which} was not recognized'
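As a usage sketch for the handler above, here is a dataprocessing job started from a cell magic with an injected payload. The bucket, job id, and config file name are placeholders; the payload keys match the ones the handler builds from line arguments.

```
# in one cell
processing_params = {
    'inputDataS3Location': 's3://my-bucket/export',
    'processedDataS3Location': 's3://my-bucket/processed',
    'id': 'my-processing-job',
    'configFileName': 'training-data-configuration.json'
}
```

```
# in another cell
%%neptune_ml dataprocessing start --wait
${processing_params}
```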
diff --git a/src/graph_notebook/magics/parsing/__init__.py b/src/graph_notebook/magics/parsing/__init__.py
new file mode 100644
index 00000000..2ea06d43
--- /dev/null
+++ b/src/graph_notebook/magics/parsing/__init__.py
@@ -0,0 +1,2 @@
+from .replace_namespace_vars import str_to_namespace_var
+from .replace_namespace_vars import replace_namespace_vars
diff --git a/src/graph_notebook/magics/parsing/replace_namespace_vars.py b/src/graph_notebook/magics/parsing/replace_namespace_vars.py
new file mode 100644
index 00000000..dbf6b710
--- /dev/null
+++ b/src/graph_notebook/magics/parsing/replace_namespace_vars.py
@@ -0,0 +1,25 @@
+import argparse
+
+
+def str_to_namespace_var(key: str, local_ns: dict) -> any:
+    if local_ns is None:
+        return key
+
+    if type(key) is not str:
+        return key
+
+    tmp_key = key.strip()
+    if not (tmp_key.startswith('${') and tmp_key.endswith('}')):
+        return key
+    else:
+        tmp_key = tmp_key[2:-1].strip()
+        return key if tmp_key not in local_ns else local_ns[tmp_key]
+
+
+def replace_namespace_vars(args: argparse.Namespace, local_ns: dict):
+    if local_ns is None or local_ns == {}:
+        return
+
+    for key in list(args.__dict__.keys()):
+        new_value = str_to_namespace_var(args.__dict__[key], local_ns)
+        args.__dict__[key] = new_value
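A quick illustration of the substitution rules these helpers implement: only an exact `${...}` string is replaced, unknown variables pass through unchanged, and partial inlining is not supported.

```
import argparse
from graph_notebook.magics.parsing import str_to_namespace_var, replace_namespace_vars

local_ns = {'bucket': 's3://my-bucket/input'}
assert str_to_namespace_var('${bucket}', local_ns) == 's3://my-bucket/input'
assert str_to_namespace_var('${missing}', local_ns) == '${missing}'          # unknown vars pass through
assert str_to_namespace_var('prefix ${bucket}', local_ns) == 'prefix ${bucket}'  # no inlining

args = argparse.Namespace(s3_input_uri='${bucket}', job_id='my-job')
replace_namespace_vars(args, local_ns)
assert args.s3_input_uri == 's3://my-bucket/input'
```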
diff --git a/src/graph_notebook/ml/__init__.py b/src/graph_notebook/ml/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/graph_notebook/ml/sagemaker.py b/src/graph_notebook/ml/sagemaker.py
new file mode 100644
index 00000000..71c4a59b
--- /dev/null
+++ b/src/graph_notebook/ml/sagemaker.py
@@ -0,0 +1,86 @@
+"""
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
+
+import json
+import requests
+from requests_aws4auth import AWS4Auth
+
+from graph_notebook.authentication.iam_credentials_provider.credentials_provider import Credentials
+from graph_notebook.request_param_generator.call_and_get_response import call_and_get_response
+
+EXPORT_SERVICE_NAME = 'execute-api'
+EXPORT_ACTION = 'neptune-export'
+EXTRA_HEADERS = {'content-type': 'application/json'}
+UPDATE_DELAY_SECONDS = 60
+
+
+def start_export(export_host: str, export_params: dict, use_ssl: bool,
+                 creds: Credentials = None) -> dict:
+    auth = None
+    if creds is not None:
+        auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME,
+                        session_token=creds.token)
+
+    protocol = 'https' if use_ssl else 'http'
+    url = f'{protocol}://{export_host}/{EXPORT_ACTION}'
+    res = requests.post(url, json=export_params, headers=EXTRA_HEADERS, auth=auth)
+    res.raise_for_status()
+    job = res.json()
+    return job
+
+
+def get_export_status(export_host: str, use_ssl: bool, job_id: str, creds: Credentials = None):
+    auth = None
+    if creds is not None:
+        auth = AWS4Auth(creds.key, creds.secret, creds.region, EXPORT_SERVICE_NAME,
+                        session_token=creds.token)
+
+    protocol = 'https' if use_ssl else 'http'
+    url = f'{protocol}://{export_host}/{EXPORT_ACTION}/{job_id}'
+    res = requests.get(url, headers=EXTRA_HEADERS, auth=auth)
+    res.raise_for_status()
+    job = res.json()
+    return job
+
+
+def get_processing_status(host: str, port: str, use_ssl: bool, request_param_generator, job_name: str):
+    res = call_and_get_response('get', f'ml/dataprocessing/{job_name}', host, port, request_param_generator,
+                                use_ssl, extra_headers=EXTRA_HEADERS)
+    status = res.json()
+    return status
+
+
+def start_processing_job(host: str, port: str, use_ssl: bool, request_param_generator, params: dict):
+    params_raw = json.dumps(params) if type(params) is dict else params
+    res = call_and_get_response('post', 'ml/dataprocessing', host, port, request_param_generator, use_ssl, params_raw,
+                                EXTRA_HEADERS)
+    job = res.json()
+    return job
+
+
+def start_training(host: str, port: str, use_ssl: bool, request_param_generator, params):
+    params_raw = json.dumps(params) if type(params) is dict else params
+    res = call_and_get_response('post', 'ml/modeltraining', host, port, request_param_generator, use_ssl, params_raw,
+                                EXTRA_HEADERS)
+    return res.json()
+
+
+def get_training_status(host: str, port: str, use_ssl: bool, request_param_generator, training_job_name: str):
+    res = call_and_get_response('get', f'ml/modeltraining/{training_job_name}', host, port,
+                                request_param_generator, use_ssl, extra_headers=EXTRA_HEADERS)
+    return res.json()
+
+
+def start_create_endpoint(host: str, port: str, use_ssl: bool, request_param_generator, params):
+    params_raw = json.dumps(params) if type(params) is dict else params
+    res = call_and_get_response('post', 'ml/endpoints', host, port, request_param_generator, use_ssl, params_raw,
+                                EXTRA_HEADERS)
+    return res.json()
+
+
+def get_endpoint_status(host: str, port: str, use_ssl: bool, request_param_generator, training_job_name: str):
+    res = call_and_get_response('get', f'ml/endpoints/{training_job_name}', host, port, request_param_generator,
+                                use_ssl, extra_headers=EXTRA_HEADERS)
+    return res.json()
diff --git a/src/graph_notebook/request_param_generator/call_and_get_response.py b/src/graph_notebook/request_param_generator/call_and_get_response.py
index 59fdcb2e..f0bc3e84 100644
--- a/src/graph_notebook/request_param_generator/call_and_get_response.py
+++ b/src/graph_notebook/request_param_generator/call_and_get_response.py
@@ -6,11 +6,17 @@
 import requests


-def call_and_get_response(method: str, action: str, host: str, port: str, request_param_generator, use_ssl: bool, query='', extra_headers={}):
+def call_and_get_response(method: str, action: str, host: str, port: str, request_param_generator, use_ssl: bool,
+                          query='', extra_headers=None):
+    if extra_headers is None:
+        extra_headers = {}
+
     method = method.upper()
     protocol = 'https' if use_ssl else 'http'

-    request_params = request_param_generator.generate_request_params(method=method, action=action, query=query, host=host, port=port, protocol=protocol, headers=extra_headers)
+    request_params = request_param_generator.generate_request_params(method=method, action=action, query=query,
+                                                                     host=host, port=port, protocol=protocol,
+                                                                     headers=extra_headers)
     headers = request_params['headers'] if request_params['headers'] is not None else {}

     if method == 'GET':
diff --git a/src/graph_notebook/request_param_generator/default_request_generator.py b/src/graph_notebook/request_param_generator/default_request_generator.py
index ca10bc84..0fb2b8cf 100644
--- a/src/graph_notebook/request_param_generator/default_request_generator.py
+++ b/src/graph_notebook/request_param_generator/default_request_generator.py
@@ -7,9 +7,12 @@
 class DefaultRequestGenerator(object):
     @staticmethod
     def generate_request_params(method, action, query, host, port, protocol, headers=None):
-        return {
+        url = f'{protocol}://{host}:{port}/{action}' if port != '' else f'{protocol}://{host}/{action}'
+        params = {
             'method': method,
-            'url': f'{protocol}://{host}:{port}/{action}',
+            'url': url,
             'headers': headers,
             'params': query,
         }
+
+        return params
diff --git a/src/graph_notebook/request_param_generator/iam_request_generator.py b/src/graph_notebook/request_param_generator/iam_request_generator.py
index ea8d7873..fc88e809 100644
--- a/src/graph_notebook/request_param_generator/iam_request_generator.py
+++ b/src/graph_notebook/request_param_generator/iam_request_generator.py
@@ -17,4 +17,5 @@ def generate_request_params(self, method, action, query, host, port, protocol, h
         else:
             use_ssl = False

-        return make_signed_request(method, action, query, host, port, credentials.key, credentials.secret, credentials.region, use_ssl, credentials.token, additional_headers=headers)
+        return make_signed_request(method, action, query, host, port, credentials.key, credentials.secret,
+                                   credentials.region, use_ssl, credentials.token, additional_headers=headers)
diff --git a/src/graph_notebook/widgets/package.json b/src/graph_notebook/widgets/package.json
index e99e7c3d..887a7d54 100644
--- a/src/graph_notebook/widgets/package.json
+++ b/src/graph_notebook/widgets/package.json
@@ -1,6 +1,6 @@
 {
   "name": "graph_notebook_widgets",
-  "version": "2.0.2",
+  "version": "2.0.3",
   "author": "amazon",
   "description": "A Custom Jupyter Library for rendering NetworkX MultiDiGraphs using vis-network",
   "dependencies": {
diff --git a/test/integration/NeptuneIntegrationWorkflowSteps.py b/test/integration/NeptuneIntegrationWorkflowSteps.py
index 56f0a78e..5149330a 100644
--- a/test/integration/NeptuneIntegrationWorkflowSteps.py
+++ b/test/integration/NeptuneIntegrationWorkflowSteps.py
@@ -330,6 +330,8 @@ def main():
     parser_run_tests.add_argument('--iam', action='store_true')
     parser_run_tests.add_argument('--cfn-stack-name', type=str, default='')
     parser_run_tests.add_argument('--aws-region', type=str, default='us-east-1')
+    parser_run_tests.add_argument('--skip-config-generation', action='store_true',
+                                  help=f'skips config generation for testing, using the one found under {TEST_CONFIG_PATH}')

     args = parser.parse_args()

@@ -341,12 +343,13 @@ def main():
     elif args.which == SUBPARSER_DELETE_CFN:
         delete_stack(args.cfn_stack_name, cfn_client)
     elif args.which == SUBPARSER_RUN_TESTS:
-        loop_until_stack_is_complete(args.cfn_stack_name, cfn_client)
-        stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client)
-        cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client)
-        set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client)
-        config = generate_config_from_stack(stack, args.aws_region, args.iam)
-        config.write_to_file(TEST_CONFIG_PATH)
+        if not args.skip_config_generation:
+            loop_until_stack_is_complete(args.cfn_stack_name, cfn_client)
+            stack = get_cfn_stack_details(args.cfn_stack_name, cfn_client)
+            cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client)
+            set_iam_auth_on_neptune_cluster(cluster_identifier, args.iam, neptune_client)
+            config = generate_config_from_stack(stack, args.aws_region, args.iam)
+            config.write_to_file(TEST_CONFIG_PATH)
         run_integration_tests(args.pattern)
     elif args.which == SUBPARSER_ENABLE_IAM:
         cluster_identifier = get_neptune_identifier_from_cfn(args.cfn_stack_name, cfn_client)
diff --git a/test/unit/graph_magic/parsing/__init__.py b/test/unit/graph_magic/parsing/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/unit/graph_magic/parsing/test_str_to_namespace_var.py b/test/unit/graph_magic/parsing/test_str_to_namespace_var.py
new file mode 100644
index 00000000..f669380e
--- /dev/null
+++ b/test/unit/graph_magic/parsing/test_str_to_namespace_var.py
@@ -0,0 +1,21 @@
+import unittest
+
+from graph_notebook.magics.parsing import str_to_namespace_var
+
+
+class TestParsingStrToNamespaceVar(unittest.TestCase):
+    def test_none_dict(self):
+        key = 'foo'
+        local_ns = None
+        res = str_to_namespace_var(key, local_ns)
+        self.assertEqual(key, res)
+
+    def test_encapsulated_key(self):
+        key = '${foo}'
+        expected_value = 'test'
+        local_ns = {
+            'foo': expected_value
+        }
+
+        res = str_to_namespace_var(key, local_ns)
+        self.assertEqual(expected_value, res)