-
Notifications
You must be signed in to change notification settings - Fork 421
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore(task): run task in provided subnets and security groups (#1145)
This PR implements a task runner that runs tasks in user provided subnets and security groups. Related #702 By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
- Loading branch information
1 parent
e3fa88d
commit 8948838
Showing
2 changed files
with
182 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package task | ||
|
||
import ( | ||
"fmt" | ||
"github.com/aws/copilot-cli/internal/pkg/aws/ecs" | ||
"github.com/pkg/errors" | ||
) | ||
|
||
// NetworkConfigRunner runs an Amazon ECS task in the specified network configuration and the default cluster. | ||
type NetworkConfigRunner struct { | ||
// Count of the tasks to be launched. | ||
Count int | ||
// Group Name of the tasks that use the same task definition. | ||
GroupName string | ||
|
||
// Network configuration | ||
Subnets []string | ||
SecurityGroups []string | ||
|
||
// Interfaces to interact with dependencies. Must not be nil. | ||
ClusterGetter DefaultClusterGetter | ||
Starter TaskRunner | ||
} | ||
|
||
// Run runs tasks in the subnets and the security groups, and returns the task ARNs. | ||
func (r *NetworkConfigRunner) Run() ([]string, error) { | ||
if err:= r.validateDependencies(); err != nil { | ||
return nil, err | ||
} | ||
|
||
cluster, err := r.ClusterGetter.DefaultCluster() | ||
if err != nil { | ||
return nil, fmt.Errorf("get default cluster: %w", err) | ||
} | ||
|
||
arns, err := r.Starter.RunTask(ecs.RunTaskInput{ | ||
Cluster: cluster, | ||
Count: r.Count, | ||
Subnets: r.Subnets, | ||
SecurityGroups: r.SecurityGroups, | ||
TaskFamilyName: taskFamilyName(r.GroupName), | ||
StartedBy: startedBy, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("run task %s: %w", r.GroupName, err) | ||
} | ||
|
||
return arns, nil | ||
} | ||
|
||
func (r *NetworkConfigRunner) validateDependencies() error { | ||
if r.ClusterGetter == nil { | ||
return errors.New("cluster getter is not set") | ||
} | ||
|
||
if r.Starter == nil { | ||
return errors.New("starter is not set") | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package task | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"github.com/aws/copilot-cli/internal/pkg/aws/ecs" | ||
"github.com/aws/copilot-cli/internal/pkg/task/mocks" | ||
"github.com/golang/mock/gomock" | ||
"github.com/stretchr/testify/require" | ||
"testing" | ||
) | ||
|
||
func TestNetworkConfigRunner_Run(t *testing.T) { | ||
testCases := map[string]struct { | ||
count int | ||
groupName string | ||
|
||
subnets []string | ||
securityGroups []string | ||
|
||
mockClusterGetter func(m *mocks.MockdefaultClusterGetter) | ||
mockStarter func(m *mocks.MocktaskRunner) | ||
|
||
wantedError error | ||
wantedARNs []string | ||
}{ | ||
"failed to get clusters": { | ||
mockClusterGetter: func(m *mocks.MockdefaultClusterGetter) { | ||
m.EXPECT().DefaultCluster().Return("", errors.New("error getting default cluster")) | ||
}, | ||
mockStarter: func(m *mocks.MocktaskRunner) { | ||
m.EXPECT().RunTask(gomock.Any()).Times(0) | ||
}, | ||
wantedError: fmt.Errorf("get default cluster: error getting default cluster"), | ||
}, | ||
"failed to kick off task": { | ||
count: 1, | ||
groupName: "my-task", | ||
|
||
subnets: []string{"subnet-1", "subnet-2"}, | ||
securityGroups: []string{"sg-1", "sg-2"}, | ||
|
||
mockClusterGetter: func(m *mocks.MockdefaultClusterGetter) { | ||
m.EXPECT().DefaultCluster().Return("cluster-1", nil) | ||
}, | ||
mockStarter: func(m *mocks.MocktaskRunner) { | ||
m.EXPECT().RunTask(ecs.RunTaskInput{ | ||
Cluster: "cluster-1", | ||
Count: 1, | ||
Subnets: []string{"subnet-1", "subnet-2"}, | ||
SecurityGroups: []string{"sg-1", "sg-2"}, | ||
TaskFamilyName: taskFamilyName("my-task"), | ||
StartedBy: startedBy, | ||
}).Return(nil, errors.New("error running task")) | ||
}, | ||
|
||
wantedError: fmt.Errorf("run task my-task: error running task"), | ||
}, | ||
"successfully kick off task with both input subnets and security groups": { | ||
count: 1, | ||
groupName: "my-task", | ||
|
||
subnets: []string{"subnet-1", "subnet-2"}, | ||
securityGroups: []string{"sg-1", "sg-2"}, | ||
|
||
mockClusterGetter: func(m *mocks.MockdefaultClusterGetter) { | ||
m.EXPECT().DefaultCluster().Return("cluster-1", nil) | ||
}, | ||
mockStarter: func(m *mocks.MocktaskRunner) { | ||
m.EXPECT().RunTask(ecs.RunTaskInput{ | ||
Cluster: "cluster-1", | ||
Count: 1, | ||
Subnets: []string{"subnet-1", "subnet-2"}, | ||
SecurityGroups: []string{"sg-1", "sg-2"}, | ||
TaskFamilyName: taskFamilyName("my-task"), | ||
StartedBy: startedBy, | ||
}).Return([]string{"task-1"}, nil) | ||
}, | ||
|
||
wantedARNs: []string{"task-1"}, | ||
}, | ||
} | ||
|
||
for name, tc := range testCases { | ||
t.Run(name, func(t *testing.T) { | ||
ctrl := gomock.NewController(t) | ||
defer ctrl.Finish() | ||
|
||
mockClusterGetter := mocks.NewMockdefaultClusterGetter(ctrl) | ||
mockStarter := mocks.NewMocktaskRunner(ctrl) | ||
|
||
tc.mockClusterGetter(mockClusterGetter) | ||
tc.mockStarter(mockStarter) | ||
|
||
task := &NetworkConfigRunner{ | ||
Count: tc.count, | ||
GroupName: tc.groupName, | ||
|
||
Subnets: tc.subnets, | ||
SecurityGroups: tc.securityGroups, | ||
|
||
ClusterGetter: mockClusterGetter, | ||
Starter: mockStarter, | ||
} | ||
|
||
arns, err := task.Run() | ||
if tc.wantedError != nil { | ||
require.EqualError(t, tc.wantedError, err.Error()) | ||
} else { | ||
require.Nil(t, err) | ||
require.Equal(t, tc.wantedARNs, arns) | ||
} | ||
}) | ||
} | ||
} |