diff --git a/internal/pkg/task/config_runner.go b/internal/pkg/task/config_runner.go new file mode 100644 index 00000000000..2adfc0579c3 --- /dev/null +++ b/internal/pkg/task/config_runner.go @@ -0,0 +1,64 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package task + +import ( + "fmt" + "github.com/aws/copilot-cli/internal/pkg/aws/ecs" + "github.com/pkg/errors" +) + +// NetworkConfigRunner runs an Amazon ECS task in the specified network configuration and the default cluster. +type NetworkConfigRunner struct { + // Count of the tasks to be launched. + Count int + // Group Name of the tasks that use the same task definition. + GroupName string + + // Network configuration + Subnets []string + SecurityGroups []string + + // Interfaces to interact with dependencies. Must not be nil. + ClusterGetter DefaultClusterGetter + Starter TaskRunner +} + +// Run runs tasks in the subnets and the security groups, and returns the task ARNs. +func (r *NetworkConfigRunner) Run() ([]string, error) { + if err:= r.validateDependencies(); err != nil { + return nil, err + } + + cluster, err := r.ClusterGetter.DefaultCluster() + if err != nil { + return nil, fmt.Errorf("get default cluster: %w", err) + } + + arns, err := r.Starter.RunTask(ecs.RunTaskInput{ + Cluster: cluster, + Count: r.Count, + Subnets: r.Subnets, + SecurityGroups: r.SecurityGroups, + TaskFamilyName: taskFamilyName(r.GroupName), + StartedBy: startedBy, + }) + if err != nil { + return nil, fmt.Errorf("run task %s: %w", r.GroupName, err) + } + + return arns, nil +} + +func (r *NetworkConfigRunner) validateDependencies() error { + if r.ClusterGetter == nil { + return errors.New("cluster getter is not set") + } + + if r.Starter == nil { + return errors.New("starter is not set") + } + + return nil +} \ No newline at end of file diff --git a/internal/pkg/task/config_runner_test.go b/internal/pkg/task/config_runner_test.go new file mode 100644 index 00000000000..788b04cfed9 --- /dev/null +++ b/internal/pkg/task/config_runner_test.go @@ -0,0 +1,118 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package task + +import ( + "errors" + "fmt" + "github.com/aws/copilot-cli/internal/pkg/aws/ecs" + "github.com/aws/copilot-cli/internal/pkg/task/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNetworkConfigRunner_Run(t *testing.T) { + testCases := map[string]struct { + count int + groupName string + + subnets []string + securityGroups []string + + mockClusterGetter func(m *mocks.MockdefaultClusterGetter) + mockStarter func(m *mocks.MocktaskRunner) + + wantedError error + wantedARNs []string + }{ + "failed to get clusters": { + mockClusterGetter: func(m *mocks.MockdefaultClusterGetter) { + m.EXPECT().DefaultCluster().Return("", errors.New("error getting default cluster")) + }, + mockStarter: func(m *mocks.MocktaskRunner) { + m.EXPECT().RunTask(gomock.Any()).Times(0) + }, + wantedError: fmt.Errorf("get default cluster: error getting default cluster"), + }, + "failed to kick off task": { + count: 1, + groupName: "my-task", + + subnets: []string{"subnet-1", "subnet-2"}, + securityGroups: []string{"sg-1", "sg-2"}, + + mockClusterGetter: func(m *mocks.MockdefaultClusterGetter) { + m.EXPECT().DefaultCluster().Return("cluster-1", nil) + }, + mockStarter: func(m *mocks.MocktaskRunner) { + m.EXPECT().RunTask(ecs.RunTaskInput{ + Cluster: "cluster-1", + Count: 1, + Subnets: []string{"subnet-1", "subnet-2"}, + SecurityGroups: []string{"sg-1", "sg-2"}, + TaskFamilyName: taskFamilyName("my-task"), + StartedBy: startedBy, + }).Return(nil, errors.New("error running task")) + }, + + wantedError: fmt.Errorf("run task my-task: error running task"), + }, + "successfully kick off task with both input subnets and security groups": { + count: 1, + groupName: "my-task", + + subnets: []string{"subnet-1", "subnet-2"}, + securityGroups: []string{"sg-1", "sg-2"}, + + mockClusterGetter: func(m *mocks.MockdefaultClusterGetter) { + m.EXPECT().DefaultCluster().Return("cluster-1", nil) + }, + mockStarter: func(m *mocks.MocktaskRunner) { + m.EXPECT().RunTask(ecs.RunTaskInput{ + Cluster: "cluster-1", + Count: 1, + Subnets: []string{"subnet-1", "subnet-2"}, + SecurityGroups: []string{"sg-1", "sg-2"}, + TaskFamilyName: taskFamilyName("my-task"), + StartedBy: startedBy, + }).Return([]string{"task-1"}, nil) + }, + + wantedARNs: []string{"task-1"}, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClusterGetter := mocks.NewMockdefaultClusterGetter(ctrl) + mockStarter := mocks.NewMocktaskRunner(ctrl) + + tc.mockClusterGetter(mockClusterGetter) + tc.mockStarter(mockStarter) + + task := &NetworkConfigRunner{ + Count: tc.count, + GroupName: tc.groupName, + + Subnets: tc.subnets, + SecurityGroups: tc.securityGroups, + + ClusterGetter: mockClusterGetter, + Starter: mockStarter, + } + + arns, err := task.Run() + if tc.wantedError != nil { + require.EqualError(t, tc.wantedError, err.Error()) + } else { + require.Nil(t, err) + require.Equal(t, tc.wantedARNs, arns) + } + }) + } +} \ No newline at end of file