Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Component] Add Managed Spot Training Support for SageMaker #2219

Merged
merged 4 commits into from
Oct 3, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions components/aws/sagemaker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q ca-certif

RUN easy_install pip

RUN pip install boto3==1.9.169 sagemaker pathlib2 pyyaml==3.12
RUN pip install boto3 sagemaker pathlib2 pyyaml==3.12

COPY hyperparameter_tuning/src/hyperparameter_tuning.py .
COPY train/src/train.py .
COPY deploy/src/deploy.py .
COPY model/src/create_model.py .
COPY batch_transform/src/batch_transform.py .
COPY workteam/workteam.py .
COPY ground_truth/ground_truth.py .
COPY workteam/src/workteam.py .
COPY ground_truth/src/ground_truth.py .

COPY common /app/common/

Expand Down
45 changes: 36 additions & 9 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from time import gmtime, strftime
import time
Expand Down Expand Up @@ -48,6 +49,9 @@
'xgboost': 'xgboost'
}

# Get current directory to open templates
__cwd__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))

def get_client(region=None):
"""Builds a client to the AWS SageMaker API."""
client = boto3.client('sagemaker', region_name=region)
Expand All @@ -56,7 +60,7 @@ def get_client(region=None):

def create_training_job_request(args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
with open('/app/common/train.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'train.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

job_name = args['job_name'] if args['job_name'] else 'TrainingJob-' + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator()
Expand Down Expand Up @@ -91,7 +95,7 @@ def create_training_job_request(args):
else:
request['AlgorithmSpecification']['AlgorithmName'] = args['algorithm_name']
request['AlgorithmSpecification'].pop('TrainingImage')

### Update metric definitions
if args['metric_definitions']:
for key, val in args['metric_definitions'].items():
Expand All @@ -116,7 +120,7 @@ def create_training_job_request(args):
request['InputDataConfig'][i-1]['DataSource']['S3DataSource']['S3Uri'] = args['data_location_' + str(i)]
else:
logging.error("Must specify at least one input channel.")
raise Exception('Could not make job request')
raise Exception('Could not create job request')

request['OutputDataConfig']['S3OutputPath'] = args['model_artifact_path']
request['OutputDataConfig']['KmsKeyId'] = args['output_encryption_key']
Expand All @@ -135,6 +139,8 @@ def create_training_job_request(args):
if args['max_run_time']:
request['StoppingCondition']['MaxRuntimeInSeconds'] = args['max_run_time']

enable_spot_instance_support(request, args)

### Update tags
for key, val in args['tags'].items():
request['Tags'].append({'Key': key, 'Value': val})
Expand Down Expand Up @@ -192,7 +198,7 @@ def get_image_from_job(client, job_name):

def create_model(client, args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model
with open('/app/common/model.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'model.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

request['ModelName'] = args['model_name']
Expand Down Expand Up @@ -253,7 +259,7 @@ def deploy_model(client, args):

def create_endpoint_config(client, args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config
with open('/app/common/endpoint_config.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'endpoint_config.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

endpoint_config_name = args['endpoint_config_name'] if args['endpoint_config_name'] else 'EndpointConfig' + args['model_name_1'][args['model_name_1'].index('-'):]
Expand Down Expand Up @@ -344,7 +350,7 @@ def wait_for_endpoint_creation(client, endpoint_name):

def create_transform_job_request(args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_transform_job
with open('/app/common/transform.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'transform.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

job_name = args['job_name'] if args['job_name'] else 'BatchTransform' + args['model_name'][args['model_name'].index('-'):]
Expand Down Expand Up @@ -436,7 +442,7 @@ def wait_for_transform_job(client, batch_job_name):

def create_hyperparameter_tuning_job_request(args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
with open('/app/common/hpo.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'hpo.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

### Create a hyperparameter tuning job
Expand Down Expand Up @@ -540,6 +546,8 @@ def create_hyperparameter_tuning_job_request(args):
raise Exception('Could not make job request')
request.pop('WarmStartConfig')

enable_spot_instance_support(request['TrainingJobDefinition'], args)

### Update tags
for key, val in args['tags'].items():
request['Tags'].append({'Key': key, 'Value': val})
Expand Down Expand Up @@ -592,7 +600,7 @@ def get_best_training_job_and_hyperparameters(client, hpo_job_name):
def create_workteam(client, args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_workteam
"""Create a workteam"""
with open('/app/common/workteam.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'workteam.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

request['WorkteamName'] = args['team_name']
Expand Down Expand Up @@ -620,7 +628,7 @@ def create_workteam(client, args):

def create_labeling_job_request(args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_labeling_job
with open('/app/common/gt.template.yaml', 'r') as f:
with open(os.path.join(__cwd__, 'gt.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

# Mapping are extracted from ARNs listed in https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_labeling_job
Expand Down Expand Up @@ -793,6 +801,25 @@ def get_labeling_job_outputs(client, labeling_job_name, auto_labeling):
active_learning_model_arn = ''
return output_manifest, active_learning_model_arn

def enable_spot_instance_support(training_job_config, args):
    """Configure (or strip) managed spot training settings on a job request.

    Args:
        training_job_config: Request dict loaded from a ``*.template.yaml``
            file (or the ``TrainingJobDefinition`` of an HPO request).
            Mutated in place.
        args: Parsed component arguments; reads ``spot_instance``,
            ``max_wait_time`` and ``checkpoint_config``.

    Raises:
        Exception: If spot training is requested but ``max_wait_time`` is
            smaller than the configured max run time, or the checkpoint
            config does not provide an S3 output location.
    """
    if args['spot_instance']:
        training_job_config['EnableManagedSpotTraining'] = args['spot_instance']
        # SageMaker requires MaxWaitTimeInSeconds >= MaxRuntimeInSeconds.
        if args['max_wait_time'] >= training_job_config['StoppingCondition']['MaxRuntimeInSeconds']:
            training_job_config['StoppingCondition']['MaxWaitTimeInSeconds'] = args['max_wait_time']
        else:
            logging.error("Max wait time must be greater than or equal to max run time.")
            raise Exception('Could not create job request.')

        # Spot capacity can be reclaimed at any time, so a checkpoint
        # location in S3 is mandatory for spot training jobs.
        if args['checkpoint_config'] and 'S3Uri' in args['checkpoint_config']:
            training_job_config['CheckpointConfig'] = args['checkpoint_config']
        else:
            logging.error("EnableManagedSpotTraining requires checkpoint config with an S3 uri.")
            raise Exception('Could not create job request.')
    else:
        # Remove spot-only keys so they are not sent to the API.
        # pop() (rather than del) tolerates templates that omit them.
        training_job_config.get('StoppingCondition', {}).pop('MaxWaitTimeInSeconds', None)
        training_job_config.pop('CheckpointConfig', None)


def id_generator(size=4, chars=string.ascii_uppercase + string.digits):
    """Return a random ID of `size` characters drawn (with replacement) from `chars`."""
    return ''.join(random.choices(chars, k=size))
Expand Down
5 changes: 5 additions & 0 deletions components/aws/sagemaker/common/hpo.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ TrainingJobDefinition:
VolumeKmsKeyId: ''
StoppingCondition:
MaxRuntimeInSeconds: 86400
MaxWaitTimeInSeconds: 86400
CheckpointConfig:
S3Uri: ''
LocalPath: ''
EnableNetworkIsolation: True
EnableInterContainerTrafficEncryption: False
EnableManagedSpotTraining: False
WarmStartConfig:
ParentHyperParameterTuningJobs: []
WarmStartType: ''
Expand Down
5 changes: 5 additions & 0 deletions components/aws/sagemaker/common/train.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ VpcConfig:
Subnets: []
StoppingCondition:
MaxRuntimeInSeconds: 86400
MaxWaitTimeInSeconds: 86400
CheckpointConfig:
S3Uri: ''
LocalPath: ''
Tags: []
EnableNetworkIsolation: True
EnableInterContainerTrafficEncryption: False
EnableManagedSpotTraining: False
4 changes: 4 additions & 0 deletions components/aws/sagemaker/hyperparameter_tuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Y
vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
network_isolation | Isolates the training container if true | Yes | No | Boolean | False, True | True |
traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | Yes | No | Boolean | False, True | False |
spot_instance | Use managed spot training if true | Yes | No | Boolean | False, True | False |
max_wait_time | The maximum time in seconds you are willing to wait for a managed spot training job to complete | Yes | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) |
checkpoint_config | Dictionary of information about the output location for managed spot training checkpoint data | Yes | Yes | Dict | | {} |
warm_start_type | Specifies the type of warm start used | Yes | No | String | IdenticalDataAndAlgorithm, TransferLearning | |
parent_hpo_jobs | List of previously completed or stopped hyperparameter tuning jobs to be used as a starting point | Yes | Yes | String | Yes | | |
tags | Key-value pairs to categorize AWS resources | Yes | Yes | Dict | | {} |
Expand All @@ -53,6 +56,7 @@ Notes:
## Outputs
Name | Description
:--- | :----------
hpo_job_name | The name of the hyperparameter tuning job
model_artifact_url | URL where model artifacts were stored
best_job_name | Best hyperparameter tuning training job name
best_hyperparameters | Tuned hyperparameters
Expand Down
17 changes: 16 additions & 1 deletion components/aws/sagemaker/hyperparameter_tuning/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ inputs:
- name: traffic_encryption
description: 'Encrypts all communications between ML compute instances in distributed training.'
default: 'False'
- name: spot_instance
description: 'Use managed spot training.'
default: 'False'
- name: max_wait_time
description: 'The maximum time in seconds you are willing to wait for a managed spot training job to complete.'
default: '86400'
- name: checkpoint_config
description: 'Dictionary of information about the output location for managed spot training checkpoint data.'
default: '{}'
- name: warm_start_type
description: 'Specifies either "IdenticalDataAndAlgorithm" or "TransferLearning"'
default: ''
Expand All @@ -115,10 +124,12 @@ inputs:
description: 'Key-value pairs, to categorize AWS resources.'
default: '{}'
outputs:
- name: hpo_job_name
description: 'The name of the hyper parameter tuning job'
- name: model_artifact_url
description: 'Model artifacts url'
- name: best_job_name
description: 'Best training job in the hyperparameter tuning job'
description: 'Best training job in the hyper parameter tuning job'
- name: best_hyperparameters
description: 'Tuned hyperparameters'
- name: training_image
Expand Down Expand Up @@ -166,11 +177,15 @@ implementation:
--vpc_subnets, {inputValue: vpc_subnets},
--network_isolation, {inputValue: network_isolation},
--traffic_encryption, {inputValue: traffic_encryption},
--spot_instance, {inputValue: spot_instance},
--max_wait_time, {inputValue: max_wait_time},
--checkpoint_config, {inputValue: checkpoint_config},
--warm_start_type, {inputValue: warm_start_type},
--parent_hpo_jobs, {inputValue: parent_hpo_jobs},
--tags, {inputValue: tags}
]
fileOutputs:
hpo_job_name: /tmp/hpo_job_name.txt
model_artifact_url: /tmp/model_artifact_url.txt
best_job_name: /tmp/best_job_name.txt
best_hyperparameters: /tmp/best_hyperparameters.txt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

from common import _utils

def main(argv=None):
def create_parser():
parser = argparse.ArgumentParser(description='SageMaker Hyperparameter Tuning Job')
parser.add_argument('--region', type=str.strip, required=True, help='The region where the cluster launches.')
parser.add_argument('--job_name', type=str.strip, required=False, help='The name of the tuning job. Must be unique within the same AWS account and AWS region.')
parser.add_argument('--role', type=str.strip, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
parser.add_argument('--image', type=str.strip, required=False, help='The registry path of the Docker image that contains the training algorithm.', default='')
parser.add_argument('--image', type=str.strip, required=True, help='The registry path of the Docker image that contains the training algorithm.', default='')
parser.add_argument('--algorithm_name', type=str.strip, required=False, help='The name of the resource algorithm to use for the hyperparameter tuning job.', default='')
parser.add_argument('--training_input_mode', choices=['File', 'Pipe'], type=str.strip, required=False, help='The input mode that the algorithm supports. File or Pipe.', default='File')
parser.add_argument('--metric_definitions', type=_utils.str_to_json_dict, required=False, help='The dictionary of name-regex pairs specify the metrics that the algorithm emits.', default='{}')
Expand Down Expand Up @@ -59,8 +59,19 @@ def main(argv=None):
parser.add_argument('--traffic_encryption', type=_utils.str_to_bool, required=False, help='Encrypts all communications between ML compute instances in distributed training.', default=False)
parser.add_argument('--warm_start_type', choices=['IdenticalDataAndAlgorithm', 'TransferLearning', ''], type=str.strip, required=False, help='Specifies either "IdenticalDataAndAlgorithm" or "TransferLearning"')
parser.add_argument('--parent_hpo_jobs', type=str.strip, required=False, help='List of previously completed or stopped hyperparameter tuning jobs to be used as a starting point.', default='')

### Start spot instance support
parser.add_argument('--spot_instance', type=_utils.str_to_bool, required=False, help='Use managed spot training.', default=False)
parser.add_argument('--max_wait_time', type=_utils.str_to_int, required=False, help='The maximum time in seconds you are willing to wait for a managed spot training job to complete.', default=86400)
parser.add_argument('--checkpoint_config', type=_utils.str_to_json_dict, required=False, help='Dictionary of information about the output location for managed spot training checkpoint data.', default='{}')
### End spot instance support

parser.add_argument('--tags', type=_utils.str_to_json_dict, required=False, help='An array of key-value pairs, to categorize AWS resources.', default='{}')

return parser

def main(argv=None):
parser = create_parser()
args = parser.parse_args()

logging.getLogger().setLevel(logging.INFO)
Expand All @@ -75,6 +86,8 @@ def main(argv=None):

logging.info('HyperParameter Tuning Job completed.')

with open('/tmp/hpo_job_name.txt', 'w') as f:
f.write(hpo_job_name)
with open('/tmp/best_job_name.txt', 'w') as f:
f.write(best_job)
with open('/tmp/best_hyperparameters.txt', 'w') as f:
Expand Down
20 changes: 20 additions & 0 deletions components/aws/sagemaker/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Configures and runs the unit tests for all the components

import os
import sys
import unittest


# Taken from http://stackoverflow.com/a/17004263/2931197
# Taken from http://stackoverflow.com/a/17004263/2931197
def load_and_run_tests():
    """Discover and run every test_*.py suite next to this script.

    Returns:
        The unittest result object, so the caller can inspect success.
    """
    setup_file = sys.modules['__main__'].__file__
    setup_dir = os.path.abspath(os.path.dirname(setup_file))

    test_loader = unittest.defaultTestLoader
    test_runner = unittest.TextTestRunner()
    test_suite = test_loader.discover(setup_dir, pattern="test_*.py")

    # Return the result instead of discarding it; otherwise failures are
    # invisible to callers and the process exits 0 even when tests fail.
    return test_runner.run(test_suite)

if __name__ == '__main__':
    result = load_and_run_tests()
    # Non-zero exit status when any test fails, so CI can detect failures.
    sys.exit(not result.wasSuccessful())
Empty file.
54 changes: 54 additions & 0 deletions components/aws/sagemaker/test/test_hpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
import unittest

from unittest.mock import patch, Mock, MagicMock
from botocore.exceptions import ClientError
from datetime import datetime

from hyperparameter_tuning.src import hyperparameter_tuning as hpo
from common import _utils

# Minimal set of CLI arguments accepted by the HPO component's parser;
# individual tests append extra (e.g. spot-training) flags on top of these.
required_args = [
    '--region', 'us-west-2',
    '--role', 'arn:aws:iam::123456789012:user/Development/product_1234/*',
    '--image', 'test-image',
    '--metric_name', 'test-metric',
    '--metric_type', 'Maximize',
    '--channels', '[{"ChannelName": "train", "DataSource": {"S3DataSource":{"S3Uri": "s3://fake-bucket/data","S3DataType":"S3Prefix","S3DataDistributionType": "FullyReplicated"}},"ContentType":"","CompressionType": "None","RecordWrapperType":"None","InputMode": "File"}]',
    '--output_location', 'test-output-location',
    '--max_num_jobs', '5',
    '--max_parallel_jobs', '2'
]

class HyperparameterTestCase(unittest.TestCase):
    """Unit tests for managed spot training support in the HPO component."""

    @classmethod
    def setUpClass(cls):
        # Parsing is stateless, so one shared parser serves every test.
        cls.parser = hpo.create_parser()

    def test_spot_bad_args(self):
        # Each arg set is missing one piece of required spot configuration:
        # wait time, checkpoint config, or the S3 URI inside it.
        incomplete_arg_sets = [
            ['--spot_instance', 'True'],
            ['--spot_instance', 'True', '--max_wait_time', '3600'],
            ['--spot_instance', 'True', '--max_wait_time', '3600', '--checkpoint_config', '{}'],
        ]
        for extra_args in incomplete_arg_sets:
            parsed = self.parser.parse_args(required_args + extra_args)
            with self.assertRaises(Exception):
                _utils.create_hyperparameter_tuning_job_request(vars(parsed))

    def test_spot_lesser_wait_time(self):
        # A wait time one second below the default max run time must be rejected.
        parsed = self.parser.parse_args(
            required_args + ['--spot_instance', 'True', '--max_wait_time', '86399',
                             '--checkpoint_config', '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}'])
        with self.assertRaises(Exception):
            _utils.create_hyperparameter_tuning_job_request(vars(parsed))

    def test_spot_good_args(self):
        parsed = self.parser.parse_args(
            required_args + ['--spot_instance', 'True', '--max_wait_time', '86400',
                             '--checkpoint_config', '{"S3Uri": "s3://fake-uri/"}'])
        request = _utils.create_hyperparameter_tuning_job_request(vars(parsed))
        definition = request['TrainingJobDefinition']
        self.assertTrue(definition['EnableManagedSpotTraining'])
        self.assertEqual(definition['StoppingCondition']['MaxWaitTimeInSeconds'], 86400)
        self.assertEqual(definition['CheckpointConfig']['S3Uri'], 's3://fake-uri/')

    def test_spot_local_path(self):
        parsed = self.parser.parse_args(
            required_args + ['--spot_instance', 'True', '--max_wait_time', '86400',
                             '--checkpoint_config', '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}'])
        checkpoint = _utils.create_hyperparameter_tuning_job_request(vars(parsed))['TrainingJobDefinition']['CheckpointConfig']
        self.assertEqual(checkpoint['S3Uri'], 's3://fake-uri/')
        self.assertEqual(checkpoint['LocalPath'], 'local-path')
Loading