Skip to content

Commit

Permalink
Pass additional information to session manager plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Mansi Dabhole authored and nateprewitt committed Mar 15, 2021
1 parent 24e1a24 commit 7d2f4fa
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 45 deletions.
45 changes: 44 additions & 1 deletion awscli/customizations/ecs/executecommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
'session-manager-plugin-not-found'
)

TASK_NOT_FOUND = (
'The task provided in the request was '
'not found.'
)


class ECSExecuteCommand(ServiceOperation):

Expand All @@ -37,6 +42,37 @@ def create_help_command(self):
return help_command


def get_container_runtime_id(client, container_name, task_id, cluster_name):
describe_tasks_params = {
"cluster": cluster_name,
"tasks": [task_id]
}
describe_tasks_response = client.describe_tasks(**describe_tasks_params)
# need to fail here if task has failed in the intermediate time
tasks = describe_tasks_response['tasks']
if not tasks:
raise ValueError(TASK_NOT_FOUND)
response = describe_tasks_response['tasks'][0]['containers']
for container in response:
if container_name == container['name']:
return container['runtimeId']


def build_ssm_request_paramaters(response, client):
cluster_name = response['clusterArn'].split('/')[-1]
task_id = response['taskArn'].split('/')[-1]
container_name = response['containerName']
# in order to get container run-time id
# we need to make a call to describe-tasks
container_runtime_id = \
get_container_runtime_id(client, container_name,
task_id, cluster_name)
target = "ecs:{}_{}_{}".format(cluster_name, task_id,
container_runtime_id)
ssm_request_params = {"Target": target}
return ssm_request_params


class ExecuteCommandCaller(CLIOperationCaller):
def invoke(self, service_name, operation_name, parameters, parsed_globals):
try:
Expand All @@ -54,6 +90,10 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
verify=parsed_globals.verify_ssl)
response = client.execute_command(**parameters)
region_name = client.meta.region_name
profile_name = self._session.profile \
if self._session.profile is not None else ''
endpoint_url = client.meta.endpoint_url
ssm_request_params = build_ssm_request_paramaters(response, client)
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
Expand All @@ -65,7 +105,10 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
check_call(["session-manager-plugin",
json.dumps(response['session']),
region_name,
"StartSession"])
"StartSession",
profile_name,
json.dumps(ssm_request_params),
endpoint_url])
return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
Expand Down
169 changes: 125 additions & 44 deletions tests/unit/customizations/ecs/test_executecommand_startsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,83 @@ def setUp(self):
self.session = mock.Mock(botocore.session.Session)
self.client = mock.Mock()
self.region = 'us-west-2'
self.profile = 'testProfile'
self.endpoint_url = 'testUrl'
self.client.meta.region_name = self.region
self.client.meta.endpoint_url = self.endpoint_url
self.caller = executecommand.ExecuteCommandCaller(self.session)
self.session.profile = self.profile
self.session.create_client.return_value = self.client
self.execute_command_params = {
"cluster": "default",
"task": "someTaskId",
"command": "ls",
"interactive": "true"}
self.execute_command_response = {
"containerName": "someContainerName",
"containerArn": "ecs/someContainerArn",
"taskArn": "ecs/someTaskArn",
"session": {"sessionId": "session-id",
"tokenValue": "token-value",
"streamUrl": "stream-url"},
"clusterArn": "ecs/someClusterArn",
"interactive": "true"
}
self.describe_tasks_response = {
"failures": [],
"tasks": [
{
"clusterArn": "ecs/someCLusterArn",
"desiredStatus": "RUNNING",
"createdAt": "1611619514.46",
"taskArn": "someTaskArn",
"containers": [
{
"containerArn": "ecs/someContainerArn",
"taskArn": "ecs/someTaskArn",
"name": "someContainerName",
"managedAgents": [
{
"reason": "Execute Command Agent started",
"lastStatus": "RUNNING",
"lastStartedAt": "1611619528.272",
"name": "ExecuteCommandAgent"
}
],
"runtimeId": "someRuntimeId"
},
{
"containerArn": "ecs/dummyContainerArn",
"taskArn": "ecs/someTaskArn",
"name": "dummyContainerName",
"managedAgents": [
{
"reason": "Execute Command Agent started",
"lastStatus": "RUNNING",
"lastStartedAt": "1611619528.272",
"name": "ExecuteCommandAgent"
}
],
"runtimeId": "dummyRuntimeId"
}
],
"lastStatus": "RUNNING",
"enableExecuteCommand": "true"
}
]
}
self.describe_tasks_response_fail = {
"failures": [
{
"reason": "MISSING",
"arn": "someTaskArn"
}
],
"tasks": []
}
self.ssm_request_parameters = {
"Target": "ecs:someClusterArn_someTaskArn_someRuntimeId"
}

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_calls_fails_from_ecs(self, mock_check_call):
Expand All @@ -48,76 +120,85 @@ def test_when_session_manager_plugin_not_installed(self, mock_check_call):
@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_execute_command_success(self, mock_check_call):
mock_check_call.return_value = 0
execute_command_params = {
"cluster": "default",
"task": "someTaskId",
"command": "ls",
"interactive": "true"
}

execute_command_response = {
"containerName": "someContainerName",
"containerArn": "someContainerArn",
"taskArn": "someTaskArn",
"session": {"sessionId": "session-id",
"tokenValue": "token-value",
"streamUrl": "stream-url"},
"clusterArn": "someClusterArn",
"interactive": "true"
}

self.client.execute_command.return_value = execute_command_response
self.client.execute_command.return_value = \
self.execute_command_response
self.client.describe_tasks.return_value = self.describe_tasks_response

rc = self.caller.invoke('ecs', 'ExecuteCommand',
execute_command_params, mock.Mock())
self.execute_command_params, mock.Mock())

self.assertEquals(rc, 0)
self.client.execute_command.\
assert_called_with(**execute_command_params)
assert_called_with(**self.execute_command_params)

mock_check_call_list = mock_check_call.call_args[0][0]
mock_check_call_list[1] = json.loads(mock_check_call_list[1])
self.assertEqual(
mock_check_call_list,
['session-manager-plugin',
execute_command_response["session"],
self.execute_command_response["session"],
self.region,
'StartSession']
'StartSession',
self.profile,
json.dumps(self.ssm_request_parameters),
self.endpoint_url
]
)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_check_call_fails(self, mock_check_call):
mock_check_call.side_effect = [0, Exception('some Exception')]
def test_when_describe_task_fails(self, mock_check_call):
mock_check_call.return_value = 0

execute_command_params = {
"cluster": "default",
"task": "someTaskId",
"command": "ls",
"interactive": "true"
}
self.client.execute_command.return_value = \
self.execute_command_response
self.client.describe_tasks.side_effect = \
Exception("Some Server Exception")

execute_command_response = {
"containerName": "someContainerName",
"containerArn": "someContainerArn",
"taskArn": "someTaskArn",
"session": {"sessionId": "session-id",
"tokenValue": "token-value",
"streamUrl": "stream-url"},
"clusterArn": "someClusterArn",
"interactive": "true"
}
with self.assertRaisesRegexp(Exception, 'Some Server Exception'):
rc = self.caller.invoke('ecs', 'ExecuteCommand',
self.execute_command_params, mock.Mock())
self.assertEquals(rc, 0)
self.client.execute_command. \
assert_called_with(**self.execute_command_params)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_describe_task_returns_no_tasks(self, mock_check_call):
mock_check_call.return_value = 0

self.client.execute_command.return_value = \
self.execute_command_response
self.client.describe_tasks.return_value = \
self.describe_tasks_response_fail

with self.assertRaises(Exception):
rc = self.caller.invoke('ecs', 'ExecuteCommand',
self.execute_command_params, mock.Mock())
self.assertEquals(rc, 0)
self.client.execute_command. \
assert_called_with(**self.execute_command_params)

@mock.patch('awscli.customizations.ecs.executecommand.check_call')
def test_when_check_call_fails(self, mock_check_call):
mock_check_call.side_effect = [0, Exception('some Exception')]

self.client.execute_command.return_value = execute_command_response
self.client.execute_command.return_value = \
self.execute_command_response
self.client.describe_tasks.return_value = self.describe_tasks_response

with self.assertRaises(Exception):
self.caller.invoke('ecs', 'ExecuteCommand',
execute_command_params, mock.Mock())
self.execute_command_params, mock.Mock())

mock_check_call_list = mock_check_call.call_args[0][0]
mock_check_call_list[1] = json.loads(mock_check_call_list[1])
self.assertEqual(
mock_check_call_list,
['session-manager-plugin',
execute_command_response["session"],
self.execute_command_response["session"],
self.region,
'StartSession'])
'StartSession',
self.profile,
json.dumps(self.ssm_request_parameters),
self.endpoint_url],
)

0 comments on commit 7d2f4fa

Please sign in to comment.