Skip to content

Commit

Permalink
Merge pull request #8336 from kdaily/ssm-session-manager-pluging-env-…
Browse files Browse the repository at this point in the history
…variable

Pass StartSession response as env variable
  • Loading branch information
kdaily authored Nov 17, 2023
2 parents a453709 + 0d5e0c1 commit 4ad0f43
Show file tree
Hide file tree
Showing 4 changed files with 441 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``ssm`` Session Manager",
"description": "Pass StartSession API response as environment variable to session-manager-plugin"
}
75 changes: 71 additions & 4 deletions awscli/customizations/sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import logging
import json
import errno
import os
import re

from subprocess import check_call
from subprocess import check_call, check_output
from awscli.compat import ignore_user_entered_signals
from awscli.clidriver import ServiceOperation, CLIOperationCaller

Expand Down Expand Up @@ -44,8 +46,43 @@ def add_custom_start_session(session, command_table, **kwargs):
)


class StartSessionCommand(ServiceOperation):
class VersionRequirement:
WHITESPACE_REGEX = re.compile(r"\s+")
SSM_SESSION_PLUGIN_VERSION_REGEX = re.compile(r"^\d+(\.\d+){0,3}$")

def __init__(self, min_version):
self.min_version = min_version

def meets_requirement(self, version):
ssm_plugin_version = self._sanitize_plugin_version(version)
if self._is_valid_version(ssm_plugin_version):
norm_version, norm_min_version = self._normalize(
ssm_plugin_version, self.min_version
)
return norm_version > norm_min_version
else:
return False

def _sanitize_plugin_version(self, plugin_version):
return re.sub(self.WHITESPACE_REGEX, "", plugin_version)

def _is_valid_version(self, plugin_version):
return bool(
self.SSM_SESSION_PLUGIN_VERSION_REGEX.match(plugin_version)
)

def _normalize(self, v1, v2):
v1_parts = [int(v) for v in v1.split(".")]
v2_parts = [int(v) for v in v2.split(".")]
while len(v1_parts) != len(v2_parts):
if len(v1_parts) - len(v2_parts) > 0:
v2_parts.append(0)
else:
v1_parts.append(0)
return v1_parts, v2_parts


class StartSessionCommand(ServiceOperation):
def create_help_command(self):
help_command = super(
StartSessionCommand, self).create_help_command()
Expand All @@ -55,6 +92,9 @@ def create_help_command(self):


class StartSessionCaller(CLIOperationCaller):
LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0"
DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"

def invoke(self, service_name, operation_name, parameters,
parsed_globals):
client = self._session.create_client(
Expand All @@ -70,8 +110,34 @@ def invoke(self, service_name, operation_name, parameters,
profile_name = self._session.profile \
if self._session.profile is not None else ''
endpoint_url = client.meta.endpoint_url
ssm_env_name = self.DEFAULT_SSM_ENV_NAME

try:
session_parameters = {
"SessionId": response["SessionId"],
"TokenValue": response["TokenValue"],
"StreamUrl": response["StreamUrl"],
}
start_session_response = json.dumps(session_parameters)

plugin_version = check_output(
["session-manager-plugin", "--version"], text=True
)
env = os.environ.copy()

# Check if this plugin supports passing the start session response
# as an environment variable name. If it does, it will set the
# value to the response from the start_session operation to the env
# variable defined in DEFAULT_SSM_ENV_NAME. If the session plugin
# version is invalid or older than the version defined in
# LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR, it will fall back to
# passing the start_session response directly.
version_requirement = VersionRequirement(
min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR
)
if version_requirement.meets_requirement(plugin_version):
env[ssm_env_name] = start_session_response
start_session_response = ssm_env_name
# ignore_user_entered_signals ignores these signals
# because if signals which kills the process are not
# captured would kill the foreground process but not the
Expand All @@ -81,12 +147,13 @@ def invoke(self, service_name, operation_name, parameters,
with ignore_user_entered_signals():
# call executable with necessary input
check_call(["session-manager-plugin",
json.dumps(response),
start_session_response,
region_name,
"StartSession",
profile_name,
json.dumps(parameters),
endpoint_url])
endpoint_url], env=env)

return 0
except OSError as ex:
if ex.errno == errno.ENOENT:
Expand Down
115 changes: 98 additions & 17 deletions tests/functional/ssm/test_start_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,119 @@

from awscli.testutils import BaseAWSCommandParamsTest
from awscli.testutils import BaseAWSHelpOutputTest
from awscli.testutils import mock
from awscli.testutils import mock

class TestSessionManager(BaseAWSCommandParamsTest):

class TestSessionManager(BaseAWSCommandParamsTest):
@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_success(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_success(self, mock_check_output, mock_check_call):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.return_value = 0
self.parsed_responses = [{
mock_check_output.return_value = "1.2.0.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"
}]
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]
start_session_params = {"Target": "instance-id"}

self.run_cmd(cmdline, expected_rc=0)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
json.dumps(expected_response),
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=self.environ,
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_with_new_version_plugin_success(
self, mock_check_output, mock_check_call
):
cmdline = "ssm start-session --target instance-id"
mock_check_call.return_value = 0
mock_check_output.return_value = "1.2.500.0\n"
expected_response = {
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
self.parsed_responses = [expected_response]

ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE"
start_session_params = {"Target": "instance-id"}
expected_env = self.environ.copy()
expected_env.update({ssm_env_name: json.dumps(expected_response)})

self.run_cmd(cmdline, expected_rc=0)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
self.assertEqual(self.operations_called[0][1],
{'Target': 'instance-id'})
actual_response = json.loads(mock_check_call.call_args[0][0][1])
self.assertEqual(
{"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url"},
actual_response)

mock_check_call.assert_called_once_with(
[
"session-manager-plugin",
ssm_env_name,
mock.ANY,
"StartSession",
mock.ANY,
json.dumps(start_session_params),
mock.ANY,
],
env=expected_env,
)

@mock.patch('awscli.customizations.sessionmanager.check_call')
def test_start_session_fails(self, mock_check_call):
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_fails(self, mock_check_output, mock_check_call):
cmdline = "ssm start-session --target instance-id"
mock_check_output.return_value = "1.2.500.0\n"
mock_check_call.side_effect = OSError(errno.ENOENT, "some error")
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(
self.operations_called[0][0].name, "StartSession"
)
self.assertEqual(
self.operations_called[0][1], {"Target": "instance-id"}
)
self.assertEqual(
self.operations_called[1][0].name, "TerminateSession"
)
self.assertEqual(
self.operations_called[1][1], {"SessionId": "session-id"}
)

@mock.patch("awscli.customizations.sessionmanager.check_call")
@mock.patch("awscli.customizations.sessionmanager.check_output")
def test_start_session_when_get_plugin_version_fails(
self, mock_check_output, mock_check_call
):
cmdline = 'ssm start-session --target instance-id'
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [{
"SessionId": "session-id"
}]
mock_check_output.side_effect = OSError(errno.ENOENT, 'some error')
self.parsed_responses = [
{
"SessionId": "session-id",
"TokenValue": "token-value",
"StreamUrl": "stream-url",
}
]
self.run_cmd(cmdline, expected_rc=255)
self.assertEqual(self.operations_called[0][0].name,
'StartSession')
Expand Down
Loading

0 comments on commit 4ad0f43

Please sign in to comment.