diff --git a/ecs_files_composer/aws_mgmt.py b/ecs_files_composer/aws_mgmt.py index d737790..5868fd1 100644 --- a/ecs_files_composer/aws_mgmt.py +++ b/ecs_files_composer/aws_mgmt.py @@ -38,7 +38,9 @@ class AwsResourceHandler(object): Class to handle all AWS related credentials init. """ - def __init__(self, role_arn=None, external_id=None, region=None, iam_config_object=None): + def __init__( + self, role_arn=None, external_id=None, region=None, iam_config_object=None, client_session_override=None + ): """ :param str role_arn: :param str external_id: @@ -47,7 +49,9 @@ def __init__(self, role_arn=None, external_id=None, region=None, iam_config_obje """ self.session = session.Session() self.client_session = session.Session() - if role_arn or iam_config_object: + if client_session_override: + self.client_session = client_session_override + elif not client_session_override and (role_arn or iam_config_object): if role_arn and not iam_config_object: params = {"RoleArn": role_arn, "RoleSessionName": "EcsConfigComposer@AwsResourceHandlerInit"} if external_id: @@ -70,8 +74,10 @@ class S3Fetcher(AwsResourceHandler): Class to handle S3 actions """ - def __init__(self, role_arn=None, external_id=None, region=None, iam_config_object=None): - super().__init__(role_arn, external_id, region, iam_config_object) + def __init__( + self, role_arn=None, external_id=None, region=None, iam_config_object=None, client_session_override=None + ): + super().__init__(role_arn, external_id, region, iam_config_object, client_session_override) self.client = self.client_session.client("s3") def get_content(self, s3_uri=None, s3_bucket=None, s3_key=None): @@ -103,8 +109,10 @@ class SsmFetcher(AwsResourceHandler): arn_re = re.compile(r"(?:^arn:aws(?:-[a-z]+)?:ssm:[\S]+:[0-9]+:parameter)(?P/[\S]+)$") - def __init__(self, role_arn=None, external_id=None, region=None, iam_config_object=None): - super().__init__(role_arn, external_id, region, iam_config_object) + def __init__( + self, role_arn=None, external_id=None, region=None, iam_config_object=None, client_session_override=None + ): + super().__init__(role_arn, external_id, region, iam_config_object, client_session_override) self.client = self.client_session.client("ssm") def get_content(self, parameter_name): @@ -131,8 +139,10 @@ class SecretFetcher(AwsResourceHandler): Class to handle Secret Manager actions """ - def __init__(self, role_arn=None, external_id=None, region=None, iam_config_object=None): - super().__init__(role_arn, external_id, region, iam_config_object) + def __init__( + self, role_arn=None, external_id=None, region=None, iam_config_object=None, client_session_override=None + ): + super().__init__(role_arn, external_id, region, iam_config_object, client_session_override) self.client = self.client_session.client("secretsmanager") def get_content(self, secret):