From a6633872c7c207c4dd53f5edde1caad060512f32 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Thu, 4 Apr 2019 23:38:29 -0400 Subject: [PATCH 1/9] Updating gomock hack --- hack/update-gomock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hack/update-gomock b/hack/update-gomock index c16467105d..b4c7752392 100755 --- a/hack/update-gomock +++ b/hack/update-gomock @@ -20,4 +20,6 @@ IMPORT_PATH=github.com/kubernetes-sigs/aws-ebs-csi-driver mockgen -package=mocks -destination=./pkg/cloud/mocks/mock_ec2.go ${IMPORT_PATH}/pkg/cloud EC2 mockgen -package=mocks -destination=./pkg/cloud/mocks/mock_ec2metadata.go ${IMPORT_PATH}/pkg/cloud EC2Metadata +mockgen -package=mocks -destination=./pkg/driver/mocks/mock_cloud.go ${IMPORT_PATH}/pkg/cloud Cloud +mockgen -package=mocks -destination=./pkg/driver/mocks/mock_metadata_service.go ${IMPORT_PATH}/pkg/cloud MetadataService From d6a4017a38074dabfa7589ef52747f9e6271c570 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Thu, 4 Apr 2019 23:49:52 -0400 Subject: [PATCH 2/9] fin --- pkg/cloud/cloud_test.go | 8 +- pkg/cloud/metadata.go | 30 +- pkg/cloud/mocks/mock_ec2.go | 210 +- pkg/cloud/mocks/mock_ec2metadata.go | 44 +- pkg/driver/controller.go | 28 +- pkg/driver/controller_test.go | 1778 ++++++++++++----- pkg/driver/fakes.go | 2 +- pkg/driver/mocks/mock_cloud.go | 211 ++ pkg/driver/mocks/mock_metadata_service.go | 75 + pkg/driver/node_test.go | 98 +- pkg/util/util.go | 2 +- .../sanity/fake_cloud_provider.go | 65 +- 12 files changed, 1864 insertions(+), 687 deletions(-) create mode 100644 pkg/driver/mocks/mock_cloud.go create mode 100644 pkg/driver/mocks/mock_metadata_service.go rename pkg/cloud/fakes.go => tests/sanity/fake_cloud_provider.go (60%) diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 66bb979008..916209c3d7 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -649,10 +649,10 @@ func TestGetSnapshotByName(t *testing.T) { func newCloud(mockEC2 EC2) Cloud { return &cloud{ - metadata: &metadata{ - instanceID: "test-instance", - region: "test-region", - availabilityZone: defaultZone, + metadata: &Metadata{ + InstanceID: "test-instance", + Region: "test-region", + AvailabilityZone: defaultZone, }, dm: dm.NewDeviceManager(), ec2: mockEC2, diff --git a/pkg/cloud/metadata.go b/pkg/cloud/metadata.go index 15a294988c..e74d32cded 100644 --- a/pkg/cloud/metadata.go +++ b/pkg/cloud/metadata.go @@ -34,27 +34,27 @@ type MetadataService interface { GetAvailabilityZone() string } -type metadata struct { - instanceID string - region string - availabilityZone string +type Metadata struct { + InstanceID string + Region string + AvailabilityZone string } -var _ MetadataService = &metadata{} +var _ MetadataService = &Metadata{} // GetInstanceID returns the instance identification. -func (m *metadata) GetInstanceID() string { - return m.instanceID +func (m *Metadata) GetInstanceID() string { + return m.InstanceID } // GetRegion returns the region which the instance is in. -func (m *metadata) GetRegion() string { - return m.region +func (m *Metadata) GetRegion() string { + return m.Region } // GetAvailabilityZone returns the Availability Zone which the instance is in. -func (m *metadata) GetAvailabilityZone() string { - return m.availabilityZone +func (m *Metadata) GetAvailabilityZone() string { + return m.AvailabilityZone } // NewMetadataService returns a new MetadataServiceImplementation. @@ -80,9 +80,9 @@ func NewMetadataService(svc EC2Metadata) (MetadataService, error) { return nil, fmt.Errorf("could not get valid EC2 availavility zone") } - return &metadata{ - instanceID: doc.InstanceID, - region: doc.Region, - availabilityZone: doc.AvailabilityZone, + return &Metadata{ + InstanceID: doc.InstanceID, + Region: doc.Region, + AvailabilityZone: doc.AvailabilityZone, }, nil } diff --git a/pkg/cloud/mocks/mock_ec2.go b/pkg/cloud/mocks/mock_ec2.go index d893d813e5..1ee105dd1e 100644 --- a/pkg/cloud/mocks/mock_ec2.go +++ b/pkg/cloud/mocks/mock_ec2.go @@ -1,6 +1,7 @@ -// Automatically generated by MockGen. DO NOT EDIT! +// Code generated by MockGen. DO NOT EDIT. // Source: github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud (interfaces: EC2) +// Package mocks is a generated GoMock package. package mocks import ( @@ -8,169 +9,208 @@ import ( request "github.com/aws/aws-sdk-go/aws/request" ec2 "github.com/aws/aws-sdk-go/service/ec2" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) -// Mock of EC2 interface +// MockEC2 is a mock of EC2 interface type MockEC2 struct { ctrl *gomock.Controller - recorder *_MockEC2Recorder + recorder *MockEC2MockRecorder } -// Recorder for MockEC2 (not exported) -type _MockEC2Recorder struct { +// MockEC2MockRecorder is the mock recorder for MockEC2 +type MockEC2MockRecorder struct { mock *MockEC2 } +// NewMockEC2 creates a new mock instance func NewMockEC2(ctrl *gomock.Controller) *MockEC2 { mock := &MockEC2{ctrl: ctrl} - mock.recorder = &_MockEC2Recorder{mock} + mock.recorder = &MockEC2MockRecorder{mock} return mock } -func (_m *MockEC2) EXPECT() *_MockEC2Recorder { - return _m.recorder +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockEC2) EXPECT() *MockEC2MockRecorder { + return m.recorder } -func (_m *MockEC2) AttachVolumeWithContext(_param0 aws.Context, _param1 *ec2.AttachVolumeInput, _param2 ...request.Option) (*ec2.VolumeAttachment, error) { - _s := []interface{}{_param0, _param1} - for _, _x := range _param2 { - _s = append(_s, _x) +// AttachVolumeWithContext mocks base method +func (m *MockEC2) AttachVolumeWithContext(arg0 aws.Context, arg1 *ec2.AttachVolumeInput, arg2 ...request.Option) (*ec2.VolumeAttachment, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "AttachVolumeWithContext", _s...) + ret := m.ctrl.Call(m, "AttachVolumeWithContext", varargs...) ret0, _ := ret[0].(*ec2.VolumeAttachment) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) AttachVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "AttachVolumeWithContext", _s...) +// AttachVolumeWithContext indicates an expected call of AttachVolumeWithContext +func (mr *MockEC2MockRecorder) AttachVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachVolumeWithContext", reflect.TypeOf((*MockEC2)(nil).AttachVolumeWithContext), varargs...) } -func (_m *MockEC2) CreateVolumeWithContext(_param0 aws.Context, _param1 *ec2.CreateVolumeInput, _param2 ...request.Option) (*ec2.Volume, error) { - _s := []interface{}{_param0, _param1} - for _, _x := range _param2 { - _s = append(_s, _x) +// CreateSnapshotWithContext mocks base method +func (m *MockEC2) CreateSnapshotWithContext(arg0 aws.Context, arg1 *ec2.CreateSnapshotInput, arg2 ...request.Option) (*ec2.Snapshot, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "CreateVolumeWithContext", _s...) - ret0, _ := ret[0].(*ec2.Volume) + ret := m.ctrl.Call(m, "CreateSnapshotWithContext", varargs...) + ret0, _ := ret[0].(*ec2.Snapshot) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) CreateVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "CreateVolumeWithContext", _s...) +// CreateSnapshotWithContext indicates an expected call of CreateSnapshotWithContext +func (mr *MockEC2MockRecorder) CreateSnapshotWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSnapshotWithContext", reflect.TypeOf((*MockEC2)(nil).CreateSnapshotWithContext), varargs...) } -func (_m *MockEC2) DeleteVolumeWithContext(_param0 aws.Context, _param1 *ec2.DeleteVolumeInput, _param2 ...request.Option) (*ec2.DeleteVolumeOutput, error) { - _s := []interface{}{_param0, _param1} - for _, _x := range _param2 { - _s = append(_s, _x) +// CreateVolumeWithContext mocks base method +func (m *MockEC2) CreateVolumeWithContext(arg0 aws.Context, arg1 *ec2.CreateVolumeInput, arg2 ...request.Option) (*ec2.Volume, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "DeleteVolumeWithContext", _s...) - ret0, _ := ret[0].(*ec2.DeleteVolumeOutput) + ret := m.ctrl.Call(m, "CreateVolumeWithContext", varargs...) + ret0, _ := ret[0].(*ec2.Volume) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) DeleteVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "DeleteVolumeWithContext", _s...) +// CreateVolumeWithContext indicates an expected call of CreateVolumeWithContext +func (mr *MockEC2MockRecorder) CreateVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateVolumeWithContext", reflect.TypeOf((*MockEC2)(nil).CreateVolumeWithContext), varargs...) } -func (_m *MockEC2) DescribeInstancesWithContext(_param0 aws.Context, _param1 *ec2.DescribeInstancesInput, _param2 ...request.Option) (*ec2.DescribeInstancesOutput, error) { - _s := []interface{}{_param0, _param1} - for _, _x := range _param2 { - _s = append(_s, _x) +// DeleteSnapshotWithContext mocks base method +func (m *MockEC2) DeleteSnapshotWithContext(arg0 aws.Context, arg1 *ec2.DeleteSnapshotInput, arg2 ...request.Option) (*ec2.DeleteSnapshotOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "DescribeInstancesWithContext", _s...) - ret0, _ := ret[0].(*ec2.DescribeInstancesOutput) + ret := m.ctrl.Call(m, "DeleteSnapshotWithContext", varargs...) + ret0, _ := ret[0].(*ec2.DeleteSnapshotOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) DescribeInstancesWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "DescribeInstancesWithContext", _s...) +// DeleteSnapshotWithContext indicates an expected call of DeleteSnapshotWithContext +func (mr *MockEC2MockRecorder) DeleteSnapshotWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSnapshotWithContext", reflect.TypeOf((*MockEC2)(nil).DeleteSnapshotWithContext), varargs...) } -func (_m *MockEC2) DescribeVolumesWithContext(_param0 aws.Context, _param1 *ec2.DescribeVolumesInput, _param2 ...request.Option) (*ec2.DescribeVolumesOutput, error) { - _s := []interface{}{_param0, _param1} - for _, _x := range _param2 { - _s = append(_s, _x) +// DeleteVolumeWithContext mocks base method +func (m *MockEC2) DeleteVolumeWithContext(arg0 aws.Context, arg1 *ec2.DeleteVolumeInput, arg2 ...request.Option) (*ec2.DeleteVolumeOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "DescribeVolumesWithContext", _s...) - ret0, _ := ret[0].(*ec2.DescribeVolumesOutput) + ret := m.ctrl.Call(m, "DeleteVolumeWithContext", varargs...) + ret0, _ := ret[0].(*ec2.DeleteVolumeOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) DescribeVolumesWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "DescribeVolumesWithContext", _s...) +// DeleteVolumeWithContext indicates an expected call of DeleteVolumeWithContext +func (mr *MockEC2MockRecorder) DeleteVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolumeWithContext", reflect.TypeOf((*MockEC2)(nil).DeleteVolumeWithContext), varargs...) } -func (_m *MockEC2) DetachVolumeWithContext(_param0 aws.Context, _param1 *ec2.DetachVolumeInput, _param2 ...request.Option) (*ec2.VolumeAttachment, error) { - _s := []interface{}{_param0, _param1} - for _, _x := range _param2 { - _s = append(_s, _x) +// DescribeInstancesWithContext mocks base method +func (m *MockEC2) DescribeInstancesWithContext(arg0 aws.Context, arg1 *ec2.DescribeInstancesInput, arg2 ...request.Option) (*ec2.DescribeInstancesOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "DetachVolumeWithContext", _s...) - ret0, _ := ret[0].(*ec2.VolumeAttachment) + ret := m.ctrl.Call(m, "DescribeInstancesWithContext", varargs...) + ret0, _ := ret[0].(*ec2.DescribeInstancesOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) DetachVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "DetachVolumeWithContext", _s...) +// DescribeInstancesWithContext indicates an expected call of DescribeInstancesWithContext +func (mr *MockEC2MockRecorder) DescribeInstancesWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeInstancesWithContext", reflect.TypeOf((*MockEC2)(nil).DescribeInstancesWithContext), varargs...) } -func (_m *MockEC2) CreateSnapshotWithContext(arg0 aws.Context, arg1 *ec2.CreateSnapshotInput, arg2 ...request.Option) (*ec2.Snapshot, error) { - _s := []interface{}{arg0, arg1} +// DescribeSnapshotsWithContext mocks base method +func (m *MockEC2) DescribeSnapshotsWithContext(arg0 aws.Context, arg1 *ec2.DescribeSnapshotsInput, arg2 ...request.Option) (*ec2.DescribeSnapshotsOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} for _, a := range arg2 { - _s = append(_s, a) + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "CreateSnapshotWithContext", _s...) - ret0, _ := ret[0].(*ec2.Snapshot) + ret := m.ctrl.Call(m, "DescribeSnapshotsWithContext", varargs...) + ret0, _ := ret[0].(*ec2.DescribeSnapshotsOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) CreateSnapshotWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - _s := append([]interface{}{arg0, arg1}, arg2...) - return _mr.mock.ctrl.RecordCall(_mr.mock, "CreateSnapshotWithContext", _s...) +// DescribeSnapshotsWithContext indicates an expected call of DescribeSnapshotsWithContext +func (mr *MockEC2MockRecorder) DescribeSnapshotsWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeSnapshotsWithContext", reflect.TypeOf((*MockEC2)(nil).DescribeSnapshotsWithContext), varargs...) } -func (_m *MockEC2) DeleteSnapshotWithContext(arg0 aws.Context, arg1 *ec2.DeleteSnapshotInput, arg2 ...request.Option) (*ec2.DeleteSnapshotOutput, error) { - _s := []interface{}{arg0, arg1} +// DescribeVolumesWithContext mocks base method +func (m *MockEC2) DescribeVolumesWithContext(arg0 aws.Context, arg1 *ec2.DescribeVolumesInput, arg2 ...request.Option) (*ec2.DescribeVolumesOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} for _, a := range arg2 { - _s = append(_s, a) + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "DeleteSnapshotWithContext", _s...) - ret0, _ := ret[0].(*ec2.DeleteSnapshotOutput) + ret := m.ctrl.Call(m, "DescribeVolumesWithContext", varargs...) + ret0, _ := ret[0].(*ec2.DescribeVolumesOutput) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) DeleteSnapshotWithContext(arg0, arg1 interface{}) *gomock.Call { - _s := []interface{}{arg0, arg1} - return _mr.mock.ctrl.RecordCall(_mr.mock, "DeleteSnapshotWithContext", _s...) +// DescribeVolumesWithContext indicates an expected call of DescribeVolumesWithContext +func (mr *MockEC2MockRecorder) DescribeVolumesWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeVolumesWithContext", reflect.TypeOf((*MockEC2)(nil).DescribeVolumesWithContext), varargs...) } -func (_m *MockEC2) DescribeSnapshotsWithContext(arg0 aws.Context, arg1 *ec2.DescribeSnapshotsInput, arg2 ...request.Option) (*ec2.DescribeSnapshotsOutput, error) { - _s := []interface{}{arg0, arg1} +// DetachVolumeWithContext mocks base method +func (m *MockEC2) DetachVolumeWithContext(arg0 aws.Context, arg1 *ec2.DetachVolumeInput, arg2 ...request.Option) (*ec2.VolumeAttachment, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} for _, a := range arg2 { - _s = append(_s, a) + varargs = append(varargs, a) } - ret := _m.ctrl.Call(_m, "DescribeSnapshotsWithContext", _s...) - ret0, _ := ret[0].(*ec2.DescribeSnapshotsOutput) + ret := m.ctrl.Call(m, "DetachVolumeWithContext", varargs...) + ret0, _ := ret[0].(*ec2.VolumeAttachment) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2Recorder) DescribeSnapshotsWithContext(arg0, arg1 interface{}) *gomock.Call { - _s := []interface{}{arg0, arg1} - return _mr.mock.ctrl.RecordCall(_mr.mock, "DescribeSnapshotsWithContext", _s...) +// DetachVolumeWithContext indicates an expected call of DetachVolumeWithContext +func (mr *MockEC2MockRecorder) DetachVolumeWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachVolumeWithContext", reflect.TypeOf((*MockEC2)(nil).DetachVolumeWithContext), varargs...) } diff --git a/pkg/cloud/mocks/mock_ec2metadata.go b/pkg/cloud/mocks/mock_ec2metadata.go index 15d05364ee..37e2ea0d0b 100644 --- a/pkg/cloud/mocks/mock_ec2metadata.go +++ b/pkg/cloud/mocks/mock_ec2metadata.go @@ -1,51 +1,63 @@ -// Automatically generated by MockGen. DO NOT EDIT! +// Code generated by MockGen. DO NOT EDIT. // Source: github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud (interfaces: EC2Metadata) +// Package mocks is a generated GoMock package. package mocks import ( ec2metadata "github.com/aws/aws-sdk-go/aws/ec2metadata" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) -// Mock of EC2Metadata interface +// MockEC2Metadata is a mock of EC2Metadata interface type MockEC2Metadata struct { ctrl *gomock.Controller - recorder *_MockEC2MetadataRecorder + recorder *MockEC2MetadataMockRecorder } -// Recorder for MockEC2Metadata (not exported) -type _MockEC2MetadataRecorder struct { +// MockEC2MetadataMockRecorder is the mock recorder for MockEC2Metadata +type MockEC2MetadataMockRecorder struct { mock *MockEC2Metadata } +// NewMockEC2Metadata creates a new mock instance func NewMockEC2Metadata(ctrl *gomock.Controller) *MockEC2Metadata { mock := &MockEC2Metadata{ctrl: ctrl} - mock.recorder = &_MockEC2MetadataRecorder{mock} + mock.recorder = &MockEC2MetadataMockRecorder{mock} return mock } -func (_m *MockEC2Metadata) EXPECT() *_MockEC2MetadataRecorder { - return _m.recorder +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockEC2Metadata) EXPECT() *MockEC2MetadataMockRecorder { + return m.recorder } -func (_m *MockEC2Metadata) Available() bool { - ret := _m.ctrl.Call(_m, "Available") +// Available mocks base method +func (m *MockEC2Metadata) Available() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Available") ret0, _ := ret[0].(bool) return ret0 } -func (_mr *_MockEC2MetadataRecorder) Available() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "Available") +// Available indicates an expected call of Available +func (mr *MockEC2MetadataMockRecorder) Available() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockEC2Metadata)(nil).Available)) } -func (_m *MockEC2Metadata) GetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error) { - ret := _m.ctrl.Call(_m, "GetInstanceIdentityDocument") +// GetInstanceIdentityDocument mocks base method +func (m *MockEC2Metadata) GetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInstanceIdentityDocument") ret0, _ := ret[0].(ec2metadata.EC2InstanceIdentityDocument) ret1, _ := ret[1].(error) return ret0, ret1 } -func (_mr *_MockEC2MetadataRecorder) GetInstanceIdentityDocument() *gomock.Call { - return _mr.mock.ctrl.RecordCall(_mr.mock, "GetInstanceIdentityDocument") +// GetInstanceIdentityDocument indicates an expected call of GetInstanceIdentityDocument +func (mr *MockEC2MetadataMockRecorder) GetInstanceIdentityDocument() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceIdentityDocument", reflect.TypeOf((*MockEC2Metadata)(nil).GetInstanceIdentityDocument)) } diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 13aca15416..7511d2b520 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -72,16 +72,9 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol return nil, status.Error(codes.InvalidArgument, "Volume name not provided") } - var volSizeBytes int64 - capRange := req.GetCapacityRange() - if capRange == nil { - volSizeBytes = cloud.DefaultVolumeSize - } else { - volSizeBytes = util.RoundUpBytes(capRange.GetRequiredBytes()) - maxVolSize := capRange.GetLimitBytes() - if maxVolSize > 0 && maxVolSize < volSizeBytes { - return nil, status.Error(codes.InvalidArgument, "After round-up, volume size exceeds the limit specified") - } + volSizeBytes, err := getVolSizeBytes(req) + if err != nil { + return nil, err } volCaps := req.GetVolumeCapabilities() @@ -436,3 +429,18 @@ func newCreateSnapshotResponse(snapshot *cloud.Snapshot) (*csi.CreateSnapshotRes }, }, nil } + +func getVolSizeBytes(req *csi.CreateVolumeRequest) (int64, error) { + var volSizeBytes int64 + capRange := req.GetCapacityRange() + if capRange == nil { + volSizeBytes = cloud.DefaultVolumeSize + } else { + volSizeBytes = util.RoundUpBytes(capRange.GetRequiredBytes()) + maxVolSize := capRange.GetLimitBytes() + if maxVolSize > 0 && maxVolSize < volSizeBytes { + return 0, status.Error(codes.InvalidArgument, "After round-up, volume size exceeds the limit specified") + } + } + return volSizeBytes, nil +} diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 4594bfab6a..5a8d71d038 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -18,18 +18,25 @@ package driver import ( "context" + "fmt" + "math/rand" "reflect" "testing" + "time" "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/mocks" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) const ( - expZone = "us-west-2b" - expFsType = "ext2" + expZone = "us-west-2b" + expFsType = "ext2" + expInstanceId = "i-123456789abcdef01" ) func TestCreateVolume(t *testing.T) { @@ -48,327 +55,768 @@ func TestCreateVolume(t *testing.T) { stdParams := map[string]string{} testCases := []struct { - name string - req *csi.CreateVolumeRequest - extraReq *csi.CreateVolumeRequest - expVol *csi.Volume - expErrCode codes.Code + name string + testFunc func(t *testing.T) }{ { name: "success normal", - req: &csi.CreateVolumeRequest{ - Name: "random-vol-name", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: nil, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "random-vol-name", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: nil, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } }, }, { name: "fail no name", - req: &csi.CreateVolumeRequest{ - Name: "", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: stdParams, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: stdParams, + } + expErr := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErr { + t.Fatalf("Expected error code %d, got %d message %s", expErr, srvErr.Code(), srvErr.Message()) + } + } else { + t.Fatalf("Expected error got nil") + } }, - expErrCode: codes.InvalidArgument, }, { name: "success same name and same capacity", - req: &csi.CreateVolumeRequest{ - Name: "test-vol", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: stdParams, - }, - extraReq: &csi.CreateVolumeRequest{ - Name: "test-vol", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: stdParams, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: stdParams, + } + extraReq := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: stdParams, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + // Subsequent call returns the created disk + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil) + resp, err := awsDriver.CreateVolume(ctx, extraReq) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + vol := resp.GetVolume() + if vol == nil && expVol != nil { + t.Fatalf("Expected volume %v, got nil", expVol) + } + + if vol.GetCapacityBytes() != expVol.GetCapacityBytes() { + t.Fatalf("Expected volume capacity bytes: %v, got: %v", expVol.GetCapacityBytes(), vol.GetCapacityBytes()) + } + + for expKey, expVal := range expVol.GetVolumeContext() { + ctx := vol.GetVolumeContext() + if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { + t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) + } + } + + if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { + t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) + } }, }, { name: "fail same name and different capacity", - req: &csi.CreateVolumeRequest{ - Name: "test-vol", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: stdParams, - }, - extraReq: &csi.CreateVolumeRequest{ - Name: "test-vol", - CapacityRange: &csi.CapacityRange{RequiredBytes: 10000}, - VolumeCapabilities: stdVolCap, - Parameters: stdParams, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: stdParams, + } + extraReq := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: &csi.CapacityRange{RequiredBytes: 10000}, + VolumeCapabilities: stdVolCap, + Parameters: stdParams, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expFsType, + } + volSizeBytes, err := getVolSizeBytes(req) + if err != nil { + t.Fatalf("Unable to get volume size bytes for req: %s", err) + } + mockDisk.CapacityGiB = util.BytesToGiB(volSizeBytes) + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + _, err = awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + extraVolSizeBytes, err := getVolSizeBytes(extraReq) + if err != nil { + t.Fatalf("Unable to get volume size bytes for req: %s", err) + } + + // Subsequent failure + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(extraReq.Name), gomock.Eq(extraVolSizeBytes)).Return(nil, cloud.ErrDiskExistsDiffSize) + if _, err := awsDriver.CreateVolume(ctx, extraReq); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.AlreadyExists { + t.Fatalf("Expected error code %d, got %d", codes.AlreadyExists, srvErr.Code()) + } + } else { + t.Fatalf("Expected error code %d, got nil", codes.AlreadyExists) + } }, - expErrCode: codes.AlreadyExists, }, { name: "success no capacity range", - req: &csi.CreateVolumeRequest{ - Name: "test-vol", - VolumeCapabilities: stdVolCap, - Parameters: stdParams, - }, - expVol: &csi.Volume{ - CapacityBytes: cloud.DefaultVolumeSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "test-vol", + VolumeCapabilities: stdVolCap, + Parameters: stdParams, + } + expVol := &csi.Volume{ + CapacityBytes: cloud.DefaultVolumeSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + } + volSizeBytes, err := getVolSizeBytes(req) + if err != nil { + t.Fatalf("Unable to get volume size bytes for req: %s", err) + } + mockDisk.CapacityGiB = util.BytesToGiB(volSizeBytes) + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + resp, err := awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + vol := resp.GetVolume() + if vol == nil && expVol != nil { + t.Fatalf("Expected volume %v, got nil", expVol) + } + + if vol.GetCapacityBytes() != expVol.GetCapacityBytes() { + t.Fatalf("Expected volume capacity bytes: %v, got: %v", expVol.GetCapacityBytes(), vol.GetCapacityBytes()) + } + + for expKey, expVal := range expVol.GetVolumeContext() { + ctx := vol.GetVolumeContext() + if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { + t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) + } + } + + if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { + t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) + } }, }, { name: "success with correct round up", - req: &csi.CreateVolumeRequest{ - Name: "vol-test", - CapacityRange: &csi.CapacityRange{RequiredBytes: 1073741825}, - VolumeCapabilities: stdVolCap, - Parameters: nil, - }, - expVol: &csi.Volume{ - CapacityBytes: 2147483648, // 1 GiB + 1 byte = 2 GiB - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-test", + CapacityRange: &csi.CapacityRange{RequiredBytes: 1073741825}, + VolumeCapabilities: stdVolCap, + Parameters: nil, + } + expVol := &csi.Volume{ + CapacityBytes: 2147483648, // 1 GiB + 1 byte = 2 GiB + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + } + volSizeBytes, err := getVolSizeBytes(req) + if err != nil { + t.Fatalf("Unable to get volume size bytes for req: %s", err) + } + mockDisk.CapacityGiB = util.BytesToGiB(volSizeBytes) + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + resp, err := awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + vol := resp.GetVolume() + if vol == nil && expVol != nil { + t.Fatalf("Expected volume %v, got nil", expVol) + } + + if vol.GetCapacityBytes() != expVol.GetCapacityBytes() { + t.Fatalf("Expected volume capacity bytes: %v, got: %v", expVol.GetCapacityBytes(), vol.GetCapacityBytes()) + } }, }, { name: "success with fstype parameter", - req: &csi.CreateVolumeRequest{ - Name: "vol-test", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{FsTypeKey: defaultFsType}, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: defaultFsType}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-test", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{FsTypeKey: defaultFsType}, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: defaultFsType}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + resp, err := awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + vol := resp.GetVolume() + if vol == nil && expVol != nil { + t.Fatalf("Expected volume %v, got nil", expVol) + } + + if vol.GetCapacityBytes() != expVol.GetCapacityBytes() { + t.Fatalf("Expected volume capacity bytes: %v, got: %v", expVol.GetCapacityBytes(), vol.GetCapacityBytes()) + } + + for expKey, expVal := range expVol.GetVolumeContext() { + ctx := vol.GetVolumeContext() + if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { + t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) + } + } + + if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { + t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) + } }, }, { name: "success with volume type io1", - req: &csi.CreateVolumeRequest{ - Name: "vol-test", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{ - VolumeTypeKey: cloud.VolumeTypeIO1, - IopsPerGBKey: "5", - }, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-test", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{ + VolumeTypeKey: cloud.VolumeTypeIO1, + IopsPerGBKey: "5", + }, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } }, }, { name: "success with volume type sc1", - req: &csi.CreateVolumeRequest{ - Name: "vol-test", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{ - VolumeTypeKey: cloud.VolumeTypeSC1, - }, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-test", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{ + VolumeTypeKey: cloud.VolumeTypeSC1, + }, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } }, }, { name: "success with volume encryption", - req: &csi.CreateVolumeRequest{ - Name: "vol-test", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{ - EncryptedKey: "true", - }, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-test", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{ + EncryptedKey: "true", + }, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } }, }, { name: "success with volume encryption with KMS key", - req: &csi.CreateVolumeRequest{ - Name: "vol-test", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{ - EncryptedKey: "true", - KmsKeyIdKey: "arn:aws:kms:us-east-1:012345678910:key/abcd1234-a123-456a-a12b-a123b4cd56ef", - }, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: ""}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-test", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{ + EncryptedKey: "true", + KmsKeyIdKey: "arn:aws:kms:us-east-1:012345678910:key/abcd1234-a123-456a-a12b-a123b4cd56ef", + }, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: ""}, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } }, }, { name: "success when volume exists and contains VolumeContext and AccessibleTopology", - req: &csi.CreateVolumeRequest{ - Name: "test-vol", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{ - FsTypeKey: expFsType, - }, - AccessibilityRequirements: &csi.TopologyRequirement{ - Requisite: []*csi.Topology{ - { - Segments: map[string]string{TopologyKey: expZone}, + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{ + FsTypeKey: expFsType, + }, + AccessibilityRequirements: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + { + Segments: map[string]string{TopologyKey: expZone}, + }, }, }, - }, - }, - extraReq: &csi.CreateVolumeRequest{ - Name: "test-vol", - CapacityRange: stdCapRange, - VolumeCapabilities: stdVolCap, - Parameters: map[string]string{ - FsTypeKey: expFsType, - }, - AccessibilityRequirements: &csi.TopologyRequirement{ - Requisite: []*csi.Topology{ + } + extraReq := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{ + FsTypeKey: expFsType, + }, + AccessibilityRequirements: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + { + Segments: map[string]string{TopologyKey: expZone}, + }, + }, + }, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{FsTypeKey: expFsType}, + AccessibleTopology: []*csi.Topology{ { Segments: map[string]string{TopologyKey: expZone}, }, }, - }, - }, - expVol: &csi.Volume{ - CapacityBytes: stdVolSize, - VolumeId: "vol-test", - VolumeContext: map[string]string{FsTypeKey: expFsType}, - AccessibleTopology: []*csi.Topology{ - { - Segments: map[string]string{TopologyKey: expZone}, - }, - }, - }, - }, - } + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()} + ctx := context.Background() - resp, err := awsDriver.CreateVolume(context.TODO(), tc.req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(stdVolSize), } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{cloud: mockCloud} + + if _, err := awsDriver.CreateVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) } - return - } - // Repeat the same request and check they results of the second call - if tc.extraReq != nil { - resp, err = awsDriver.CreateVolume(context.TODO(), tc.extraReq) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil) + resp, err := awsDriver.CreateVolume(ctx, extraReq) if err != nil { srvErr, ok := status.FromError(err) if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code %d, got %d", tc.expErrCode, srvErr.Code()) - } - return + t.Fatalf("Unexpected error: %v", srvErr.Code()) } - } - if tc.expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.expErrCode) - } - - vol := resp.GetVolume() - if vol == nil && tc.expVol != nil { - t.Fatalf("Expected volume %v, got nil", tc.expVol) - } + vol := resp.GetVolume() + if vol == nil && expVol != nil { + t.Fatalf("Expected volume %v, got nil", expVol) + } - if vol.GetCapacityBytes() != tc.expVol.GetCapacityBytes() { - t.Fatalf("Expected volume capacity bytes: %v, got: %v", tc.expVol.GetCapacityBytes(), vol.GetCapacityBytes()) - } + for expKey, expVal := range expVol.GetVolumeContext() { + ctx := vol.GetVolumeContext() + if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { + t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) + } + } - for expKey, expVal := range tc.expVol.GetVolumeContext() { - ctx := vol.GetVolumeContext() - if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { - t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) + if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { + t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) } - } - if tc.expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { - t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) - } - if tc.expVol.GetAccessibleTopology() != nil { - if !reflect.DeepEqual(tc.expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) { - t.Fatalf("Expected AccessibleTopology to be %+v, got: %+v", tc.expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) + + if expVol.GetAccessibleTopology() != nil { + if !reflect.DeepEqual(expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) { + t.Fatalf("Expected AccessibleTopology to be %+v, got: %+v", expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) + } } - } - }) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) } } func TestDeleteVolume(t *testing.T) { testCases := []struct { - name string - req *csi.DeleteVolumeRequest - expResp *csi.DeleteVolumeResponse - expErrCode codes.Code + name string + testFunc func(t *testing.T) }{ { name: "success normal", - req: &csi.DeleteVolumeRequest{ - VolumeId: "vol-test", + testFunc: func(t *testing.T) { + req := &csi.DeleteVolumeRequest{ + VolumeId: "vol-test", + } + expResp := &csi.DeleteVolumeResponse{} + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(true, nil) + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.DeleteVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + if !reflect.DeepEqual(resp, expResp) { + t.Fatalf("Expected resp to be %+v, got: %+v", expResp, resp) + } }, - expResp: &csi.DeleteVolumeResponse{}, }, { name: "success invalid volume id", - req: &csi.DeleteVolumeRequest{ - VolumeId: "invalid-volume-name", + testFunc: func(t *testing.T) { + req := &csi.DeleteVolumeRequest{ + VolumeId: "invalid-volume-name", + } + expResp := &csi.DeleteVolumeResponse{} + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(true, nil) + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.DeleteVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + if !reflect.DeepEqual(resp, expResp) { + t.Fatalf("Expected resp to be %+v, got: %+v", expResp, resp) + } }, - expResp: &csi.DeleteVolumeResponse{}, }, } for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()} - _, err := awsDriver.DeleteVolume(context.TODO(), tc.req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code %d, got %d", tc.expErrCode, srvErr.Code()) - } - return - } - if tc.expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.expErrCode) - } - }) + t.Run(tc.name, tc.testFunc) } } @@ -432,166 +880,300 @@ func TestPickAvailabilityZone(t *testing.T) { func TestCreateSnapshot(t *testing.T) { testCases := []struct { - name string - req *csi.CreateSnapshotRequest - extraReq *csi.CreateSnapshotRequest - expSnapshot *csi.Snapshot - expErrCode codes.Code - extraExpErrCode codes.Code + name string + testFunc func(t *testing.T) }{ { name: "success normal", - req: &csi.CreateSnapshotRequest{ - Name: "test-snapshot", - Parameters: nil, - SourceVolumeId: "vol-test", - }, - expSnapshot: &csi.Snapshot{ - ReadyToUse: true, + testFunc: func(t *testing.T) { + req := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + expErrCode := codes.OK + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.CreateSnapshot(context.Background(), req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } + snap := resp.GetSnapshot() + if snap == nil && expSnapshot != nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } }, - expErrCode: codes.OK, }, { name: "fail no name", - req: &csi.CreateSnapshotRequest{ - Parameters: nil, - SourceVolumeId: "vol-test", + testFunc: func(t *testing.T) { + req := &csi.CreateSnapshotRequest{ + Parameters: nil, + SourceVolumeId: "vol-test", + } + expErrCode := codes.InvalidArgument + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.CreateSnapshot(context.Background(), req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } }, - expSnapshot: nil, - expErrCode: codes.InvalidArgument, }, { name: "fail same name different volume ID", - req: &csi.CreateSnapshotRequest{ - Name: "test-snapshot", - Parameters: nil, - SourceVolumeId: "vol-test", - }, - extraReq: &csi.CreateSnapshotRequest{ - Name: "test-snapshot", - Parameters: nil, - SourceVolumeId: "vol-xxx", - }, - expSnapshot: &csi.Snapshot{ - ReadyToUse: true, + testFunc: func(t *testing.T) { + req := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + } + extraReq := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-xxx", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + expErrCode := codes.OK + extraExpErrCode := codes.AlreadyExists + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.CreateSnapshot(context.Background(), req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } + snap := resp.GetSnapshot() + if snap == nil && expSnapshot != nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(extraReq.GetName())).Return(mockSnapshot, nil) + _, err = awsDriver.CreateSnapshot(ctx, extraReq) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != extraExpErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if extraExpErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", extraExpErrCode) + } }, - expErrCode: codes.OK, - extraExpErrCode: codes.AlreadyExists, }, { name: "success same name same volume ID", - req: &csi.CreateSnapshotRequest{ - Name: "test-snapshot", - Parameters: nil, - SourceVolumeId: "vol-test", - }, - extraReq: &csi.CreateSnapshotRequest{ - Name: "test-snapshot", - Parameters: nil, - SourceVolumeId: "vol-test", - }, - expSnapshot: &csi.Snapshot{ - ReadyToUse: true, + testFunc: func(t *testing.T) { + req := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + } + extraReq := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + expErrCode := codes.OK + extraExpErrCode := codes.OK + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Any()).Return(mockSnapshot, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.CreateSnapshot(context.Background(), req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } + snap := resp.GetSnapshot() + if snap == nil && expSnapshot != nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(extraReq.GetName())).Return(mockSnapshot, nil) + _, err = awsDriver.CreateSnapshot(ctx, extraReq) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != extraExpErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if extraExpErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", extraExpErrCode) + } }, - expErrCode: codes.OK, - extraExpErrCode: codes.OK, }, } + for _, tc := range testCases { - t.Logf("Test case: %s", tc.name) - awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()} - resp, err := awsDriver.CreateSnapshot(context.TODO(), tc.req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) - } - continue - } - if tc.expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.expErrCode) - } - snap := resp.GetSnapshot() - if snap == nil && tc.expSnapshot != nil { - t.Fatalf("Expected snapshot %v, got nil", tc.expSnapshot) - } - if tc.extraReq != nil { - // extraReq is never used in a situation when a new snapshot - // should be really created: checking the return code is enough - _, err = awsDriver.CreateSnapshot(context.TODO(), tc.extraReq) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != tc.extraExpErrCode { - t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) - } - continue - } - if tc.extraExpErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.extraExpErrCode) - } - } + t.Run(tc.name, tc.testFunc) } } func TestDeleteSnapshot(t *testing.T) { - snapReq := &csi.CreateSnapshotRequest{ - Name: "test-snapshot", - Parameters: nil, - SourceVolumeId: "vol-test", - } testCases := []struct { - name string - req *csi.DeleteSnapshotRequest - expErrCode codes.Code + name string + testFunc func(t *testing.T) }{ { - name: "success normal", - req: &csi.DeleteSnapshotRequest{}, - expErrCode: codes.OK, + name: "success normal", + testFunc: func(t *testing.T) { + expErrCode := codes.OK + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + + req := &csi.DeleteSnapshotRequest{ + SnapshotId: "xxx", + } + + mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq("xxx")).Return(true, nil) + if _, err := awsDriver.DeleteSnapshot(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } + }, }, { name: "success not found", - req: &csi.DeleteSnapshotRequest{ - SnapshotId: "xxx", + testFunc: func(t *testing.T) { + expErrCode := codes.OK + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + + req := &csi.DeleteSnapshotRequest{ + SnapshotId: "xxx", + } + + mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq("xxx")).Return(false, cloud.ErrNotFound) + if _, err := awsDriver.DeleteSnapshot(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - expErrCode: codes.OK, }, } + for _, tc := range testCases { - t.Logf("Test case: %s", tc.name) - awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()} - snapResp, err := awsDriver.CreateSnapshot(context.TODO(), snapReq) - if err != nil { - t.Fatalf("Error creating testing snapshot: %v", err) - } - if len(tc.req.SnapshotId) == 0 { - tc.req.SnapshotId = snapResp.Snapshot.SnapshotId - } - _, err = awsDriver.DeleteSnapshot(context.TODO(), tc.req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) - } - continue - } - if tc.expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.expErrCode) - } + t.Run(tc.name, tc.testFunc) } } func TestControllerPublishVolume(t *testing.T) { - fakeCloud := cloud.NewFakeCloudProvider() stdVolCap := &csi.VolumeCapability{ AccessType: &csi.VolumeCapability_Mount{ Mount: &csi.VolumeCapability_MountVolume{}, @@ -600,193 +1182,391 @@ func TestControllerPublishVolume(t *testing.T) { Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, }, } + expDevicePath := "/dev/xvda" + testCases := []struct { - name string - req *csi.ControllerPublishVolumeRequest - expResp *csi.ControllerPublishVolumeResponse - expErrCode codes.Code - setup func(req *csi.ControllerPublishVolumeRequest) + name string + testFunc func(t *testing.T) }{ { - name: "success normal", - expResp: &csi.ControllerPublishVolumeResponse{}, - req: &csi.ControllerPublishVolumeRequest{ - NodeId: fakeCloud.GetMetadata().GetInstanceID(), - VolumeCapability: stdVolCap, - }, - // create a fake disk and setup the request - // parameters appropriately - setup: func(req *csi.ControllerPublishVolumeRequest) { - fakeDiskOpts := &cloud.DiskOptions{ - CapacityBytes: 1, - AvailabilityZone: "az", + name: "success normal", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + NodeId: expInstanceId, + VolumeCapability: stdVolCap, + VolumeId: "vol-test", + } + expResp := &csi.ControllerPublishVolumeResponse{ + PublishContext: map[string]string{DevicePathKey: expDevicePath}, + } + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(true) + mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(&cloud.Disk{}, nil) + mockCloud.EXPECT().AttachDisk(gomock.Eq(ctx), gomock.Any(), gomock.Eq(req.NodeId)).Return(expDevicePath, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ControllerPublishVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + if !reflect.DeepEqual(resp, expResp) { + t.Fatalf("Expected resp to be %+v, got: %+v", expResp, resp) } - fakeDisk, _ := fakeCloud.CreateDisk(context.TODO(), "vol-test", fakeDiskOpts) - req.VolumeId = fakeDisk.VolumeID }, }, { - name: "fail no VolumeId", - req: &csi.ControllerPublishVolumeRequest{}, - expErrCode: codes.InvalidArgument, - setup: func(req *csi.ControllerPublishVolumeRequest) {}, + name: "fail no VolumeId", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{} + expErrCode := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } + }, }, { - name: "fail no NodeId", - expErrCode: codes.InvalidArgument, - req: &csi.ControllerPublishVolumeRequest{ - VolumeId: "vol-test", + name: "fail no NodeId", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + VolumeId: "vol-test", + } + expErrCode := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - setup: func(req *csi.ControllerPublishVolumeRequest) {}, }, { - name: "fail no VolumeCapability", - expErrCode: codes.InvalidArgument, - req: &csi.ControllerPublishVolumeRequest{ - NodeId: fakeCloud.GetMetadata().GetInstanceID(), - VolumeId: "vol-test", + name: "fail no VolumeCapability", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + NodeId: expInstanceId, + VolumeId: "vol-test", + } + expErrCode := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - setup: func(req *csi.ControllerPublishVolumeRequest) {}, }, { - name: "fail invalid VolumeCapability", - expErrCode: codes.InvalidArgument, - req: &csi.ControllerPublishVolumeRequest{ - NodeId: fakeCloud.GetMetadata().GetInstanceID(), - VolumeCapability: &csi.VolumeCapability{ - AccessMode: &csi.VolumeCapability_AccessMode{ - Mode: csi.VolumeCapability_AccessMode_UNKNOWN, + name: "fail invalid VolumeCapability", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + NodeId: expInstanceId, + VolumeCapability: &csi.VolumeCapability{ + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_UNKNOWN, + }, }, - }, - VolumeId: "vol-test", + VolumeId: "vol-test", + } + expErrCode := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - setup: func(req *csi.ControllerPublishVolumeRequest) {}, }, { - name: "fail instance not found", - expErrCode: codes.NotFound, - req: &csi.ControllerPublishVolumeRequest{ - NodeId: "does-not-exist", - VolumeId: "vol-test", - VolumeCapability: stdVolCap, + name: "fail instance not found", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + NodeId: "does-not-exist", + VolumeId: "vol-test", + VolumeCapability: stdVolCap, + } + expErrCode := codes.NotFound + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(false) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - setup: func(req *csi.ControllerPublishVolumeRequest) {}, }, { - name: "fail volume not found", - expErrCode: codes.NotFound, - req: &csi.ControllerPublishVolumeRequest{ - VolumeId: "does-not-exist", - NodeId: fakeCloud.GetMetadata().GetInstanceID(), - VolumeCapability: stdVolCap, + name: "fail volume not found", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + VolumeId: "does-not-exist", + NodeId: expInstanceId, + VolumeCapability: stdVolCap, + } + expErrCode := codes.NotFound + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(true) + mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(nil, cloud.ErrNotFound) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - setup: func(req *csi.ControllerPublishVolumeRequest) {}, }, { - name: "fail attach disk with already exists error", - expErrCode: codes.AlreadyExists, - req: &csi.ControllerPublishVolumeRequest{ - VolumeId: "does-not-exist", - NodeId: fakeCloud.GetMetadata().GetInstanceID(), - VolumeCapability: stdVolCap, - }, - // create a fake disk, attach it and setup the - // request appropriately - setup: func(req *csi.ControllerPublishVolumeRequest) { - fakeDiskOpts := &cloud.DiskOptions{ - CapacityBytes: 1, - AvailabilityZone: "az", + name: "fail attach disk with already exists error", + testFunc: func(t *testing.T) { + req := &csi.ControllerPublishVolumeRequest{ + VolumeId: "does-not-exist", + NodeId: expInstanceId, + VolumeCapability: stdVolCap, + } + expErrCode := codes.AlreadyExists + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().IsExistInstance(gomock.Eq(ctx), gomock.Eq(req.NodeId)).Return(true) + mockCloud.EXPECT().GetDiskByID(gomock.Eq(ctx), gomock.Any()).Return(&cloud.Disk{}, nil) + mockCloud.EXPECT().AttachDisk(gomock.Eq(ctx), gomock.Any(), gomock.Eq(req.NodeId)).Return("", cloud.ErrAlreadyExists) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerPublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) } - fakeDisk, _ := fakeCloud.CreateDisk(context.TODO(), "vol-test", fakeDiskOpts) - req.VolumeId = fakeDisk.VolumeID - _, _ = fakeCloud.AttachDisk(context.TODO(), fakeDisk.VolumeID, fakeCloud.GetMetadata().GetInstanceID()) }, }, } for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.setup(tc.req) - awsDriver := controllerService{cloud: fakeCloud} - _, err := awsDriver.ControllerPublishVolume(context.TODO(), tc.req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code '%d', got '%d': %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) - } - return - } - if tc.expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.expErrCode) - } - }) + t.Run(tc.name, tc.testFunc) } } func TestControllerUnpublishVolume(t *testing.T) { - fakeCloud := cloud.NewFakeCloudProvider() testCases := []struct { - name string - req *csi.ControllerUnpublishVolumeRequest - expResp *csi.ControllerUnpublishVolumeResponse - expErrCode codes.Code - setup func(req *csi.ControllerUnpublishVolumeRequest) + name string + setup func(req *csi.ControllerUnpublishVolumeRequest) + testFunc func(t *testing.T) }{ { - name: "success normal", - expResp: &csi.ControllerUnpublishVolumeResponse{}, - req: &csi.ControllerUnpublishVolumeRequest{ - NodeId: fakeCloud.GetMetadata().GetInstanceID(), - }, - // create a fake disk, attach it and setup the request - // parameters appropriately - setup: func(req *csi.ControllerUnpublishVolumeRequest) { - fakeDiskOpts := &cloud.DiskOptions{ - CapacityBytes: 1, - AvailabilityZone: "az", + name: "success normal", + testFunc: func(t *testing.T) { + req := &csi.ControllerUnpublishVolumeRequest{ + NodeId: expInstanceId, + VolumeId: "vol-test", + } + expResp := &csi.ControllerUnpublishVolumeResponse{} + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().DetachDisk(gomock.Eq(ctx), req.VolumeId, req.NodeId).Return(nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ControllerUnpublishVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + if !reflect.DeepEqual(resp, expResp) { + t.Fatalf("Expected resp to be %+v, got: %+v", expResp, resp) } - fakeDisk, _ := fakeCloud.CreateDisk(context.TODO(), "vol-test", fakeDiskOpts) - req.VolumeId = fakeDisk.VolumeID - _, _ = fakeCloud.AttachDisk(context.TODO(), fakeDisk.VolumeID, fakeCloud.GetMetadata().GetInstanceID()) }, }, { - name: "fail no VolumeId", - req: &csi.ControllerUnpublishVolumeRequest{}, - expErrCode: codes.InvalidArgument, - setup: func(req *csi.ControllerUnpublishVolumeRequest) {}, + name: "fail no VolumeId", + testFunc: func(t *testing.T) { + req := &csi.ControllerUnpublishVolumeRequest{} + expErrCode := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerUnpublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } + }, }, { - name: "fail no NodeId", - expErrCode: codes.InvalidArgument, - req: &csi.ControllerUnpublishVolumeRequest{ - VolumeId: "vol-test", + name: "fail no NodeId", + testFunc: func(t *testing.T) { + req := &csi.ControllerUnpublishVolumeRequest{ + VolumeId: "vol-test", + } + expErrCode := codes.InvalidArgument + + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ControllerUnpublishVolume(ctx, req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != expErrCode { + t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", expErrCode) + } }, - setup: func(req *csi.ControllerUnpublishVolumeRequest) {}, }, } for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tc.setup(tc.req) - awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()} - _, err := awsDriver.ControllerUnpublishVolume(context.TODO(), tc.req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code '%d', got '%d': %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) - } - return - } - if tc.expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", tc.expErrCode) - } - }) + t.Run(tc.name, tc.testFunc) } } diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 1d1c23802f..8372303806 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -37,7 +37,7 @@ func NewFakeSafeFormatAndMounter(fakeMounter mount.Interface) *mount.SafeFormatA } // NewFakeDriver creates a new mock driver used for testing -func NewFakeDriver(endpoint string, fakeCloud *cloud.FakeCloudProvider, fakeMounter *mount.FakeMounter) *Driver { +func NewFakeDriver(endpoint string, fakeCloud cloud.Cloud, fakeMounter *mount.FakeMounter) *Driver { return &Driver{ endpoint: endpoint, controllerService: controllerService{ diff --git a/pkg/driver/mocks/mock_cloud.go b/pkg/driver/mocks/mock_cloud.go new file mode 100644 index 0000000000..53b4aa199f --- /dev/null +++ b/pkg/driver/mocks/mock_cloud.go @@ -0,0 +1,211 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud (interfaces: Cloud) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + cloud "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" + reflect "reflect" +) + +// MockCloud is a mock of Cloud interface +type MockCloud struct { + ctrl *gomock.Controller + recorder *MockCloudMockRecorder +} + +// MockCloudMockRecorder is the mock recorder for MockCloud +type MockCloudMockRecorder struct { + mock *MockCloud +} + +// NewMockCloud creates a new mock instance +func NewMockCloud(ctrl *gomock.Controller) *MockCloud { + mock := &MockCloud{ctrl: ctrl} + mock.recorder = &MockCloudMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCloud) EXPECT() *MockCloudMockRecorder { + return m.recorder +} + +// AttachDisk mocks base method +func (m *MockCloud) AttachDisk(arg0 context.Context, arg1, arg2 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AttachDisk", arg0, arg1, arg2) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AttachDisk indicates an expected call of AttachDisk +func (mr *MockCloudMockRecorder) AttachDisk(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachDisk", reflect.TypeOf((*MockCloud)(nil).AttachDisk), arg0, arg1, arg2) +} + +// CreateDisk mocks base method +func (m *MockCloud) CreateDisk(arg0 context.Context, arg1 string, arg2 *cloud.DiskOptions) (*cloud.Disk, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateDisk", arg0, arg1, arg2) + ret0, _ := ret[0].(*cloud.Disk) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateDisk indicates an expected call of CreateDisk +func (mr *MockCloudMockRecorder) CreateDisk(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDisk", reflect.TypeOf((*MockCloud)(nil).CreateDisk), arg0, arg1, arg2) +} + +// CreateSnapshot mocks base method +func (m *MockCloud) CreateSnapshot(arg0 context.Context, arg1 string, arg2 *cloud.SnapshotOptions) (*cloud.Snapshot, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSnapshot", arg0, arg1, arg2) + ret0, _ := ret[0].(*cloud.Snapshot) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSnapshot indicates an expected call of CreateSnapshot +func (mr *MockCloudMockRecorder) CreateSnapshot(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSnapshot", reflect.TypeOf((*MockCloud)(nil).CreateSnapshot), arg0, arg1, arg2) +} + +// DeleteDisk mocks base method +func (m *MockCloud) DeleteDisk(arg0 context.Context, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteDisk", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteDisk indicates an expected call of DeleteDisk +func (mr *MockCloudMockRecorder) DeleteDisk(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDisk", reflect.TypeOf((*MockCloud)(nil).DeleteDisk), arg0, arg1) +} + +// DeleteSnapshot mocks base method +func (m *MockCloud) DeleteSnapshot(arg0 context.Context, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSnapshot", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteSnapshot indicates an expected call of DeleteSnapshot +func (mr *MockCloudMockRecorder) DeleteSnapshot(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSnapshot", reflect.TypeOf((*MockCloud)(nil).DeleteSnapshot), arg0, arg1) +} + +// DetachDisk mocks base method +func (m *MockCloud) DetachDisk(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DetachDisk", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DetachDisk indicates an expected call of DetachDisk +func (mr *MockCloudMockRecorder) DetachDisk(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachDisk", reflect.TypeOf((*MockCloud)(nil).DetachDisk), arg0, arg1, arg2) +} + +// GetDiskByID mocks base method +func (m *MockCloud) GetDiskByID(arg0 context.Context, arg1 string) (*cloud.Disk, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDiskByID", arg0, arg1) + ret0, _ := ret[0].(*cloud.Disk) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDiskByID indicates an expected call of GetDiskByID +func (mr *MockCloudMockRecorder) GetDiskByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDiskByID", reflect.TypeOf((*MockCloud)(nil).GetDiskByID), arg0, arg1) +} + +// GetDiskByName mocks base method +func (m *MockCloud) GetDiskByName(arg0 context.Context, arg1 string, arg2 int64) (*cloud.Disk, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDiskByName", arg0, arg1, arg2) + ret0, _ := ret[0].(*cloud.Disk) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDiskByName indicates an expected call of GetDiskByName +func (mr *MockCloudMockRecorder) GetDiskByName(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDiskByName", reflect.TypeOf((*MockCloud)(nil).GetDiskByName), arg0, arg1, arg2) +} + +// GetMetadata mocks base method +func (m *MockCloud) GetMetadata() cloud.MetadataService { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMetadata") + ret0, _ := ret[0].(cloud.MetadataService) + return ret0 +} + +// GetMetadata indicates an expected call of GetMetadata +func (mr *MockCloudMockRecorder) GetMetadata() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetadata", reflect.TypeOf((*MockCloud)(nil).GetMetadata)) +} + +// GetSnapshotByName mocks base method +func (m *MockCloud) GetSnapshotByName(arg0 context.Context, arg1 string) (*cloud.Snapshot, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSnapshotByName", arg0, arg1) + ret0, _ := ret[0].(*cloud.Snapshot) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSnapshotByName indicates an expected call of GetSnapshotByName +func (mr *MockCloudMockRecorder) GetSnapshotByName(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSnapshotByName", reflect.TypeOf((*MockCloud)(nil).GetSnapshotByName), arg0, arg1) +} + +// IsExistInstance mocks base method +func (m *MockCloud) IsExistInstance(arg0 context.Context, arg1 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsExistInstance", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsExistInstance indicates an expected call of IsExistInstance +func (mr *MockCloudMockRecorder) IsExistInstance(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsExistInstance", reflect.TypeOf((*MockCloud)(nil).IsExistInstance), arg0, arg1) +} + +// WaitForAttachmentState mocks base method +func (m *MockCloud) WaitForAttachmentState(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WaitForAttachmentState", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// WaitForAttachmentState indicates an expected call of WaitForAttachmentState +func (mr *MockCloudMockRecorder) WaitForAttachmentState(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForAttachmentState", reflect.TypeOf((*MockCloud)(nil).WaitForAttachmentState), arg0, arg1, arg2) +} diff --git a/pkg/driver/mocks/mock_metadata_service.go b/pkg/driver/mocks/mock_metadata_service.go new file mode 100644 index 0000000000..dc9ef67c20 --- /dev/null +++ b/pkg/driver/mocks/mock_metadata_service.go @@ -0,0 +1,75 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud (interfaces: MetadataService) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockMetadataService is a mock of MetadataService interface +type MockMetadataService struct { + ctrl *gomock.Controller + recorder *MockMetadataServiceMockRecorder +} + +// MockMetadataServiceMockRecorder is the mock recorder for MockMetadataService +type MockMetadataServiceMockRecorder struct { + mock *MockMetadataService +} + +// NewMockMetadataService creates a new mock instance +func NewMockMetadataService(ctrl *gomock.Controller) *MockMetadataService { + mock := &MockMetadataService{ctrl: ctrl} + mock.recorder = &MockMetadataServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMetadataService) EXPECT() *MockMetadataServiceMockRecorder { + return m.recorder +} + +// GetAvailabilityZone mocks base method +func (m *MockMetadataService) GetAvailabilityZone() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAvailabilityZone") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetAvailabilityZone indicates an expected call of GetAvailabilityZone +func (mr *MockMetadataServiceMockRecorder) GetAvailabilityZone() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAvailabilityZone", reflect.TypeOf((*MockMetadataService)(nil).GetAvailabilityZone)) +} + +// GetInstanceID mocks base method +func (m *MockMetadataService) GetInstanceID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInstanceID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetInstanceID indicates an expected call of GetInstanceID +func (mr *MockMetadataServiceMockRecorder) GetInstanceID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceID", reflect.TypeOf((*MockMetadataService)(nil).GetInstanceID)) +} + +// GetRegion mocks base method +func (m *MockMetadataService) GetRegion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRegion") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetRegion indicates an expected call of GetRegion +func (mr *MockMetadataServiceMockRecorder) GetRegion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRegion", reflect.TypeOf((*MockMetadataService)(nil).GetRegion)) +} diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 6a74188999..92ced51ec8 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -22,8 +22,10 @@ import ( "testing" "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/internal" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/mocks" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "k8s.io/kubernetes/pkg/util/mount" @@ -226,9 +228,15 @@ func TestNodeStageVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - awsDriver := newTestNodeService( - cloud.NewFakeCloudProvider().GetMetadata(), - tc.fakeMounter) + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata) + + awsDriver := newTestNodeService(mockCloud, tc.fakeMounter) _, err := awsDriver.NodeStageVolume(context.TODO(), tc.req) if err != nil { @@ -324,13 +332,19 @@ func TestNodeUnstageVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata) + fakeMounter := NewFakeMounter() if len(tc.fakeMountPoints) > 0 { fakeMounter.MountPoints = tc.fakeMountPoints } - awsDriver := newTestNodeService( - cloud.NewFakeCloudProvider().GetMetadata(), - fakeMounter) + awsDriver := newTestNodeService(mockCloud, fakeMounter) _, err := awsDriver.NodeUnstageVolume(context.TODO(), tc.req) if err != nil { @@ -574,9 +588,15 @@ func TestNodePublishVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - awsDriver := newTestNodeService( - cloud.NewFakeCloudProvider().GetMetadata(), - tc.fakeMounter) + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata) + + awsDriver := newTestNodeService(mockCloud, tc.fakeMounter) _, err := awsDriver.NodePublishVolume(context.TODO(), tc.req) if err != nil { @@ -650,14 +670,20 @@ func TestNodeUnpublishVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata) + fakeMounter := NewFakeMounter() if tc.fakeMountPoint != nil { fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) } - awsDriver := newTestNodeService( - cloud.NewFakeCloudProvider().GetMetadata(), - fakeMounter) + awsDriver := newTestNodeService(mockCloud, fakeMounter) _, err := awsDriver.NodeUnpublishVolume(context.TODO(), tc.req) if err != nil { @@ -682,11 +708,17 @@ func TestNodeUnpublishVolume(t *testing.T) { } func TestNodeGetVolumeStats(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata) + req := &csi.NodeGetVolumeStatsRequest{} - awsDriver := newTestNodeService( - cloud.NewFakeCloudProvider().GetMetadata(), - NewFakeMounter()) + awsDriver := newTestNodeService(mockCloud, NewFakeMounter()) expErrCode := codes.Unimplemented @@ -704,10 +736,17 @@ func TestNodeGetVolumeStats(t *testing.T) { } func TestNodeGetCapabilities(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata) + req := &csi.NodeGetCapabilitiesRequest{} - awsDriver := newTestNodeService( - cloud.NewFakeCloudProvider().GetMetadata(), - NewFakeMounter()) + + awsDriver := newTestNodeService(mockCloud, NewFakeMounter()) caps := []*csi.NodeServiceCapability{ { @@ -734,16 +773,23 @@ func TestNodeGetCapabilities(t *testing.T) { } func TestNodeGetInfo(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockMetadata := mocks.NewMockMetadataService(mockCtl) + mockMetadata.EXPECT().GetAvailabilityZone().Return(expZone).Times(2) + mockMetadata.EXPECT().GetInstanceID().Return(expInstanceId) + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetMetadata().Return(mockMetadata).Times(2) + req := &csi.NodeGetInfoRequest{} - cloud := cloud.NewFakeCloudProvider() - awsDriver := newTestNodeService( - cloud.GetMetadata(), - NewFakeMounter()) + awsDriver := newTestNodeService(mockCloud, NewFakeMounter()) - m := cloud.GetMetadata() + m := mockCloud.GetMetadata() expResp := &csi.NodeGetInfoResponse{ - NodeId: "instanceID", + NodeId: expInstanceId, AccessibleTopology: &csi.Topology{ Segments: map[string]string{TopologyKey: m.GetAvailabilityZone()}, }, @@ -762,9 +808,9 @@ func TestNodeGetInfo(t *testing.T) { } } -func newTestNodeService(metadata cloud.MetadataService, mounter mount.Interface) nodeService { +func newTestNodeService(cloud cloud.Cloud, mounter mount.Interface) nodeService { return nodeService{ - metadata: cloud.NewFakeCloudProvider().GetMetadata(), + metadata: cloud.GetMetadata(), mounter: NewFakeSafeFormatAndMounter(mounter), inFlight: internal.NewInFlight(), } diff --git a/pkg/util/util.go b/pkg/util/util.go index 0be9125c69..387b107525 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -41,7 +41,7 @@ func RoundUpGiB(volumeSizeBytes int64) int64 { return roundUpSize(volumeSizeBytes, GiB) } -// BytesToGiB conversts Bytes to GiB +// BytesToGiB converts Bytes to GiB func BytesToGiB(volumeSizeBytes int64) int64 { return volumeSizeBytes / GiB } diff --git a/pkg/cloud/fakes.go b/tests/sanity/fake_cloud_provider.go similarity index 60% rename from pkg/cloud/fakes.go rename to tests/sanity/fake_cloud_provider.go index 83b1c3013b..d1e7fac2a5 100644 --- a/pkg/cloud/fakes.go +++ b/tests/sanity/fake_cloud_provider.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package cloud +package sanity import ( "context" @@ -22,43 +22,48 @@ import ( "math/rand" "time" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" ) -type FakeCloudProvider struct { +type fakeCloudProvider struct { disks map[string]*fakeDisk snapshots map[string]*fakeSnapshot - m *metadata + m *cloud.Metadata pub map[string]string } type fakeDisk struct { - *Disk + *cloud.Disk tags map[string]string } type fakeSnapshot struct { - *Snapshot + *cloud.Snapshot tags map[string]string } -func NewFakeCloudProvider() *FakeCloudProvider { - return &FakeCloudProvider{ +func newFakeCloudProvider() *fakeCloudProvider { + return &fakeCloudProvider{ disks: make(map[string]*fakeDisk), snapshots: make(map[string]*fakeSnapshot), pub: make(map[string]string), - m: &metadata{"instanceID", "region", "az"}, + m: &cloud.Metadata{ + InstanceID: "instanceID", + Region: "region", + AvailabilityZone: "az", + }, } } -func (c *FakeCloudProvider) GetMetadata() MetadataService { +func (c *fakeCloudProvider) GetMetadata() cloud.MetadataService { return c.m } -func (c *FakeCloudProvider) CreateDisk(ctx context.Context, volumeName string, diskOptions *DiskOptions) (*Disk, error) { +func (c *fakeCloudProvider) CreateDisk(ctx context.Context, volumeName string, diskOptions *cloud.DiskOptions) (*cloud.Disk, error) { r1 := rand.New(rand.NewSource(time.Now().UnixNano())) d := &fakeDisk{ - Disk: &Disk{ + Disk: &cloud.Disk{ VolumeID: fmt.Sprintf("vol-%d", r1.Uint64()), CapacityGiB: util.BytesToGiB(diskOptions.CapacityBytes), AvailabilityZone: diskOptions.AvailabilityZone, @@ -69,7 +74,7 @@ func (c *FakeCloudProvider) CreateDisk(ctx context.Context, volumeName string, d return d.Disk, nil } -func (c *FakeCloudProvider) DeleteDisk(ctx context.Context, volumeID string) (bool, error) { +func (c *fakeCloudProvider) DeleteDisk(ctx context.Context, volumeID string) (bool, error) { for volName, f := range c.disks { if f.Disk.VolumeID == volumeID { delete(c.disks, volName) @@ -78,64 +83,64 @@ func (c *FakeCloudProvider) DeleteDisk(ctx context.Context, volumeID string) (bo return true, nil } -func (c *FakeCloudProvider) AttachDisk(ctx context.Context, volumeID, nodeID string) (string, error) { +func (c *fakeCloudProvider) AttachDisk(ctx context.Context, volumeID, nodeID string) (string, error) { if _, ok := c.pub[volumeID]; ok { - return "", ErrAlreadyExists + return "", cloud.ErrAlreadyExists } c.pub[volumeID] = nodeID return "/dev/xvdbc", nil } -func (c *FakeCloudProvider) DetachDisk(ctx context.Context, volumeID, nodeID string) error { +func (c *fakeCloudProvider) DetachDisk(ctx context.Context, volumeID, nodeID string) error { return nil } -func (c *FakeCloudProvider) WaitForAttachmentState(ctx context.Context, volumeID, state string) error { +func (c *fakeCloudProvider) WaitForAttachmentState(ctx context.Context, volumeID, state string) error { return nil } -func (c *FakeCloudProvider) GetDiskByName(ctx context.Context, name string, capacityBytes int64) (*Disk, error) { +func (c *fakeCloudProvider) GetDiskByName(ctx context.Context, name string, capacityBytes int64) (*cloud.Disk, error) { var disks []*fakeDisk for _, d := range c.disks { for key, value := range d.tags { - if key == VolumeNameTagKey && value == name { + if key == cloud.VolumeNameTagKey && value == name { disks = append(disks, d) } } } if len(disks) > 1 { - return nil, ErrMultiDisks + return nil, cloud.ErrMultiDisks } else if len(disks) == 1 { if capacityBytes != disks[0].Disk.CapacityGiB*util.GiB { - return nil, ErrDiskExistsDiffSize + return nil, cloud.ErrDiskExistsDiffSize } return disks[0].Disk, nil } return nil, nil } -func (c *FakeCloudProvider) GetDiskByID(ctx context.Context, volumeID string) (*Disk, error) { +func (c *fakeCloudProvider) GetDiskByID(ctx context.Context, volumeID string) (*cloud.Disk, error) { for _, f := range c.disks { if f.Disk.VolumeID == volumeID { return f.Disk, nil } } - return nil, ErrNotFound + return nil, cloud.ErrNotFound } -func (c *FakeCloudProvider) IsExistInstance(ctx context.Context, nodeID string) bool { +func (c *fakeCloudProvider) IsExistInstance(ctx context.Context, nodeID string) bool { return nodeID == c.m.GetInstanceID() } -func (c *FakeCloudProvider) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) { +func (c *fakeCloudProvider) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *cloud.SnapshotOptions) (snapshot *cloud.Snapshot, err error) { r1 := rand.New(rand.NewSource(time.Now().UnixNano())) snapshotID := fmt.Sprintf("snapshot-%d", r1.Uint64()) - if len(snapshotOptions.Tags[SnapshotNameTagKey]) == 0 { + if len(snapshotOptions.Tags[cloud.SnapshotNameTagKey]) == 0 { // for simplicity: let's have the Name and ID identical - snapshotOptions.Tags[SnapshotNameTagKey] = snapshotID + snapshotOptions.Tags[cloud.SnapshotNameTagKey] = snapshotID } s := &fakeSnapshot{ - Snapshot: &Snapshot{ + Snapshot: &cloud.Snapshot{ SnapshotID: snapshotID, SourceVolumeID: volumeID, Size: 1, @@ -148,17 +153,17 @@ func (c *FakeCloudProvider) CreateSnapshot(ctx context.Context, volumeID string, } -func (c *FakeCloudProvider) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) { +func (c *fakeCloudProvider) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) { delete(c.snapshots, snapshotID) return true, nil } -func (c *FakeCloudProvider) GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) { +func (c *fakeCloudProvider) GetSnapshotByName(ctx context.Context, name string) (snapshot *cloud.Snapshot, err error) { var snapshots []*fakeSnapshot for _, s := range c.snapshots { for key, value := range s.tags { - if key == SnapshotNameTagKey && value == name { + if key == cloud.SnapshotNameTagKey && value == name { snapshots = append(snapshots, s) } } From 017dfdfe48a5553874c63ee15453b46f836d95ac Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Thu, 4 Apr 2019 23:54:30 -0400 Subject: [PATCH 3/9] Removing dependency on cloud from sanity --- tests/sanity/sanity_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index b50b0456fd..991d17f347 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -26,7 +26,6 @@ import ( sanity "github.com/kubernetes-csi/csi-test/pkg/sanity" - "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" ) @@ -50,7 +49,7 @@ var _ = BeforeSuite(func() { "/dev/xvdbc": mount.FileTypeFile, }, } - ebsDriver = driver.NewFakeDriver(endpoint, cloud.NewFakeCloudProvider(), fakeMounter) + ebsDriver = driver.NewFakeDriver(endpoint, newFakeCloudProvider(), fakeMounter) go func() { Expect(ebsDriver.Run()).NotTo(HaveOccurred()) }() From 636fe707fad2b367c78ed99cdab2cac38c63911f Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Fri, 5 Apr 2019 09:07:12 -0400 Subject: [PATCH 4/9] Fixing prow vet errors --- pkg/driver/controller_test.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 5a8d71d038..6b6fbd4010 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -206,10 +206,6 @@ func TestCreateVolume(t *testing.T) { t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) } } - - if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { - t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) - } }, }, { @@ -339,10 +335,6 @@ func TestCreateVolume(t *testing.T) { t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) } } - - if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { - t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) - } }, }, { @@ -458,10 +450,6 @@ func TestCreateVolume(t *testing.T) { t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) } } - - if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { - t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) - } }, }, { @@ -734,10 +722,6 @@ func TestCreateVolume(t *testing.T) { } } - if expVol.GetVolumeContext() == nil && vol.GetVolumeContext() != nil { - t.Fatalf("Expected volume context to be nil, got: %#v", vol.GetVolumeContext()) - } - if expVol.GetAccessibleTopology() != nil { if !reflect.DeepEqual(expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) { t.Fatalf("Expected AccessibleTopology to be %+v, got: %+v", expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) From 79fe15e76c7f26f997bb0246426de5918e791660 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Fri, 5 Apr 2019 09:08:46 -0400 Subject: [PATCH 5/9] Removing unneeded setup func --- pkg/driver/controller_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 6b6fbd4010..6d3659ee3a 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1453,7 +1453,6 @@ func TestControllerPublishVolume(t *testing.T) { func TestControllerUnpublishVolume(t *testing.T) { testCases := []struct { name string - setup func(req *csi.ControllerUnpublishVolumeRequest) testFunc func(t *testing.T) }{ { From 0038a0c86b87ab39828342f9c40ba26945a244b3 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Fri, 5 Apr 2019 09:15:50 -0400 Subject: [PATCH 6/9] Removing nilness vet issue for returned volume --- pkg/driver/controller_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 6d3659ee3a..8cecfc056f 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -192,7 +192,7 @@ func TestCreateVolume(t *testing.T) { } vol := resp.GetVolume() - if vol == nil && expVol != nil { + if vol == nil { t.Fatalf("Expected volume %v, got nil", expVol) } @@ -321,7 +321,7 @@ func TestCreateVolume(t *testing.T) { } vol := resp.GetVolume() - if vol == nil && expVol != nil { + if vol == nil { t.Fatalf("Expected volume %v, got nil", expVol) } @@ -384,7 +384,7 @@ func TestCreateVolume(t *testing.T) { } vol := resp.GetVolume() - if vol == nil && expVol != nil { + if vol == nil { t.Fatalf("Expected volume %v, got nil", expVol) } @@ -436,7 +436,7 @@ func TestCreateVolume(t *testing.T) { } vol := resp.GetVolume() - if vol == nil && expVol != nil { + if vol == nil { t.Fatalf("Expected volume %v, got nil", expVol) } @@ -711,7 +711,7 @@ func TestCreateVolume(t *testing.T) { } vol := resp.GetVolume() - if vol == nil && expVol != nil { + if vol == nil { t.Fatalf("Expected volume %v, got nil", expVol) } From e4b300fb4f8c6e704687e9f96c7b7592cabcdb53 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Fri, 5 Apr 2019 09:23:03 -0400 Subject: [PATCH 7/9] Fixing snap test case vet errors --- pkg/driver/controller_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 8cecfc056f..f36ebdbda4 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -909,8 +909,7 @@ func TestCreateSnapshot(t *testing.T) { if expErrCode != codes.OK { t.Fatalf("Expected error %v, got no error", expErrCode) } - snap := resp.GetSnapshot() - if snap == nil && expSnapshot != nil { + if snap := resp.GetSnapshot(); snap == nil { t.Fatalf("Expected snapshot %v, got nil", expSnapshot) } }, @@ -987,7 +986,7 @@ func TestCreateSnapshot(t *testing.T) { t.Fatalf("Expected error %v, got no error", expErrCode) } snap := resp.GetSnapshot() - if snap == nil && expSnapshot != nil { + if snap == nil { t.Fatalf("Expected snapshot %v, got nil", expSnapshot) } @@ -1053,7 +1052,7 @@ func TestCreateSnapshot(t *testing.T) { t.Fatalf("Expected error %v, got no error", expErrCode) } snap := resp.GetSnapshot() - if snap == nil && expSnapshot != nil { + if snap == nil { t.Fatalf("Expected snapshot %v, got nil", expSnapshot) } From cb6243a23bde6834b843f0b8ea225014adeccc96 Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Mon, 8 Apr 2019 11:34:09 -0400 Subject: [PATCH 8/9] Getting rid of mock cloud creation --- pkg/driver/node_test.go | 44 +++++++++++------------------------------ 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 92ced51ec8..73732d21dd 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -233,10 +233,7 @@ func TestNodeStageVolume(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata) - - awsDriver := newTestNodeService(mockCloud, tc.fakeMounter) + awsDriver := newTestNodeService(mockMetadata, tc.fakeMounter) _, err := awsDriver.NodeStageVolume(context.TODO(), tc.req) if err != nil { @@ -337,14 +334,11 @@ func TestNodeUnstageVolume(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata) - fakeMounter := NewFakeMounter() if len(tc.fakeMountPoints) > 0 { fakeMounter.MountPoints = tc.fakeMountPoints } - awsDriver := newTestNodeService(mockCloud, fakeMounter) + awsDriver := newTestNodeService(mockMetadata, fakeMounter) _, err := awsDriver.NodeUnstageVolume(context.TODO(), tc.req) if err != nil { @@ -593,10 +587,7 @@ func TestNodePublishVolume(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata) - - awsDriver := newTestNodeService(mockCloud, tc.fakeMounter) + awsDriver := newTestNodeService(mockMetadata, tc.fakeMounter) _, err := awsDriver.NodePublishVolume(context.TODO(), tc.req) if err != nil { @@ -675,15 +666,12 @@ func TestNodeUnpublishVolume(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata) - fakeMounter := NewFakeMounter() if tc.fakeMountPoint != nil { fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) } - awsDriver := newTestNodeService(mockCloud, fakeMounter) + awsDriver := newTestNodeService(mockMetadata, fakeMounter) _, err := awsDriver.NodeUnpublishVolume(context.TODO(), tc.req) if err != nil { @@ -713,12 +701,9 @@ func TestNodeGetVolumeStats(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata) - req := &csi.NodeGetVolumeStatsRequest{} - awsDriver := newTestNodeService(mockCloud, NewFakeMounter()) + awsDriver := newTestNodeService(mockMetadata, NewFakeMounter()) expErrCode := codes.Unimplemented @@ -741,12 +726,9 @@ func TestNodeGetCapabilities(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata) - req := &csi.NodeGetCapabilitiesRequest{} - awsDriver := newTestNodeService(mockCloud, NewFakeMounter()) + awsDriver := newTestNodeService(mockMetadata, NewFakeMounter()) caps := []*csi.NodeServiceCapability{ { @@ -777,21 +759,17 @@ func TestNodeGetInfo(t *testing.T) { defer mockCtl.Finish() mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockMetadata.EXPECT().GetAvailabilityZone().Return(expZone).Times(2) mockMetadata.EXPECT().GetInstanceID().Return(expInstanceId) - - mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetMetadata().Return(mockMetadata).Times(2) + mockMetadata.EXPECT().GetAvailabilityZone().Return(expZone).Times(2) req := &csi.NodeGetInfoRequest{} - awsDriver := newTestNodeService(mockCloud, NewFakeMounter()) + awsDriver := newTestNodeService(mockMetadata, NewFakeMounter()) - m := mockCloud.GetMetadata() expResp := &csi.NodeGetInfoResponse{ NodeId: expInstanceId, AccessibleTopology: &csi.Topology{ - Segments: map[string]string{TopologyKey: m.GetAvailabilityZone()}, + Segments: map[string]string{TopologyKey: mockMetadata.GetAvailabilityZone()}, }, } @@ -808,9 +786,9 @@ func TestNodeGetInfo(t *testing.T) { } } -func newTestNodeService(cloud cloud.Cloud, mounter mount.Interface) nodeService { +func newTestNodeService(metadataService cloud.MetadataService, mounter mount.Interface) nodeService { return nodeService{ - metadata: cloud.GetMetadata(), + metadata: metadataService, mounter: NewFakeSafeFormatAndMounter(mounter), inFlight: internal.NewInFlight(), } From ba5c1c729d8abf6de0eab57bae0e3e755811d81a Mon Sep 17 00:00:00 2001 From: Zach Abrahamson Date: Mon, 8 Apr 2019 11:34:28 -0400 Subject: [PATCH 9/9] Refactor individual test cases for readability --- pkg/driver/controller_test.go | 271 ++++++++++++++-------------------- 1 file changed, 108 insertions(+), 163 deletions(-) diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index f36ebdbda4..6e0ab037fc 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -109,7 +109,6 @@ func TestCreateVolume(t *testing.T) { VolumeCapabilities: stdVolCap, Parameters: stdParams, } - expErr := codes.InvalidArgument ctx := context.Background() @@ -125,11 +124,11 @@ func TestCreateVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErr { - t.Fatalf("Expected error code %d, got %d message %s", expErr, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } } else { - t.Fatalf("Expected error got nil") + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -150,7 +149,7 @@ func TestCreateVolume(t *testing.T) { } expVol := &csi.Volume{ CapacityBytes: stdVolSize, - VolumeId: "vol-test", + VolumeId: "test-vol", VolumeContext: map[string]string{FsTypeKey: ""}, } @@ -200,6 +199,16 @@ func TestCreateVolume(t *testing.T) { t.Fatalf("Expected volume capacity bytes: %v, got: %v", expVol.GetCapacityBytes(), vol.GetCapacityBytes()) } + if vol.GetVolumeId() != expVol.GetVolumeId() { + t.Fatalf("Expected volume id: %v, got: %v", expVol.GetVolumeId(), vol.GetVolumeId()) + } + + if expVol.GetAccessibleTopology() != nil { + if !reflect.DeepEqual(expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) { + t.Fatalf("Expected AccessibleTopology to be %+v, got: %+v", expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) + } + } + for expKey, expVal := range expVol.GetVolumeContext() { ctx := vol.GetVolumeContext() if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { @@ -271,7 +280,7 @@ func TestCreateVolume(t *testing.T) { t.Fatalf("Expected error code %d, got %d", codes.AlreadyExists, srvErr.Code()) } } else { - t.Fatalf("Expected error code %d, got nil", codes.AlreadyExists) + t.Fatalf("Expected error %v, got no error", codes.AlreadyExists) } }, }, @@ -295,18 +304,14 @@ func TestCreateVolume(t *testing.T) { VolumeID: req.Name, AvailabilityZone: expZone, FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(cloud.DefaultVolumeSize), } - volSizeBytes, err := getVolSizeBytes(req) - if err != nil { - t.Fatalf("Unable to get volume size bytes for req: %s", err) - } - mockDisk.CapacityGiB = util.BytesToGiB(volSizeBytes) mockCtl := gomock.NewController(t) defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(cloud.DefaultVolumeSize)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) awsDriver := controllerService{cloud: mockCloud} @@ -358,18 +363,14 @@ func TestCreateVolume(t *testing.T) { VolumeID: req.Name, AvailabilityZone: expZone, FsType: expVol.VolumeContext[FsTypeKey], + CapacityGiB: util.BytesToGiB(expVol.CapacityBytes), } - volSizeBytes, err := getVolSizeBytes(req) - if err != nil { - t.Fatalf("Unable to get volume size bytes for req: %s", err) - } - mockDisk.CapacityGiB = util.BytesToGiB(volSizeBytes) mockCtl := gomock.NewController(t) defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(volSizeBytes)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(expVol.CapacityBytes)).Return(nil, cloud.ErrNotFound) mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) awsDriver := controllerService{cloud: mockCloud} @@ -782,7 +783,7 @@ func TestDeleteVolume(t *testing.T) { defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(true, nil) + mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(false, cloud.ErrNotFound) awsDriver := controllerService{cloud: mockCloud} resp, err := awsDriver.DeleteVolume(ctx, req) if err != nil { @@ -797,6 +798,38 @@ func TestDeleteVolume(t *testing.T) { } }, }, + { + name: "fail delete disk", + testFunc: func(t *testing.T) { + req := &csi.DeleteVolumeRequest{ + VolumeId: "test-vol", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().DeleteDisk(gomock.Eq(ctx), gomock.Eq(req.VolumeId)).Return(false, fmt.Errorf("DeleteDisk could not delete volume")) + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.DeleteVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.Internal { + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + } else { + t.Fatalf("Expected error, got nil") + } + + if resp != nil { + t.Fatalf("Expected resp to be nil, got: %+v", resp) + } + }, + }, } for _, tc := range testCases { @@ -878,7 +911,6 @@ func TestCreateSnapshot(t *testing.T) { expSnapshot := &csi.Snapshot{ ReadyToUse: true, } - expErrCode := codes.OK ctx := context.Background() mockSnapshot := &cloud.Snapshot{ @@ -897,18 +929,9 @@ func TestCreateSnapshot(t *testing.T) { awsDriver := controllerService{cloud: mockCloud} resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) - } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + t.Fatalf("Unexpected error: %v", err) } + if snap := resp.GetSnapshot(); snap == nil { t.Fatalf("Expected snapshot %v, got nil", expSnapshot) } @@ -921,7 +944,6 @@ func TestCreateSnapshot(t *testing.T) { Parameters: nil, SourceVolumeId: "vol-test", } - expErrCode := codes.InvalidArgument mockCtl := gomock.NewController(t) defer mockCtl.Finish() @@ -934,10 +956,11 @@ func TestCreateSnapshot(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -957,8 +980,6 @@ func TestCreateSnapshot(t *testing.T) { expSnapshot := &csi.Snapshot{ ReadyToUse: true, } - expErrCode := codes.OK - extraExpErrCode := codes.AlreadyExists ctx := context.Background() mockSnapshot := &cloud.Snapshot{ @@ -981,9 +1002,10 @@ func TestCreateSnapshot(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + if srvErr.Code() != codes.OK { + t.Fatalf("Expected error code %d, got %d message %s", codes.OK, srvErr.Code(), srvErr.Message()) + } + t.Fatalf("Unexpected error: %v", err) } snap := resp.GetSnapshot() if snap == nil { @@ -997,13 +1019,11 @@ func TestCreateSnapshot(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != extraExpErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.AlreadyExists { + t.Fatalf("Expected error code %d, got %d message %s", codes.AlreadyExists, srvErr.Code(), srvErr.Message()) } - return - } - if extraExpErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", extraExpErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.AlreadyExists) } }, }, @@ -1023,8 +1043,6 @@ func TestCreateSnapshot(t *testing.T) { expSnapshot := &csi.Snapshot{ ReadyToUse: true, } - expErrCode := codes.OK - extraExpErrCode := codes.OK ctx := context.Background() mockSnapshot := &cloud.Snapshot{ @@ -1043,13 +1061,7 @@ func TestCreateSnapshot(t *testing.T) { awsDriver := controllerService{cloud: mockCloud} resp, err := awsDriver.CreateSnapshot(context.Background(), req) if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + t.Fatalf("Unexpected error: %v", err) } snap := resp.GetSnapshot() if snap == nil { @@ -1059,17 +1071,7 @@ func TestCreateSnapshot(t *testing.T) { mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(extraReq.GetName())).Return(mockSnapshot, nil) _, err = awsDriver.CreateSnapshot(ctx, extraReq) if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != extraExpErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) - } - return - } - if extraExpErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", extraExpErrCode) + t.Fatalf("Unexpected error: %v", err) } }, }, @@ -1088,7 +1090,6 @@ func TestDeleteSnapshot(t *testing.T) { { name: "success normal", testFunc: func(t *testing.T) { - expErrCode := codes.OK ctx := context.Background() mockCtl := gomock.NewController(t) @@ -1103,24 +1104,13 @@ func TestDeleteSnapshot(t *testing.T) { mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq("xxx")).Return(true, nil) if _, err := awsDriver.DeleteSnapshot(ctx, req); err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) - } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + t.Fatalf("Unexpected error: %v", err) } }, }, { name: "success not found", testFunc: func(t *testing.T) { - expErrCode := codes.OK ctx := context.Background() mockCtl := gomock.NewController(t) @@ -1135,17 +1125,7 @@ func TestDeleteSnapshot(t *testing.T) { mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq("xxx")).Return(false, cloud.ErrNotFound) if _, err := awsDriver.DeleteSnapshot(ctx, req); err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) - } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + t.Fatalf("Unexpected error: %v", err) } }, }, @@ -1196,11 +1176,7 @@ func TestControllerPublishVolume(t *testing.T) { awsDriver := controllerService{cloud: mockCloud} resp, err := awsDriver.ControllerPublishVolume(ctx, req) if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - t.Fatalf("Unexpected error: %v", srvErr.Code()) + t.Fatalf("Unexpected error: %v", err) } if !reflect.DeepEqual(resp, expResp) { @@ -1212,7 +1188,6 @@ func TestControllerPublishVolume(t *testing.T) { name: "fail no VolumeId", testFunc: func(t *testing.T) { req := &csi.ControllerPublishVolumeRequest{} - expErrCode := codes.InvalidArgument ctx := context.Background() @@ -1227,13 +1202,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -1243,7 +1216,6 @@ func TestControllerPublishVolume(t *testing.T) { req := &csi.ControllerPublishVolumeRequest{ VolumeId: "vol-test", } - expErrCode := codes.InvalidArgument ctx := context.Background() @@ -1258,13 +1230,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -1275,7 +1245,6 @@ func TestControllerPublishVolume(t *testing.T) { NodeId: expInstanceId, VolumeId: "vol-test", } - expErrCode := codes.InvalidArgument ctx := context.Background() @@ -1290,13 +1259,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -1312,7 +1279,6 @@ func TestControllerPublishVolume(t *testing.T) { }, VolumeId: "vol-test", } - expErrCode := codes.InvalidArgument ctx := context.Background() @@ -1327,13 +1293,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -1345,7 +1309,6 @@ func TestControllerPublishVolume(t *testing.T) { VolumeId: "vol-test", VolumeCapability: stdVolCap, } - expErrCode := codes.NotFound ctx := context.Background() @@ -1361,13 +1324,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.NotFound { + t.Fatalf("Expected error code %d, got %d message %s", codes.NotFound, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.NotFound) } }, }, @@ -1379,7 +1340,6 @@ func TestControllerPublishVolume(t *testing.T) { NodeId: expInstanceId, VolumeCapability: stdVolCap, } - expErrCode := codes.NotFound ctx := context.Background() @@ -1396,13 +1356,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.NotFound { + t.Fatalf("Expected error code %d, got %d message %s", codes.NotFound, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.NotFound) } }, }, @@ -1414,7 +1372,6 @@ func TestControllerPublishVolume(t *testing.T) { NodeId: expInstanceId, VolumeCapability: stdVolCap, } - expErrCode := codes.AlreadyExists ctx := context.Background() @@ -1432,13 +1389,11 @@ func TestControllerPublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.AlreadyExists { + t.Fatalf("Expected error code %d, got %d message %s", codes.AlreadyExists, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.AlreadyExists) } }, }, @@ -1474,11 +1429,7 @@ func TestControllerUnpublishVolume(t *testing.T) { awsDriver := controllerService{cloud: mockCloud} resp, err := awsDriver.ControllerUnpublishVolume(ctx, req) if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - t.Fatalf("Unexpected error: %v", srvErr.Code()) + t.Fatalf("Unexpected error: %v", err) } if !reflect.DeepEqual(resp, expResp) { @@ -1490,7 +1441,6 @@ func TestControllerUnpublishVolume(t *testing.T) { name: "fail no VolumeId", testFunc: func(t *testing.T) { req := &csi.ControllerUnpublishVolumeRequest{} - expErrCode := codes.InvalidArgument ctx := context.Background() @@ -1505,13 +1455,11 @@ func TestControllerUnpublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, }, @@ -1521,7 +1469,6 @@ func TestControllerUnpublishVolume(t *testing.T) { req := &csi.ControllerUnpublishVolumeRequest{ VolumeId: "vol-test", } - expErrCode := codes.InvalidArgument ctx := context.Background() @@ -1536,13 +1483,11 @@ func TestControllerUnpublishVolume(t *testing.T) { if !ok { t.Fatalf("Could not get error status code from error: %v", srvErr) } - if srvErr.Code() != expErrCode { - t.Fatalf("Expected error code %d, got %d message %s", expErrCode, srvErr.Code(), srvErr.Message()) + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) } - return - } - if expErrCode != codes.OK { - t.Fatalf("Expected error %v, got no error", expErrCode) + } else { + t.Fatalf("Expected error %v, got no error", codes.InvalidArgument) } }, },