Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(stepfunctions): Downscope SageMaker permissions #2991

Merged
merged 12 commits into from
Jul 3, 2019
Merged
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import ec2 = require('@aws-cdk/aws-ec2');
import ecr = require('@aws-cdk/aws-ecr');
import { DockerImageAsset, DockerImageAssetProps } from '@aws-cdk/aws-ecr-assets';
import iam = require('@aws-cdk/aws-iam');
import kms = require('@aws-cdk/aws-kms');
import { Duration } from '@aws-cdk/core';
import s3 = require('@aws-cdk/aws-s3');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Duration } from '@aws-cdk/core';

export interface ISageMakerTask extends sfn.IStepFunctionsTask, iam.IGrantable {}

//
// Create Training Job types
Expand All @@ -24,7 +31,7 @@ export interface AlgorithmSpecification {
/**
* Registry path of the Docker image that contains the training algorithm.
*/
readonly trainingImage?: string;
readonly trainingImage?: DockerImage;

/**
* Input mode that the algorithm supports.
Expand Down Expand Up @@ -125,7 +132,7 @@ export interface S3DataSource {
/**
* S3 Uri
*/
readonly s3Uri: string;
readonly s3Location: S3Location;
}

/**
Expand All @@ -140,7 +147,7 @@ export interface OutputDataConfig {
/**
* Identifies the S3 path where you want Amazon SageMaker to store the model artifacts.
*/
readonly s3OutputPath: string;
readonly s3OutputLocation: S3Location;
}

export interface StoppingCondition {
Expand Down Expand Up @@ -169,7 +176,7 @@ export interface ResourceConfig {
/**
* KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job.
*/
readonly volumeKmsKeyId?: kms.IKey;
readonly volumeEncryptionKey?: kms.IKey;

/**
* Size of the ML storage volume that you want to provision.
Expand Down Expand Up @@ -218,6 +225,128 @@ export interface MetricDefinition {
readonly regex: string;
}

export interface S3LocationConfig {
readonly uri: string;
}

/**
* Constructs `IS3Location` objects.
*/
export abstract class S3Location {
/**
* An `IS3Location` built with a determined bucket and key prefix.
*
* @param bucket is the bucket where the objects are to be stored.
* @param keyPrefix is the key prefix used by the location.
*/
public static fromBucket(bucket: s3.IBucket, keyPrefix: string): S3Location {
return new StandardS3Location({ bucket, keyPrefix, uri: bucket.urlForObject(keyPrefix) });
}

/**
* An `IS3Location` determined fully by a JSON Path from the task input.
*
* Due to the dynamic nature of those locations, the IAM grants that will be set by `grantRead` and `grantWrite`
* apply to the `*` resource.
*
* @param expression the JSON expression resolving to an S3 location URI.
*/
public static fromJsonExpression(expression: string): S3Location {
return new StandardS3Location({ uri: sfn.Data.stringAt(expression) });
}

/**
* Called when the S3Location is bound to a StepFunctions task.
*/
public abstract bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig;
}

/**
* Options for binding an S3 Location.
*/
export interface S3LocationBindOptions {
/**
* Allow reading from the S3 Location.
*
* @default false
*/
readonly forReading?: boolean;

/**
* Allow writing to the S3 Location.
*
* @default false
*/
readonly forWriting?: boolean;
}

/**
* Configuration for a using Docker image.
*
* @experimental
*/
export interface DockerImageConfig {
/**
* The fully qualified URI of the Docker image.
*/
readonly imageUri: string;
}

/**
* Creates `IDockerImage` instances.
*
* @experimental
*/
export abstract class DockerImage {
/**
* Reference a Docker image stored in an ECR repository.
*
* @param repository the ECR repository where the image is hosted.
* @param tag an optional `tag`
*/
public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): DockerImage {
return new StandardDockerImage({ repository, imageUri: repository.repositoryUriForTag(tag) });
}

/**
* Reference a Docker image which URI is obtained from the task's input.
*
* @param expression the JSON path expression with the task input.
* @param allowAnyEcrImagePull whether ECR access should be permitted (set to `false` if the image will never be in ECR).
*/
public static fromJsonExpression(expression: string, allowAnyEcrImagePull = true): DockerImage {
return new StandardDockerImage({ imageUri: expression, allowAnyEcrImagePull });
}

/**
* Reference a Docker image by it's URI.
*
* When referencing ECR images, prefer using `inEcr`.
*
* @param imageUri the URI to the docker image.
*/
public static fromRegistry(imageUri: string): DockerImage {
return new StandardDockerImage({ imageUri });
}

/**
* Reference a Docker image that is provided as an Asset in the current app.
*
* @param scope the scope in which to create the Asset.
* @param id the ID for the asset in the construct tree.
* @param props the configuration props of the asset.
*/
public static fromAsset(scope: Construct, id: string, props: DockerImageAssetProps): DockerImage {
const asset = new DockerImageAsset(scope, id, props);
return new StandardDockerImage({ repository: asset.repository, imageUri: asset.imageUri });
}

/**
* Called when the image is used by a SageMaker task.
*/
public abstract bind(task: ISageMakerTask): DockerImageConfig;
}

/**
* S3 Data Type.
*/
Expand Down Expand Up @@ -472,3 +601,70 @@ export enum AssembleWith {
LINE = 'Line'

}

class StandardDockerImage extends DockerImage {
private readonly allowAnyEcrImagePull: boolean;
private readonly imageUri: string;
private readonly repository?: ecr.IRepository;

constructor(opts: { allowAnyEcrImagePull?: boolean, imageUri: string, repository?: ecr.IRepository }) {
super();

this.allowAnyEcrImagePull = !!opts.allowAnyEcrImagePull;
this.imageUri = opts.imageUri;
this.repository = opts.repository;
}

public bind(task: ISageMakerTask): DockerImageConfig {
if (this.repository) {
this.repository.grantPull(task);
}
if (this.allowAnyEcrImagePull) {
task.grantPrincipal.addToPolicy(new iam.PolicyStatement({
actions: [
'ecr:BatchCheckLayerAvailability',
'ecr:GetDownloadUrlForLayer',
'ecr:BatchGetImage',
],
resources: ['*']
}));
}
return {
imageUri: this.imageUri,
};
}
}

class StandardS3Location extends S3Location {
private readonly bucket?: s3.IBucket;
private readonly keyGlob: string;
private readonly uri: string;

constructor(opts: { bucket?: s3.IBucket, keyPrefix?: string, uri: string }) {
super();
this.bucket = opts.bucket;
this.keyGlob = `${opts.keyPrefix || ''}*`;
this.uri = opts.uri;
}

public bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig {
if (this.bucket) {
if (opts.forReading) {
this.bucket.grantRead(task, this.keyGlob);
}
if (opts.forWriting) {
this.bucket.grantWrite(task, this.keyGlob);
}
} else {
const actions = new Array<string>();
if (opts.forReading) {
actions.push('s3:GetObject', 's3:ListBucket');
}
if (opts.forWriting) {
actions.push('s3:PutObject');
}
task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ actions, resources: ['*'], }));
}
return { uri: this.uri };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceC
/**
* @experimental
*/
export interface SagemakerTrainProps {
export interface SagemakerTrainTaskProps {

/**
* Training Job Name.
*/
readonly trainingJobName: string;

/**
* Role for thte Training Job.
* Role for the Training Job. The role must be granted all necessary permissions for the SageMaker training job to
* be able to operate.
*
* See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms
*
* @default - a role with appropriate permissions will be created.
*/
readonly role?: iam.IRole;

Expand Down Expand Up @@ -71,7 +76,7 @@ export interface SagemakerTrainProps {
/**
* Class representing the SageMaker Create Training Job task.
*/
export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsTask {
export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn.IStepFunctionsTask {

/**
* Allows specify security group connections for instances of this fleet.
Expand All @@ -85,6 +90,8 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
*/
public readonly role: iam.IRole;

public readonly grantPrincipal: iam.IPrincipal;

/**
* The Algorithm Specification
*/
Expand All @@ -105,7 +112,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
*/
private readonly stoppingCondition: StoppingCondition;

constructor(scope: Construct, private readonly props: SagemakerTrainProps) {
constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) {

// set the default resource config if not defined.
this.resourceConfig = props.resourceConfig || {
Expand All @@ -120,13 +127,48 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
};

// set the sagemaker role or create new one
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
this.grantPrincipal = this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
]
inlinePolicies: {
CreateTrainingJob: new iam.PolicyDocument({
statements: [
new iam.PolicyStatement({
actions: [
'cloudwatch:PutMetricData',
'logs:CreateLogStream',
'logs:PutLogEvents',
'logs:CreateLogGroup',
'logs:DescribeLogStreams',
'ecr:GetAuthorizationToken',
...props.vpcConfig
? [
'ec2:CreateNetworkInterface',
'ec2:CreateNetworkInterfacePermission',
'ec2:DeleteNetworkInterface',
'ec2:DeleteNetworkInterfacePermission',
'ec2:DescribeNetworkInterfaces',
'ec2:DescribeVpcs',
'ec2:DescribeDhcpOptions',
'ec2:DescribeSubnets',
'ec2:DescribeSecurityGroups',
]
: [],
],
resources: ['*'], // Those permissions cannot be resource-scoped
})
]
}),
}
});

if (props.outputDataConfig.encryptionKey) {
props.outputDataConfig.encryptionKey.grantEncrypt(this.role);
}

if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) {
props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant');
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
Expand Down Expand Up @@ -175,7 +217,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
return {
AlgorithmSpecification: {
TrainingInputMode: spec.trainingInputMode,
...(spec.trainingImage) ? { TrainingImage: spec.trainingImage } : {},
...(spec.trainingImage) ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {},
...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {},
...(spec.metricDefinitions) ?
{ MetricDefinitions: spec.metricDefinitions
Expand All @@ -190,7 +232,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
ChannelName: channel.channelName,
DataSource: {
S3DataSource: {
S3Uri: channel.dataSource.s3DataSource.s3Uri,
S3Uri: channel.dataSource.s3DataSource.s3Location.bind(this, { forReading: true }).uri,
S3DataType: channel.dataSource.s3DataSource.s3DataType,
...(channel.dataSource.s3DataSource.s3DataDistributionType) ?
{ S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {},
Expand All @@ -209,7 +251,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} {
return {
OutputDataConfig: {
S3OutputPath: config.s3OutputPath,
S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri,
...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {},
}
};
Expand All @@ -221,7 +263,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
InstanceCount: config.instanceCount,
InstanceType: 'ml.' + config.instanceType,
VolumeSizeInGB: config.volumeSizeInGB,
...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {},
...(config.volumeEncryptionKey) ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {},
}
};
}
Expand Down Expand Up @@ -260,7 +302,8 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
stack.formatArn({
service: 'sagemaker',
resource: 'training-job',
resourceName: '*'
// If the job name comes from input, we cannot target the policy to a particular ARN prefix reliably...
resourceName: sfn.Data.isJsonPathString(this.props.trainingJobName) ? '*' : `${this.props.trainingJobName}*`
})
],
}),
Expand Down
Loading