From ee19d850454e95f6a19398b72ab9f6cf78f4020f Mon Sep 17 00:00:00 2001 From: Drew Sirenko <68304519+AndrewSirenko@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:26:56 +0000 Subject: [PATCH] Batch EC2 DescribeInstances calls --- pkg/cloud/cloud.go | 153 ++++++++++++++++++++++++++++------------ pkg/cloud/cloud_test.go | 108 +++++++++++++++++++++++++++- 2 files changed, 215 insertions(+), 46 deletions(-) diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 751cb78a18..c19df71b13 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -143,6 +143,8 @@ const ( const ( volumeIDBatcher batcherType = iota volumeTagBatcher + + batchDescribeTimeout = 30 * time.Second ) var ( @@ -245,7 +247,9 @@ type batcherType int // batcherManager maintains a collection of batchers for different types of tasks. type batcherManager struct { - batchers map[batcherType]*batcher.Batcher[string, *ec2.Volume] + volumeIDBatcher *batcher.Batcher[string, *ec2.Volume] + volumeTagBatcher *batcher.Batcher[string, *ec2.Volume] + instanceIDBatcher *batcher.Batcher[string, *ec2.Instance] } type cloud struct { @@ -260,25 +264,15 @@ var _ Cloud = &cloud{} // NewCloud returns a new instance of AWS cloud // It panics if session is invalid func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (Cloud, error) { - c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra) - - if batching { - klog.V(4).InfoS("NewCloud: batching enabled") - cloudInstance, ok := c.(*cloud) - if !ok { - return nil, fmt.Errorf("expected *cloud type but got %T", c) - } - cloudInstance.bm = newBatcherManager(cloudInstance.ec2) - } - + c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra, batching) return c, nil } -func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) Cloud { +func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string, isBatchingEnabled bool) Cloud { awsConfig := &aws.Config{ Region: aws.String(region), CredentialsChainVerboseErrors: aws.Bool(true), - // Set MaxRetries to a high value. It will be "ovewritten" if context deadline comes sooner. + // Set MaxRetries to a high value. It will be "overwritten" if context deadline comes sooner. MaxRetries: aws.Int(8), } @@ -317,33 +311,36 @@ func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) Clou Fn: RecordRequestsHandler, }) + var bm *batcherManager + if isBatchingEnabled { + klog.V(4).InfoS("newEC2Cloud: batching enabled") + bm = newBatcherManager(svc) + } + return &cloud{ region: region, dm: dm.NewDeviceManager(), ec2: svc, + bm: bm, } } // newBatcherManager initializes a new instance of batcherManager. func newBatcherManager(svc ec2iface.EC2API) *batcherManager { return &batcherManager{ - batchers: map[batcherType]*batcher.Batcher[string, *ec2.Volume]{ - volumeIDBatcher: batcher.New(500, 1*time.Second, func(ids []string) (map[string]*ec2.Volume, error) { - return execBatchDescribeVolumes(svc, ids, volumeIDBatcher) - }), - volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) { - return execBatchDescribeVolumes(svc, names, volumeTagBatcher) - }), - }, + volumeIDBatcher: batcher.New(500, 1*time.Second, func(ids []string) (map[string]*ec2.Volume, error) { + return execBatchDescribeVolumes(svc, ids, volumeIDBatcher) + }), + volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) { + return execBatchDescribeVolumes(svc, names, volumeTagBatcher) + }), + instanceIDBatcher: batcher.New(50, 500*time.Millisecond, func(names []string) (map[string]*ec2.Instance, error) { + return execBatchDescribeInstances(svc, names) + }), } } -// getBatcher fetches a specific type of batcher from the batcherManager. -func (bm *batcherManager) getBatcher(b batcherType) *batcher.Batcher[string, *ec2.Volume] { - return bm.batchers[b] -} - -// executes a batched DescribeVolumes API call depending on the type of batcher. +// execBatchDescribeVolumes executes a batched DescribeVolumes API call depending on the type of batcher. func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batcherType) (map[string]*ec2.Volume, error) { var request *ec2.DescribeVolumesInput @@ -370,7 +367,7 @@ func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batch return nil, fmt.Errorf("execBatchDescribeVolumes: unsupported request type") } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), batchDescribeTimeout) defer cancel() resp, err := describeVolumes(ctx, svc, request) @@ -396,16 +393,16 @@ func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batch // batchDescribeVolumes processes a DescribeVolumes request. Depending on the request, // it determines the appropriate batcher to use, queues the task, and waits for the result. func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { - var bType batcherType + var b *batcher.Batcher[string, *ec2.Volume] var task string switch { case len(request.VolumeIds) == 1 && request.VolumeIds[0] != nil: - bType = volumeIDBatcher + b = c.bm.volumeIDBatcher task = *request.VolumeIds[0] case len(request.Filters) == 1 && *request.Filters[0].Name == "tag:"+VolumeNameTagKey && len(request.Filters[0].Values) == 1: - bType = volumeTagBatcher + b = c.bm.volumeTagBatcher task = *request.Filters[0].Values[0] default: @@ -414,7 +411,6 @@ func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Vo ch := make(chan batcher.BatchResult[*ec2.Volume]) - b := c.bm.getBatcher(bType) b.AddTask(task, ch) r := <-ch @@ -423,7 +419,7 @@ func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Vo return nil, r.Err } if r.Result == nil { - return nil, fmt.Errorf("batchDescribeVolumes: no volume found %s", task) + return nil, ErrNotFound } return r.Result, nil } @@ -684,6 +680,61 @@ func (c *cloud) DeleteDisk(ctx context.Context, volumeID string) (bool, error) { return true, nil } +// executes a batched DescribeInstances API call +func execBatchDescribeInstances(svc ec2iface.EC2API, input []string) (map[string]*ec2.Instance, error) { + klog.V(7).InfoS("execBatchDescribeInstances", "instanceIds", input) + request := &ec2.DescribeInstancesInput{ + InstanceIds: aws.StringSlice(input), + } + + ctx, cancel := context.WithTimeout(context.Background(), batchDescribeTimeout) + defer cancel() + + resp, err := describeInstances(ctx, svc, request) + if err != nil { + return nil, err + } + + result := make(map[string]*ec2.Instance) + + for _, instance := range resp { + if instance.InstanceId == nil { + klog.Warningf("execBatchDescribeInstances: skipping instance: %v, reason: missing instance ID", instance) + continue + } + result[*instance.InstanceId] = instance + } + + klog.V(7).InfoS("execBatchDescribeInstances: success", "result", result) + return result, nil +} + +// batchDescribeInstances processes a DescribeInstances request by queuing the task and waiting for the result. +func (c *cloud) batchDescribeInstances(request *ec2.DescribeInstancesInput) (*ec2.Instance, error) { + var task string + + if len(request.InstanceIds) == 1 && request.InstanceIds[0] != nil { + task = *request.InstanceIds[0] + } else { + return nil, fmt.Errorf("batchDescribeInstances: invalid request, request: %v", request) + } + + ch := make(chan batcher.BatchResult[*ec2.Instance]) + + b := c.bm.instanceIDBatcher + b.AddTask(task, ch) + + r := <-ch + + if r.Err != nil { + return nil, r.Err + } + if r.Result == nil { + return nil, ErrNotFound + } + return r.Result, nil +} + // Node likely bad device names cache // Remember device names that are already in use on an instance and use them last when attaching volumes // This works around device names that are used but do not appear in the mapping from DescribeInstanceStatus @@ -1166,15 +1217,11 @@ func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput } } -func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, error) { +func describeInstances(ctx context.Context, svc ec2iface.EC2API, request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) { instances := []*ec2.Instance{} - request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{&nodeID}, - } - var nextToken *string for { - response, err := c.ec2.DescribeInstancesWithContext(ctx, request) + response, err := svc.DescribeInstancesWithContext(ctx, request) if err != nil { if isAWSErrorInstanceNotFound(err) { return nil, ErrNotFound @@ -1192,14 +1239,30 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, } request.NextToken = nextToken } + return instances, nil +} - if l := len(instances); l > 1 { - return nil, fmt.Errorf("found %d instances with ID %q", l, nodeID) - } else if l < 1 { - return nil, ErrNotFound +func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, error) { + request := &ec2.DescribeInstancesInput{ + InstanceIds: []*string{&nodeID}, } - return instances[0], nil + if c.bm == nil { + instances, err := describeInstances(ctx, c.ec2, request) + if err != nil { + return nil, err + } + + if l := len(instances); l > 1 { + return nil, fmt.Errorf("found %d instances with ID %q", l, nodeID) + } else if l < 1 { + return nil, ErrNotFound + } + + return instances[0], nil + } else { + return c.batchDescribeInstances(request) + } } func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 9c26bb3278..12382147ab 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -223,7 +223,113 @@ func executeDescribeVolumesTest(t *testing.T, c *cloud, volumeIDs, volumeNames [ } resultCh <- volume // passing `request` as a parameter to create a copy - // TODO remove after https://github.com/golang/go/discussions/56010 is implemented + // TODO remove after upgrading to go v1.22 (see https://github.com/golang/go/discussions/56010) + }(request, r[i], e[i]) + } + + wg.Wait() + + for i := range requests { + select { + case result := <-r[i]: + if result == nil { + t.Errorf("Received nil result for a request") + } + case err := <-e[i]: + if expErr == nil { + t.Errorf("Error while processing request: %v", err) + } + if !errors.Is(err, expErr) { + t.Errorf("Expected error %v, but got %v", expErr, err) + } + default: + t.Errorf("Did not receive a result or an error for a request") + } + } +} + +func TestBatchDescribeInstances(t *testing.T) { + testCases := []struct { + name string + instanceIds []string + mockFunc func(mockEC2 *MockEC2API, expErr error, reservations []*ec2.Reservation) + expErr error + }{ + { + name: "TestBatchDescribeInstances: instance by ID", + instanceIds: []string{"i-001", "i-002", "i-003"}, + mockFunc: func(mockEC2 *MockEC2API, expErr error, reservations []*ec2.Reservation) { + reservationOutput := &ec2.DescribeInstancesOutput{Reservations: reservations} + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), gomock.Any()).Return(reservationOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeInstances: EC2 API generic error", + instanceIds: []string{"i-001", "i-002", "i-003"}, + mockFunc: func(mockEC2 *MockEC2API, expErr error, reservations []*ec2.Reservation) { + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(1) + }, + expErr: fmt.Errorf("generic EC2 API error"), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2 := NewMockEC2API(mockCtrl) + c := newCloud(mockEC2) + cloudInstance := c.(*cloud) + cloudInstance.bm = newBatcherManager(cloudInstance.ec2) + + // Setup mocks + var instances []*ec2.Instance + for _, instanceId := range tc.instanceIds { + instances = append(instances, &ec2.Instance{InstanceId: aws.String(instanceId)}) + } + reservation := &ec2.Reservation{Instances: instances} + reservations := []*ec2.Reservation{reservation} + tc.mockFunc(mockEC2, tc.expErr, reservations) + + executeDescribeInstancesTest(t, cloudInstance, tc.instanceIds, tc.expErr) + }) + } +} + +func executeDescribeInstancesTest(t *testing.T, c *cloud, instanceIds []string, expErr error) { + var wg sync.WaitGroup + + getRequestForID := func(id string) *ec2.DescribeInstancesInput { + return &ec2.DescribeInstancesInput{InstanceIds: []*string{&id}} + } + + requests := make([]*ec2.DescribeInstancesInput, 0, len(instanceIds)) + for _, instanceID := range instanceIds { + requests = append(requests, getRequestForID(instanceID)) + } + + r := make([]chan *ec2.Instance, len(requests)) + e := make([]chan error, len(requests)) + + for i, request := range requests { + wg.Add(1) + r[i] = make(chan *ec2.Instance, 1) + e[i] = make(chan error, 1) + + go func(req *ec2.DescribeInstancesInput, resultCh chan *ec2.Instance, errCh chan error) { + defer wg.Done() + instance, err := c.batchDescribeInstances(req) + if err != nil { + errCh <- err + return + } + resultCh <- instance + // passing `request` as a parameter to create a copy + // TODO remove after upgrading to go v1.22 (see https://github.com/golang/go/discussions/56010) }(request, r[i], e[i]) }