Skip to content

Commit

Permalink
fix(aws-stepfunctions): refactor sagemaker tasks and fix default role…
Browse files Browse the repository at this point in the history
… issue (#3014)

* fix(aws-stepfunctions) refactor and fix default role issue

* fix(aws-stepfunctions) removed console log statements and fixed s3 prefix error

* fix(aws-stepfunctions) removed construct from constructor for sagemaker tasks. Changed ISubnet[] to SubnetSelection in props

* fix(aws-stepfunctions) renamed cdk core package reference

* Update tests
  • Loading branch information
mmcclean-aws authored and mergify[bot] committed Aug 21, 2019
1 parent c020efa commit d8fcb50
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,15 @@ export interface ResourceConfig {
* @experimental
*/
export interface VpcConfig {
/**
* VPC security groups.
*/
readonly securityGroups: ec2.ISecurityGroup[];

/**
* The VPC to use. This is a VPC object reference, not a VPC id string.
*/
readonly vpc: ec2.Vpc;
readonly vpc: ec2.IVpc;

/**
* VPC subnets.
*
* When the subnet selection is omitted, SagemakerTrainTask falls back to
* `vpc.selectSubnets()` called without arguments.
*/
readonly subnets: ec2.ISubnet[];
readonly subnets?: ec2.SubnetSelection;
}

/**
Expand Down
126 changes: 84 additions & 42 deletions packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ec2 = require('@aws-cdk/aws-ec2');
import iam = require('@aws-cdk/aws-iam');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Duration, Stack } from '@aws-cdk/core';
import { Duration, Lazy, Stack } from '@aws-cdk/core';
import { resourceArnSuffix } from './resource-arn-suffix';
import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig,
S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types';
Expand Down Expand Up @@ -53,7 +53,7 @@ export interface SagemakerTrainTaskProps {
/**
* Tags to be applied to the train job.
*/
readonly tags?: {[key: string]: any};
readonly tags?: {[key: string]: string};

/**
* Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training.
Expand Down Expand Up @@ -88,15 +88,6 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
*/
public readonly connections: ec2.Connections = new ec2.Connections();

/**
* The execution role for the Sagemaker training job.
*
* @default new role for Amazon SageMaker to assume is automatically created.
*/
public readonly role: iam.IRole;

public readonly grantPrincipal: iam.IPrincipal;

/**
* The Algorithm Specification
*/
Expand All @@ -117,9 +108,15 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
*/
private readonly stoppingCondition: StoppingCondition;

private readonly vpc: ec2.IVpc;
private securityGroup: ec2.ISecurityGroup;
private readonly securityGroups: ec2.ISecurityGroup[] = [];
private readonly subnets: string[];
private readonly integrationPattern: sfn.ServiceIntegrationPattern;
private _role?: iam.IRole;
private _grantPrincipal?: iam.IPrincipal;

constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) {
constructor(private readonly props: SagemakerTrainTaskProps) {
this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET;

const supportedPatterns = [
Expand All @@ -143,8 +140,66 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
maxRuntime: Duration.hours(1)
};

// check that either algorithm name or image is defined
if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) {
throw new Error("Must define either an algorithm name or training image URI in the algorithm specification");
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } );

// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.inputDataConfig = props.inputDataConfig.map(config => {
if (!config.dataSource.s3DataSource.s3DataType) {
return Object.assign({}, config, { dataSource: { s3DataSource:
{ ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } });
} else {
return config;
}
});

// add the security groups to the connections object
if (props.vpcConfig) {
this.vpc = props.vpcConfig.vpc;
this.subnets = (props.vpcConfig.subnets) ?
(this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds;
}
}

/**
 * Execution role used by the SageMaker training job.
 *
 * The backing field is assigned in `bind()`, so reading this property before
 * the object has been used in a Task throws.
 *
 * @throws Error if the role has not been assigned yet.
 */
public get role(): iam.IRole {
  const executionRole = this._role;
  if (executionRole !== undefined) {
    return executionRole;
  }
  throw new Error(`role not available yet--use the object in a Task first`);
}

/**
 * Principal that permissions are granted to (the `iam.IGrantable` contract).
 *
 * Assigned together with the role in `bind()`, so reading this property
 * before the object has been used in a Task throws.
 *
 * @throws Error if the principal has not been assigned yet.
 */
public get grantPrincipal(): iam.IPrincipal {
  const principal = this._grantPrincipal;
  if (principal !== undefined) {
    return principal;
  }
  throw new Error(`Principal not available yet--use the object in a Task first`);
}

/**
* Register an additional security group for the training job.
*
* The group is appended to this task's internal security group list, whose
* ids are rendered lazily as `SecurityGroupIds` when a VPC config is present.
* This method does not add the group to `connections`.
*
* @param securityGroup The security group to add
*/
public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void {
this.securityGroups.push(securityGroup);
}

public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
// set the sagemaker role or create new one
this.grantPrincipal = this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
this._grantPrincipal = this._role = this.props.role || new iam.Role(task, 'SagemakerRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
inlinePolicies: {
CreateTrainingJob: new iam.PolicyDocument({
Expand All @@ -157,7 +212,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
'logs:CreateLogGroup',
'logs:DescribeLogStreams',
'ecr:GetAuthorizationToken',
...props.vpcConfig
...this.props.vpcConfig
? [
'ec2:CreateNetworkInterface',
'ec2:CreateNetworkInterfacePermission',
Expand All @@ -178,36 +233,23 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
}
});

if (props.outputDataConfig.encryptionKey) {
props.outputDataConfig.encryptionKey.grantEncrypt(this.role);
if (this.props.outputDataConfig.encryptionKey) {
this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role);
}

if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) {
props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant');
if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) {
this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant');
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } );

// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.inputDataConfig = props.inputDataConfig.map(config => {
if (!config.dataSource.s3DataSource.s3DataType) {
return Object.assign({}, config, { dataSource: { s3DataSource:
{ ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } });
} else {
return config;
}
});

// add the security groups to the connections object
if (this.props.vpcConfig) {
this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg));
// create a security group if not defined
if (this.vpc && this.securityGroup === undefined) {
this.securityGroup = new ec2.SecurityGroup(task, 'TrainJobSecurityGroup', {
vpc: this.vpc
});
this.connections.addSecurityGroup(this.securityGroup);
this.securityGroups.push(this.securityGroup);
}
}

public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
return {
resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + resourceArnSuffix.get(this.integrationPattern),
parameters: this.renderParameters(),
Expand All @@ -218,7 +260,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
private renderParameters(): {[key: string]: any} {
return {
TrainingJobName: this.props.trainingJobName,
RoleArn: this.role.roleArn,
RoleArn: this._role!.roleArn,
...(this.renderAlgorithmSpecification(this.algorithmSpecification)),
...(this.renderInputDataConfig(this.inputDataConfig)),
...(this.renderOutputDataConfig(this.props.outputDataConfig)),
Expand Down Expand Up @@ -303,8 +345,8 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn

private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} {
return (config) ? { VpcConfig: {
SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )),
Subnets: config.subnets.map(subnet => ( subnet.subnetId )),
SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }),
Subnets: this.subnets,
}} : {};
}

Expand All @@ -330,7 +372,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
}),
new iam.PolicyStatement({
actions: ['iam:PassRole'],
resources: [this.role.roleArn],
resources: [this._role!.roleArn],
conditions: {
StringEquals: { "iam:PassedToService": "sagemaker.amazonaws.com" }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ec2 = require('@aws-cdk/aws-ec2');
import iam = require('@aws-cdk/aws-iam');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Stack } from '@aws-cdk/core';
import { Stack } from '@aws-cdk/core';
import { resourceArnSuffix } from './resource-arn-suffix';
import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types';

Expand Down Expand Up @@ -37,7 +37,7 @@ export interface SagemakerTransformProps {
/**
* Environment variables to set in the Docker container.
*/
readonly environment?: {[key: string]: any};
readonly environment?: {[key: string]: string};

/**
* Maximum number of parallel requests that can be sent to each instance in a transform job.
Expand All @@ -57,7 +57,7 @@ export interface SagemakerTransformProps {
/**
* Tags to be applied to the transform job.
*/
readonly tags?: {[key: string]: any};
readonly tags?: {[key: string]: string};

/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
Expand All @@ -82,13 +82,6 @@ export interface SagemakerTransformProps {
*/
export class SagemakerTransformTask implements sfn.IStepFunctionsTask {

/**
* The execution role for the Sagemaker training job.
*
* @default new role for Amazon SageMaker to assume is automatically created.
*/
public readonly role: iam.IRole;

/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
*/
Expand All @@ -98,10 +91,10 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
* ML compute instances for the transform job.
*/
private readonly transformResources: TransformResources;

private readonly integrationPattern: sfn.ServiceIntegrationPattern;
private _role?: iam.IRole;

constructor(scope: Construct, private readonly props: SagemakerTransformProps) {
constructor(private readonly props: SagemakerTransformProps) {
this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET;

const supportedPatterns = [
Expand All @@ -114,12 +107,9 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
}

// set the sagemaker role or create new one
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
]
});
if (props.role) {
this._role = props.role;
}

// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) :
Expand All @@ -140,13 +130,35 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
}

/**
 * Produce the Step Functions task configuration for the
 * `sagemaker:createTransformJob` service integration.
 *
 * If no role has been set yet, a new one is created under `task` before the
 * parameters and policy statements are rendered.
 *
 * @param task The Task this object is being bound to; used as the scope for
 * a newly created role and passed on to the policy statement builder.
 * @returns The resource ARN, rendered parameters and policy statements.
 */
public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
  // create new role if doesn't exist
  if (this._role === undefined) {
    const sagemakerPrincipal = new iam.ServicePrincipal('sagemaker.amazonaws.com');
    const fullAccessPolicy = iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess');
    this._role = new iam.Role(task, 'SagemakerTransformRole', {
      assumedBy: sagemakerPrincipal,
      managedPolicies: [fullAccessPolicy],
    });
  }

  // Evaluate in the same order as the fields of the returned config.
  const resourceArn = 'arn:aws:states:::sagemaker:createTransformJob' + resourceArnSuffix.get(this.integrationPattern);
  const parameters = this.renderParameters();
  const policyStatements = this.makePolicyStatements(task);

  return { resourceArn, parameters, policyStatements };
}

/**
 * Execution role used by the SageMaker transform job.
 *
 * Available immediately when a role was supplied through the props (the
 * constructor stores it); otherwise only after `bind()` has created one,
 * i.e. after the object has been used in a Task.
 *
 * @throws Error if no role has been set or created yet.
 */
public get role(): iam.IRole {
  const executionRole = this._role;
  if (executionRole !== undefined) {
    return executionRole;
  }
  throw new Error(`role not available yet--use the object in a Task first`);
}

private renderParameters(): {[key: string]: any} {
return {
...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {},
Expand Down
Loading

0 comments on commit d8fcb50

Please sign in to comment.