diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 57b9f2ff18..01d682076e 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -268,7 +268,7 @@ func TestCreateVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) resp, err := awsDriver.CreateVolume(context.TODO(), tc.req) if err != nil { @@ -353,7 +353,7 @@ func TestDeleteVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) _, err := awsDriver.DeleteVolume(context.TODO(), tc.req) if err != nil { srvErr, ok := status.FromError(err) diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 1683a59d4b..bb8e4b1e6f 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -22,6 +22,10 @@ import ( "k8s.io/kubernetes/pkg/util/mount" ) +func NewFakeCloudProvider() *cloud.FakeCloudProvider { + return cloud.NewFakeCloudProvider() +} + func NewFakeMounter() *mount.FakeMounter { return &mount.FakeMounter{ MountPoints: []mount.MountPoint{}, @@ -38,12 +42,11 @@ func NewFakeSafeFormatAndMounter(fakeMounter *mount.FakeMounter) *mount.SafeForm } // NewFakeDriver creates a new mock driver used for testing -func NewFakeDriver(endpoint string, fakeMounter *mount.FakeMounter) *Driver { - cloud := cloud.NewFakeCloudProvider() +func NewFakeDriver(endpoint string, fakeCloud *cloud.FakeCloudProvider, fakeMounter *mount.FakeMounter) *Driver { return &Driver{ endpoint: endpoint, - nodeID: cloud.GetMetadata().GetInstanceID(), - cloud: cloud, + nodeID: fakeCloud.GetMetadata().GetInstanceID(), + cloud: fakeCloud, mounter: NewFakeSafeFormatAndMounter(fakeMounter), inFlight: internal.NewInFlight(), } diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 9e5b102785..4556d75417 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -187,7 +187,7 @@ func TestNodeStageVolume(t *testing.T) { if tc.fakeMountPoint != nil { fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) } - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodeStageVolume(context.TODO(), tc.req) if err != nil { @@ -287,7 +287,7 @@ func TestNodeUnstageVolume(t *testing.T) { if len(tc.fakeMountPoints) > 0 { fakeMounter.MountPoints = tc.fakeMountPoints } - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodeUnstageVolume(context.TODO(), tc.req) if err != nil { @@ -484,7 +484,7 @@ func TestNodePublishVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { fakeMounter := NewFakeMounter() - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodePublishVolume(context.TODO(), tc.req) if err != nil { @@ -561,7 +561,7 @@ func TestNodeUnpublishVolume(t *testing.T) { if tc.fakeMountPoint != nil { fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) } - awsDriver := NewFakeDriver("", fakeMounter) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), fakeMounter) _, err := awsDriver.NodeUnpublishVolume(context.TODO(), tc.req) if err != nil { @@ -587,7 +587,7 @@ func TestNodeUnpublishVolume(t *testing.T) { func TestNodeGetVolumeStats(t *testing.T) { req := &csi.NodeGetVolumeStatsRequest{} - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) expErrCode := codes.Unimplemented _, err := awsDriver.NodeGetVolumeStats(context.TODO(), req) @@ -605,7 +605,7 @@ func TestNodeGetVolumeStats(t *testing.T) { func TestNodeGetCapabilities(t *testing.T) { req := &csi.NodeGetCapabilitiesRequest{} - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) caps := []*csi.NodeServiceCapability{ { Type: &csi.NodeServiceCapability_Rpc{ @@ -632,7 +632,7 @@ func TestNodeGetCapabilities(t *testing.T) { func TestNodeGetInfo(t *testing.T) { req := &csi.NodeGetInfoRequest{} - awsDriver := NewFakeDriver("", NewFakeMounter()) + awsDriver := NewFakeDriver("", NewFakeCloudProvider(), NewFakeMounter()) m := awsDriver.cloud.GetMetadata() expResp := &csi.NodeGetInfoResponse{ NodeId: "instanceID", diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index 554f367de9..87b597e055 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -43,7 +43,7 @@ func TestSanity(t *testing.T) { } var _ = BeforeSuite(func() { - ebsDriver = driver.NewFakeDriver(endpoint, driver.NewFakeMounter()) + ebsDriver = driver.NewFakeDriver(endpoint, driver.NewFakeCloudProvider(), driver.NewFakeMounter()) go func() { Expect(ebsDriver.Run()).NotTo(HaveOccurred()) }()