Skip to content

Commit

Permalink
[AIRFLOW-2524] Add SageMaker Batch Inference (apache#3767)
Browse files Browse the repository at this point in the history
* Fix for comments
* Fix sensor test
* Update non_terminal_states and failed_states to static variables of SageMakerHook

Add SageMaker Transform Operator & Sensor
Co-authored-by: srrajeev-aws <[email protected]>
  • Loading branch information
troychen728 authored and galak75 committed Nov 23, 2018
1 parent 3752828 commit c0dd882
Show file tree
Hide file tree
Showing 10 changed files with 627 additions and 16 deletions.
79 changes: 69 additions & 10 deletions airflow/contrib/hooks/sagemaker_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class SageMakerHook(AwsHook):
sagemaker_conn_id is required for using
the config stored in db for training/tuning
"""
non_terminal_states = {'InProgress', 'Stopping', 'Stopped'}
failed_states = {'Failed'}

def __init__(self,
sagemaker_conn_id=None,
Expand Down Expand Up @@ -96,9 +98,9 @@ def check_status(self, non_terminal_states,
describe_function, *args):
"""
:param non_terminal_states: the set of non_terminal states
:type non_terminal_states: dict
:type non_terminal_states: set
:param failed_state: the set of failed states
:type failed_state: dict
:type failed_state: set
:param key: the key of the response dict
that points to the state
:type key: string
Expand Down Expand Up @@ -177,7 +179,7 @@ def create_training_job(self, training_job_config, wait_for_completion=True):
:param training_job_config: the config for training
:type training_job_config: dict
:param wait_for_completion: if the program should keep running until job finishes
:param wait_for_completion: bool
:type wait_for_completion: bool
:return: A dict that contains ARN of the training job.
"""
if self.use_db_config:
Expand All @@ -194,8 +196,8 @@ def create_training_job(self, training_job_config, wait_for_completion=True):
response = self.conn.create_training_job(
**training_job_config)
if wait_for_completion:
self.check_status(['InProgress', 'Stopping', 'Stopped'],
['Failed'],
self.check_status(SageMakerHook.non_terminal_states,
SageMakerHook.failed_states,
'TrainingJobStatus',
self.describe_training_job,
training_job_config['TrainingJobName'])
Expand All @@ -213,8 +215,8 @@ def create_tuning_job(self, tuning_job_config, wait_for_completion=True):
if self.use_db_config:
if not self.sagemaker_conn_id:
raise AirflowException(
"sagemaker connection id must be present to \
read sagemaker tunning job configuration.")
"SageMaker connection id must be present to \
read SageMaker tunning job configuration.")

sagemaker_conn = self.get_connection(self.sagemaker_conn_id)

Expand All @@ -226,13 +228,59 @@ def create_tuning_job(self, tuning_job_config, wait_for_completion=True):
response = self.conn.create_hyper_parameter_tuning_job(
**tuning_job_config)
if wait_for_completion:
self.check_status(['InProgress', 'Stopping', 'Stopped'],
['Failed'],
self.check_status(SageMakerHook.non_terminal_states,
SageMakerHook.failed_states,
'HyperParameterTuningJobStatus',
self.describe_tuning_job,
tuning_job_config['HyperParameterTuningJobName'])
return response

def create_transform_job(self, transform_job_config, wait_for_completion=True):
"""
Create a transform job
:param transform_job_config: the config for transform job
:type transform_job_config: dict
:param wait_for_completion:
if the program should keep running until job finishes
:type wait_for_completion: bool
:return: A dict that contains ARN of the transform job.
"""
if self.use_db_config:
if not self.sagemaker_conn_id:
raise AirflowException(
"SageMaker connection id must be present to \
read SageMaker transform job configuration.")

sagemaker_conn = self.get_connection(self.sagemaker_conn_id)

config = sagemaker_conn.extra_dejson.copy()
transform_job_config.update(config)

self.check_for_url(transform_job_config
['TransformInput']['DataSource']
['S3DataSource']['S3Uri'])

response = self.conn.create_transform_job(
**transform_job_config)
if wait_for_completion:
self.check_status(SageMakerHook.non_terminal_states,
SageMakerHook.failed_states,
'TransformJobStatus',
self.describe_transform_job,
transform_job_config['TransformJobName'])
return response

def create_model(self, model_config):
"""
Create a model job
:param model_config: the config for model
:type model_config: dict
:return: A dict that contains ARN of the model.
"""

return self.conn.create_model(
**model_config)

def describe_training_job(self, training_job_name):
"""
:param training_job_name: the name of the training job
Expand All @@ -245,11 +293,22 @@ def describe_training_job(self, training_job_name):

def describe_tuning_job(self, tuning_job_name):
"""
:param tuning_job_name: the name of the training job
:param tuning_job_name: the name of the tuning job
:type tuning_job_name: string
Return the tuning job info associated with the current job_name
:return: A dict contains all the tuning job info
"""
return self.conn\
.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuning_job_name)

def describe_transform_job(self, transform_job_name):
"""
:param transform_job_name: the name of the transform job
:type transform_job_name: string
Return the transform job info associated with the current job_name
:return: A dict contains all the transform job info
"""
return self.conn\
.describe_transform_job(
TransformJobName=transform_job_name)
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SageMakerCreateTrainingJobOperator(BaseOperator):
until training job finishes
:type wait_for_completion: bool
:param check_interval: if wait is set to be true, this is the time interval
which the operator will check the status of the training job
in seconds which the operator will check the status of the training job
:type check_interval: int
:param max_ingestion_time: if wait is set to be true, the operator will fail
if the training job hasn't finish within the max_ingestion_time
Expand Down
132 changes: 132 additions & 0 deletions airflow/contrib/operators/sagemaker_create_transform_job_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException


class SageMakerCreateTransformJobOperator(BaseOperator):
"""
Initiate a SageMaker transform
This operator returns The ARN of the model created in Amazon SageMaker
:param sagemaker_conn_id: The SageMaker connection ID to use.
:type sagemaker_conn_id: string
:param transform_job_config:
The configuration necessary to start a transform job (templated)
:type transform_job_config: dict
:param model_config:
The configuration necessary to create a model, the default is none
which means that user should provide a created model in transform_job_config
If given, will be used to create a model before creating transform job
:type model_config: dict
:param use_db_config: Whether or not to use db config
associated with sagemaker_conn_id.
If set to true, will automatically update the transform config
with what's in db, so the db config doesn't need to
included everything, but what's there does replace the ones
in the transform_job_config, so be careful
:type use_db_config: bool
:param region_name: The AWS region_name
:type region_name: string
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the transform job
:type check_interval: int
:param max_ingestion_time: if wait is set to be true, the operator will fail
if the transform job hasn't finish within the max_ingestion_time
(Caution: be careful to set this parameters because transform can take very long)
:type max_ingestion_time: int
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: string
**Example**:
The following operator would start a transform job when executed
sagemaker_transform =
SageMakerCreateTransformJobOperator(
task_id='sagemaker_transform',
transform_job_config=config_transform,
model_config=config_model,
region_name='us-west-2'
sagemaker_conn_id='sagemaker_customers_conn',
use_db_config=True,
aws_conn_id='aws_customers_conn'
)
"""

template_fields = ['transform_job_config']
template_ext = ()
ui_color = '#ededed'

@apply_defaults
def __init__(self,
sagemaker_conn_id=None,
transform_job_config=None,
model_config=None,
use_db_config=False,
region_name=None,
wait_for_completion=True,
check_interval=2,
max_ingestion_time=None,
*args, **kwargs):
super(SageMakerCreateTransformJobOperator, self).__init__(*args, **kwargs)

self.sagemaker_conn_id = sagemaker_conn_id
self.transform_job_config = transform_job_config
self.model_config = model_config
self.use_db_config = use_db_config
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time

def execute(self, context):
sagemaker = SageMakerHook(
sagemaker_conn_id=self.sagemaker_conn_id,
use_db_config=self.use_db_config,
region_name=self.region_name,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)

if self.model_config:
self.log.info(
"Creating SageMaker Model %s for transform job"
% self.model_config['ModelName']
)
sagemaker.create_model(self.model_config)

self.log.info(
"Creating SageMaker transform Job %s."
% self.transform_job_config['TransformJobName']
)
response = sagemaker.create_transform_job(
self.transform_job_config,
wait_for_completion=self.wait_for_completion)
if not response['ResponseMetadata']['HTTPStatusCode'] \
== 200:
raise AirflowException(
'Sagemaker transform Job creation failed: %s' % response)
else:
return response
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SageMakerCreateTuningJobOperator(BaseOperator):
until tuning job finishes
:type wait_for_completion: bool
:param check_interval: if wait is set to be true, this is the time interval
which the operator will check the status of the tuning job
in seconds which the operator will check the status of the tuning job
:type check_interval: int
:param max_ingestion_time: if wait is set to be true, the operator will fail
if the tuning job hasn't finish within the max_ingestion_time
Expand Down
4 changes: 2 additions & 2 deletions airflow/contrib/sensors/sagemaker_training_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def __init__(self,
self.region_name = region_name

def non_terminal_states(self):
return ['InProgress', 'Stopping', 'Stopped']
return SageMakerHook.non_terminal_states

def failed_states(self):
return ['Failed']
return SageMakerHook.failed_states

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
Expand Down
69 changes: 69 additions & 0 deletions airflow/contrib/sensors/sagemaker_transform_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults


class SageMakerTransformSensor(SageMakerBaseSensor):
"""
Asks for the state of the transform state until it reaches a terminal state.
The sensor will error if the job errors, throwing a AirflowException
containing the failure reason.
:param job_name: job_name of the transform job instance to check the state of
:type job_name: string
:param region_name: The AWS region_name
:type region_name: string
"""

template_fields = ['job_name']
template_ext = ()

@apply_defaults
def __init__(self,
job_name,
region_name=None,
*args,
**kwargs):
super(SageMakerTransformSensor, self).__init__(*args, **kwargs)
self.job_name = job_name
self.region_name = region_name

def non_terminal_states(self):
return SageMakerHook.non_terminal_states

def failed_states(self):
return SageMakerHook.failed_states

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name
)

self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
return sagemaker.describe_transform_job(self.job_name)

def get_failed_reason_from_response(self, response):
return response['FailureReason']

def state_from_response(self, response):
return response['TransformJobStatus']
4 changes: 2 additions & 2 deletions airflow/contrib/sensors/sagemaker_tuning_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(self,
self.region_name = region_name

def non_terminal_states(self):
return ['InProgress', 'Stopping', 'Stopped']
return SageMakerHook.non_terminal_states

def failed_states(self):
return ['Failed']
return SageMakerHook.failed_states

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
Expand Down
Loading

0 comments on commit c0dd882

Please sign in to comment.