From ba1d5f5fd029841795d4011da9103102f2d23ea4 Mon Sep 17 00:00:00 2001 From: Keliang Chen Date: Sun, 12 Aug 2018 11:03:43 -0700 Subject: [PATCH] Add Amazon SageMaker Training (#3658) Add SageMaker Hook, Training Operator & Sensor Co-authored-by: srrajeev-aws --- airflow/contrib/hooks/sagemaker_hook.py | 241 ++++++++++ .../sagemaker_create_training_job_operator.py | 119 +++++ .../contrib/sensors/sagemaker_base_sensor.py | 76 ++++ .../sensors/sagemaker_training_sensor.py | 66 +++ tests/contrib/hooks/test_sagemaker_hook.py | 415 ++++++++++++++++++ ..._sagemaker_create_training_job_operator.py | 141 ++++++ .../sensors/test_sagemaker_base_sensor.py | 149 +++++++ .../sensors/test_sagemaker_training_sensor.py | 118 +++++ 8 files changed, 1325 insertions(+) create mode 100644 airflow/contrib/hooks/sagemaker_hook.py create mode 100644 airflow/contrib/operators/sagemaker_create_training_job_operator.py create mode 100644 airflow/contrib/sensors/sagemaker_base_sensor.py create mode 100644 airflow/contrib/sensors/sagemaker_training_sensor.py create mode 100644 tests/contrib/hooks/test_sagemaker_hook.py create mode 100644 tests/contrib/operators/test_sagemaker_create_training_job_operator.py create mode 100644 tests/contrib/sensors/test_sagemaker_base_sensor.py create mode 100644 tests/contrib/sensors/test_sagemaker_training_sensor.py diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py new file mode 100644 index 0000000000000..8b8e2e41e7678 --- /dev/null +++ b/airflow/contrib/hooks/sagemaker_hook.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import copy +import time +from botocore.exceptions import ClientError + +from airflow.exceptions import AirflowException +from airflow.contrib.hooks.aws_hook import AwsHook +from airflow.hooks.S3_hook import S3Hook + + +class SageMakerHook(AwsHook): + """ + Interact with Amazon SageMaker. + sagemaker_conn_id is required for using + the config stored in db for training/tuning + """ + + def __init__(self, + sagemaker_conn_id=None, + use_db_config=False, + region_name=None, + check_interval=5, + max_ingestion_time=None, + *args, **kwargs): + super(SageMakerHook, self).__init__(*args, **kwargs) + self.sagemaker_conn_id = sagemaker_conn_id + self.use_db_config = use_db_config + self.region_name = region_name + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + self.conn = self.get_conn() + + def check_for_url(self, s3url): + """ + check if the s3url exists + :param s3url: S3 url + :type s3url:str + :return: bool + """ + bucket, key = S3Hook.parse_s3_url(s3url) + s3hook = S3Hook(aws_conn_id=self.aws_conn_id) + if not s3hook.check_for_bucket(bucket_name=bucket): + raise AirflowException( + "The input S3 Bucket {} does not exist ".format(bucket)) + if not s3hook.check_for_key(key=key, bucket_name=bucket): + raise AirflowException("The input S3 Key {} does not exist in the Bucket" + .format(s3url, bucket)) + return True + + def check_valid_training_input(self, training_config): + """ + Run checks before a training starts + :param training_config: training_config + :type training_config: dict + :return: None + """ + for channel in training_config['InputDataConfig']: + self.check_for_url(channel['DataSource'] + ['S3DataSource']['S3Uri']) + + def check_valid_tuning_input(self, tuning_config): + """ + Run checks before a tuning job starts + :param tuning_config: tuning_config + :type tuning_config: dict + :return: None + """ + for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']: + self.check_for_url(channel['DataSource'] + ['S3DataSource']['S3Uri']) + + def check_status(self, non_terminal_states, + failed_state, key, + describe_function, *args): + """ + :param non_terminal_states: the set of non_terminal states + :type non_terminal_states: dict + :param failed_state: the set of failed states + :type failed_state: dict + :param key: the key of the response dict + that points to the state + :type key: string + :param describe_function: the function used to retrieve the status + :type describe_function: python callable + :param args: the arguments for the function + :return: None + """ + sec = 0 + running = True + + while running: + + sec = sec + self.check_interval + + if self.max_ingestion_time and sec > self.max_ingestion_time: + # ensure that the job gets killed if the max ingestion time is exceeded + raise AirflowException("SageMaker job took more than " + "%s seconds", self.max_ingestion_time) + + time.sleep(self.check_interval) + try: + response = describe_function(*args) + status = response[key] + self.log.info("Job still running for %s seconds... " + "current status is %s" % (sec, status)) + except KeyError: + raise AirflowException("Could not get status of the SageMaker job") + except ClientError: + raise AirflowException("AWS request failed, check log for more info") + + if status in non_terminal_states: + running = True + elif status in failed_state: + raise AirflowException("SageMaker job failed because %s" + % response['FailureReason']) + else: + running = False + + self.log.info('SageMaker Job Compeleted') + + def get_conn(self): + """ + Establish an AWS connection + :return: a boto3 SageMaker client + """ + return self.get_client_type('sagemaker', region_name=self.region_name) + + def list_training_job(self, name_contains=None, status_equals=None): + """ + List the training jobs associated with the given input + :param name_contains: A string in the training job name + :type name_contains: str + :param status_equals: 'InProgress'|'Completed' + |'Failed'|'Stopping'|'Stopped' + :return:dict + """ + return self.conn.list_training_jobs( + NameContains=name_contains, StatusEquals=status_equals) + + def list_tuning_job(self, name_contains=None, status_equals=None): + """ + List the tuning jobs associated with the given input + :param name_contains: A string in the training job name + :type name_contains: str + :param status_equals: 'InProgress'|'Completed' + |'Failed'|'Stopping'|'Stopped' + :return:dict + """ + return self.conn.list_hyper_parameter_tuning_job( + NameContains=name_contains, StatusEquals=status_equals) + + def create_training_job(self, training_job_config, wait_for_completion=True): + """ + Create a training job + :param training_job_config: the config for training + :type training_job_config: dict + :param wait_for_completion: if the program should keep running until job finishes + :param wait_for_completion: bool + :return: A dict that contains ARN of the training job. + """ + if self.use_db_config: + if not self.sagemaker_conn_id: + raise AirflowException("SageMaker connection id must be present to read \ + SageMaker training jobs configuration.") + sagemaker_conn = self.get_connection(self.sagemaker_conn_id) + + config = copy.deepcopy(sagemaker_conn.extra_dejson) + training_job_config.update(config) + + self.check_valid_training_input(training_job_config) + + response = self.conn.create_training_job( + **training_job_config) + if wait_for_completion: + self.check_status(['InProgress', 'Stopping', 'Stopped'], + ['Failed'], + 'TrainingJobStatus', + self.describe_training_job, + training_job_config['TrainingJobName']) + return response + + def create_tuning_job(self, tuning_job_config): + """ + Create a tuning job + :param tuning_job_config: the config for tuning + :type tuning_job_config: dict + :return: A dict that contains ARN of the tuning job. + """ + if self.use_db_config: + if not self.sagemaker_conn_id: + raise AirflowException( + "sagemaker connection id must be present to \ + read sagemaker tunning job configuration.") + + sagemaker_conn = self.get_connection(self.sagemaker_conn_id) + + config = sagemaker_conn.extra_dejson.copy() + tuning_job_config.update(config) + + self.check_valid_tuning_input(tuning_job_config) + + return self.conn.create_hyper_parameter_tuning_job( + **tuning_job_config) + + def describe_training_job(self, training_job_name): + """ + :param training_job_name: the name of the training job + :type train_job_name: string + Return the training job info associated with the current job_name + :return: A dict contains all the training job info + """ + return self.conn\ + .describe_training_job(TrainingJobName=training_job_name) + + def describe_tuning_job(self, tuning_job_name): + """ + :param tuning_job_name: the name of the training job + :type tuning_job_name: string + Return the tuning job info associated with the current job_name + :return: A dict contains all the tuning job info + """ + return self.conn\ + .describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name) diff --git a/airflow/contrib/operators/sagemaker_create_training_job_operator.py b/airflow/contrib/operators/sagemaker_create_training_job_operator.py new file mode 100644 index 0000000000000..409c5f6aa936a --- /dev/null +++ b/airflow/contrib/operators/sagemaker_create_training_job_operator.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException + + +class SageMakerCreateTrainingJobOperator(BaseOperator): + + """ + Initiate a SageMaker training + + This operator returns The ARN of the model created in Amazon SageMaker + + :param training_job_config: + The configuration necessary to start a training job (templated) + :type training_job_config: dict + :param region_name: The AWS region_name + :type region_name: string + :param sagemaker_conn_id: The SageMaker connection ID to use. + :type sagemaker_conn_id: string + :param use_db_config: Whether or not to use db config + associated with sagemaker_conn_id. + If set to true, will automatically update the training config + with what's in db, so the db config doesn't need to + included everything, but what's there does replace the ones + in the training_job_config, so be careful + :type use_db_config: bool + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: string + :param wait_for_completion: if the operator should block + until training job finishes + :type wait_for_completion: bool + :param check_interval: if wait is set to be true, this is the time interval + which the operator will check the status of the training job + :type check_interval: int + :param max_ingestion_time: if wait is set to be true, the operator will fail + if the training job hasn't finish within the max_ingestion_time + (Caution: be careful to set this parameters because training can take very long) + :type max_ingestion_time: int + + **Example**: + The following operator would start a training job when executed + + sagemaker_training = + SageMakerCreateTrainingJobOperator( + task_id='sagemaker_training', + training_job_config=config, + region_name='us-west-2' + sagemaker_conn_id='sagemaker_customers_conn', + use_db_config=True, + aws_conn_id='aws_customers_conn' + ) + """ + + template_fields = ['training_job_config'] + template_ext = () + ui_color = '#ededed' + + @apply_defaults + def __init__(self, + training_job_config=None, + region_name=None, + sagemaker_conn_id=None, + use_db_config=False, + wait_for_completion=True, + check_interval=5, + max_ingestion_time=None, + *args, **kwargs): + super(SageMakerCreateTrainingJobOperator, self).__init__(*args, **kwargs) + + self.sagemaker_conn_id = sagemaker_conn_id + self.training_job_config = training_job_config + self.use_db_config = use_db_config + self.region_name = region_name + self.wait_for_completion = wait_for_completion + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + + def execute(self, context): + sagemaker = SageMakerHook( + sagemaker_conn_id=self.sagemaker_conn_id, + use_db_config=self.use_db_config, + region_name=self.region_name, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time + ) + + self.log.info( + "Creating SageMaker Training Job %s." + % self.training_job_config['TrainingJobName'] + ) + response = sagemaker.create_training_job( + self.training_job_config, + wait_for_completion=self.wait_for_completion) + if not response['ResponseMetadata']['HTTPStatusCode'] \ + == 200: + raise AirflowException( + 'Sagemaker Training Job creation failed: %s' % response) + else: + return response diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py new file mode 100644 index 0000000000000..149c2a1aab124 --- /dev/null +++ b/airflow/contrib/sensors/sagemaker_base_sensor.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException + + +class SageMakerBaseSensor(BaseSensorOperator): + """ + Contains general sensor behavior for SageMaker. + Subclasses should implement get_sagemaker_response() + and state_from_response() methods. + Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods. + """ + ui_color = '#66c3ff' + + @apply_defaults + def __init__( + self, + aws_conn_id='aws_default', + *args, **kwargs): + super(SageMakerBaseSensor, self).__init__(*args, **kwargs) + self.aws_conn_id = aws_conn_id + + def poke(self, context): + response = self.get_sagemaker_response() + + if not response['ResponseMetadata']['HTTPStatusCode'] == 200: + self.log.info('Bad HTTP response: %s', response) + return False + + state = self.state_from_response(response) + + self.log.info('Job currently %s', state) + + if state in self.non_terminal_states(): + return False + + if state in self.failed_states(): + failed_reason = self.get_failed_reason_from_response(response) + raise AirflowException("Sagemaker job failed for the following reason: %s" + % failed_reason) + return True + + def non_terminal_states(self): + raise AirflowException("Non Terminal States need to be specified in subclass") + + def failed_states(self): + raise AirflowException("Failed States need to be specified in subclass") + + def get_sagemaker_response(self): + raise AirflowException( + "Method get_sagemaker_response()not implemented.") + + def get_failed_reason_from_response(self, response): + return 'Unknown' + + def state_from_response(self, response): + raise AirflowException( + "Method state_from_response()not implemented.") diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py new file mode 100644 index 0000000000000..90c62ce988fbf --- /dev/null +++ b/airflow/contrib/sensors/sagemaker_training_sensor.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor +from airflow.utils.decorators import apply_defaults + + +class SageMakerTrainingSensor(SageMakerBaseSensor): + """ + Asks for the state of the training state until it reaches a terminal state. + If it fails the sensor errors, failing the task. + + :param job_name: job_name of the training instance to check the state of + :type job_name: string + """ + + template_fields = ['job_name'] + template_ext = () + + @apply_defaults + def __init__(self, + job_name, + region_name=None, + *args, + **kwargs): + super(SageMakerTrainingSensor, self).__init__(*args, **kwargs) + self.job_name = job_name + self.region_name = region_name + + def non_terminal_states(self): + return ['InProgress', 'Stopping', 'Stopped'] + + def failed_states(self): + return ['Failed'] + + def get_sagemaker_response(self): + sagemaker = SageMakerHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name + ) + + self.log.info('Poking Sagemaker Training Job %s', self.job_name) + return sagemaker.describe_training_job(self.job_name) + + def get_failed_reason_from_response(self, response): + return response['FailureReason'] + + def state_from_response(self, response): + return response['TrainingJobStatus'] diff --git a/tests/contrib/hooks/test_sagemaker_hook.py b/tests/contrib/hooks/test_sagemaker_hook.py new file mode 100644 index 0000000000000..6887a5b484bed --- /dev/null +++ b/tests/contrib/hooks/test_sagemaker_hook.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import json +import unittest +import copy +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow import configuration +from airflow import models +from airflow.utils import db +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.hooks.S3_hook import S3Hook +from airflow.exceptions import AirflowException + + +role = 'test-role' + +bucket = 'test-bucket' + +key = 'test/data' +data_url = 's3://{}/{}'.format(bucket, key) + +job_name = 'test-job-name' + +image = 'test-image' + +test_arn_return = {'TrainingJobArn': 'testarn'} + +test_list_training_job_return = { + 'TrainingJobSummaries': [ + { + 'TrainingJobName': job_name, + 'TrainingJobStatus': 'InProgress' + }, + ], + 'NextToken': 'test-token' +} + +test_list_tuning_job_return = { + 'TrainingJobSummaries': [ + { + 'TrainingJobName': job_name, + 'TrainingJobArn': 'testarn', + 'TunedHyperParameters': { + 'k': '3' + }, + 'TrainingJobStatus': 'InProgress' + }, + ], + 'NextToken': 'test-token' +} + +output_url = 's3://{}/test/output'.format(bucket) +create_training_params = \ + { + 'AlgorithmSpecification': { + 'TrainingImage': image, + 'TrainingInputMode': 'File' + }, + 'RoleArn': role, + 'OutputDataConfig': { + 'S3OutputPath': output_url + }, + 'ResourceConfig': { + 'InstanceCount': 2, + 'InstanceType': 'ml.c4.8xlarge', + 'VolumeSizeInGB': 50 + }, + 'TrainingJobName': job_name, + 'HyperParameters': { + 'k': '10', + 'feature_dim': '784', + 'mini_batch_size': '500', + 'force_dense': 'True' + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 60 * 60 + }, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url, + 'S3DataDistributionType': 'FullyReplicated' + } + }, + 'CompressionType': 'None', + 'RecordWrapperType': 'None' + } + ] + } + +create_tuning_params = \ + { + 'HyperParameterTuningJobName': job_name, + 'HyperParameterTuningJobConfig': { + 'Strategy': 'Bayesian', + 'HyperParameterTuningJobObjective': { + 'Type': 'Maximize', + 'MetricName': 'test_metric' + }, + 'ResourceLimits': { + 'MaxNumberOfTrainingJobs': 123, + 'MaxParallelTrainingJobs': 123 + }, + 'ParameterRanges': { + 'IntegerParameterRanges': [ + { + 'Name': 'k', + 'MinValue': '2', + 'MaxValue': '10' + }, + + ] + } + }, + 'TrainingJobDefinition': { + 'StaticHyperParameters': create_training_params['HyperParameters'], + 'AlgorithmSpecification': create_training_params['AlgorithmSpecification'], + 'RoleArn': 'string', + 'InputDataConfig': create_training_params['InputDataConfig'], + 'OutputDataConfig': create_training_params['OutputDataConfig'], + 'ResourceConfig': create_training_params['ResourceConfig'], + 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60) + } + } + +db_config = { + 'Tags': [ + { + 'Key': 'test-db-key', + 'Value': 'test-db-value', + + }, + ] +} + +DESCRIBE_TRAINING_INPROGRESS_RETURN = { + 'TrainingJobStatus': 'InProgress', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRAINING_COMPELETED_RETURN = { + 'TrainingJobStatus': 'Compeleted', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRAINING_FAILED_RETURN = { + 'TrainingJobStatus': 'Failed', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + }, + 'FailureReason': 'Unknown' +} +DESCRIBE_TRAINING_STOPPING_RETURN = { + 'TrainingJobStatus': 'Stopping', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRAINING_STOPPED_RETURN = { + 'TrainingJobStatus': 'Stopped', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} + + +class TestSageMakerHook(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + db.merge_conn( + models.Connection( + conn_id='sagemaker_test_conn_id', + conn_type='sagemaker', + login='access_id', + password='access_key', + extra=json.dumps(db_config) + ) + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(S3Hook, 'check_for_key') + @mock.patch.object(S3Hook, 'check_for_bucket') + def test_check_for_url(self, + mock_check_bucket, mock_check_key, mock_client): + mock_client.return_value = None + hook = SageMakerHook() + mock_check_bucket.side_effect = [False, True, True] + mock_check_key.side_effect = [False, True] + self.assertRaises(AirflowException, + hook.check_for_url, data_url) + self.assertRaises(AirflowException, + hook.check_for_url, data_url) + self.assertEqual(hook.check_for_url(data_url), True) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'check_for_url') + def test_check_valid_training(self, mock_check_url, mock_client): + mock_client.return_value = None + hook = SageMakerHook() + hook.check_valid_training_input(create_training_params) + mock_check_url.assert_called_once_with(data_url) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'check_for_url') + def test_check_valid_tuning(self, mock_check_url, mock_client): + mock_client.return_value = None + hook = SageMakerHook() + hook.check_valid_tuning_input(create_tuning_params) + mock_check_url.assert_called_once_with(data_url) + + @mock.patch.object(SageMakerHook, 'get_client_type') + def test_conn(self, mock_get_client): + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', + region_name='us-east-1' + ) + self.assertEqual(hook.sagemaker_conn_id, 'sagemaker_test_conn_id') + mock_get_client.assert_called_once_with('sagemaker', + region_name='us-east-1' + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_list_training_job(self, mock_client): + mock_session = mock.Mock() + attrs = {'list_training_jobs.return_value': + test_list_training_job_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.list_training_job(name_contains=job_name, + status_equals='InProgress') + mock_session.list_training_jobs. \ + assert_called_once_with(NameContains=job_name, + StatusEquals='InProgress') + self.assertEqual(response, test_list_training_job_return) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_list_tuning_job(self, mock_client): + mock_session = mock.Mock() + attrs = {'list_hyper_parameter_tuning_job.return_value': + test_list_tuning_job_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.list_tuning_job(name_contains=job_name, + status_equals='InProgress') + mock_session.list_hyper_parameter_tuning_job. \ + assert_called_once_with(NameContains=job_name, + StatusEquals='InProgress') + self.assertEqual(response, test_list_tuning_job_return) + + @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_training_job(self, mock_client, mock_check_training): + mock_check_training.return_value = True + mock_session = mock.Mock() + attrs = {'create_training_job.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.create_training_job(create_training_params, + wait_for_completion=False) + mock_session.create_training_job.assert_called_once_with(**create_training_params) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_training_job_db_config(self, mock_client, mock_check_training): + mock_check_training.return_value = True + mock_session = mock.Mock() + attrs = {'create_training_job.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook_use_db_config = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', + use_db_config=True) + response = hook_use_db_config.create_training_job(create_training_params, + wait_for_completion=False) + updated_config = copy.deepcopy(create_training_params) + updated_config.update(db_config) + mock_session.create_training_job.assert_called_once_with(**updated_config) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_training_ends_with_wait_on(self, mock_client, mock_check_training): + mock_check_training.return_value = True + mock_session = mock.Mock() + attrs = {'create_training_job.return_value': + test_arn_return, + 'describe_training_job.side_effect': + [DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + DESCRIBE_TRAINING_STOPPED_RETURN, + DESCRIBE_TRAINING_COMPELETED_RETURN] + } + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1') + hook.create_training_job(create_training_params, wait_for_completion=True) + self.assertEqual(mock_session.describe_training_job.call_count, 4) + + @mock.patch.object(SageMakerHook, 'check_valid_training_input') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_training_throws_error_when_failed_with_wait_on( + self, mock_client, mock_check_training): + mock_check_training.return_value = True + mock_session = mock.Mock() + attrs = {'create_training_job.return_value': + test_arn_return, + 'describe_training_job.side_effect': + [DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + DESCRIBE_TRAINING_STOPPED_RETURN, + DESCRIBE_TRAINING_FAILED_RETURN] + } + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id_1') + self.assertRaises(AirflowException, hook.create_training_job, + create_training_params, wait_for_completion=True) + self.assertEqual(mock_session.describe_training_job.call_count, 4) + + @mock.patch.object(SageMakerHook, 'check_valid_tuning_input') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_tuning_job(self, mock_client, mock_check_tuning): + mock_session = mock.Mock() + attrs = {'create_hyper_parameter_tuning_job.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.create_tuning_job(create_tuning_params) + mock_session.create_hyper_parameter_tuning_job.\ + assert_called_once_with(**create_tuning_params) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'check_valid_tuning_input') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_tuning_job_db_config(self, mock_client, mock_check_tuning): + mock_check_tuning.return_value = True + mock_session = mock.Mock() + attrs = {'create_hyper_parameter_tuning_job.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', + use_db_config=True) + response = hook.create_tuning_job(create_tuning_params) + updated_config = copy.deepcopy(create_tuning_params) + updated_config.update(db_config) + mock_session.create_hyper_parameter_tuning_job. \ + assert_called_once_with(**updated_config) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_training_job(self, mock_client): + mock_session = mock.Mock() + attrs = {'describe_training_job.return_value': 'InProgress'} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.describe_training_job(job_name) + mock_session.describe_training_job.\ + assert_called_once_with(TrainingJobName=job_name) + self.assertEqual(response, 'InProgress') + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_tuning_job(self, mock_client): + mock_session = mock.Mock() + attrs = {'describe_hyper_parameter_tuning_job.return_value': + 'InProgress'} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.describe_tuning_job(job_name) + mock_session.describe_hyper_parameter_tuning_job.\ + assert_called_once_with(HyperParameterTuningJobName=job_name) + self.assertEqual(response, 'InProgress') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/operators/test_sagemaker_create_training_job_operator.py b/tests/contrib/operators/test_sagemaker_create_training_job_operator.py new file mode 100644 index 0000000000000..156c9d74c79ec --- /dev/null +++ b/tests/contrib/operators/test_sagemaker_create_training_job_operator.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow import configuration +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.contrib.operators.sagemaker_create_training_job_operator \ + import SageMakerCreateTrainingJobOperator +from airflow.exceptions import AirflowException + +role = "test-role" + +bucket = "test-bucket" + +key = "test/data" +data_url = "s3://{}/{}".format(bucket, key) + +job_name = "test-job-name" + +image = "test-image" + +output_url = "s3://{}/test/output".format(bucket) +create_training_params = \ + { + "AlgorithmSpecification": { + "TrainingImage": image, + "TrainingInputMode": "File" + }, + "RoleArn": role, + "OutputDataConfig": { + "S3OutputPath": output_url + }, + "ResourceConfig": { + "InstanceCount": 2, + "InstanceType": "ml.c4.8xlarge", + "VolumeSizeInGB": 50 + }, + "TrainingJobName": job_name, + "HyperParameters": { + "k": "10", + "feature_dim": "784", + "mini_batch_size": "500", + "force_dense": "True" + }, + "StoppingCondition": { + "MaxRuntimeInSeconds": 60 * 60 + }, + "InputDataConfig": [ + { + "ChannelName": "train", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": data_url, + "S3DataDistributionType": "FullyReplicated" + } + }, + "CompressionType": "None", + "RecordWrapperType": "None" + } + ] + } + + +class TestSageMakerTrainingOperator(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + self.sagemaker = SageMakerCreateTrainingJobOperator( + task_id='test_sagemaker_operator', + sagemaker_conn_id='sagemaker_test_id', + training_job_config=create_training_params, + region_name='us-west-2', + use_db_config=True, + wait_for_completion=False, + check_interval=5 + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_training_job') + @mock.patch.object(SageMakerHook, '__init__') + def test_hook_init(self, hook_init, mock_training, mock_client): + mock_training.return_value = {"TrainingJobArn": "testarn", + "ResponseMetadata": + {"HTTPStatusCode": 200}} + hook_init.return_value = None + self.sagemaker.execute(None) + hook_init.assert_called_once_with( + sagemaker_conn_id='sagemaker_test_id', + region_name='us-west-2', + use_db_config=True, + check_interval=5, + max_ingestion_time=None + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_training_job') + def test_execute_without_failure(self, mock_training, mock_client): + mock_training.return_value = {"TrainingJobArn": "testarn", + "ResponseMetadata": + {"HTTPStatusCode": 200}} + self.sagemaker.execute(None) + mock_training.assert_called_once_with(create_training_params, + wait_for_completion=False + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_training_job') + def test_execute_with_failure(self, mock_training, mock_client): + mock_training.return_value = {"TrainingJobArn": "testarn", + "ResponseMetadata": + {"HTTPStatusCode": 404}} + self.assertRaises(AirflowException, self.sagemaker.execute, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/sensors/test_sagemaker_base_sensor.py b/tests/contrib/sensors/test_sagemaker_base_sensor.py new file mode 100644 index 0000000000000..bc8cbe349858f --- /dev/null +++ b/tests/contrib/sensors/test_sagemaker_base_sensor.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from airflow import configuration +from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor +from airflow.exceptions import AirflowException + + +class TestSagemakerBaseSensor(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + def test_subclasses_succeed_when_response_is_good(self): + class SageMakerBaseSensorSubclass(SageMakerBaseSensor): + def non_terminal_states(self): + return ['PENDING', 'RUNNING', 'CONTINUE'] + + def failed_states(self): + return ['FAILED'] + + def get_sagemaker_response(self): + return { + 'SomeKey': {'State': 'COMPLETED'}, + 'ResponseMetadata': {'HTTPStatusCode': 200} + } + + def state_from_response(self, response): + return response['SomeKey']['State'] + + sensor = SageMakerBaseSensorSubclass( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test' + ) + + sensor.execute(None) + + def test_poke_returns_false_when_state_is_a_non_terminal_state(self): + class SageMakerBaseSensorSubclass(SageMakerBaseSensor): + def non_terminal_states(self): + return ['PENDING', 'RUNNING', 'CONTINUE'] + + def failed_states(self): + return ['FAILED'] + + def get_sagemaker_response(self): + return { + 'SomeKey': {'State': 'PENDING'}, + 'ResponseMetadata': {'HTTPStatusCode': 200} + } + + def state_from_response(self, response): + return response['SomeKey']['State'] + + sensor = SageMakerBaseSensorSubclass( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test' + ) + + self.assertEqual(sensor.poke(None), False) + + def test_poke_raise_exception_when_method_not_implemented(self): + class SageMakerBaseSensorSubclass(SageMakerBaseSensor): + def non_terminal_states(self): + return ['PENDING', 'RUNNING', 'CONTINUE'] + + def failed_states(self): + return ['FAILED'] + + sensor = SageMakerBaseSensorSubclass( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test' + ) + + self.assertRaises(AirflowException, sensor.poke, None) + + def test_poke_returns_false_when_http_response_is_bad(self): + class SageMakerBaseSensorSubclass(SageMakerBaseSensor): + def non_terminal_states(self): + return ['PENDING', 'RUNNING', 'CONTINUE'] + + def failed_states(self): + return ['FAILED'] + + def get_sagemaker_response(self): + return { + 'SomeKey': {'State': 'COMPLETED'}, + 'ResponseMetadata': {'HTTPStatusCode': 400} + } + + def state_from_response(self, response): + return response['SomeKey']['State'] + + sensor = SageMakerBaseSensorSubclass( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test' + ) + + self.assertEqual(sensor.poke(None), False) + + def test_poke_raises_error_when_job_has_failed(self): + class SageMakerBaseSensorSubclass(SageMakerBaseSensor): + def non_terminal_states(self): + return ['PENDING', 'RUNNING', 'CONTINUE'] + + def failed_states(self): + return ['FAILED'] + + def get_sagemaker_response(self): + return { + 'SomeKey': {'State': 'FAILED'}, + 'ResponseMetadata': {'HTTPStatusCode': 200} + } + + def state_from_response(self, response): + return response['SomeKey']['State'] + + sensor = SageMakerBaseSensorSubclass( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test' + ) + + self.assertRaises(AirflowException, sensor.poke, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/sensors/test_sagemaker_training_sensor.py b/tests/contrib/sensors/test_sagemaker_training_sensor.py new file mode 100644 index 0000000000000..fb966f60afbf0 --- /dev/null +++ b/tests/contrib/sensors/test_sagemaker_training_sensor.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow import configuration +from airflow.contrib.sensors.sagemaker_training_sensor \ + import SageMakerTrainingSensor +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.exceptions import AirflowException + +DESCRIBE_TRAINING_INPROGRESS_RETURN = { + 'TrainingJobStatus': 'InProgress', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRAINING_COMPELETED_RETURN = { + 'TrainingJobStatus': 'Compeleted', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRAINING_FAILED_RETURN = { + 'TrainingJobStatus': 'Failed', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + }, + 'FailureReason': 'Unknown' +} +DESCRIBE_TRAINING_STOPPING_RETURN = { + 'TrainingJobStatus': 'Stopping', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRAINING_STOPPED_RETURN = { + 'TrainingJobStatus': 'Stopped', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} + + +class TestSageMakerTrainingSensor(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'describe_training_job') + def test_raises_errors_failed_state(self, mock_describe_job, mock_client): + mock_describe_job.side_effect = [DESCRIBE_TRAINING_FAILED_RETURN] + sensor = SageMakerTrainingSensor( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test', + job_name='test_job_name' + ) + self.assertRaises(AirflowException, sensor.execute, None) + mock_describe_job.assert_called_once_with('test_job_name') + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, '__init__') + @mock.patch.object(SageMakerHook, 'describe_training_job') + def test_calls_until_a_terminal_state(self, + mock_describe_job, hook_init, mock_client): + hook_init.return_value = None + + mock_describe_job.side_effect = [ + DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + DESCRIBE_TRAINING_STOPPED_RETURN, + DESCRIBE_TRAINING_COMPELETED_RETURN + ] + sensor = SageMakerTrainingSensor( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test', + job_name='test_job_name', + region_name='us-east-1' + ) + + sensor.execute(None) + + # make sure we called 4 times(terminated when its compeleted) + self.assertEqual(mock_describe_job.call_count, 4) + + # make sure the hook was initialized with the specific params + hook_init.assert_called_with(aws_conn_id='aws_test', + region_name='us-east-1') + + +if __name__ == '__main__': + unittest.main()