Skip to content

Commit

Permalink
Add snapshots support
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmetana committed Feb 8, 2019
1 parent c05802a commit 8e5d99f
Show file tree
Hide file tree
Showing 6 changed files with 708 additions and 9 deletions.
186 changes: 185 additions & 1 deletion pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
"math"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -73,6 +74,8 @@ const (
const (
// VolumeNameTagKey is the key value that refers to the volume's name.
VolumeNameTagKey = "CSIVolumeName"
// SnapshotNameTagKey is the key value that refers to the snapshot's name.
SnapshotNameTagKey = "CSIVolumeSnapshotName"
)

var (
Expand Down Expand Up @@ -109,7 +112,21 @@ type DiskOptions struct {
Encrypted bool
// KmsKeyID represents a fully qualified resource name to the key to use for encryption.
// example: arn:aws:kms:us-east-1:012345678910:key/abcd1234-a123-456a-a12b-a123b4cd56ef
KmsKeyID string
KmsKeyID string
SnapshotID string
}

// Snapshot represents an EBS volume snapshot
type Snapshot struct {
SnapshotID string
SourceVolumeID string
Size int64
CreationTime time.Time
}

// SnapshotOptions represents parameters to create an EBS volume
type SnapshotOptions struct {
Tags map[string]string
}

// EC2 abstracts aws.EC2 to facilitate its mocking.
Expand All @@ -121,6 +138,9 @@ type EC2 interface {
DetachVolumeWithContext(ctx aws.Context, input *ec2.DetachVolumeInput, opts ...request.Option) (*ec2.VolumeAttachment, error)
AttachVolumeWithContext(ctx aws.Context, input *ec2.AttachVolumeInput, opts ...request.Option) (*ec2.VolumeAttachment, error)
DescribeInstancesWithContext(ctx aws.Context, input *ec2.DescribeInstancesInput, opts ...request.Option) (*ec2.DescribeInstancesOutput, error)
CreateSnapshotWithContext(ctx aws.Context, input *ec2.CreateSnapshotInput, opts ...request.Option) (*ec2.Snapshot, error)
DeleteSnapshotWithContext(ctx aws.Context, input *ec2.DeleteSnapshotInput, opts ...request.Option) (*ec2.DeleteSnapshotOutput, error)
DescribeSnapshotsWithContext(ctx aws.Context, input *ec2.DescribeSnapshotsInput, opts ...request.Option) (*ec2.DescribeSnapshotsOutput, error)
}

type Cloud interface {
Expand All @@ -133,6 +153,9 @@ type Cloud interface {
GetDiskByName(ctx context.Context, name string, capacityBytes int64) (disk *Disk, err error)
GetDiskByID(ctx context.Context, volumeID string) (disk *Disk, err error)
IsExistInstance(ctx context.Context, nodeID string) (success bool)
CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error)
DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error)
GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error)
}

type cloud struct {
Expand Down Expand Up @@ -245,6 +268,10 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions *
if iops > 0 {
request.Iops = aws.Int64(iops)
}
snapshotID := diskOptions.SnapshotID
if len(snapshotID) > 0 {
request.SnapshotId = aws.String(snapshotID)
}

response, err := c.ec2.CreateVolumeWithContext(ctx, request)
if err != nil {
Expand Down Expand Up @@ -457,6 +484,84 @@ func (c *cloud) IsExistInstance(ctx context.Context, nodeID string) bool {
return true
}

func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) {
descriptions := "Created by AWS EBS CSI driver for volume " + volumeID

var tags []*ec2.Tag
for key, value := range snapshotOptions.Tags {
tags = append(tags, &ec2.Tag{Key: &key, Value: &value})
}
tagSpec := ec2.TagSpecification{
ResourceType: aws.String("snapshot"),
Tags: tags,
}
request := &ec2.CreateSnapshotInput{
VolumeId: aws.String(volumeID),
DryRun: aws.Bool(false),
TagSpecifications: []*ec2.TagSpecification{&tagSpec},
Description: aws.String(descriptions),
}

res, err := c.ec2.CreateSnapshotWithContext(ctx, request)
if err != nil {
return nil, fmt.Errorf("error creating snapshot of volume %s: %v", volumeID, err)
}
if res == nil {
return nil, fmt.Errorf("nil CreateSnapshotResponse")
}
err = c.waitForSnapshotCreate(ctx, res.SnapshotId)
if err != nil {
return nil, err
}

return c.ec2SnapshotResponseToStruct(res), nil
}

func (c *cloud) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) {
request := &ec2.DeleteSnapshotInput{}
request.SnapshotId = aws.String(snapshotID)
request.DryRun = aws.Bool(false)
if _, err := c.ec2.DeleteSnapshotWithContext(ctx, request); err != nil {
if isAWSErrorSnapshotNotFound(err) {
return false, ErrNotFound
}
return false, fmt.Errorf("DeleteSnapshot could not delete volume: %v", err)
}
return true, nil
}

func (c *cloud) GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) {
request := &ec2.DescribeSnapshotsInput{
Filters: []*ec2.Filter{
{
Name: aws.String("tag:" + SnapshotNameTagKey),
Values: []*string{aws.String(name)},
},
},
}

ec2snapshot, err := c.getSnapshot(ctx, request)
if err != nil {
return nil, err
}

return c.ec2SnapshotResponseToStruct(ec2snapshot), nil
}

// Helper method converting EC2 snapshot type to the internal struct
func (c *cloud) ec2SnapshotResponseToStruct(ec2Snapshot *ec2.Snapshot) *Snapshot {
if ec2Snapshot == nil {
return nil
}
snapshotSize := util.GiBToBytes(aws.Int64Value(ec2Snapshot.VolumeSize))
return &Snapshot{
SnapshotID: aws.StringValue(ec2Snapshot.SnapshotId),
SourceVolumeID: aws.StringValue(ec2Snapshot.VolumeId),
Size: snapshotSize,
CreationTime: aws.TimeValue(ec2Snapshot.StartTime),
}
}

func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) {
var volumes []*ec2.Volume
var nextToken *string
Expand Down Expand Up @@ -516,6 +621,32 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance,
return instances[0], nil
}

func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) {
var snapshots []*ec2.Snapshot
var nextToken *string

for {
response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request)
if err != nil {
return nil, err
}
snapshots = append(snapshots, response.Snapshots...)
nextToken = response.NextToken
if aws.StringValue(nextToken) == "" {
break
}
request.NextToken = nextToken
}

if l := len(snapshots); l > 1 {
return nil, errors.New("Multiple snapshots with the same name found")
} else if l < 1 {
return nil, ErrNotFound
}

return snapshots[0], nil
}

// waitForVolume waits for volume to be in the "available" state.
// On a random AWS account (shared among several developers) it took 4s on average.
func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error {
Expand Down Expand Up @@ -564,3 +695,56 @@ func isAWSErrorVolumeNotFound(err error) bool {
}
return false
}

// Helper function for describeSnapshot callers. Tries to retype given error to AWS error
// and returns true in case the AWS error is "InvalidSnapshot.NotFound", false otherwise
func isAWSErrorSnapshotNotFound(err error) bool {
if awsError, ok := err.(awserr.Error); ok {
// https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html
if awsError.Code() == "InvalidSnapshot.NotFound" {
return true
}
}

return false
}

func (c *cloud) waitForSnapshotCreate(ctx context.Context, snapshotID *string) error {
// This should give about 1 minute maximal interval
backoff := wait.Backoff{
Duration: 1 * time.Second,
Factor: 1.5,
Steps: 10,
}
request := &ec2.DescribeSnapshotsInput{
SnapshotIds: []*string{
snapshotID,
},
}

conditionFunc := func() (done bool, err error) {
snapshot, err := c.getSnapshot(ctx, request)
if err != nil {
return true, err
}
if snapshot.State != nil {
switch *snapshot.State {
case "completed":
return true, nil
case "pending":
return false, nil
default:
return true, fmt.Errorf("unexpected State of newly created AWS EBS snapshot %v: %q", snapshotID, *snapshot.State)
}
}
return false, nil
}

// Truncated exponential backoff: if the exponential backoff times-out, just keep polling using the longest interval
err := wait.ExponentialBackoff(backoff, conditionFunc)
if err == wait.ErrWaitTimeout {
timeout := time.Duration(backoff.Duration.Seconds() * math.Pow(backoff.Factor, float64(backoff.Steps)))
err = wait.PollInfinite(timeout*time.Second, conditionFunc)
}
return err
}
Loading

0 comments on commit 8e5d99f

Please sign in to comment.