diff --git a/internal/capacity/autoscalinggroup_test.go b/internal/capacity/autoscalinggroup_test.go index 02c0241..10cddb8 100644 --- a/internal/capacity/autoscalinggroup_test.go +++ b/internal/capacity/autoscalinggroup_test.go @@ -186,7 +186,7 @@ func expectTerminateInstances( }), // For InstanceTerminatedWaiter - ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + ec2Mock.EXPECT().DescribeInstances(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { instanceIds := make([]string, len(instancesToTerminate)) instances := make([]ec2types.Instance, len(instancesToTerminate)) for i, instance := range instancesToTerminate { @@ -618,7 +618,7 @@ func TestAutoScalingGroup_ReplaceInstances(t *testing.T) { t.Fatal(err) } - if err := group.ReplaceInstances(context.Background(), drainerMock); err != nil { + if err := group.ReplaceInstances(ctx, drainerMock); err != nil { t.Errorf("err = %#v; want nil", err) } }) @@ -751,7 +751,7 @@ func TestAutoScalingGroup_ReduceCapacity(t *testing.T) { Reservations: append(reservationsToTerminate, reservationsToKeep...), }, nil), - drainerMock.EXPECT().Drain(context.Background(), gomock.Len(len(instancesToTerminate))), + drainerMock.EXPECT().Drain(ctx, gomock.Len(len(instancesToTerminate))), asMock.EXPECT().DetachInstances(ctx, gomock.Any()).Times(4).Do(func(_ context.Context, input *autoscaling.DetachInstancesInput, _ ...func(options *autoscaling.Options)) { detachedInstanceIds = append(detachedInstanceIds, input.InstanceIds...) @@ -762,7 +762,7 @@ func TestAutoScalingGroup_ReduceCapacity(t *testing.T) { }), // For InstanceTerminatedWaiter - ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + ec2Mock.EXPECT().DescribeInstances(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { if !testutil.MatchSlice(input.InstanceIds, terminatedInstanceIds) { t.Errorf("input.InstanceIds = %v; want %v", input.InstanceIds, terminatedInstanceIds) } diff --git a/internal/capacity/drainer_test.go b/internal/capacity/drainer_test.go index eef88f3..fb1d132 100644 --- a/internal/capacity/drainer_test.go +++ b/internal/capacity/drainer_test.go @@ -14,6 +14,7 @@ import ( "github.com/abicky/ecsmec/internal/capacity" "github.com/abicky/ecsmec/internal/testing/capacitymock" + "github.com/abicky/ecsmec/internal/testing/testutil" ) func TestDrainer_Drain(t *testing.T) { @@ -87,7 +88,7 @@ func TestDrainer_Drain(t *testing.T) { ecsMock.EXPECT().UpdateContainerInstancesState(ctx, gomock.Any()).Return(&ecs.UpdateContainerInstancesStateOutput{}, nil) // For ecs.TasksStoppedWaiter - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { + ecsMock.EXPECT().DescribeTasks(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { return &ecs.DescribeTasksOutput{ Tasks: []ecstypes.Task{ { @@ -97,7 +98,7 @@ func TestDrainer_Drain(t *testing.T) { }, nil }) // For ecs.ServicesStableWaiter - ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ecs.DescribeServicesInput, _ ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { + ecsMock.EXPECT().DescribeServices(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ecs.DescribeServicesInput, _ ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { return &ecs.DescribeServicesOutput{ Services: []ecstypes.Service{ { @@ -115,7 +116,7 @@ func TestDrainer_Drain(t *testing.T) { t.Fatal(err) } - if err := drainer.Drain(context.Background(), instanceIDs); err != nil { + if err := drainer.Drain(ctx, instanceIDs); err != nil { t.Errorf("err = %#v; want nil", err) } }) @@ -147,7 +148,7 @@ func TestDrainer_Drain(t *testing.T) { t.Fatal(err) } - if err := drainer.Drain(context.Background(), instanceIDs); err == nil { + if err := drainer.Drain(ctx, instanceIDs); err == nil { t.Errorf("err = nil; want non-nil") } }) @@ -245,7 +246,7 @@ func TestDrainer_ProcessInterruptions(t *testing.T) { t.Fatal(err) } - entries, err := drainer.ProcessInterruptions(context.Background(), messages) + entries, err := drainer.ProcessInterruptions(ctx, messages) if err != nil { t.Errorf("err = %#v; want nil", err) } diff --git a/internal/capacity/spotfleetrequest_test.go b/internal/capacity/spotfleetrequest_test.go index dd8e41a..063f59b 100644 --- a/internal/capacity/spotfleetrequest_test.go +++ b/internal/capacity/spotfleetrequest_test.go @@ -73,7 +73,7 @@ func TestSpotFleetRequest_TerminateAllInstances(t *testing.T) { }), // For InstanceTerminatedWaiter - ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + ec2Mock.EXPECT().DescribeInstances(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { instanceIds := make([]string, len(instances)) for i, instance := range instances { instanceIds[i] = *instance.InstanceId diff --git a/internal/service/service_test.go b/internal/service/service_test.go index fafde2b..9a28aac 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -67,7 +67,7 @@ func expectCopy( }), // For ecs.ServicesStableWaiter - ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ecs.DescribeServicesInput, _ ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { + ecsMock.EXPECT().DescribeServices(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ecs.DescribeServicesInput, _ ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { if input.Services[0] != dstServiceName { t.Errorf("*input.Service[0] = %s; want %s", input.Services[0], dstServiceName) } @@ -119,7 +119,7 @@ func expectStopAndDelete( }), // For ecs.TasksStoppedWaiter - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { + ecsMock.EXPECT().DescribeTasks(testutil.AnyContext(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { if !reflect.DeepEqual(input.Tasks, runningTaskArns) { t.Errorf("input.Tasks = %#v; want %#v", input.Tasks, runningTaskArns) } diff --git a/internal/testing/testutil/testutil.go b/internal/testing/testutil/testutil.go index a0bcf19..0a6550a 100644 --- a/internal/testing/testutil/testutil.go +++ b/internal/testing/testutil/testutil.go @@ -1,9 +1,18 @@ package testutil import ( + "context" + "reflect" + "go.uber.org/mock/gomock" ) +var ctxIface = reflect.TypeOf((*context.Context)(nil)).Elem() + +func AnyContext() gomock.Matcher { + return gomock.AssignableToTypeOf(ctxIface) +} + func InOrder(calls ...*gomock.Call) *gomock.Call { args := make([]any, 0, len(calls)) for _, call := range calls {