-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_deploy.py
204 lines (168 loc) · 6.43 KB
/
train_deploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import boto3
import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.serializers import IdentitySerializer
from sagemaker.deserializers import JSONDeserializer
import json
import time
import sys
# Constants
BUCKET_NAME = "ml-classify-butterfly"
# Key prefix (inside BUCKET_NAME) under which all dataset objects live.
# Every S3 location below is derived from BUCKET_NAME and this root, so
# changing the bucket or dataset location is a one-line edit.
_DATASET_KEY_ROOT = "dataset/train"
_DATASET_S3_ROOT = f"s3://{BUCKET_NAME}/{_DATASET_KEY_ROOT}"
# Image channels.
# NOTE(review): these URIs have no trailing "/", so as S3 key prefixes they
# also match ".../train.lst" and ".../validation.lst" respectively -- confirm
# the extra .lst file in each image channel is harmless for this algorithm.
S3_DATASET_TRAIN = f"{_DATASET_S3_ROOT}/train"
S3_DATASET_VAL = f"{_DATASET_S3_ROOT}/val"
# .lst index files for the train/validation images.
S3_LST_TRAIN = f"{_DATASET_S3_ROOT}/train.lst"
S3_LST_VAL = f"{_DATASET_S3_ROOT}/validation.lst"
MODEL_NAME = "classify-butterfly"
# Bare key prefix (no "s3://bucket/") of the training-image folder; used to
# count class sub-folders and training objects. Trailing "/" matters: it
# restricts listings to the folder's contents.
PREFIX = f"{_DATASET_KEY_ROOT}/train/"
REGION = "us-east-1"
# Create a boto3 session pinned to REGION. No profile is passed here, so
# credentials are resolved by boto3's default lookup (environment variables,
# AWS_PROFILE, shared config/credentials files, instance role, ...).
session = boto3.Session(region_name=REGION)
# Wrap the boto3 session so SageMaker SDK calls use the same region/credentials.
sagemaker_session = sagemaker.Session(boto_session=session)
# Create IAM role
import time
import json
# Trust policy allowing the SageMaker service to assume the role. Used both to
# create the role and to verify it afterwards, so the two can never drift.
_SAGEMAKER_TRUST_POLICY = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": "sagemaker.amazonaws.com"
            },
            "Action": "sts:AssumeRole"
        }
    ]
}
# AWS-managed policies attached to the role and then verified as attached.
_REQUIRED_POLICY_ARNS = (
    "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
    "arn:aws:iam::aws:policy/AmazonS3FullAccess",
)


def create_iam_role():
    """Create a new IAM role that SageMaker can assume and return its ARN.

    A new role named ``SageMakerRole-<unix-timestamp>`` is created on every
    call (existing roles are never reused), the policies in
    ``_REQUIRED_POLICY_ARNS`` are attached, and the result is read back from
    IAM to verify both the attachments and the trust relationship.

    Uses the module-level boto3 ``session`` for the IAM client.

    Returns:
        str: ARN of the newly created role.

    Raises:
        RuntimeError: if a required policy is not attached after creation, or
            the role's trust relationship differs from the one submitted.
    """
    iam_client = session.client('iam')
    # Timestamp suffix keeps role names unique across runs.
    role_name = f"SageMakerRole-{int(time.time())}"

    response = iam_client.create_role(
        RoleName=role_name,
        AssumeRolePolicyDocument=json.dumps(_SAGEMAKER_TRUST_POLICY),
        Description="Role for SageMaker to access S3 bucket and other AWS services"
    )
    role_arn = response['Role']['Arn']

    for policy_arn in _REQUIRED_POLICY_ARNS:
        iam_client.attach_role_policy(RoleName=role_name, PolicyArn=policy_arn)

    # IAM changes are eventually consistent; wait before reading back and
    # before SageMaker tries to assume the role.
    time.sleep(10)

    # Verify the role policies.
    attached_policies = iam_client.list_attached_role_policies(RoleName=role_name)['AttachedPolicies']
    attached_arns = {policy['PolicyArn'] for policy in attached_policies}
    for policy_arn in _REQUIRED_POLICY_ARNS:
        if policy_arn not in attached_arns:
            raise RuntimeError(f"Required policy {policy_arn} is not attached to the role {role_name}")

    # Verify the trust relationship (boto3 returns the document as a dict).
    trust_relationship = iam_client.get_role(RoleName=role_name)['Role']['AssumeRolePolicyDocument']
    if trust_relationship != _SAGEMAKER_TRUST_POLICY:
        raise RuntimeError(f"The trust relationship for the role {role_name} is not as expected")

    return role_arn
def get_num_classes(bucket_name, prefix):
    """Return the number of immediate sub-"folders" under ``prefix``.

    Each sub-folder (S3 common prefix one ``/`` level below ``prefix``) is
    counted as one class. All result pages are followed, so buckets with more
    than 1000 class folders are counted correctly.

    Args:
        bucket_name: S3 bucket to list.
        prefix: key prefix of the parent folder, e.g. ``"dataset/train/train/"``.
            A missing trailing ``/`` is added; without it the delimiter
            listing would return the parent folder itself instead of its
            children.

    Returns:
        int: number of sub-folders found (0 if none).
    """
    if prefix and not prefix.endswith('/'):
        prefix += '/'
    s3_client = boto3.client('s3', region_name=REGION)
    paginator = s3_client.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter='/')
    return sum(len(page.get('CommonPrefixes', [])) for page in pages)
def get_num_training_samples(bucket_name, prefix):
    """Return the number of objects (recursively) under ``prefix``.

    Zero-byte "folder marker" keys ending in ``/`` (such as those created by
    the S3 console) are not training samples and are skipped, so they no
    longer inflate the ``num_training_samples`` hyperparameter. All result
    pages are followed.

    Args:
        bucket_name: S3 bucket to list.
        prefix: key prefix of the training folder, e.g. ``"dataset/train/train/"``.
            A missing trailing ``/`` is added so sibling keys that merely
            share the same leading characters are not counted.

    Returns:
        int: number of non-folder objects found (0 if none).
    """
    if prefix and not prefix.endswith('/'):
        prefix += '/'
    s3_client = boto3.client('s3', region_name=REGION)
    paginator = s3_client.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
    return sum(
        1
        for page in pages
        for obj in page.get('Contents', [])
        if not obj['Key'].endswith('/')
    )
#----MAIN----
print(f"Using training data on S3: {S3_DATASET_TRAIN} and {S3_DATASET_VAL}")

# hyperparameters derived from the dataset layout in S3.
# Computed BEFORE create_iam_role() so a missing or empty dataset aborts the
# run before an IAM role is created and a billed training job is started.
num_classes = get_num_classes(BUCKET_NAME, PREFIX)
print(f"Number of classes (folders) in the training data: {num_classes}")
num_training_samples = get_num_training_samples(BUCKET_NAME, PREFIX)
print(f"Number of training samples: {num_training_samples}")
if num_classes < 1 or num_training_samples < 1:
    # Both values are passed as hyperparameters and must be positive.
    sys.exit(
        f"No usable training data found under s3://{BUCKET_NAME}/{PREFIX} "
        f"(classes={num_classes}, samples={num_training_samples}); aborting."
    )

role_arn = create_iam_role()

# Define the image URI for the algorithm container
algorithm_image_uri = sagemaker.image_uris.retrieve(framework='image-classification', region=REGION)

# Create and configure the Estimator
print("Creating and configuring the Estimator...")
estimator = Estimator(
    image_uri=algorithm_image_uri,
    role=role_arn,
    instance_count=1,
    instance_type='ml.p2.xlarge',
    volume_size=50,
    max_run=3600,
    input_mode='File',
    output_path=f"s3://{BUCKET_NAME}/{MODEL_NAME}/output",
    sagemaker_session=sagemaker_session
)

# Set hyperparameters
estimator.set_hyperparameters(
    num_layers=18,
    use_pretrained_model=1,
    num_classes=num_classes,
    mini_batch_size=32,
    epochs=10,
    learning_rate=0.001,
    precision_dtype='float32',
    num_training_samples=num_training_samples
)

# Prepare the dataset for training: image files in train/validation, .lst
# index files in train_lst/validation_lst; all four channels are declared
# with the same "application/x-image" content type.
data_channels = {
    channel: TrainingInput(s3_data=s3_uri, content_type="application/x-image", input_mode='File')
    for channel, s3_uri in (
        ('train', S3_DATASET_TRAIN),
        ('validation', S3_DATASET_VAL),
        ('train_lst', S3_LST_TRAIN),
        ('validation_lst', S3_LST_VAL),
    )
}

# Start the training job (fit() blocks until the job finishes)
print("Starting the training job...")
estimator.fit(inputs=data_channels)
print("Training job completed.")

# Deploy the model
print("Deploying the model...")
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.large',
    endpoint_name=MODEL_NAME,
    serializer=IdentitySerializer("image/jpeg"),
    deserializer=JSONDeserializer()
)

# Get the endpoint URL
endpoint_name = predictor.endpoint_name
endpoint_url = f"https://runtime.sagemaker.{REGION}.amazonaws.com/endpoints/{endpoint_name}/invocations"
print(f"Model deployed. Endpoint name: {endpoint_name}")
print(f"Endpoint URL: {endpoint_url}")