diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 0c1603bf3b..a0529fb8eb 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -268,7 +268,7 @@ func TestCreateVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) resp, err := awsDriver.CreateVolume(context.TODO(), tc.req) if err != nil { @@ -353,7 +353,7 @@ func TestDeleteVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) _, err := awsDriver.DeleteVolume(context.TODO(), tc.req) if err != nil { srvErr, ok := status.FromError(err) @@ -499,7 +499,7 @@ func TestCreateSnapshot(t *testing.T) { } for _, tc := range testCases { t.Logf("Test case: %s", tc.name) - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) resp, err := awsDriver.CreateSnapshot(context.TODO(), tc.req) if err != nil { srvErr, ok := status.FromError(err) @@ -565,7 +565,7 @@ func TestDeleteSnapshot(t *testing.T) { } for _, tc := range testCases { t.Logf("Test case: %s", tc.name) - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) snapResp, err := awsDriver.CreateSnapshot(context.TODO(), snapReq) if err != nil { t.Fatalf("Error creating testing snapshot: %v", err) @@ -589,3 +589,204 @@ func TestDeleteSnapshot(t *testing.T) { } } } + +func TestControllerPublishVolume(t *testing.T) { + fakeCloud := NewFakeCloudProvider() + stdVolCap := &csi.VolumeCapability{ + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + } + testCases := []struct { + name string + req *csi.ControllerPublishVolumeRequest + expResp *csi.ControllerPublishVolumeResponse + expErrCode codes.Code + setup func(req *csi.ControllerPublishVolumeRequest) + }{ + { + name: "success normal", + expResp: &csi.ControllerPublishVolumeResponse{}, + req: &csi.ControllerPublishVolumeRequest{ + NodeId: fakeCloud.GetMetadata().GetInstanceID(), + VolumeCapability: stdVolCap, + }, + // create a fake disk and setup the request + // parameters appropriately + setup: func(req *csi.ControllerPublishVolumeRequest) { + fakeDiskOpts := &cloud.DiskOptions{ + CapacityBytes: 1, + AvailabilityZone: "az", + } + fakeDisk, _ := fakeCloud.CreateDisk(context.TODO(), "vol-test", fakeDiskOpts) + req.VolumeId = fakeDisk.VolumeID + }, + }, + { + name: "fail no VolumeId", + req: &csi.ControllerPublishVolumeRequest{}, + expErrCode: codes.InvalidArgument, + setup: func(req *csi.ControllerPublishVolumeRequest) {}, + }, + { + name: "fail no NodeId", + expErrCode: codes.InvalidArgument, + req: &csi.ControllerPublishVolumeRequest{ + VolumeId: "vol-test", + }, + setup: func(req *csi.ControllerPublishVolumeRequest) {}, + }, + { + name: "fail no VolumeCapability", + expErrCode: codes.InvalidArgument, + req: &csi.ControllerPublishVolumeRequest{ + NodeId: fakeCloud.GetMetadata().GetInstanceID(), + VolumeId: "vol-test", + }, + setup: func(req *csi.ControllerPublishVolumeRequest) {}, + }, + { + name: "fail invalid VolumeCapability", + expErrCode: codes.InvalidArgument, + req: &csi.ControllerPublishVolumeRequest{ + NodeId: fakeCloud.GetMetadata().GetInstanceID(), + VolumeCapability: &csi.VolumeCapability{ + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_UNKNOWN, + }, + }, + VolumeId: "vol-test", + }, + setup: func(req *csi.ControllerPublishVolumeRequest) {}, + }, + { + name: "fail instance not found", + expErrCode: codes.NotFound, + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "does-not-exist", + VolumeId: "vol-test", + VolumeCapability: stdVolCap, + }, + setup: func(req *csi.ControllerPublishVolumeRequest) {}, + }, + { + name: "fail volume not found", + expErrCode: codes.NotFound, + req: &csi.ControllerPublishVolumeRequest{ + VolumeId: "does-not-exist", + NodeId: fakeCloud.GetMetadata().GetInstanceID(), + VolumeCapability: stdVolCap, + }, + setup: func(req *csi.ControllerPublishVolumeRequest) {}, + }, + { + name: "fail attach disk with already exists error", + expErrCode: codes.AlreadyExists, + req: &csi.ControllerPublishVolumeRequest{ + VolumeId: "does-not-exist", + NodeId: fakeCloud.GetMetadata().GetInstanceID(), + VolumeCapability: stdVolCap, + }, + // create a fake disk, attach it and setup the + // request appropriately + setup: func(req *csi.ControllerPublishVolumeRequest) { + fakeDiskOpts := &cloud.DiskOptions{ + CapacityBytes: 1, + AvailabilityZone: "az", + } + fakeDisk, _ := fakeCloud.CreateDisk(context.TODO(), "vol-test", fakeDiskOpts) + req.VolumeId = fakeDisk.VolumeID + _, _ = fakeCloud.AttachDisk(context.TODO(), fakeDisk.VolumeID, fakeCloud.GetMetadata().GetInstanceID()) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setup(tc.req) + awsDriver := NewFakeDriver("", fakeCloud, NewFakeMounter()) + _, err := awsDriver.ControllerPublishVolume(context.TODO(), tc.req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != tc.expErrCode { + t.Fatalf("Expected error code '%d', got '%d': %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if tc.expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", tc.expErrCode) + } + }) + } +} + +func TestControllerUnpublishVolume(t *testing.T) { + fakeCloud := NewFakeCloudProvider() + testCases := []struct { + name string + req *csi.ControllerUnpublishVolumeRequest + expResp *csi.ControllerUnpublishVolumeResponse + expErrCode codes.Code + setup func(req *csi.ControllerUnpublishVolumeRequest) + }{ + { + name: "success normal", + expResp: &csi.ControllerUnpublishVolumeResponse{}, + req: &csi.ControllerUnpublishVolumeRequest{ + NodeId: fakeCloud.GetMetadata().GetInstanceID(), + }, + // create a fake disk, attach it and setup the request + // parameters appropriately + setup: func(req *csi.ControllerUnpublishVolumeRequest) { + fakeDiskOpts := &cloud.DiskOptions{ + CapacityBytes: 1, + AvailabilityZone: "az", + } + fakeDisk, _ := fakeCloud.CreateDisk(context.TODO(), "vol-test", fakeDiskOpts) + req.VolumeId = fakeDisk.VolumeID + _, _ = fakeCloud.AttachDisk(context.TODO(), fakeDisk.VolumeID, fakeCloud.GetMetadata().GetInstanceID()) + }, + }, + { + name: "fail no VolumeId", + req: &csi.ControllerUnpublishVolumeRequest{}, + expErrCode: codes.InvalidArgument, + setup: func(req *csi.ControllerUnpublishVolumeRequest) {}, + }, + { + name: "fail no NodeId", + expErrCode: codes.InvalidArgument, + req: &csi.ControllerUnpublishVolumeRequest{ + VolumeId: "vol-test", + }, + setup: func(req *csi.ControllerUnpublishVolumeRequest) {}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setup(tc.req) + awsDriver := NewFakeDriver("", fakeCloud, NewFakeMounter()) + _, err := awsDriver.ControllerUnpublishVolume(context.TODO(), tc.req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != tc.expErrCode { + t.Fatalf("Expected error code '%d', got '%d': %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) + } + return + } + if tc.expErrCode != codes.OK { + t.Fatalf("Expected error %v, got no error", tc.expErrCode) + } + }) + } +} diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 1683a59d4b..bb8e4b1e6f 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -22,6 +22,10 @@ import ( "k8s.io/kubernetes/pkg/util/mount" ) +func NewFakeCloudProvider() *cloud.FakeCloudProvider { + return cloud.NewFakeCloudProvider() +} + func NewFakeMounter() *mount.FakeMounter { return &mount.FakeMounter{ MountPoints: []mount.MountPoint{}, @@ -38,12 +42,11 @@ func NewFakeSafeFormatAndMounter(fakeMounter *mount.FakeMounter) *mount.SafeForm } // NewFakeDriver creates a new mock driver used for testing -func NewFakeDriver(endpoint string, fakeMounter *mount.FakeMounter) *Driver { - cloud := cloud.NewFakeCloudProvider() +func NewFakeDriver(endpoint string, fakeCloud *cloud.FakeCloudProvider, fakeMounter *mount.FakeMounter) *Driver { return &Driver{ endpoint: endpoint, - nodeID: cloud.GetMetadata().GetInstanceID(), - cloud: cloud, + nodeID: fakeCloud.GetMetadata().GetInstanceID(), + cloud: fakeCloud, mounter: NewFakeSafeFormatAndMounter(fakeMounter), inFlight: internal.NewInFlight(), } diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 0918d55d86..cea882ecce 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -205,7 +205,7 @@ func TestNodeStageVolume(t *testing.T) { if tc.fakeMountPoint != nil { fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) } - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodeStageVolume(context.TODO(), tc.req) if err != nil { @@ -305,7 +305,7 @@ func TestNodeUnstageVolume(t *testing.T) { if len(tc.fakeMountPoints) > 0 { fakeMounter.MountPoints = tc.fakeMountPoints } - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodeUnstageVolume(context.TODO(), tc.req) if err != nil { @@ -535,7 +535,7 @@ func TestNodePublishVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { fakeMounter := NewFakeMounter() - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodePublishVolume(context.TODO(), tc.req) if err != nil { @@ -612,7 +612,7 @@ func TestNodeUnpublishVolume(t *testing.T) { if tc.fakeMountPoint != nil { fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) } - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodeUnpublishVolume(context.TODO(), tc.req) if err != nil { @@ -638,7 +638,7 @@ func TestNodeUnpublishVolume(t *testing.T) { func TestNodeGetVolumeStats(t *testing.T) { req := &csi.NodeGetVolumeStatsRequest{} - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) expErrCode := codes.Unimplemented _, err := awsDriver.NodeGetVolumeStats(context.TODO(), req) @@ -656,7 +656,7 @@ func TestNodeGetVolumeStats(t *testing.T) { func TestNodeGetCapabilities(t *testing.T) { req := &csi.NodeGetCapabilitiesRequest{} - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) caps := []*csi.NodeServiceCapability{ { Type: &csi.NodeServiceCapability_Rpc{ @@ -683,7 +683,7 @@ func TestNodeGetCapabilities(t *testing.T) { func TestNodeGetInfo(t *testing.T) { req := &csi.NodeGetInfoRequest{} - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) m := awsDriver.cloud.GetMetadata() expResp := &csi.NodeGetInfoResponse{ NodeId: "instanceID", diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index 554f367de9..87b597e055 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -43,7 +43,7 @@ func TestSanity(t *testing.T) { } var _ = BeforeSuite(func() { - ebsDriver = driver.NewFakeDriver(endpoint, driver.NewFakeMounter()) + ebsDriver = driver.NewFakeDriver(endpoint, driver.NewFakeCloudProvider(), driver.NewFakeMounter()) go func() { Expect(ebsDriver.Run()).NotTo(HaveOccurred()) }()