From 59788192c455a3864655dd79ccf350319c44f76c Mon Sep 17 00:00:00 2001 From: "Lau, Luke" Date: Mon, 16 Dec 2024 15:21:54 -0500 Subject: [PATCH] refactor changes and add UT --- service/controller.go | 81 ++++++++---- service/controller_test.go | 263 +++++++++++++++++++++++++++++++++++++ 2 files changed, 316 insertions(+), 28 deletions(-) create mode 100644 service/controller_test.go diff --git a/service/controller.go b/service/controller.go index ee178631..7f37d402 100644 --- a/service/controller.go +++ b/service/controller.go @@ -2375,16 +2375,9 @@ func (s *service) GetCapacity( // If using availability zones, get capacity for the system in the zone // using accessible topology parameter from k8s. if s.opts.zoneLabelKey != "" { - zoneLabel, ok := req.AccessibleTopology.Segments[s.opts.zoneLabelKey] - if !ok { - Log.Infof("could not get availability zone from accessible topology. Getting capacity for all systems") - } else { - for _, array := range s.opts.arrays { - if zoneLabel == string(array.AvailabilityZone.Name) { - systemID = array.SystemID - break - } - } + systemID, err = s.getSystemIDFromZoneLabelKey(req) + if err != nil { + return nil, status.Errorf(codes.Internal, "%s", err.Error()) } } @@ -2430,6 +2423,28 @@ func (s *service) GetCapacity( }, nil } +// getSystemIDFromZoneLabelKey returns the system ID associated with the zoneLabelKey if zoneLabelKey is set and +// contains an associated zone name. Returns an empty string otherwise. +func (s *service) getSystemIDFromZoneLabelKey(req *csi.GetCapacityRequest) (systemID string, err error) { + zoneName, ok := req.AccessibleTopology.Segments[s.opts.zoneLabelKey] + if !ok { + Log.Infof("could not get availability zone from accessible topology. Getting capacity for all systems") + return "", nil + } + + // find the systemID with the matching zone name + for _, array := range s.opts.arrays { + if zoneName == string(array.AvailabilityZone.Name) { + systemID = array.SystemID + break + } + } + if systemID == "" { + return "", fmt.Errorf("could not find an array assigned to zone '%s'", zoneName) + } + return systemID, nil +} + func (s *service) getMaximumVolumeSize(systemID string) (int64, error) { valueInCache, found := getCachedMaximumVolumeSize(systemID) if !found || valueInCache < 0 { @@ -2563,36 +2578,46 @@ func (s *service) ControllerGetCapabilities( }, nil } +func (s *service) getZoneFromZoneLabelKey(ctx context.Context, zoneLabelKey string) (zone string, err error) { + if zoneLabelKey == "" { + return "", nil + } + + labels, err := GetNodeLabels(ctx, s) + if err != nil { + return "", err + } + + Log.Infof("Listing labels: %v", labels) + + if val, ok := labels[zoneLabelKey]; ok { + Log.Infof("probing zoneLabel %s, zone value: %s", zoneLabelKey, val) + return val, nil + } + + return "", fmt.Errorf("label %s not found", zoneLabelKey) +} + // systemProbeAll will iterate through all arrays in service.opts.arrays and probe them. If failed, it logs // the failed system name -func (s *service) systemProbeAll(ctx context.Context, zoneLabel string) error { +func (s *service) systemProbeAll(ctx context.Context, zoneLabelKey string) error { // probe all arrays // Log.Infof("Probing all arrays. Number of arrays: %d", len(s.opts.arrays)) Log.Infoln("Probing all associated arrays") allArrayFail := true errMap := make(map[string]error) - zone := "" - if zoneLabel != "" { - labels, err := GetNodeLabels(ctx, s) - if err != nil { - return err - } - - Log.Infof("Listing labels: %v", labels) - - if val, ok := labels[zoneLabel]; ok { - Log.Infof("probing zoneLabel %s, zone value: %s", zoneLabel, val) - zone = val - } else { - return fmt.Errorf("label %s not found", zoneLabel) - } + zoneName, err := s.getZoneFromZoneLabelKey(ctx, zoneLabelKey) + if err != nil { + return err } for _, array := range s.opts.arrays { // If zone information is available, use it to probe the array - if strings.EqualFold(s.mode, "node") && array.AvailabilityZone != nil && array.AvailabilityZone.Name != ZoneName(zone) { - Log.Warnf("array %s zone %s does not match %s, not pinging this array\n", array.SystemID, array.AvailabilityZone.Name, zone) + if strings.EqualFold(s.mode, "node") && array.AvailabilityZone != nil && array.AvailabilityZone.Name != ZoneName(zoneName) { + // Driver node containers should not probe arrays that exist outside their assigned zone + // Driver controller container should probe all arrays + Log.Warnf("array %s zone %s does not match %s, not pinging this array\n", array.SystemID, array.AvailabilityZone.Name, zoneName) continue } diff --git a/service/controller_test.go b/service/controller_test.go new file mode 100644 index 00000000..d4d532f4 --- /dev/null +++ b/service/controller_test.go @@ -0,0 +1,263 @@ +// Copyright © 2019-2024 Dell Inc. or its subsidiaries. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "errors" + "sync" + "testing" + + csi "github.com/container-storage-interface/spec/lib/go/csi" + sio "github.com/dell/goscaleio" + siotypes "github.com/dell/goscaleio/types/v1" + "golang.org/x/net/context" +) + +func Test_service_getZoneFromZoneLabelKey(t *testing.T) { + type fields struct { + opts Opts + adminClients map[string]*sio.Client + systems map[string]*sio.System + mode string + volCache []*siotypes.Volume + volCacheRWL sync.RWMutex + volCacheSystemID string + snapCache []*siotypes.Volume + snapCacheRWL sync.RWMutex + snapCacheSystemID string + privDir string + storagePoolIDToName map[string]string + statisticsCounter int + volumePrefixToSystems map[string][]string + connectedSystemNameToID map[string]string + } + type args struct { + ctx context.Context + zoneLabelKey string + } + tests := []struct { + name string + fields fields + args args + wantZone string + wantErr bool + getNodeLabelFunc func(ctx context.Context, s *service) (map[string]string, error) + }{ + { + name: "get good zone label", + args: args{ + ctx: context.Background(), + zoneLabelKey: "topology.kubernetes.io/zone", + }, + wantZone: "zoneA", + wantErr: false, + getNodeLabelFunc: func(ctx context.Context, s *service) (map[string]string, error) { + nodeLabels := map[string]string{"topology.kubernetes.io/zone": "zoneA"} + return nodeLabels, nil + }, + }, + { + name: "use bad zone label key", + args: args{ + ctx: context.Background(), + zoneLabelKey: "badkey", + }, + wantZone: "", + wantErr: true, + getNodeLabelFunc: func(ctx context.Context, s *service) (map[string]string, error) { + return nil, nil + }, + }, + { + name: "fail to get node labels", + args: args{ + ctx: context.Background(), + zoneLabelKey: "unimportant", + }, + wantZone: "", + wantErr: true, + getNodeLabelFunc: func(ctx context.Context, s *service) (map[string]string, error) { + return nil, errors.New("") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &service{ + opts: tt.fields.opts, + adminClients: tt.fields.adminClients, + systems: tt.fields.systems, + mode: tt.fields.mode, + volCache: tt.fields.volCache, + volCacheRWL: tt.fields.volCacheRWL, + volCacheSystemID: tt.fields.volCacheSystemID, + snapCache: tt.fields.snapCache, + snapCacheRWL: tt.fields.snapCacheRWL, + snapCacheSystemID: tt.fields.snapCacheSystemID, + privDir: tt.fields.privDir, + storagePoolIDToName: tt.fields.storagePoolIDToName, + statisticsCounter: tt.fields.statisticsCounter, + volumePrefixToSystems: tt.fields.volumePrefixToSystems, + connectedSystemNameToID: tt.fields.connectedSystemNameToID, + } + GetNodeLabels = tt.getNodeLabelFunc + gotZone, err := s.getZoneFromZoneLabelKey(tt.args.ctx, tt.args.zoneLabelKey) + if (err != nil) != tt.wantErr { + t.Errorf("service.getZoneFromZoneLabelKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotZone != tt.wantZone { + t.Errorf("service.getZoneFromZoneLabelKey() = %v, want %v", gotZone, tt.wantZone) + } + }) + } +} + +func Test_service_getSystemIDFromZoneLabelKey(t *testing.T) { + type fields struct { + opts Opts + adminClients map[string]*sio.Client + systems map[string]*sio.System + mode string + volCache []*siotypes.Volume + volCacheRWL sync.RWMutex + volCacheSystemID string + snapCache []*siotypes.Volume + snapCacheRWL sync.RWMutex + snapCacheSystemID string + privDir string + storagePoolIDToName map[string]string + statisticsCounter int + volumePrefixToSystems map[string][]string + connectedSystemNameToID map[string]string + } + type args struct { + req *csi.GetCapacityRequest + } + const validSystemID = "valid-id" + const validTopologyKey = "topology.kubernetes.io/zone" + const validZone = "zoneA" + + tests := []struct { + name string + fields fields + args args + wantSystemID string + wantErr bool + }{ + { + name: "get a valid system ID", + wantErr: false, + wantSystemID: validSystemID, + args: args{ + req: &csi.GetCapacityRequest{ + AccessibleTopology: &csi.Topology{ + Segments: map[string]string{ + validTopologyKey: validZone, + }, + }, + }, + }, + fields: fields{ + opts: Opts{ + zoneLabelKey: "topology.kubernetes.io/zone", + arrays: map[string]*ArrayConnectionData{ + "array1": { + SystemID: validSystemID, + AvailabilityZone: &AvailabilityZone{ + Name: validZone, + }, + }, + }, + }, + }, + }, + { + name: "topology not passed with csi request", + wantErr: false, + wantSystemID: "", + args: args{ + req: &csi.GetCapacityRequest{ + AccessibleTopology: &csi.Topology{ + // don't pass any topology info with the request + Segments: map[string]string{}, + }, + }, + }, + fields: fields{ + opts: Opts{ + zoneLabelKey: "topology.kubernetes.io/zone", + }, + }, + }, + { + name: "zone name missing in secret", + wantErr: true, + wantSystemID: "", + args: args{ + req: &csi.GetCapacityRequest{ + AccessibleTopology: &csi.Topology{ + Segments: map[string]string{ + validTopologyKey: validZone, + }, + }, + }, + }, + fields: fields{ + opts: Opts{ + zoneLabelKey: "topology.kubernetes.io/zone", + arrays: map[string]*ArrayConnectionData{ + "array1": { + SystemID: validSystemID, + AvailabilityZone: &AvailabilityZone{ + // ensure the zone name will not match the topology key value + // in the request + Name: validZone + "no-match", + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &service{ + opts: tt.fields.opts, + adminClients: tt.fields.adminClients, + systems: tt.fields.systems, + mode: tt.fields.mode, + volCache: tt.fields.volCache, + volCacheRWL: tt.fields.volCacheRWL, + volCacheSystemID: tt.fields.volCacheSystemID, + snapCache: tt.fields.snapCache, + snapCacheRWL: tt.fields.snapCacheRWL, + snapCacheSystemID: tt.fields.snapCacheSystemID, + privDir: tt.fields.privDir, + storagePoolIDToName: tt.fields.storagePoolIDToName, + statisticsCounter: tt.fields.statisticsCounter, + volumePrefixToSystems: tt.fields.volumePrefixToSystems, + connectedSystemNameToID: tt.fields.connectedSystemNameToID, + } + gotSystemID, err := s.getSystemIDFromZoneLabelKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("service.getSystemIDFromZoneLabelKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotSystemID != tt.wantSystemID { + t.Errorf("service.getSystemIDFromZoneLabelKey() = %v, want %v", gotSystemID, tt.wantSystemID) + } + }) + } +}