From 734aa0dac56b5c910673c58b209a7b74a3f85ab5 Mon Sep 17 00:00:00 2001 From: Drew Sirenko <68304519+AndrewSirenko@users.noreply.github.com> Date: Wed, 4 Oct 2023 09:43:51 -0400 Subject: [PATCH] Refactor format option validation and tests --- pkg/driver/controller.go | 12 +-- pkg/driver/controller_test.go | 161 +++++++++++++++++++++++----------- pkg/driver/node.go | 17 ++-- pkg/util/util.go | 10 +++ pkg/util/util_test.go | 26 ++++++ 5 files changed, 161 insertions(+), 65 deletions(-) diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 1a911dc564..f2f5a37247 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -188,26 +188,22 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol blockExpress = true } case BlockSizeKey: - _, err = strconv.Atoi(value) - if err != nil { + if isAlphanumeric := util.StringIsAlphanumeric(value); !isAlphanumeric { return nil, status.Errorf(codes.InvalidArgument, "Could not parse blockSize (%s): %v", value, err) } blockSize = value case INodeSizeKey: - _, err = strconv.Atoi(value) - if err != nil { + if isAlphanumeric := util.StringIsAlphanumeric(value); !isAlphanumeric { return nil, status.Errorf(codes.InvalidArgument, "Could not parse inodeSize (%s): %v", value, err) } inodeSize = value case BytesPerINodeKey: - _, err = strconv.Atoi(value) - if err != nil { + if isAlphanumeric := util.StringIsAlphanumeric(value); !isAlphanumeric { return nil, status.Errorf(codes.InvalidArgument, "Could not parse bytesPerINode (%s): %v", value, err) } bytesPerINode = value case NumberOfINodesKey: - _, err = strconv.Atoi(value) - if err != nil { + if isAlphanumeric := util.StringIsAlphanumeric(value); !isAlphanumeric { return nil, status.Errorf(codes.InvalidArgument, "Could not parse numberOfINodes (%s): %v", value, err) } numberOfINodes = value diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 7a24de0a04..1f692042a8 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1634,81 +1634,144 @@ func TestCreateVolume(t *testing.T) { checkExpectedErrorCode(t, err, codes.AlreadyExists) }, }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + +func TestCreateVolumeWithFormattingParameters(t *testing.T) { + stdVolCap := []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + } + stdVolSize := int64(5 * 1024 * 1024 * 1024) + stdCapRange := &csi.CapacityRange{RequiredBytes: stdVolSize} + + testCases := []struct { + name string + formattingOptionParameters map[string]string + errExpected bool + }{ { name: "success with block size", - testFunc: func(t *testing.T) { - testSuccessWithParameter(t, BlockSizeKey, "4096", stdCapRange, stdVolCap, stdVolSize) + formattingOptionParameters: map[string]string{ + BlockSizeKey: "4096", }, + errExpected: false, }, { name: "success with inode size", - testFunc: func(t *testing.T) { - testSuccessWithParameter(t, INodeSizeKey, "256", stdCapRange, stdVolCap, stdVolSize) + formattingOptionParameters: map[string]string{ + INodeSizeKey: "256", }, + errExpected: false, }, { name: "success with bytes-per-inode", - testFunc: func(t *testing.T) { - testSuccessWithParameter(t, BytesPerINodeKey, "8192", stdCapRange, stdVolCap, stdVolSize) + formattingOptionParameters: map[string]string{ + BytesPerINodeKey: "8192", }, + errExpected: false, }, { name: "success with number-of-inodes", - testFunc: func(t *testing.T) { - testSuccessWithParameter(t, NumberOfINodesKey, "13107200", stdCapRange, stdVolCap, stdVolSize) + formattingOptionParameters: map[string]string{ + NumberOfINodesKey: "13107200", + }, + errExpected: false, + }, + { + name: "failure with block size", + formattingOptionParameters: map[string]string{ + BlockSizeKey: "wrong_value", + }, + errExpected: true, + }, + { + name: "failure with inode size", + formattingOptionParameters: map[string]string{ + INodeSizeKey: "wrong_value", + }, + errExpected: true, + }, + { + name: "failure with bytes-per-inode", + formattingOptionParameters: map[string]string{ + BytesPerINodeKey: "wrong_value", + }, + errExpected: true, + }, + { + name: "failure with number-of-inodes", + formattingOptionParameters: map[string]string{ + NumberOfINodesKey: "wrong_value", }, + errExpected: true, }, } - for _, tc := range testCases { - t.Run(tc.name, tc.testFunc) - } -} + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) -func testSuccessWithParameter(t *testing.T, key, value string, capRange *csi.CapacityRange, volCap []*csi.VolumeCapability, volSize int64) { - req := &csi.CreateVolumeRequest{ - Name: "random-vol-name", - CapacityRange: capRange, - VolumeCapabilities: volCap, - Parameters: map[string]string{key: value}, - } + req := &csi.CreateVolumeRequest{ + Name: "random-vol-name", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: tc.formattingOptionParameters, + } - ctx := context.Background() + ctx := context.Background() - mockDisk := &cloud.Disk{ - VolumeID: req.Name, - AvailabilityZone: expZone, - CapacityGiB: util.BytesToGiB(volSize), - } + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + CapacityGiB: util.BytesToGiB(stdVolSize), + } - mockCtl := gomock.NewController(t) - defer mockCtl.Finish() + mockCtl := gomock.NewController(t) - mockCloud := cloud.NewMockCloud(mockCtl) - mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + mockCloud := cloud.NewMockCloud(mockCtl) - awsDriver := controllerService{ - cloud: mockCloud, - inFlight: internal.NewInFlight(), - driverOptions: &DriverOptions{}, - } + // CreateDisk not called on Unhappy Case + if !tc.errExpected { + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + defer mockCtl.Finish() + } - response, err := awsDriver.CreateVolume(ctx, req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - t.Fatalf("Unexpected error: %v", srvErr.Code()) - } + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + driverOptions: &DriverOptions{}, + } - volCtx := response.Volume.VolumeContext - if sizeValue, ok := volCtx[key]; ok { - if sizeValue != value { - t.Fatalf("Invalid %s in VolumeContext (got %s expected %s)", key, sizeValue, value) - } - } else { - t.Fatalf("Missing key %s in VolumeContext", key) + response, err := awsDriver.CreateVolume(ctx, req) + + // Splits happy case tests from unhappy case tests + if !tc.errExpected { + assert.Nilf(err, "Unexpected error: %w", err) + + volCtx := response.Volume.VolumeContext + + for formattingParamKey, formattingParamValue := range tc.formattingOptionParameters { + createdFormattingParamValue, ok := volCtx[formattingParamKey] + assert.Truef(ok, "Missing key %s in VolumeContext", formattingParamKey) + + assert.Equalf(createdFormattingParamValue, formattingParamValue, "Invalid %s in VolumeContext", formattingParamKey) + } + } else { + assert.NotNilf(err, "CreateVolume did not return an error") + + checkExpectedErrorCode(t, err, codes.InvalidArgument) + } + }) } } diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 894161b4b6..39278ead86 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -22,7 +22,6 @@ import ( "fmt" "os" "path/filepath" - "strconv" "strings" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -158,19 +157,22 @@ func (d *nodeService) NodeStageVolume(ctx context.Context, req *csi.NodeStageVol context := req.GetVolumeContext() - blockSize, err := recheckParameter(context, BlockSizeKey, FileSystemConfigs, fsType) + blockSize, err := recheckFormattingOptionParameter(context, BlockSizeKey, FileSystemConfigs, fsType) if err != nil { return nil, err } - inodeSize, err := recheckParameter(context, INodeSizeKey, FileSystemConfigs, fsType) + inodeSize, err := recheckFormattingOptionParameter(context, INodeSizeKey, FileSystemConfigs, fsType) if err != nil { return nil, err } - bytesPerINode, err := recheckParameter(context, BytesPerINodeKey, FileSystemConfigs, fsType) + bytesPerINode, err := recheckFormattingOptionParameter(context, BytesPerINodeKey, FileSystemConfigs, fsType) + if err != nil { + return nil, err + } + numINodes, err := recheckFormattingOptionParameter(context, NumberOfINodesKey, FileSystemConfigs, fsType) if err != nil { return nil, err } - numINodes, err := recheckParameter(context, NumberOfINodesKey, FileSystemConfigs, fsType) if err != nil { return nil, err } @@ -897,13 +899,12 @@ func removeNotReadyTaint(k8sClient cloud.KubernetesAPIClient) error { return nil } -func recheckParameter(context map[string]string, key string, fsConfigs map[string]fileSystemConfig, fsType string) (value string, err error) { +func recheckFormattingOptionParameter(context map[string]string, key string, fsConfigs map[string]fileSystemConfig, fsType string) (value string, err error) { v, ok := context[key] if ok { // This check is already performed on the controller side // However, because it is potentially security-sensitive, we redo it here to be safe - _, err := strconv.Atoi(v) - if err != nil { + if isAlphanumeric := util.StringIsAlphanumeric(value); !isAlphanumeric { return "", status.Errorf(codes.InvalidArgument, "Invalid %s (aborting!): %v", key, err) } diff --git a/pkg/util/util.go b/pkg/util/util.go index 2dc6b25a3c..24e4433280 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -21,6 +21,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -30,6 +31,10 @@ const ( GiB = 1024 * 1024 * 1024 ) +var ( + isAlphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]*$`).MatchString +) + // RoundUpBytes rounds up the volume size in bytes upto multiplications of GiB // in the unit of Bytes func RoundUpBytes(volumeSizeBytes int64) int64 { @@ -93,3 +98,8 @@ func GetAccessModes(caps []*csi.VolumeCapability) *[]string { func IsSBE(region string) bool { return region == "snow" } + +// StringIsAlphanumeric returns true if a given string contains only English letters or numbers +func StringIsAlphanumeric(s string) bool { + return isAlphanumericRegex(s) +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 7715491ba9..1dfe707a56 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -21,6 +21,7 @@ package util import ( "fmt" + "github.com/stretchr/testify/assert" "reflect" "testing" @@ -154,3 +155,28 @@ func TestGetAccessModes(t *testing.T) { t.Fatalf("Wrong values returned for volume capabilities. Expected %v, got %v", expectedModes, actualModes) } } + +func TestIsAlphanumeric(t *testing.T) { + testCases := []struct { + name string + testString string + expResult bool + }{ + { + name: "success with alphanumeric", + testString: "4Kib", + expResult: true, + }, + { + name: "failure with non-alphanumeric", + testString: "space 4Kib", + expResult: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := StringIsAlphanumeric(tc.testString) + assert.Equalf(t, res, tc.expResult, "Wrong value returned for StringIsAlphanumeric. Expected %s for string %s, got %s", tc.expResult, tc.testString, res) + }) + } +}