diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index e4022447e7..584c2c1889 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "math" "time" "github.com/aws/aws-sdk-go/aws" @@ -73,6 +74,8 @@ const ( const ( // VolumeNameTagKey is the key value that refers to the volume's name. VolumeNameTagKey = "CSIVolumeName" + // SnapshotNameTagKey is the key value that refers to the snapshot's name. + SnapshotNameTagKey = "CSIVolumeSnapshotName" ) var ( @@ -109,7 +112,21 @@ type DiskOptions struct { Encrypted bool // KmsKeyID represents a fully qualified resource name to the key to use for encryption. // example: arn:aws:kms:us-east-1:012345678910:key/abcd1234-a123-456a-a12b-a123b4cd56ef - KmsKeyID string + KmsKeyID string + SnapshotID string +} + +// Snapshot represents an EBS volume snapshot +type Snapshot struct { + SnapshotID string + SourceVolumeID string + Size int64 + CreationTime time.Time +} + +// SnapshotOptions represents parameters to create an EBS volume +type SnapshotOptions struct { + Tags map[string]string } // EC2 abstracts aws.EC2 to facilitate its mocking. @@ -121,6 +138,9 @@ type EC2 interface { DetachVolumeWithContext(ctx aws.Context, input *ec2.DetachVolumeInput, opts ...request.Option) (*ec2.VolumeAttachment, error) AttachVolumeWithContext(ctx aws.Context, input *ec2.AttachVolumeInput, opts ...request.Option) (*ec2.VolumeAttachment, error) DescribeInstancesWithContext(ctx aws.Context, input *ec2.DescribeInstancesInput, opts ...request.Option) (*ec2.DescribeInstancesOutput, error) + CreateSnapshotWithContext(ctx aws.Context, input *ec2.CreateSnapshotInput, opts ...request.Option) (*ec2.Snapshot, error) + DeleteSnapshotWithContext(ctx aws.Context, input *ec2.DeleteSnapshotInput, opts ...request.Option) (*ec2.DeleteSnapshotOutput, error) + DescribeSnapshotsWithContext(ctx aws.Context, input *ec2.DescribeSnapshotsInput, opts ...request.Option) (*ec2.DescribeSnapshotsOutput, error) } type Cloud interface { @@ -133,6 +153,9 @@ type Cloud interface { GetDiskByName(ctx context.Context, name string, capacityBytes int64) (disk *Disk, err error) GetDiskByID(ctx context.Context, volumeID string) (disk *Disk, err error) IsExistInstance(ctx context.Context, nodeID string) (success bool) + CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) + DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) + GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) } type cloud struct { @@ -245,6 +268,10 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * if iops > 0 { request.Iops = aws.Int64(iops) } + snapshotID := diskOptions.SnapshotID + if len(snapshotID) > 0 { + request.SnapshotId = aws.String(snapshotID) + } response, err := c.ec2.CreateVolumeWithContext(ctx, request) if err != nil { @@ -457,6 +484,84 @@ func (c *cloud) IsExistInstance(ctx context.Context, nodeID string) bool { return true } +func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) { + descriptions := "Created by AWS EBS CSI driver for volume " + volumeID + + var tags []*ec2.Tag + for key, value := range snapshotOptions.Tags { + tags = append(tags, &ec2.Tag{Key: &key, Value: &value}) + } + tagSpec := ec2.TagSpecification{ + ResourceType: aws.String("snapshot"), + Tags: tags, + } + request := &ec2.CreateSnapshotInput{ + VolumeId: aws.String(volumeID), + DryRun: aws.Bool(false), + TagSpecifications: []*ec2.TagSpecification{&tagSpec}, + Description: aws.String(descriptions), + } + + res, err := c.ec2.CreateSnapshotWithContext(ctx, request) + if err != nil { + return nil, fmt.Errorf("error creating snapshot of volume %s: %v", volumeID, err) + } + if res == nil { + return nil, fmt.Errorf("nil CreateSnapshotResponse") + } + err = c.waitForSnapshotCreate(ctx, res.SnapshotId) + if err != nil { + return nil, err + } + + return c.ec2SnapshotResponseToStruct(res), nil +} + +func (c *cloud) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) { + request := &ec2.DeleteSnapshotInput{} + request.SnapshotId = aws.String(snapshotID) + request.DryRun = aws.Bool(false) + if _, err := c.ec2.DeleteSnapshotWithContext(ctx, request); err != nil { + if isAWSErrorSnapshotNotFound(err) { + return false, ErrNotFound + } + return false, fmt.Errorf("DeleteSnapshot could not delete volume: %v", err) + } + return true, nil +} + +func (c *cloud) GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) { + request := &ec2.DescribeSnapshotsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("tag:" + SnapshotNameTagKey), + Values: []*string{aws.String(name)}, + }, + }, + } + + ec2snapshot, err := c.getSnapshot(ctx, request) + if err != nil { + return nil, err + } + + return c.ec2SnapshotResponseToStruct(ec2snapshot), nil +} + +// Helper method converting EC2 snapshot type to the internal struct +func (c *cloud) ec2SnapshotResponseToStruct(ec2Snapshot *ec2.Snapshot) *Snapshot { + if ec2Snapshot == nil { + return nil + } + snapshotSize := util.GiBToBytes(aws.Int64Value(ec2Snapshot.VolumeSize)) + return &Snapshot{ + SnapshotID: aws.StringValue(ec2Snapshot.SnapshotId), + SourceVolumeID: aws.StringValue(ec2Snapshot.VolumeId), + Size: snapshotSize, + CreationTime: aws.TimeValue(ec2Snapshot.StartTime), + } +} + func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { var volumes []*ec2.Volume var nextToken *string @@ -516,6 +621,32 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, return instances[0], nil } +func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { + var snapshots []*ec2.Snapshot + var nextToken *string + + for { + response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) + if err != nil { + return nil, err + } + snapshots = append(snapshots, response.Snapshots...) + nextToken = response.NextToken + if aws.StringValue(nextToken) == "" { + break + } + request.NextToken = nextToken + } + + if l := len(snapshots); l > 1 { + return nil, errors.New("Multiple snapshots with the same name found") + } else if l < 1 { + return nil, ErrNotFound + } + + return snapshots[0], nil +} + // waitForVolume waits for volume to be in the "available" state. // On a random AWS account (shared among several developers) it took 4s on average. func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error { @@ -564,3 +695,56 @@ func isAWSErrorVolumeNotFound(err error) bool { } return false } + +// Helper function for describeSnapshot callers. Tries to retype given error to AWS error +// and returns true in case the AWS error is "InvalidSnapshot.NotFound", false otherwise +func isAWSErrorSnapshotNotFound(err error) bool { + if awsError, ok := err.(awserr.Error); ok { + // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html + if awsError.Code() == "InvalidSnapshot.NotFound" { + return true + } + } + + return false +} + +func (c *cloud) waitForSnapshotCreate(ctx context.Context, snapshotID *string) error { + // This should give about 1 minute maximal interval + backoff := wait.Backoff{ + Duration: 1 * time.Second, + Factor: 1.5, + Steps: 10, + } + request := &ec2.DescribeSnapshotsInput{ + SnapshotIds: []*string{ + snapshotID, + }, + } + + conditionFunc := func() (done bool, err error) { + snapshot, err := c.getSnapshot(ctx, request) + if err != nil { + return true, err + } + if snapshot.State != nil { + switch *snapshot.State { + case "completed": + return true, nil + case "pending": + return false, nil + default: + return true, fmt.Errorf("unexpected State of newly created AWS EBS snapshot %v: %q", snapshotID, *snapshot.State) + } + } + return false, nil + } + + // Truncated exponential backoff: if the exponential backoff times-out, just keep polling using the longest interval + err := wait.ExponentialBackoff(backoff, conditionFunc) + if err == wait.ErrWaitTimeout { + timeout := time.Duration(backoff.Duration.Seconds() * math.Pow(backoff.Factor, float64(backoff.Steps))) + err = wait.PollInfinite(timeout*time.Second, conditionFunc) + } + return err +} diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index b2c56dc550..c993749418 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -127,6 +127,22 @@ func TestCreateDisk(t *testing.T) { }, expErr: fmt.Errorf("failed to get an available volume in EC2: timed out waiting for the condition"), }, + { + name: "success: normal from snapshot", + volumeName: "vol-test-name", + diskOptions: &DiskOptions{ + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + AvailabilityZone: expZone, + SnapshotID: "snapshot-test", + }, + expDisk: &Disk{ + VolumeID: "vol-test", + CapacityGiB: 1, + AvailabilityZone: expZone, + }, + expErr: nil, + }, } for _, tc := range testCases { @@ -146,10 +162,18 @@ func TestCreateDisk(t *testing.T) { State: aws.String(volState), AvailabilityZone: aws.String(tc.diskOptions.AvailabilityZone), } + snapshot := &ec2.Snapshot{ + SnapshotId: aws.String(tc.diskOptions.SnapshotID), + VolumeId: aws.String("snap-test-volume"), + State: aws.String("completed"), + } ctx := context.Background() mockEC2.EXPECT().CreateVolumeWithContext(gomock.Eq(ctx), gomock.Any()).Return(vol, tc.expCreateVolumeErr) mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeVolumesOutput{Volumes: []*ec2.Volume{vol}}, tc.expDescVolumeErr).AnyTimes() + if len(tc.diskOptions.SnapshotID) > 0 { + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: []*ec2.Snapshot{snapshot}}, nil).AnyTimes() + } disk, err := c.CreateDisk(ctx, tc.volumeName, tc.diskOptions) if err != nil { @@ -457,6 +481,164 @@ func TestGetDiskByID(t *testing.T) { } } +func TestCreateSnapshot(t *testing.T) { + testCases := []struct { + name string + snapshotName string + snapshotOptions *SnapshotOptions + expSnapshot *Snapshot + expErr error + }{ + { + name: "success: normal", + snapshotName: "snap-test-name", + snapshotOptions: &SnapshotOptions{ + Tags: map[string]string{ + SnapshotNameTagKey: "snap-test-name", + }, + }, + expSnapshot: &Snapshot{ + SourceVolumeID: "snap-test-volume", + }, + expErr: nil, + }, + } + + for _, tc := range testCases { + t.Logf("Test case: %s", tc.name) + mockCtrl := gomock.NewController(t) + mockEC2 := mocks.NewMockEC2(mockCtrl) + c := newCloud(mockEC2) + + ec2snapshot := &ec2.Snapshot{ + SnapshotId: aws.String(tc.snapshotOptions.Tags[SnapshotNameTagKey]), + VolumeId: aws.String("snap-test-volume"), + State: aws.String("completed"), + } + + ctx := context.Background() + mockEC2.EXPECT().CreateSnapshotWithContext(gomock.Eq(ctx), gomock.Any()).Return(ec2snapshot, tc.expErr) + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: []*ec2.Snapshot{ec2snapshot}}, nil).AnyTimes() + + snapshot, err := c.CreateSnapshot(ctx, tc.expSnapshot.SourceVolumeID, tc.snapshotOptions) + if err != nil { + if tc.expErr == nil { + t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err) + } + } else { + if tc.expErr != nil { + t.Fatal("CreateSnapshot() failed: expected error, got nothing") + } else { + if snapshot.SourceVolumeID != tc.expSnapshot.SourceVolumeID { + t.Fatalf("CreateSnapshot() failed: expected source volume ID %s, got %v", tc.expSnapshot.SourceVolumeID, snapshot.SourceVolumeID) + } + } + } + + mockCtrl.Finish() + } +} + +func TestDeleteSnapshot(t *testing.T) { + testCases := []struct { + name string + snapshotName string + snapshotOptions *SnapshotOptions + expSnapshot *Snapshot + expErr error + }{ + { + name: "success: normal", + snapshotName: "snap-test-name", + snapshotOptions: &SnapshotOptions{ + Tags: map[string]string{ + SnapshotNameTagKey: "snap-test-name", + }, + }, + expSnapshot: &Snapshot{ + SourceVolumeID: "snap-test-volume", + }, + expErr: nil, + }, + } + + for _, tc := range testCases { + t.Logf("Test case: %s", tc.name) + mockCtrl := gomock.NewController(t) + mockEC2 := mocks.NewMockEC2(mockCtrl) + c := newCloud(mockEC2) + + ctx := context.Background() + mockEC2.EXPECT().DeleteSnapshotWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DeleteSnapshotOutput{}, tc.expErr) + + _, err := c.DeleteSnapshot(ctx, tc.snapshotOptions.Tags[SnapshotNameTagKey]) + if err != nil { + if tc.expErr == nil { + t.Fatalf("DeleteSnapshot() failed: expected no error, got: %v", err) + } + } else { + if tc.expErr != nil { + t.Fatal("DeleteSnapshot() failed: expected error, got nothing") + } + } + + mockCtrl.Finish() + } +} + +func TestGetSnapshotByName(t *testing.T) { + testCases := []struct { + name string + snapshotName string + snapshotOptions *SnapshotOptions + expSnapshot *Snapshot + expErr error + }{ + { + name: "success: normal", + snapshotName: "snap-test-name", + snapshotOptions: &SnapshotOptions{ + Tags: map[string]string{ + SnapshotNameTagKey: "snap-test-name", + }, + }, + expSnapshot: &Snapshot{ + SourceVolumeID: "snap-test-volume", + }, + expErr: nil, + }, + } + + for _, tc := range testCases { + t.Logf("Test case: %s", tc.name) + mockCtrl := gomock.NewController(t) + mockEC2 := mocks.NewMockEC2(mockCtrl) + c := newCloud(mockEC2) + + ec2snapshot := &ec2.Snapshot{ + SnapshotId: aws.String(tc.snapshotOptions.Tags[SnapshotNameTagKey]), + VolumeId: aws.String("snap-test-volume"), + State: aws.String("completed"), + } + + ctx := context.Background() + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: []*ec2.Snapshot{ec2snapshot}}, nil) + + _, err := c.GetSnapshotByName(ctx, tc.snapshotOptions.Tags[SnapshotNameTagKey]) + if err != nil { + if tc.expErr == nil { + t.Fatalf("GetSnapshotByName() failed: expected no error, got: %v", err) + } + } else { + if tc.expErr != nil { + t.Fatal("GetSnapshotByName() failed: expected error, got nothing") + } + } + + mockCtrl.Finish() + } +} + func newCloud(mockEC2 EC2) Cloud { return &cloud{ metadata: &metadata{ diff --git a/pkg/cloud/fakes.go b/pkg/cloud/fakes.go index 8bc085fe9b..e6941fac43 100644 --- a/pkg/cloud/fakes.go +++ b/pkg/cloud/fakes.go @@ -26,9 +26,10 @@ import ( ) type FakeCloudProvider struct { - disks map[string]*fakeDisk - m *metadata - pub map[string]string + disks map[string]*fakeDisk + snapshots map[string]*fakeSnapshot + m *metadata + pub map[string]string } type fakeDisk struct { @@ -36,11 +37,17 @@ type fakeDisk struct { tags map[string]string } +type fakeSnapshot struct { + *Snapshot + tags map[string]string +} + func NewFakeCloudProvider() *FakeCloudProvider { return &FakeCloudProvider{ - disks: make(map[string]*fakeDisk), - pub: make(map[string]string), - m: &metadata{"instanceID", "region", "az"}, + disks: make(map[string]*fakeDisk), + snapshots: make(map[string]*fakeSnapshot), + pub: make(map[string]string), + m: &metadata{"instanceID", "region", "az"}, } } @@ -119,3 +126,45 @@ func (c *FakeCloudProvider) GetDiskByID(ctx context.Context, volumeID string) (* func (c *FakeCloudProvider) IsExistInstance(ctx context.Context, nodeID string) bool { return nodeID == c.m.GetInstanceID() } + +func (c *FakeCloudProvider) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) { + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + snapshotID := fmt.Sprintf("snapshot-%d", r1.Uint64()) + if len(snapshotOptions.Tags[SnapshotNameTagKey]) == 0 { + // for simplicity: let's have the Name and ID identical + snapshotOptions.Tags[SnapshotNameTagKey] = snapshotID + } + s := &fakeSnapshot{ + Snapshot: &Snapshot{ + SnapshotID: snapshotID, + SourceVolumeID: volumeID, + Size: 1, + CreationTime: time.Now(), + }, + tags: snapshotOptions.Tags, + } + c.snapshots[snapshotID] = s + return s.Snapshot, nil + +} + +func (c *FakeCloudProvider) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) { + delete(c.snapshots, snapshotID) + return true, nil + +} + +func (c *FakeCloudProvider) GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) { + var snapshots []*fakeSnapshot + for _, s := range c.snapshots { + for key, value := range s.tags { + if key == SnapshotNameTagKey && value == name { + snapshots = append(snapshots, s) + } + } + } + if len(snapshots) == 0 { + return nil, nil + } + return snapshots[0].Snapshot, nil +} diff --git a/pkg/cloud/mocks/mock_ec2.go b/pkg/cloud/mocks/mock_ec2.go index 53d4ac1ba2..d893d813e5 100644 --- a/pkg/cloud/mocks/mock_ec2.go +++ b/pkg/cloud/mocks/mock_ec2.go @@ -126,3 +126,51 @@ func (_mr *_MockEC2Recorder) DetachVolumeWithContext(arg0, arg1 interface{}, arg _s := append([]interface{}{arg0, arg1}, arg2...) return _mr.mock.ctrl.RecordCall(_mr.mock, "DetachVolumeWithContext", _s...) } + +func (_m *MockEC2) CreateSnapshotWithContext(arg0 aws.Context, arg1 *ec2.CreateSnapshotInput, arg2 ...request.Option) (*ec2.Snapshot, error) { + _s := []interface{}{arg0, arg1} + for _, a := range arg2 { + _s = append(_s, a) + } + ret := _m.ctrl.Call(_m, "CreateSnapshotWithContext", _s...) + ret0, _ := ret[0].(*ec2.Snapshot) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockEC2Recorder) CreateSnapshotWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + _s := append([]interface{}{arg0, arg1}, arg2...) + return _mr.mock.ctrl.RecordCall(_mr.mock, "CreateSnapshotWithContext", _s...) +} + +func (_m *MockEC2) DeleteSnapshotWithContext(arg0 aws.Context, arg1 *ec2.DeleteSnapshotInput, arg2 ...request.Option) (*ec2.DeleteSnapshotOutput, error) { + _s := []interface{}{arg0, arg1} + for _, a := range arg2 { + _s = append(_s, a) + } + ret := _m.ctrl.Call(_m, "DeleteSnapshotWithContext", _s...) + ret0, _ := ret[0].(*ec2.DeleteSnapshotOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockEC2Recorder) DeleteSnapshotWithContext(arg0, arg1 interface{}) *gomock.Call { + _s := []interface{}{arg0, arg1} + return _mr.mock.ctrl.RecordCall(_mr.mock, "DeleteSnapshotWithContext", _s...) +} + +func (_m *MockEC2) DescribeSnapshotsWithContext(arg0 aws.Context, arg1 *ec2.DescribeSnapshotsInput, arg2 ...request.Option) (*ec2.DescribeSnapshotsOutput, error) { + _s := []interface{}{arg0, arg1} + for _, a := range arg2 { + _s = append(_s, a) + } + ret := _m.ctrl.Call(_m, "DescribeSnapshotsWithContext", _s...) + ret0, _ := ret[0].(*ec2.DescribeSnapshotsOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockEC2Recorder) DescribeSnapshotsWithContext(arg0, arg1 interface{}) *gomock.Call { + _s := []interface{}{arg0, arg1} + return _mr.mock.ctrl.RecordCall(_mr.mock, "DescribeSnapshotsWithContext", _s...) +} diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 634879ffc5..10bb9477e8 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -21,6 +21,7 @@ import ( "strconv" csi "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/protobuf/ptypes" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" "google.golang.org/grpc/codes" @@ -42,6 +43,7 @@ var ( controllerCaps = []csi.ControllerServiceCapability_RPC_Type{ csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, } ) @@ -124,6 +126,19 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) Encrypted: isEncrypted, KmsKeyID: kmsKeyId, } + + volumeSource := req.GetVolumeContentSource() + if volumeSource != nil { + if _, ok := volumeSource.GetType().(*csi.VolumeContentSource_Snapshot); !ok { + return nil, status.Error(codes.InvalidArgument, "Unsupported volumeContentSource type") + } + sourceSnapshot := volumeSource.GetSnapshot() + if sourceSnapshot == nil { + return nil, status.Error(codes.InvalidArgument, "Error retrieving snapshot from the volumeContentSource") + } + opts.SnapshotID = sourceSnapshot.GetSnapshotId() + } + disk, err = d.cloud.CreateDisk(ctx, volName, opts) if err != nil { return nil, status.Errorf(codes.Internal, "Could not create volume %q: %v", volName, err) @@ -290,11 +305,56 @@ func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool } func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { - return nil, status.Error(codes.Unimplemented, "") + klog.V(4).Infof("CreateSnapshot: called with args %+v", req) + snapshotName := req.GetName() + if len(snapshotName) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot name not provided") + } + + volumeID := req.GetSourceVolumeId() + if len(volumeID) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot volume source ID not provided") + } + snapshot, err := d.cloud.GetSnapshotByName(ctx, snapshotName) + if err != nil && err != cloud.ErrNotFound { + klog.Errorf("Error looking for the snapshot %s: %v", snapshotName, err) + return nil, err + } + if snapshot != nil { + if snapshot.SourceVolumeID != volumeID { + return nil, status.Errorf(codes.AlreadyExists, "Snapshot %s already exists for different volume (%s)", snapshotName, snapshot.SourceVolumeID) + } else { + klog.Infof("Snapshot %s of volume %s already exists; nothing to do", snapshotName, volumeID) + return newCreateSnapshotResponse(snapshot) + } + } + opts := &cloud.SnapshotOptions{ + Tags: map[string]string{cloud.SnapshotNameTagKey: snapshotName}, + } + snapshot, err = d.cloud.CreateSnapshot(ctx, volumeID, opts) + + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not create snapshot %q: %v", snapshotName, err) + } + return newCreateSnapshotResponse(snapshot) } func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { - return nil, status.Error(codes.Unimplemented, "") + klog.V(4).Infof("DeleteSnapshot: called with args %+v", req) + snapshotID := req.GetSnapshotId() + if len(snapshotID) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot ID not provided") + } + + if _, err := d.cloud.DeleteSnapshot(ctx, snapshotID); err != nil { + if err == cloud.ErrNotFound { + klog.V(4).Info("DeleteSnapshot: snapshot not found, returning with success") + return &csi.DeleteSnapshotResponse{}, nil + } + return nil, status.Errorf(codes.Internal, "Could not delete snapshot ID %q: %v", snapshotID, err) + } + + return &csi.DeleteSnapshotResponse{}, nil } func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { @@ -338,3 +398,19 @@ func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse { }, } } + +func newCreateSnapshotResponse(snapshot *cloud.Snapshot) (*csi.CreateSnapshotResponse, error) { + ts, err := ptypes.TimestampProto(snapshot.CreationTime) + if err != nil { + return nil, err + } + return &csi.CreateSnapshotResponse{ + Snapshot: &csi.Snapshot{ + SnapshotId: snapshot.SnapshotID, + SourceVolumeId: snapshot.SourceVolumeID, + SizeBytes: snapshot.Size, + CreationTime: ts, + ReadyToUse: true, // In AWS it's eiter this or error + }, + }, nil +} diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 7317aacc2a..675386f101 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -429,3 +429,163 @@ func TestPickAvailabilityZone(t *testing.T) { } } + +func TestCreateSnapshot(t *testing.T) { + testCases := []struct { + name string + req *csi.CreateSnapshotRequest + extraReq *csi.CreateSnapshotRequest + expSnapshot *csi.Snapshot + expErrCode codes.Code + extraExpErrCode codes.Code + }{ + { + name: "success normal", + req: &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + }, + expSnapshot: &csi.Snapshot{ + ReadyToUse: true, + }, + expErrCode: codes.OK, + }, + { + name: "fail no name", + req: &csi.CreateSnapshotRequest{ + Parameters: nil, + SourceVolumeId: "vol-test", + }, + expSnapshot: nil, + expErrCode: codes.InvalidArgument, + }, + { + name: "fail same name different volume ID", + req: &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + }, + extraReq: &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-xxx", + }, + expSnapshot: &csi.Snapshot{ + ReadyToUse: true, + }, + expErrCode: codes.OK, + extraExpErrCode: codes.AlreadyExists, + }, + { + name: "success same name same volume ID", + req: &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + }, + extraReq: &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + }, + expSnapshot: &csi.Snapshot{ + ReadyToUse: true, + }, + expErrCode: codes.OK, + extraExpErrCode: codes.OK, + }, + } + for _, tc := range testCases { + t.Logf("Test case: %s", tc.name) + awsDriver := NewFakeDriver("", NewFakeMounter()) + resp, err := awsDriver.CreateSnapshot(context.TODO(), tc.req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != tc.expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) + } + continue + } + if tc.expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", tc.expErrCode) + } + snap := resp.GetSnapshot() + if snap == nil && tc.expSnapshot != nil { + t.Fatalf("Expected snapshot %v, got nil", tc.expSnapshot) + } + if tc.extraReq != nil { + // extraReq is never used in a situation when a new snapshot + // should be really created: checking the return code is enough + _, err = awsDriver.CreateSnapshot(context.TODO(), tc.extraReq) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != tc.extraExpErrCode { + t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) + } + continue + } + if tc.extraExpErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", tc.extraExpErrCode) + } + } + } +} + +func TestDeleteSnapshot(t *testing.T) { + snapReq := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + } + testCases := []struct { + name string + req *csi.DeleteSnapshotRequest + expErrCode codes.Code + }{ + { + name: "success normal", + req: &csi.DeleteSnapshotRequest{}, + expErrCode: codes.OK, + }, + { + name: "success not found", + req: &csi.DeleteSnapshotRequest{ + SnapshotId: "xxx", + }, + expErrCode: codes.OK, + }, + } + for _, tc := range testCases { + t.Logf("Test case: %s", tc.name) + awsDriver := NewFakeDriver("", NewFakeMounter()) + snapResp, err := awsDriver.CreateSnapshot(context.TODO(), snapReq) + if err != nil { + t.Fatalf("Error creating testing snapshot: %v", err) + } + if len(tc.req.SnapshotId) == 0 { + tc.req.SnapshotId = snapResp.Snapshot.SnapshotId + } + _, err = awsDriver.DeleteSnapshot(context.TODO(), tc.req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != tc.expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) + } + continue + } + if tc.expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", tc.expErrCode) + } + } +}