Skip to content

Commit

Permalink
Batch EC2 DescribeInstances calls
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewSirenko committed Feb 28, 2024
1 parent 6f67b3d commit f320c1a
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 46 deletions.
153 changes: 108 additions & 45 deletions pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ const (
const (
volumeIDBatcher batcherType = iota
volumeTagBatcher

batchDescribeTimeout = 30 * time.Second
)

var (
Expand Down Expand Up @@ -245,7 +247,9 @@ type batcherType int

// batcherManager maintains a collection of batchers for different types of tasks.
type batcherManager struct {
batchers map[batcherType]*batcher.Batcher[string, *ec2.Volume]
volumeIDBatcher *batcher.Batcher[string, *ec2.Volume]
volumeTagBatcher *batcher.Batcher[string, *ec2.Volume]
instanceIDBatcher *batcher.Batcher[string, *ec2.Instance]
}

type cloud struct {
Expand All @@ -260,25 +264,15 @@ var _ Cloud = &cloud{}
// NewCloud returns a new instance of AWS cloud
// It panics if session is invalid
func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (Cloud, error) {
c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra)

if batching {
klog.V(4).InfoS("NewCloud: batching enabled")
cloudInstance, ok := c.(*cloud)
if !ok {
return nil, fmt.Errorf("expected *cloud type but got %T", c)
}
cloudInstance.bm = newBatcherManager(cloudInstance.ec2)
}

c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra, batching)
return c, nil
}

func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) Cloud {
func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string, batchingEnabled bool) Cloud {
awsConfig := &aws.Config{
Region: aws.String(region),
CredentialsChainVerboseErrors: aws.Bool(true),
// Set MaxRetries to a high value. It will be "ovewritten" if context deadline comes sooner.
// Set MaxRetries to a high value. It will be "overwritten" if context deadline comes sooner.
MaxRetries: aws.Int(8),
}

Expand Down Expand Up @@ -317,33 +311,36 @@ func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) Clou
Fn: RecordRequestsHandler,
})

var bm *batcherManager
if batchingEnabled {
klog.V(4).InfoS("newEC2Cloud: batching enabled")
bm = newBatcherManager(svc)
}

return &cloud{
region: region,
dm: dm.NewDeviceManager(),
ec2: svc,
bm: bm,
}
}

// newBatcherManager initializes a new instance of batcherManager.
func newBatcherManager(svc ec2iface.EC2API) *batcherManager {
return &batcherManager{
batchers: map[batcherType]*batcher.Batcher[string, *ec2.Volume]{
volumeIDBatcher: batcher.New(500, 1*time.Second, func(ids []string) (map[string]*ec2.Volume, error) {
return execBatchDescribeVolumes(svc, ids, volumeIDBatcher)
}),
volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) {
return execBatchDescribeVolumes(svc, names, volumeTagBatcher)
}),
},
volumeIDBatcher: batcher.New(500, 1*time.Second, func(ids []string) (map[string]*ec2.Volume, error) {
return execBatchDescribeVolumes(svc, ids, volumeIDBatcher)
}),
volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) {
return execBatchDescribeVolumes(svc, names, volumeTagBatcher)
}),
instanceIDBatcher: batcher.New(50, 500*time.Millisecond, func(names []string) (map[string]*ec2.Instance, error) {
return execBatchDescribeInstances(svc, names)
}),
}
}

// getBatcher fetches a specific type of batcher from the batcherManager.
// It is a plain lookup in bm.batchers keyed by b; when nothing is stored
// under b the map yields the zero value, so the returned pointer is nil and
// must not be used without a check for batcher types that may be unregistered.
func (bm *batcherManager) getBatcher(b batcherType) *batcher.Batcher[string, *ec2.Volume] {
return bm.batchers[b]
}

// executes a batched DescribeVolumes API call depending on the type of batcher.
// execBatchDescribeVolumes executes a batched DescribeVolumes API call depending on the type of batcher.
func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batcherType) (map[string]*ec2.Volume, error) {
var request *ec2.DescribeVolumesInput

Expand All @@ -370,7 +367,7 @@ func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batch
return nil, fmt.Errorf("execBatchDescribeVolumes: unsupported request type")
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), batchDescribeTimeout)
defer cancel()

resp, err := describeVolumes(ctx, svc, request)
Expand All @@ -396,16 +393,16 @@ func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batch
// batchDescribeVolumes processes a DescribeVolumes request. Depending on the request,
// it determines the appropriate batcher to use, queues the task, and waits for the result.
func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Volume, error) {
var bType batcherType
var b *batcher.Batcher[string, *ec2.Volume]
var task string

switch {
case len(request.VolumeIds) == 1 && request.VolumeIds[0] != nil:
bType = volumeIDBatcher
b = c.bm.volumeIDBatcher
task = *request.VolumeIds[0]

case len(request.Filters) == 1 && *request.Filters[0].Name == "tag:"+VolumeNameTagKey && len(request.Filters[0].Values) == 1:
bType = volumeTagBatcher
b = c.bm.volumeTagBatcher
task = *request.Filters[0].Values[0]

default:
Expand All @@ -414,7 +411,6 @@ func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Vo

ch := make(chan batcher.BatchResult[*ec2.Volume])

b := c.bm.getBatcher(bType)
b.AddTask(task, ch)

r := <-ch
Expand All @@ -423,7 +419,7 @@ func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Vo
return nil, r.Err
}
if r.Result == nil {
return nil, fmt.Errorf("batchDescribeVolumes: no volume found %s", task)
return nil, ErrNotFound
}
return r.Result, nil
}
Expand Down Expand Up @@ -684,6 +680,61 @@ func (c *cloud) DeleteDisk(ctx context.Context, volumeID string) (bool, error) {
return true, nil
}

// execBatchDescribeInstances looks up all instance IDs in input with one
// describeInstances call and returns the instances it got back, keyed by
// instance ID.
//
// The lookup runs on a fresh background context limited by
// batchDescribeTimeout. Any error from describeInstances is returned as-is
// with a nil map. Instances that come back without an instance ID are logged
// and left out of the result; IDs from input that were not returned simply
// have no entry in the map.
func execBatchDescribeInstances(svc ec2iface.EC2API, input []string) (map[string]*ec2.Instance, error) {
	klog.V(7).InfoS("execBatchDescribeInstances", "instanceIds", input)

	ctx, cancel := context.WithTimeout(context.Background(), batchDescribeTimeout)
	defer cancel()

	instances, err := describeInstances(ctx, svc, &ec2.DescribeInstancesInput{
		InstanceIds: aws.StringSlice(input),
	})
	if err != nil {
		return nil, err
	}

	byID := map[string]*ec2.Instance{}
	for _, inst := range instances {
		if id := inst.InstanceId; id != nil {
			byID[*id] = inst
			continue
		}
		// Without an ID there is no key to file the instance under.
		klog.Warningf("execBatchDescribeInstances: skipping instance: %v, reason: missing instance ID", inst)
	}

	klog.V(7).InfoS("execBatchDescribeInstances: success", "result", byID)
	return byID, nil
}

// batchDescribeInstances resolves a single instance through the instance ID
// batcher: it queues the requested ID on c.bm.instanceIDBatcher and blocks
// until the batcher delivers the result for that ID.
//
// The request must name exactly one non-nil instance ID. An error is returned
// when the request is nil or has any other shape, when the instance batcher
// has not been set up on c, or when the batch itself reports an error.
// ErrNotFound is returned when the batch completed but carried no instance
// for the requested ID.
func (c *cloud) batchDescribeInstances(request *ec2.DescribeInstancesInput) (*ec2.Instance, error) {
	// A nil request is rejected together with malformed ones instead of
	// panicking on the field access.
	if request == nil || len(request.InstanceIds) != 1 || request.InstanceIds[0] == nil {
		return nil, fmt.Errorf("batchDescribeInstances: invalid request, request: %v", request)
	}

	// c.bm is only populated when batching is enabled, so fail cleanly rather
	// than dereferencing a nil manager or batcher.
	if c.bm == nil || c.bm.instanceIDBatcher == nil {
		return nil, fmt.Errorf("batchDescribeInstances: instance batcher is not initialized")
	}

	task := *request.InstanceIds[0]

	ch := make(chan batcher.BatchResult[*ec2.Instance])
	c.bm.instanceIDBatcher.AddTask(task, ch)

	// Wait for the batcher to hand back the outcome for this task.
	r := <-ch

	if r.Err != nil {
		return nil, r.Err
	}
	if r.Result == nil {
		return nil, ErrNotFound
	}
	return r.Result, nil
}

// Node likely bad device names cache
// Remember device names that are already in use on an instance and use them last when attaching volumes
// This works around device names that are used but do not appear in the mapping from DescribeInstanceStatus
Expand Down Expand Up @@ -1166,15 +1217,11 @@ func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput
}
}

func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, error) {
func describeInstances(ctx context.Context, svc ec2iface.EC2API, request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) {
instances := []*ec2.Instance{}
request := &ec2.DescribeInstancesInput{
InstanceIds: []*string{&nodeID},
}

var nextToken *string
for {
response, err := c.ec2.DescribeInstancesWithContext(ctx, request)
response, err := svc.DescribeInstancesWithContext(ctx, request)
if err != nil {
if isAWSErrorInstanceNotFound(err) {
return nil, ErrNotFound
Expand All @@ -1192,14 +1239,30 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance,
}
request.NextToken = nextToken
}
return instances, nil
}

if l := len(instances); l > 1 {
return nil, fmt.Errorf("found %d instances with ID %q", l, nodeID)
} else if l < 1 {
return nil, ErrNotFound
func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, error) {
request := &ec2.DescribeInstancesInput{
InstanceIds: []*string{&nodeID},
}

return instances[0], nil
if c.bm == nil {
instances, err := describeInstances(ctx, c.ec2, request)
if err != nil {
return nil, err
}

if l := len(instances); l > 1 {
return nil, fmt.Errorf("found %d instances with ID %q", l, nodeID)
} else if l < 1 {
return nil, ErrNotFound
}

return instances[0], nil
} else {
return c.batchDescribeInstances(request)
}
}

func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) {
Expand Down
108 changes: 107 additions & 1 deletion pkg/cloud/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,113 @@ func executeDescribeVolumesTest(t *testing.T, c *cloud, volumeIDs, volumeNames [
}
resultCh <- volume
// passing `request` as a parameter to create a copy
// TODO remove after https://github.com/golang/go/discussions/56010 is implemented
// TODO remove after upgrading to go v1.22 (see https://github.com/golang/go/discussions/56010)
}(request, r[i], e[i])
}

wg.Wait()

for i := range requests {
select {
case result := <-r[i]:
if result == nil {
t.Errorf("Received nil result for a request")
}
case err := <-e[i]:
if expErr == nil {
t.Errorf("Error while processing request: %v", err)
}
if !errors.Is(err, expErr) {
t.Errorf("Expected error %v, but got %v", expErr, err)
}
default:
t.Errorf("Did not receive a result or an error for a request")
}
}
}

// TestBatchDescribeInstances drives batchDescribeInstances with several
// instance IDs at once (via executeDescribeInstancesTest) and expects the
// mocked EC2 client to see exactly one DescribeInstancesWithContext call per
// case, both when that call succeeds and when it returns a generic error.
func TestBatchDescribeInstances(t *testing.T) {
	testCases := []struct {
		name        string
		instanceIds []string
		mockFunc    func(mockEC2 *MockEC2API, expErr error, reservations []*ec2.Reservation)
		expErr      error
	}{
		{
			name:        "TestBatchDescribeInstances: instance by ID",
			instanceIds: []string{"i-001", "i-002", "i-003"},
			mockFunc: func(mockEC2 *MockEC2API, expErr error, reservations []*ec2.Reservation) {
				out := &ec2.DescribeInstancesOutput{Reservations: reservations}
				mockEC2.EXPECT().
					DescribeInstancesWithContext(gomock.Any(), gomock.Any()).
					Return(out, expErr).
					Times(1)
			},
			expErr: nil,
		},
		{
			name:        "TestBatchDescribeInstances: EC2 API generic error",
			instanceIds: []string{"i-001", "i-002", "i-003"},
			mockFunc: func(mockEC2 *MockEC2API, expErr error, reservations []*ec2.Reservation) {
				mockEC2.EXPECT().
					DescribeInstancesWithContext(gomock.Any(), gomock.Any()).
					Return(nil, expErr).
					Times(1)
			},
			expErr: fmt.Errorf("generic EC2 API error"),
		},
	}

	for _, tc := range testCases {
		tc := tc // per-iteration copy for the parallel subtest closure
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			// Build a cloud backed by the mock and switch batching on by hand.
			ec2Mock := NewMockEC2API(ctrl)
			cl := newCloud(ec2Mock).(*cloud)
			cl.bm = newBatcherManager(cl.ec2)

			// All requested instances are returned inside a single reservation.
			instances := make([]*ec2.Instance, 0, len(tc.instanceIds))
			for _, id := range tc.instanceIds {
				instances = append(instances, &ec2.Instance{InstanceId: aws.String(id)})
			}
			tc.mockFunc(ec2Mock, tc.expErr, []*ec2.Reservation{{Instances: instances}})

			executeDescribeInstancesTest(t, cl, tc.instanceIds, tc.expErr)
		})
	}
}

func executeDescribeInstancesTest(t *testing.T, c *cloud, instanceIds []string, expErr error) {
var wg sync.WaitGroup

getRequestForID := func(id string) *ec2.DescribeInstancesInput {
return &ec2.DescribeInstancesInput{InstanceIds: []*string{&id}}
}

requests := make([]*ec2.DescribeInstancesInput, 0, len(instanceIds))
for _, instanceID := range instanceIds {
requests = append(requests, getRequestForID(instanceID))
}

r := make([]chan *ec2.Instance, len(requests))
e := make([]chan error, len(requests))

for i, request := range requests {
wg.Add(1)
r[i] = make(chan *ec2.Instance, 1)
e[i] = make(chan error, 1)

go func(req *ec2.DescribeInstancesInput, resultCh chan *ec2.Instance, errCh chan error) {
defer wg.Done()
instance, err := c.batchDescribeInstances(req)
if err != nil {
errCh <- err
return
}
resultCh <- instance
// passing `request` as a parameter to create a copy
// TODO remove after upgrading to go v1.22 (see https://github.com/golang/go/discussions/56010)
}(request, r[i], e[i])
}

Expand Down

0 comments on commit f320c1a

Please sign in to comment.