From c0dd882707b5795ee933ba51d5ca882fb947f0f4 Mon Sep 17 00:00:00 2001 From: Keliang Chen Date: Fri, 14 Sep 2018 02:19:25 -0500 Subject: [PATCH] [AIRFLOW-2524] Add SageMaker Batch Inference (#3767) * Fix for comments * Fix sensor test * Update non_terminal_states and failed_states to static variables of SageMakerHook Add SageMaker Transform Operator & Sensor Co-authored-by: srrajeev-aws --- airflow/contrib/hooks/sagemaker_hook.py | 79 ++++++++-- .../sagemaker_create_training_job_operator.py | 2 +- ...sagemaker_create_transform_job_operator.py | 132 +++++++++++++++++ .../sagemaker_create_tuning_job_operator.py | 2 +- .../sensors/sagemaker_training_sensor.py | 4 +- .../sensors/sagemaker_transform_sensor.py | 69 +++++++++ .../sensors/sagemaker_tuning_sensor.py | 4 +- tests/contrib/hooks/test_sagemaker_hook.py | 93 ++++++++++++ .../test_sagemaker_transform_sensor.py | 118 +++++++++++++++ ...sagemaker_create_transform_job_operator.py | 140 ++++++++++++++++++ 10 files changed, 627 insertions(+), 16 deletions(-) create mode 100644 airflow/contrib/operators/sagemaker_create_transform_job_operator.py create mode 100644 airflow/contrib/sensors/sagemaker_transform_sensor.py create mode 100644 tests/contrib/sensors/test_sagemaker_transform_sensor.py create mode 100644 tests/operators/test_sagemaker_create_transform_job_operator.py diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py index 09993f96d8738..ebab5d83e4099 100644 --- a/airflow/contrib/hooks/sagemaker_hook.py +++ b/airflow/contrib/hooks/sagemaker_hook.py @@ -31,6 +31,8 @@ class SageMakerHook(AwsHook): sagemaker_conn_id is required for using the config stored in db for training/tuning """ + non_terminal_states = {'InProgress', 'Stopping', 'Stopped'} + failed_states = {'Failed'} def __init__(self, sagemaker_conn_id=None, @@ -96,9 +98,9 @@ def check_status(self, non_terminal_states, describe_function, *args): """ :param non_terminal_states: the set of non_terminal states - :type non_terminal_states: dict + :type non_terminal_states: set :param failed_state: the set of failed states - :type failed_state: dict + :type failed_state: set :param key: the key of the response dict that points to the state :type key: string @@ -177,7 +179,7 @@ def create_training_job(self, training_job_config, wait_for_completion=True): :param training_job_config: the config for training :type training_job_config: dict :param wait_for_completion: if the program should keep running until job finishes - :param wait_for_completion: bool + :type wait_for_completion: bool :return: A dict that contains ARN of the training job. """ if self.use_db_config: @@ -194,8 +196,8 @@ def create_training_job(self, training_job_config, wait_for_completion=True): response = self.conn.create_training_job( **training_job_config) if wait_for_completion: - self.check_status(['InProgress', 'Stopping', 'Stopped'], - ['Failed'], + self.check_status(SageMakerHook.non_terminal_states, + SageMakerHook.failed_states, 'TrainingJobStatus', self.describe_training_job, training_job_config['TrainingJobName']) @@ -213,8 +215,8 @@ def create_tuning_job(self, tuning_job_config, wait_for_completion=True): if self.use_db_config: if not self.sagemaker_conn_id: raise AirflowException( - "sagemaker connection id must be present to \ - read sagemaker tunning job configuration.") + "SageMaker connection id must be present to \ + read SageMaker tunning job configuration.") sagemaker_conn = self.get_connection(self.sagemaker_conn_id) @@ -226,13 +228,59 @@ def create_tuning_job(self, tuning_job_config, wait_for_completion=True): response = self.conn.create_hyper_parameter_tuning_job( **tuning_job_config) if wait_for_completion: - self.check_status(['InProgress', 'Stopping', 'Stopped'], - ['Failed'], + self.check_status(SageMakerHook.non_terminal_states, + SageMakerHook.failed_states, 'HyperParameterTuningJobStatus', self.describe_tuning_job, tuning_job_config['HyperParameterTuningJobName']) return response + def create_transform_job(self, transform_job_config, wait_for_completion=True): + """ + Create a transform job + :param transform_job_config: the config for transform job + :type transform_job_config: dict + :param wait_for_completion: + if the program should keep running until job finishes + :type wait_for_completion: bool + :return: A dict that contains ARN of the transform job. + """ + if self.use_db_config: + if not self.sagemaker_conn_id: + raise AirflowException( + "SageMaker connection id must be present to \ + read SageMaker transform job configuration.") + + sagemaker_conn = self.get_connection(self.sagemaker_conn_id) + + config = sagemaker_conn.extra_dejson.copy() + transform_job_config.update(config) + + self.check_for_url(transform_job_config + ['TransformInput']['DataSource'] + ['S3DataSource']['S3Uri']) + + response = self.conn.create_transform_job( + **transform_job_config) + if wait_for_completion: + self.check_status(SageMakerHook.non_terminal_states, + SageMakerHook.failed_states, + 'TransformJobStatus', + self.describe_transform_job, + transform_job_config['TransformJobName']) + return response + + def create_model(self, model_config): + """ + Create a model job + :param model_config: the config for model + :type model_config: dict + :return: A dict that contains ARN of the model. + """ + + return self.conn.create_model( + **model_config) + def describe_training_job(self, training_job_name): """ :param training_job_name: the name of the training job @@ -245,7 +293,7 @@ def describe_training_job(self, training_job_name): def describe_tuning_job(self, tuning_job_name): """ - :param tuning_job_name: the name of the training job + :param tuning_job_name: the name of the tuning job :type tuning_job_name: string Return the tuning job info associated with the current job_name :return: A dict contains all the tuning job info @@ -253,3 +301,14 @@ def describe_tuning_job(self, tuning_job_name): return self.conn\ .describe_hyper_parameter_tuning_job( HyperParameterTuningJobName=tuning_job_name) + + def describe_transform_job(self, transform_job_name): + """ + :param transform_job_name: the name of the transform job + :type transform_job_name: string + Return the transform job info associated with the current job_name + :return: A dict contains all the transform job info + """ + return self.conn\ + .describe_transform_job( + TransformJobName=transform_job_name) diff --git a/airflow/contrib/operators/sagemaker_create_training_job_operator.py b/airflow/contrib/operators/sagemaker_create_training_job_operator.py index 409c5f6aa936a..fdd935fc2931b 100644 --- a/airflow/contrib/operators/sagemaker_create_training_job_operator.py +++ b/airflow/contrib/operators/sagemaker_create_training_job_operator.py @@ -50,7 +50,7 @@ class SageMakerCreateTrainingJobOperator(BaseOperator): until training job finishes :type wait_for_completion: bool :param check_interval: if wait is set to be true, this is the time interval - which the operator will check the status of the training job + in seconds which the operator will check the status of the training job :type check_interval: int :param max_ingestion_time: if wait is set to be true, the operator will fail if the training job hasn't finish within the max_ingestion_time diff --git a/airflow/contrib/operators/sagemaker_create_transform_job_operator.py b/airflow/contrib/operators/sagemaker_create_transform_job_operator.py new file mode 100644 index 0000000000000..22c8c2b4ba297 --- /dev/null +++ b/airflow/contrib/operators/sagemaker_create_transform_job_operator.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException + + +class SageMakerCreateTransformJobOperator(BaseOperator): + """ + Initiate a SageMaker transform + + This operator returns The ARN of the model created in Amazon SageMaker + + :param sagemaker_conn_id: The SageMaker connection ID to use. + :type sagemaker_conn_id: string + :param transform_job_config: + The configuration necessary to start a transform job (templated) + :type transform_job_config: dict + :param model_config: + The configuration necessary to create a model, the default is none + which means that user should provide a created model in transform_job_config + If given, will be used to create a model before creating transform job + :type model_config: dict + :param use_db_config: Whether or not to use db config + associated with sagemaker_conn_id. + If set to true, will automatically update the transform config + with what's in db, so the db config doesn't need to + included everything, but what's there does replace the ones + in the transform_job_config, so be careful + :type use_db_config: bool + :param region_name: The AWS region_name + :type region_name: string + :param wait_for_completion: if the program should keep running until job finishes + :type wait_for_completion: bool + :param check_interval: if wait is set to be true, this is the time interval + in seconds which the operator will check the status of the transform job + :type check_interval: int + :param max_ingestion_time: if wait is set to be true, the operator will fail + if the transform job hasn't finish within the max_ingestion_time + (Caution: be careful to set this parameters because transform can take very long) + :type max_ingestion_time: int + :param aws_conn_id: The AWS connection ID to use. + :type aws_conn_id: string + + **Example**: + The following operator would start a transform job when executed + + sagemaker_transform = + SageMakerCreateTransformJobOperator( + task_id='sagemaker_transform', + transform_job_config=config_transform, + model_config=config_model, + region_name='us-west-2' + sagemaker_conn_id='sagemaker_customers_conn', + use_db_config=True, + aws_conn_id='aws_customers_conn' + ) + """ + + template_fields = ['transform_job_config'] + template_ext = () + ui_color = '#ededed' + + @apply_defaults + def __init__(self, + sagemaker_conn_id=None, + transform_job_config=None, + model_config=None, + use_db_config=False, + region_name=None, + wait_for_completion=True, + check_interval=2, + max_ingestion_time=None, + *args, **kwargs): + super(SageMakerCreateTransformJobOperator, self).__init__(*args, **kwargs) + + self.sagemaker_conn_id = sagemaker_conn_id + self.transform_job_config = transform_job_config + self.model_config = model_config + self.use_db_config = use_db_config + self.region_name = region_name + self.wait_for_completion = wait_for_completion + self.check_interval = check_interval + self.max_ingestion_time = max_ingestion_time + + def execute(self, context): + sagemaker = SageMakerHook( + sagemaker_conn_id=self.sagemaker_conn_id, + use_db_config=self.use_db_config, + region_name=self.region_name, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time + ) + + if self.model_config: + self.log.info( + "Creating SageMaker Model %s for transform job" + % self.model_config['ModelName'] + ) + sagemaker.create_model(self.model_config) + + self.log.info( + "Creating SageMaker transform Job %s." + % self.transform_job_config['TransformJobName'] + ) + response = sagemaker.create_transform_job( + self.transform_job_config, + wait_for_completion=self.wait_for_completion) + if not response['ResponseMetadata']['HTTPStatusCode'] \ + == 200: + raise AirflowException( + 'Sagemaker transform Job creation failed: %s' % response) + else: + return response diff --git a/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py b/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py index 0c40a9adc93f4..46ccb2a201144 100644 --- a/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py +++ b/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py @@ -48,7 +48,7 @@ class SageMakerCreateTuningJobOperator(BaseOperator): until tuning job finishes :type wait_for_completion: bool :param check_interval: if wait is set to be true, this is the time interval - which the operator will check the status of the tuning job + in seconds which the operator will check the status of the tuning job :type check_interval: int :param max_ingestion_time: if wait is set to be true, the operator will fail if the tuning job hasn't finish within the max_ingestion_time diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py index 90c62ce988fbf..449de44c0819c 100644 --- a/airflow/contrib/sensors/sagemaker_training_sensor.py +++ b/airflow/contrib/sensors/sagemaker_training_sensor.py @@ -45,10 +45,10 @@ def __init__(self, self.region_name = region_name def non_terminal_states(self): - return ['InProgress', 'Stopping', 'Stopped'] + return SageMakerHook.non_terminal_states def failed_states(self): - return ['Failed'] + return SageMakerHook.failed_states def get_sagemaker_response(self): sagemaker = SageMakerHook( diff --git a/airflow/contrib/sensors/sagemaker_transform_sensor.py b/airflow/contrib/sensors/sagemaker_transform_sensor.py new file mode 100644 index 0000000000000..68ef1d8dd7b05 --- /dev/null +++ b/airflow/contrib/sensors/sagemaker_transform_sensor.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor +from airflow.utils.decorators import apply_defaults + + +class SageMakerTransformSensor(SageMakerBaseSensor): + """ + Asks for the state of the transform state until it reaches a terminal state. + The sensor will error if the job errors, throwing a AirflowException + containing the failure reason. + + :param job_name: job_name of the transform job instance to check the state of + :type job_name: string + :param region_name: The AWS region_name + :type region_name: string + """ + + template_fields = ['job_name'] + template_ext = () + + @apply_defaults + def __init__(self, + job_name, + region_name=None, + *args, + **kwargs): + super(SageMakerTransformSensor, self).__init__(*args, **kwargs) + self.job_name = job_name + self.region_name = region_name + + def non_terminal_states(self): + return SageMakerHook.non_terminal_states + + def failed_states(self): + return SageMakerHook.failed_states + + def get_sagemaker_response(self): + sagemaker = SageMakerHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name + ) + + self.log.info('Poking Sagemaker Transform Job %s', self.job_name) + return sagemaker.describe_transform_job(self.job_name) + + def get_failed_reason_from_response(self, response): + return response['FailureReason'] + + def state_from_response(self, response): + return response['TransformJobStatus'] diff --git a/airflow/contrib/sensors/sagemaker_tuning_sensor.py b/airflow/contrib/sensors/sagemaker_tuning_sensor.py index bc74e3a5c5461..1f081100e2c69 100644 --- a/airflow/contrib/sensors/sagemaker_tuning_sensor.py +++ b/airflow/contrib/sensors/sagemaker_tuning_sensor.py @@ -48,10 +48,10 @@ def __init__(self, self.region_name = region_name def non_terminal_states(self): - return ['InProgress', 'Stopping', 'Stopped'] + return SageMakerHook.non_terminal_states def failed_states(self): - return ['Failed'] + return SageMakerHook.failed_states def get_sagemaker_response(self): sagemaker = SageMakerHook( diff --git a/tests/contrib/hooks/test_sagemaker_hook.py b/tests/contrib/hooks/test_sagemaker_hook.py index 8bb56cc8e7d12..3a863b3cb0dc7 100644 --- a/tests/contrib/hooks/test_sagemaker_hook.py +++ b/tests/contrib/hooks/test_sagemaker_hook.py @@ -47,6 +47,8 @@ job_name = 'test-job-name' +model_name = 'test-model-name' + image = 'test-image' test_arn_return = {'TrainingJobArn': 'testarn'} @@ -152,6 +154,38 @@ } } +create_transform_params = \ + { + 'TransformJobName': job_name, + 'ModelName': model_name, + 'BatchStrategy': 'MultiRecord', + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url + } + } + }, + 'TransformOutput': { + 'S3OutputPath': output_url, + }, + 'TransformResources': { + 'InstanceType': 'ml.m4.xlarge', + 'InstanceCount': 123 + } + } + +create_model_params = \ + { + 'ModelName': model_name, + 'PrimaryContainer': { + 'Image': image, + 'ModelDataUrl': output_url, + }, + 'ExecutionRoleArn': role + } + db_config = { 'Tags': [ { @@ -393,6 +427,52 @@ def test_create_tuning_job_db_config(self, mock_client, mock_check_tuning): assert_called_once_with(**updated_config) self.assertEqual(response, test_arn_return) + @mock.patch.object(SageMakerHook, 'check_for_url') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_transform_job(self, mock_client, mock_check_url): + mock_check_url.return_value = True + mock_session = mock.Mock() + attrs = {'create_transform_job.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.create_transform_job(create_transform_params, + wait_for_completion=False) + mock_session.create_transform_job.assert_called_once_with( + **create_transform_params) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'check_for_url') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_transform_job_db_config(self, mock_client, mock_check_url): + mock_check_url.return_value = True + mock_session = mock.Mock() + attrs = {'create_transform_job.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook_use_db_config = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', + use_db_config=True) + response = hook_use_db_config.create_transform_job( + create_transform_params, wait_for_completion=False) + updated_config = copy.deepcopy(create_transform_params) + updated_config.update(db_config) + mock_session.create_transform_job.assert_called_once_with(**updated_config) + self.assertEqual(response, test_arn_return) + + @mock.patch.object(SageMakerHook, 'get_conn') + def test_create_model(self, mock_client): + mock_session = mock.Mock() + attrs = {'create_model.return_value': + test_arn_return} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.create_model(create_model_params) + mock_session.create_model.assert_called_once_with(**create_model_params) + self.assertEqual(response, test_arn_return) + @mock.patch.object(SageMakerHook, 'get_conn') def test_describe_training_job(self, mock_client): mock_session = mock.Mock() @@ -418,6 +498,19 @@ def test_describe_tuning_job(self, mock_client): assert_called_once_with(HyperParameterTuningJobName=job_name) self.assertEqual(response, 'InProgress') + @mock.patch.object(SageMakerHook, 'get_conn') + def test_describe_transform_job(self, mock_client): + mock_session = mock.Mock() + attrs = {'describe_transform_job.return_value': + 'InProgress'} + mock_session.configure_mock(**attrs) + mock_client.return_value = mock_session + hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') + response = hook.describe_transform_job(job_name) + mock_session.describe_transform_job.\ + assert_called_once_with(TransformJobName=job_name) + self.assertEqual(response, 'InProgress') + if __name__ == '__main__': unittest.main() diff --git a/tests/contrib/sensors/test_sagemaker_transform_sensor.py b/tests/contrib/sensors/test_sagemaker_transform_sensor.py new file mode 100644 index 0000000000000..bb4a184bb2797 --- /dev/null +++ b/tests/contrib/sensors/test_sagemaker_transform_sensor.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow import configuration +from airflow.contrib.sensors.sagemaker_transform_sensor \ + import SageMakerTransformSensor +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.exceptions import AirflowException + +DESCRIBE_TRANSFORM_INPROGRESS_RETURN = { + 'TransformJobStatus': 'InProgress', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRANSFORM_COMPELETED_RETURN = { + 'TransformJobStatus': 'Compeleted', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRANSFORM_FAILED_RETURN = { + 'TransformJobStatus': 'Failed', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + }, + 'FailureReason': 'Unknown' +} +DESCRIBE_TRANSFORM_STOPPING_RETURN = { + 'TransformJobStatus': 'Stopping', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} +DESCRIBE_TRANSFORM_STOPPED_RETURN = { + 'TransformJobStatus': 'Stopped', + 'ResponseMetadata': { + 'HTTPStatusCode': 200, + } +} + + +class TestSageMakerTransformSensor(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'describe_transform_job') + def test_raises_errors_failed_state(self, mock_describe_job, mock_client): + mock_describe_job.side_effect = [DESCRIBE_TRANSFORM_FAILED_RETURN] + sensor = SageMakerTransformSensor( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test', + job_name='test_job_name' + ) + self.assertRaises(AirflowException, sensor.execute, None) + mock_describe_job.assert_called_once_with('test_job_name') + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, '__init__') + @mock.patch.object(SageMakerHook, 'describe_transform_job') + def test_calls_until_a_terminal_state(self, + mock_describe_job, hook_init, mock_client): + hook_init.return_value = None + + mock_describe_job.side_effect = [ + DESCRIBE_TRANSFORM_INPROGRESS_RETURN, + DESCRIBE_TRANSFORM_STOPPING_RETURN, + DESCRIBE_TRANSFORM_STOPPED_RETURN, + DESCRIBE_TRANSFORM_COMPELETED_RETURN + ] + sensor = SageMakerTransformSensor( + task_id='test_task', + poke_interval=2, + aws_conn_id='aws_test', + job_name='test_job_name', + region_name='us-east-1' + ) + + sensor.execute(None) + + # make sure we called 4 times(terminated when its compeleted) + self.assertEqual(mock_describe_job.call_count, 4) + + # make sure the hook was initialized with the specific params + hook_init.assert_called_with(aws_conn_id='aws_test', + region_name='us-east-1') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/operators/test_sagemaker_create_transform_job_operator.py b/tests/operators/test_sagemaker_create_transform_job_operator.py new file mode 100644 index 0000000000000..a8701530d9daa --- /dev/null +++ b/tests/operators/test_sagemaker_create_transform_job_operator.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow import configuration +from airflow.contrib.hooks.sagemaker_hook import SageMakerHook +from airflow.contrib.operators.sagemaker_create_transform_job_operator \ + import SageMakerCreateTransformJobOperator +from airflow.exceptions import AirflowException + +role = 'test-role' + +bucket = 'test-bucket' + +key = 'test/data' +data_url = 's3://{}/{}'.format(bucket, key) + +job_name = 'test-job-name' + +model_name = 'test-model-name' + +image = 'test-image' + +output_url = 's3://{}/test/output'.format(bucket) + +create_transform_params = \ + { + 'TransformJobName': job_name, + 'ModelName': model_name, + 'BatchStrategy': 'MultiRecord', + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url + } + } + }, + 'TransformOutput': { + 'S3OutputPath': output_url, + }, + 'TransformResources': { + 'InstanceType': 'ml.m4.xlarge', + 'InstanceCount': 123 + } + } + +create_model_params = \ + { + 'ModelName': model_name, + 'PrimaryContainer': { + 'Image': image, + 'ModelDataUrl': output_url, + }, + 'ExecutionRoleArn': role + } + + +class TestSageMakertransformOperator(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + self.sagemaker = SageMakerCreateTransformJobOperator( + task_id='test_sagemaker_operator', + sagemaker_conn_id='sagemaker_test_id', + transform_job_config=create_transform_params, + model_config=create_model_params, + region_name='us-west-2', + use_db_config=True, + wait_for_completion=False, + check_interval=5 + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_model') + @mock.patch.object(SageMakerHook, 'create_transform_job') + @mock.patch.object(SageMakerHook, '__init__') + def test_hook_init(self, hook_init, mock_transform, mock_model, mock_client): + mock_transform.return_value = {"TransformJobArn": "testarn", + "ResponseMetadata": + {"HTTPStatusCode": 200}} + hook_init.return_value = None + self.sagemaker.execute(None) + hook_init.assert_called_once_with( + sagemaker_conn_id='sagemaker_test_id', + region_name='us-west-2', + use_db_config=True, + check_interval=5, + max_ingestion_time=None + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_model') + @mock.patch.object(SageMakerHook, 'create_transform_job') + def test_execute_without_failure(self, mock_transform, mock_model, mock_client): + mock_transform.return_value = {"TransformJobArn": "testarn", + "ResponseMetadata": + {"HTTPStatusCode": 200}} + self.sagemaker.execute(None) + mock_model.assert_called_once_with(create_model_params) + mock_transform.assert_called_once_with(create_transform_params, + wait_for_completion=False + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_model') + @mock.patch.object(SageMakerHook, 'create_transform_job') + def test_execute_with_failure(self, mock_transform, mock_model, mock_client): + mock_transform.return_value = {"TransformJobArn": "testarn", + "ResponseMetadata": + {"HTTPStatusCode": 404}} + self.assertRaises(AirflowException, self.sagemaker.execute, None) + + +if __name__ == '__main__': + unittest.main()