inference_pipeline_create.py

import os

from inference_pipeline_define import define_inference_pipeline


def format_template_str():
    """
    Post-process the generated CloudFormation template: add a Parameters
    section, wrap the state machine name, definition, and role ARN in !Sub,
    append an Outputs section, and write the result to ./templates/.
    """
    with open("/tmp/my_inference_pipeline.yaml", "r") as file:
        data = file.read()
    # add the parameters
    data = data.replace(
        """AWSTemplateFormatVersion: '2010-09-09'
Description: CloudFormation template for AWS Step Functions - State Machine
""",
        """AWSTemplateFormatVersion: '2010-09-09'
Description: CloudFormation template for AWS Step Functions - State Machine
# Added by script
Parameters:
  PipelineName:
    Type: String
  SagerMakerRoleArn:
    Type: String
  WorkflowExecutionRoleArn:
    Type: String
  TargetEnv:
    Type: String
""",
    )
    # replace StateMachineName
    data = data.replace(
        "StateMachineName: ${InferencePipelineName}",
        'StateMachineName: !Sub "${PipelineName}-Inference-${TargetEnv}"',
    )
    # replace DefinitionString
    data = data.replace("DefinitionString:", "DefinitionString: !Sub")
    # replace Role Arn
    data = data.replace(
        "RoleArn: ${WorkflowExecutionRoleArn}",
        'RoleArn: !Sub "${WorkflowExecutionRoleArn}"',
    )
    # add output
    data = (
        data
        + "\n"
        + """Outputs:
  StateMachineComponentArn:
    Description: The step function ARN
    Value: !GetAtt StateMachineComponent.Arn
"""
    )
    with open("./templates/my_inference_pipeline.yaml", "w") as file:
        file.write(data)
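

# For reference only: after format_template_str() runs, the template header is
# expected to look roughly like the sketch below (the state machine resource
# itself comes from the generated template and is abbreviated here):
#
#   AWSTemplateFormatVersion: '2010-09-09'
#   Description: CloudFormation template for AWS Step Functions - State Machine
#   # Added by script
#   Parameters:
#     PipelineName:
#       Type: String
#     SagerMakerRoleArn:
#       Type: String
#     WorkflowExecutionRoleArn:
#       Type: String
#     TargetEnv:
#       Type: String
#   Resources:
#     StateMachineComponent:
#       Properties:
#         StateMachineName: !Sub "${PipelineName}-Inference-${TargetEnv}"
#         DefinitionString: !Sub ...
#         RoleArn: !Sub "${WorkflowExecutionRoleArn}"
#   Outputs:
#     StateMachineComponentArn:
#       Description: The step function ARN
#       Value: !GetAtt StateMachineComponent.Arn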


def create_inference_pipeline(
    sm_role,
    workflow_execution_role,
    inference_pipeline_name,
    return_yaml=True,
    dump_yaml_file="templates/sagemaker_inference_pipeline.yaml",
    kms_key_id=None,
):
    """
    Return the YAML definition of the inference pipeline, which consists of
    multiple AWS Step Functions steps.

    sm_role: ARN of the SageMaker execution role
    workflow_execution_role: ARN of the Step Functions execution role
    inference_pipeline_name: name used for the inference state machine
    return_yaml: if True, return the YAML representation; if False,
                 return an instance of
                 `stepfunctions.workflow.WorkflowObject`
    dump_yaml_file: if not None, a YAML file will be generated at this
                    file location
    kms_key_id: optional KMS key passed through to the pipeline definition
    """
    inference_pipeline = define_inference_pipeline(
        sm_role,
        workflow_execution_role,
        inference_pipeline_name,
        return_yaml,
        dump_yaml_file,
        kms_key_id=kms_key_id,
    )
    # dump YAML CloudFormation template
    yml = inference_pipeline.get_cloudformation_template()
    if dump_yaml_file is not None:
        with open(dump_yaml_file, "w") as fout:
            fout.write(yml)
    if return_yaml:
        return yml
    else:
        return inference_pipeline
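

# Usage sketch (hypothetical ARNs and pipeline name; with dump_yaml_file set,
# the template is written directly instead of being handled by the caller):
#
#   create_inference_pipeline(
#       sm_role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
#       workflow_execution_role="arn:aws:iam::123456789012:role/StepFunctionsExecutionRole",
#       inference_pipeline_name="my-inference-pipeline",
#       dump_yaml_file="templates/sagemaker_inference_pipeline.yaml",
#   )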


def example_create_inference_pipeline():
    """
    An example of obtaining the YAML CloudFormation template from the
    inference pipeline definition.
    """
    sm_role = "${SagerMakerRoleArn}"
    workflow_execution_role = "${WorkflowExecutionRoleArn}"
    inference_pipeline_name = "${InferencePipelineName}"
    kms_key_id = os.getenv("KMSKEY_ARN", None)
    yaml_rep = create_inference_pipeline(
        sm_role=sm_role,
        workflow_execution_role=workflow_execution_role,
        inference_pipeline_name=inference_pipeline_name,
        dump_yaml_file=None,
        kms_key_id=kms_key_id,
    )
    with open("/tmp/my_inference_pipeline.yaml", "w") as fout:
        fout.write(yaml_rep)


if __name__ == "__main__":
    example_create_inference_pipeline()
    format_template_str()
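

# Deployment sketch (assumption: the rewritten template under ./templates/ is
# deployed with the AWS CLI; stack name and parameter values are illustrative):
#
#   aws cloudformation deploy \
#       --template-file templates/my_inference_pipeline.yaml \
#       --stack-name my-inference-pipeline-dev \
#       --parameter-overrides \
#           PipelineName=my-pipeline \
#           SagerMakerRoleArn=arn:aws:iam::123456789012:role/SageMakerExecutionRole \
#           WorkflowExecutionRoleArn=arn:aws:iam::123456789012:role/StepFunctionsExecutionRole \
#           TargetEnv=dev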