From 44bf7e1df3fb6c0bcd7eb854733e0f2cb08810bf Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Thu, 25 Apr 2019 10:27:56 -0400 Subject: [PATCH 1/3] Implementing ListSnapshots --- pkg/cloud/cloud.go | 79 +++++++++++- pkg/cloud/cloud_test.go | 160 ++++++++++++++++++++++++ pkg/driver/controller.go | 78 +++++++++++- pkg/driver/controller_test.go | 183 ++++++++++++++++++++++++++++ pkg/driver/mocks/mock_cloud.go | 15 +++ tests/sanity/fake_cloud_provider.go | 10 ++ 6 files changed, 523 insertions(+), 2 deletions(-) diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index fda0dcdff0..6832960865 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -91,6 +91,10 @@ var ( // ErrAlreadyExists is returned when a resource is already existent. ErrAlreadyExists = errors.New("Resource already exists") + + // ErrMultiSnapshots is returned when multiple snapshots are found + // with the same ID + ErrMultiSnapshots = errors.New("Multiple snapshots with the same name found") ) // Disk represents a EBS volume @@ -124,11 +128,23 @@ type Snapshot struct { ReadyToUse bool } +// ListSnapshotsResponse is the container for our snapshots along with a pagination token to pass back to the caller +type ListSnapshotsResponse struct { + Snapshots []*Snapshot + NextToken string +} + // SnapshotOptions represents parameters to create an EBS volume type SnapshotOptions struct { Tags map[string]string } +// ec2ListSnapshotsResponse is a helper struct returned from the AWS API calling function to the main ListSnapshots function +type ec2ListSnapshotsResponse struct { + Snapshots []*ec2.Snapshot + NextToken string +} + // EC2 abstracts aws.EC2 to facilitate its mocking. // See https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/ for details type EC2 interface { @@ -156,6 +172,7 @@ type Cloud interface { CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) + ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *ListSnapshotsResponse, err error) } type cloud struct { @@ -542,6 +559,44 @@ func (c *cloud) GetSnapshotByName(ctx context.Context, name string) (snapshot *S return c.ec2SnapshotResponseToStruct(ec2snapshot), nil } +func (c *cloud) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *ListSnapshotsResponse, err error) { + describeSnapshotsInput := &ec2.DescribeSnapshotsInput{} + + if maxResults >= 5 { + describeSnapshotsInput.MaxResults = aws.Int64(maxResults) + } + + if len(nextToken) != 0 { + describeSnapshotsInput.NextToken = aws.String(nextToken) + } + if len(volumeID) != 0 { + describeSnapshotsInput.Filters = []*ec2.Filter{ + { + Name: aws.String("volume-id"), + Values: []*string{aws.String(volumeID)}, + }, + } + } + + ec2SnapshotsResponse, err := c.listSnapshots(ctx, describeSnapshotsInput) + if err != nil { + return nil, err + } + var snapshots []*Snapshot + for _, ec2Snapshot := range ec2SnapshotsResponse.Snapshots { + snapshots = append(snapshots, c.ec2SnapshotResponseToStruct(ec2Snapshot)) + } + + if len(snapshots) == 0 { + return nil, ErrNotFound + } + + return &ListSnapshotsResponse{ + Snapshots: snapshots, + NextToken: ec2SnapshotsResponse.NextToken, + }, nil +} + // Helper method converting EC2 snapshot type to the internal struct func (c *cloud) ec2SnapshotResponseToStruct(ec2Snapshot *ec2.Snapshot) *Snapshot { if ec2Snapshot == nil { @@ -640,7 +695,7 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI } if l := len(snapshots); l > 1 { - return nil, errors.New("Multiple snapshots with the same name found") + return nil, ErrMultiSnapshots } else if l < 1 { return nil, ErrNotFound } @@ -648,6 +703,28 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI return snapshots[0], nil } +// listSnapshots returns all snapshots based from a request +func (c *cloud) listSnapshots(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2ListSnapshotsResponse, error) { + var snapshots []*ec2.Snapshot + var nextToken string + + response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) + if err != nil { + return nil, err + } + + snapshots = append(snapshots, response.Snapshots...) + + if response.NextToken != nil { + nextToken = *response.NextToken + } + + return &ec2ListSnapshotsResponse{ + Snapshots: snapshots, + NextToken: nextToken, + }, nil +} + // waitForVolume waits for volume to be in the "available" state. // On a random AWS account (shared among several developers) it took 4s on average. func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error { diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 916209c3d7..8e1dc51d81 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -18,6 +18,7 @@ package cloud import ( "context" + "errors" "fmt" "strings" "testing" @@ -647,6 +648,165 @@ func TestGetSnapshotByName(t *testing.T) { } } +func TestListSnapshots(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "success: normal", + testFunc: func(t *testing.T) { + expSnapshots := []*Snapshot{ + { + SourceVolumeID: "snap-test-volume1", + SnapshotID: "snap-test-name1", + }, + { + SourceVolumeID: "snap-test-volume2", + SnapshotID: "snap-test-name2", + }, + } + ec2Snapshots := []*ec2.Snapshot{ + { + SnapshotId: aws.String(expSnapshots[0].SnapshotID), + VolumeId: aws.String("snap-test-volume1"), + State: aws.String("completed"), + }, + { + SnapshotId: aws.String(expSnapshots[1].SnapshotID), + VolumeId: aws.String("snap-test-volume2"), + State: aws.String("completed"), + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: ec2Snapshots}, nil) + + _, err := c.ListSnapshots(ctx, "", 0, "") + if err != nil { + t.Fatalf("ListSnapshots() failed: expected no error, got: %v", err) + } + }, + }, + { + name: "success: max results, next token", + testFunc: func(t *testing.T) { + maxResults := 5 + nextTokenValue := "nextTokenValue" + var expSnapshots []*Snapshot + for i := 0; i < maxResults*2; i++ { + expSnapshots = append(expSnapshots, &Snapshot{ + SourceVolumeID: "snap-test-volume1", + SnapshotID: fmt.Sprintf("snap-test-name%d", i), + }) + } + + var ec2Snapshots []*ec2.Snapshot + for i := 0; i < maxResults*2; i++ { + ec2Snapshots = append(ec2Snapshots, &ec2.Snapshot{ + SnapshotId: aws.String(expSnapshots[i].SnapshotID), + VolumeId: aws.String(fmt.Sprintf("snap-test-volume%d", i)), + State: aws.String("completed"), + }) + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + firstCall := mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{ + Snapshots: ec2Snapshots[:maxResults], + NextToken: aws.String(nextTokenValue), + }, nil) + secondCall := mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{ + Snapshots: ec2Snapshots[maxResults:], + }, nil) + gomock.InOrder( + firstCall, + secondCall, + ) + + firstSnapshotsResponse, err := c.ListSnapshots(ctx, "", 5, "") + if err != nil { + t.Fatalf("ListSnapshots() failed: expected no error, got: %v", err) + } + + if len(firstSnapshotsResponse.Snapshots) != maxResults { + t.Fatalf("Expected %d snapshots, got %d", maxResults, len(firstSnapshotsResponse.Snapshots)) + } + + if firstSnapshotsResponse.NextToken != nextTokenValue { + t.Fatalf("Expected next token value '%s' got '%s'", nextTokenValue, firstSnapshotsResponse.NextToken) + } + + secondSnapshotsResponse, err := c.ListSnapshots(ctx, "", 0, firstSnapshotsResponse.NextToken) + if err != nil { + t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err) + } + + if len(secondSnapshotsResponse.Snapshots) != maxResults { + t.Fatalf("Expected %d snapshots, got %d", maxResults, len(secondSnapshotsResponse.Snapshots)) + } + + if secondSnapshotsResponse.NextToken != "" { + t.Fatalf("Expected next token value to be empty got %s", secondSnapshotsResponse.NextToken) + } + }, + }, + { + name: "fail: AWS DescribeSnapshotsWithContext error", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(nil, errors.New("test error")) + + if _, err := c.ListSnapshots(ctx, "", 0, ""); err == nil { + t.Fatalf("ListSnapshots() failed: expected an error, got none") + } + }, + }, + { + name: "fail: no snapshots ErrNotFound", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{}, nil) + + if _, err := c.ListSnapshots(ctx, "", 0, ""); err != nil { + if err != ErrNotFound { + t.Fatalf("Expected error %v, got %v", ErrNotFound, err) + } + } else { + t.Fatalf("Expected error, got none") + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + func newCloud(mockEC2 EC2) Cloud { return &cloud{ metadata: &Metadata{ diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 3e9c5efbba..73bcd81e50 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -369,7 +369,51 @@ func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteS } func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { - return nil, status.Error(codes.Unimplemented, "") + klog.V(4).Infof("ListSnapshots: called with args %+v", req) + var snapshots []*cloud.Snapshot + + snapshotID := req.GetSnapshotId() + if len(snapshotID) != 0 { + snapshot, err := d.cloud.GetSnapshotByName(ctx, snapshotID) + if err != nil { + if err == cloud.ErrNotFound { + klog.V(4).Info("ListSnapshots: snapshot not found, returning with success") + return &csi.ListSnapshotsResponse{}, nil + } + return nil, status.Errorf(codes.Internal, "Could not get snapshot ID %q: %v", snapshotID, err) + } + snapshots = append(snapshots, snapshot) + if response, err := newListSnapshotsResponse(&cloud.ListSnapshotsResponse{ + Snapshots: snapshots, + }); err != nil { + return nil, status.Errorf(codes.Internal, "Could not build ListSnapshotsResponse: %v", err) + } else { + return response, nil + } + } + + volumeID := req.GetSourceVolumeId() + nextToken := req.GetStartingToken() + maxEntries := int64(req.GetMaxEntries()) + + if maxEntries > 0 && maxEntries < 5 { + return nil, status.Errorf(codes.InvalidArgument, "MaxEntries must be greater than or equal to 5") + } + + cloudSnapshots, err := d.cloud.ListSnapshots(ctx, volumeID, maxEntries, nextToken) + if err != nil { + if err == cloud.ErrNotFound { + klog.V(4).Info("ListSnapshots: snapshot not found, returning with success") + return &csi.ListSnapshotsResponse{}, nil + } + return nil, status.Errorf(codes.Internal, "Could not list snapshots: %v", err) + } + + response, err := newListSnapshotsResponse(cloudSnapshots) + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not build ListSnapshotsResponse: %v", err) + } + return response, nil } func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { @@ -430,6 +474,38 @@ func newCreateSnapshotResponse(snapshot *cloud.Snapshot) (*csi.CreateSnapshotRes }, nil } +func newListSnapshotsResponse(cloudResponse *cloud.ListSnapshotsResponse) (*csi.ListSnapshotsResponse, error) { + + var entries []*csi.ListSnapshotsResponse_Entry + for _, snapshot := range cloudResponse.Snapshots { + snapshotResponseEntry, err := newListSnapshotsResponseEntry(snapshot) + if err != nil { + return nil, err + } + entries = append(entries, snapshotResponseEntry) + } + return &csi.ListSnapshotsResponse{ + Entries: entries, + NextToken: cloudResponse.NextToken, + }, nil +} + +func newListSnapshotsResponseEntry(snapshot *cloud.Snapshot) (*csi.ListSnapshotsResponse_Entry, error) { + ts, err := ptypes.TimestampProto(snapshot.CreationTime) + if err != nil { + return nil, err + } + return &csi.ListSnapshotsResponse_Entry{ + Snapshot: &csi.Snapshot{ + SnapshotId: snapshot.SnapshotID, + SourceVolumeId: snapshot.SourceVolumeID, + SizeBytes: snapshot.Size, + CreationTime: ts, + ReadyToUse: snapshot.ReadyToUse, + }, + }, nil +} + func getVolSizeBytes(req *csi.CreateVolumeRequest) (int64, error) { var volSizeBytes int64 capRange := req.GetCapacityRange() diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 6e0ab037fc..d1fdd89322 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1136,6 +1136,189 @@ func TestDeleteSnapshot(t *testing.T) { } } +func TestListSnapshots(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "success normal", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{} + mockCloudSnapshotsResponse := &cloud.ListSnapshotsResponse{ + Snapshots: []*cloud.Snapshot{ + { + SnapshotID: "snapshot-1", + SourceVolumeID: "test-vol", + Size: 1, + CreationTime: time.Now(), + }, + { + SnapshotID: "snapshot-2", + SourceVolumeID: "test-vol", + Size: 1, + CreationTime: time.Now(), + }, + }, + NextToken: "", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(mockCloudSnapshotsResponse, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(resp.GetEntries()) != len(mockCloudSnapshotsResponse.Snapshots) { + t.Fatalf("Expected %d entries, got %d", len(mockCloudSnapshotsResponse.Snapshots), len(resp.GetEntries())) + } + }, + }, + { + name: "success no snapshots", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{} + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(nil, cloud.ErrNotFound) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(resp, &csi.ListSnapshotsResponse{}) { + t.Fatalf("Expected empty response, got %+v", resp) + } + }, + }, + { + name: "success snapshot ID", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + SnapshotId: "snapshot-1", + } + mockCloudSnapshotsResponse := &cloud.Snapshot{ + SnapshotID: "snapshot-1", + SourceVolumeID: "test-vol", + Size: 1, + CreationTime: time.Now(), + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(mockCloudSnapshotsResponse, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(resp.GetEntries()) != 1 { + t.Fatalf("Expected %d entry, got %d", 1, len(resp.GetEntries())) + } + }, + }, + { + name: "success snapshot ID not found", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + SnapshotId: "snapshot-1", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrNotFound) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(resp, &csi.ListSnapshotsResponse{}) { + t.Fatalf("Expected empty response, got %+v", resp) + } + }, + }, + { + name: "fail snapshot ID multiple found", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + SnapshotId: "snapshot-1", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrMultiSnapshots) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.Internal { + t.Fatalf("Expected error code %d, got %d message %s", codes.Internal, srvErr.Code(), srvErr.Message()) + } + } else { + t.Fatalf("Expected error code %d, got no error", codes.Internal) + } + }, + }, + { + name: "fail 0 < MaxEntries < 5", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + MaxEntries: 4, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.Internal, srvErr.Code(), srvErr.Message()) + } + } else { + t.Fatalf("Expected error code %d, got no error", codes.Internal) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + func TestControllerPublishVolume(t *testing.T) { stdVolCap := &csi.VolumeCapability{ AccessType: &csi.VolumeCapability_Mount{ diff --git a/pkg/driver/mocks/mock_cloud.go b/pkg/driver/mocks/mock_cloud.go index 53b4aa199f..cb5f47016b 100644 --- a/pkg/driver/mocks/mock_cloud.go +++ b/pkg/driver/mocks/mock_cloud.go @@ -196,6 +196,21 @@ func (mr *MockCloudMockRecorder) IsExistInstance(arg0, arg1 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsExistInstance", reflect.TypeOf((*MockCloud)(nil).IsExistInstance), arg0, arg1) } +// ListSnapshots mocks base method +func (m *MockCloud) ListSnapshots(arg0 context.Context, arg1 string, arg2 int64, arg3 string) (*cloud.ListSnapshotsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListSnapshots", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*cloud.ListSnapshotsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSnapshots indicates an expected call of ListSnapshots +func (mr *MockCloudMockRecorder) ListSnapshots(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSnapshots", reflect.TypeOf((*MockCloud)(nil).ListSnapshots), arg0, arg1, arg2, arg3) +} + // WaitForAttachmentState mocks base method func (m *MockCloud) WaitForAttachmentState(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() diff --git a/tests/sanity/fake_cloud_provider.go b/tests/sanity/fake_cloud_provider.go index d1e7fac2a5..02fb3786d4 100644 --- a/tests/sanity/fake_cloud_provider.go +++ b/tests/sanity/fake_cloud_provider.go @@ -173,3 +173,13 @@ func (c *fakeCloudProvider) GetSnapshotByName(ctx context.Context, name string) } return snapshots[0].Snapshot, nil } + +func (c *fakeCloudProvider) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *cloud.ListSnapshotsResponse, err error) { + var snapshots []*cloud.Snapshot + for _, fakeSnapshot := range c.snapshots { + snapshots = append(snapshots, fakeSnapshot.Snapshot) + } + return &cloud.ListSnapshotsResponse{ + Snapshots: snapshots, + }, nil +} From fdca56406799ed4b05f36c5bd257b53f2389eff7 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Wed, 8 May 2019 11:05:06 -0400 Subject: [PATCH 2/3] Sanity tests passing --- pkg/cloud/cloud.go | 23 ++++++++----- pkg/cloud/cloud_test.go | 52 +++++++++++++++++++++++++++++ pkg/driver/controller.go | 8 ++--- tests/sanity/fake_cloud_provider.go | 39 +++++++++++++++++----- 4 files changed, 102 insertions(+), 20 deletions(-) diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 6832960865..f2d22b591e 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -95,6 +95,9 @@ var ( // ErrMultiSnapshots is returned when multiple snapshots are found // with the same ID ErrMultiSnapshots = errors.New("Multiple snapshots with the same name found") + + // ErrInvalidMaxResults is returned when a MaxResults pagination parameter is between 1 and 4 + ErrInvalidMaxResults = errors.New("MaxResults parameter must be 0 or greater than or equal to 5") ) // Disk represents a EBS volume @@ -142,7 +145,7 @@ type SnapshotOptions struct { // ec2ListSnapshotsResponse is a helper struct returned from the AWS API calling function to the main ListSnapshots function type ec2ListSnapshotsResponse struct { Snapshots []*ec2.Snapshot - NextToken string + NextToken *string } // EC2 abstracts aws.EC2 to facilitate its mocking. @@ -559,11 +562,16 @@ func (c *cloud) GetSnapshotByName(ctx context.Context, name string) (snapshot *S return c.ec2SnapshotResponseToStruct(ec2snapshot), nil } +// ListSnapshots retrieves AWS EBS snapshots for an optionally specified volume ID. If maxResults is set, it will return up to maxResults snapshots. If there are more snapshots than maxResults, +// a next token value will be returned to the client as well. They can use this token with subsequent calls to retrieve the next page of results. If maxResults is not set (0), +// there will be no restriction up to 1000 results (https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/#DescribeSnapshotsInput). func (c *cloud) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *ListSnapshotsResponse, err error) { - describeSnapshotsInput := &ec2.DescribeSnapshotsInput{} + if maxResults > 0 && maxResults < 5 { + return nil, ErrInvalidMaxResults + } - if maxResults >= 5 { - describeSnapshotsInput.MaxResults = aws.Int64(maxResults) + describeSnapshotsInput := &ec2.DescribeSnapshotsInput{ + MaxResults: aws.Int64(maxResults), } if len(nextToken) != 0 { @@ -593,7 +601,7 @@ func (c *cloud) ListSnapshots(ctx context.Context, volumeID string, maxResults i return &ListSnapshotsResponse{ Snapshots: snapshots, - NextToken: ec2SnapshotsResponse.NextToken, + NextToken: aws.StringValue(ec2SnapshotsResponse.NextToken), }, nil } @@ -680,7 +688,6 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { var snapshots []*ec2.Snapshot var nextToken *string - for { response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) if err != nil { @@ -706,7 +713,7 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI // listSnapshots returns all snapshots based from a request func (c *cloud) listSnapshots(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2ListSnapshotsResponse, error) { var snapshots []*ec2.Snapshot - var nextToken string + var nextToken *string response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) if err != nil { @@ -716,7 +723,7 @@ func (c *cloud) listSnapshots(ctx context.Context, request *ec2.DescribeSnapshot snapshots = append(snapshots, response.Snapshots...) if response.NextToken != nil { - nextToken = *response.NextToken + nextToken = response.NextToken } return &ec2ListSnapshotsResponse{ diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 8e1dc51d81..f877443001 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -694,6 +694,58 @@ func TestListSnapshots(t *testing.T) { } }, }, + { + name: "success: with volume ID", + testFunc: func(t *testing.T) { + sourceVolumeID := "snap-test-volume" + expSnapshots := []*Snapshot{ + { + SourceVolumeID: sourceVolumeID, + SnapshotID: "snap-test-name1", + }, + { + SourceVolumeID: sourceVolumeID, + SnapshotID: "snap-test-name2", + }, + } + ec2Snapshots := []*ec2.Snapshot{ + { + SnapshotId: aws.String(expSnapshots[0].SnapshotID), + VolumeId: aws.String(sourceVolumeID), + State: aws.String("completed"), + }, + { + SnapshotId: aws.String(expSnapshots[1].SnapshotID), + VolumeId: aws.String(sourceVolumeID), + State: aws.String("completed"), + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: ec2Snapshots}, nil) + + resp, err := c.ListSnapshots(ctx, sourceVolumeID, 0, "") + if err != nil { + t.Fatalf("ListSnapshots() failed: expected no error, got: %v", err) + } + + if len(resp.Snapshots) != len(expSnapshots) { + t.Fatalf("Expected %d snapshots, got %d", len(expSnapshots), len(resp.Snapshots)) + } + + for _, snap := range resp.Snapshots { + if snap.SourceVolumeID != sourceVolumeID { + t.Fatalf("Unexpected source volume. Expected %s, got %s", sourceVolumeID, snap.SourceVolumeID) + } + } + }, + }, { name: "success: max results, next token", testFunc: func(t *testing.T) { diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 73bcd81e50..03c490ce19 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -44,6 +44,7 @@ var ( csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, + csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS, } ) @@ -396,16 +397,15 @@ func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnap nextToken := req.GetStartingToken() maxEntries := int64(req.GetMaxEntries()) - if maxEntries > 0 && maxEntries < 5 { - return nil, status.Errorf(codes.InvalidArgument, "MaxEntries must be greater than or equal to 5") - } - cloudSnapshots, err := d.cloud.ListSnapshots(ctx, volumeID, maxEntries, nextToken) if err != nil { if err == cloud.ErrNotFound { klog.V(4).Info("ListSnapshots: snapshot not found, returning with success") return &csi.ListSnapshotsResponse{}, nil } + if err == cloud.ErrInvalidMaxResults { + return nil, status.Errorf(codes.InvalidArgument, "Error mapping MaxEntries to AWS MaxResults: %v", err) + } return nil, status.Errorf(codes.Internal, "Could not list snapshots: %v", err) } diff --git a/tests/sanity/fake_cloud_provider.go b/tests/sanity/fake_cloud_provider.go index 02fb3786d4..4270b97a8d 100644 --- a/tests/sanity/fake_cloud_provider.go +++ b/tests/sanity/fake_cloud_provider.go @@ -31,6 +31,7 @@ type fakeCloudProvider struct { snapshots map[string]*fakeSnapshot m *cloud.Metadata pub map[string]string + tokens map[string]int64 } type fakeDisk struct { @@ -53,6 +54,7 @@ func newFakeCloudProvider() *fakeCloudProvider { Region: "region", AvailabilityZone: "az", }, + tokens: make(map[string]int64), } } @@ -133,11 +135,19 @@ func (c *fakeCloudProvider) IsExistInstance(ctx context.Context, nodeID string) } func (c *fakeCloudProvider) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *cloud.SnapshotOptions) (snapshot *cloud.Snapshot, err error) { - r1 := rand.New(rand.NewSource(time.Now().UnixNano())) - snapshotID := fmt.Sprintf("snapshot-%d", r1.Uint64()) + var snapshotID string if len(snapshotOptions.Tags[cloud.SnapshotNameTagKey]) == 0 { // for simplicity: let's have the Name and ID identical + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + snapshotID = fmt.Sprintf("snapshot-%d", r1.Uint64()) snapshotOptions.Tags[cloud.SnapshotNameTagKey] = snapshotID + } else { + snapshotID = snapshotOptions.Tags[cloud.SnapshotNameTagKey] + } + for _, existingSnapshot := range c.snapshots { + if existingSnapshot.Snapshot.SnapshotID == snapshotID && existingSnapshot.Snapshot.SourceVolumeID == volumeID { + return nil, cloud.ErrAlreadyExists + } } s := &fakeSnapshot{ Snapshot: &cloud.Snapshot{ @@ -162,24 +172,37 @@ func (c *fakeCloudProvider) DeleteSnapshot(ctx context.Context, snapshotID strin func (c *fakeCloudProvider) GetSnapshotByName(ctx context.Context, name string) (snapshot *cloud.Snapshot, err error) { var snapshots []*fakeSnapshot for _, s := range c.snapshots { - for key, value := range s.tags { - if key == cloud.SnapshotNameTagKey && value == name { - snapshots = append(snapshots, s) - } + if s.SnapshotID == name { + snapshots = append(snapshots, s) } } if len(snapshots) == 0 { - return nil, nil + return nil, cloud.ErrNotFound } return snapshots[0].Snapshot, nil } func (c *fakeCloudProvider) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *cloud.ListSnapshotsResponse, err error) { var snapshots []*cloud.Snapshot + var retToken string for _, fakeSnapshot := range c.snapshots { - snapshots = append(snapshots, fakeSnapshot.Snapshot) + if fakeSnapshot.Snapshot.SourceVolumeID == volumeID || len(volumeID) == 0 { + snapshots = append(snapshots, fakeSnapshot.Snapshot) + } + } + if maxResults > 0 { + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + retToken = fmt.Sprintf("token-%d", r1.Uint64()) + c.tokens[retToken] = maxResults + snapshots = snapshots[0:maxResults] + fmt.Printf("%v\n", snapshots) + } + if len(nextToken) != 0 { + snapshots = snapshots[c.tokens[nextToken]:] } return &cloud.ListSnapshotsResponse{ Snapshots: snapshots, + NextToken: retToken, }, nil + } From 6c0eb0dc0557737a6e6e4181d3faf21bd194b300 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Wed, 8 May 2019 11:19:20 -0400 Subject: [PATCH 3/3] Adding mock to TestListSnapshots --- pkg/driver/controller_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index d1fdd89322..96f2656c19 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1294,9 +1294,11 @@ func TestListSnapshots(t *testing.T) { MaxEntries: 4, } + ctx := context.Background() mockCtl := gomock.NewController(t) defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(4)), gomock.Eq("")).Return(nil, cloud.ErrInvalidMaxResults) awsDriver := controllerService{cloud: mockCloud} if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { @@ -1305,10 +1307,10 @@ func TestListSnapshots(t *testing.T) { t.Fatalf("Could not get error status code from error: %v", srvErr) } if srvErr.Code() != codes.InvalidArgument { - t.Fatalf("Expected error code %d, got %d message %s", codes.Internal, srvErr.Code(), srvErr.Message()) + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } } else { - t.Fatalf("Expected error code %d, got no error", codes.Internal) + t.Fatalf("Expected error code %d, got no error", codes.InvalidArgument) } }, },