diff --git a/README.md b/README.md index 5e96169..97e0812 100644 --- a/README.md +++ b/README.md @@ -316,6 +316,11 @@ This features is experimental, but now you can run commands to check and analyze * Incoming SSH Disabled * Cloudtrail enabled + +## Regions outside of main partition + +If you wish to analyze accounts in regions outside the main AWS partition (e.g. GovCloud or China), you should provide credentials (e.g. a profile) that are applicable to a given partition. It's not possible to analyze regions from multiple partitions. + ## Using a Docker container To build docker container using Dockerfile diff --git a/cloudiscovery/__init__.py b/cloudiscovery/__init__.py index 312afc5..67ce1e6 100644 --- a/cloudiscovery/__init__.py +++ b/cloudiscovery/__init__.py @@ -38,6 +38,7 @@ exit_critical, Filterable, parse_filters, + message_handler, ) from shared.common_aws import aws_verbose, generate_session @@ -51,6 +52,7 @@ AVAILABLE_LANGUAGES = ["en_US", "pt_BR"] DEFAULT_REGION = "us-east-1" +DEFAULT_PARTITION_CODE = "aws" def str2bool(v): @@ -184,7 +186,7 @@ def add_default_arguments( ) -# pylint: disable=too-many-branches,too-many-statements +# pylint: disable=too-many-branches,too-many-statements,too-many-locals def main(): # Entry point for the CLI. # Load commands @@ -237,10 +239,11 @@ def main(): session.get_credentials() region_name = session.region_name + partition_code = get_partition(session, region_name) + if "region_name" not in args: region_names = [DEFAULT_REGION] else: - # checking region configuration check_region_profile( arg_region_name=args.region_name, profile_region_name=region_name @@ -252,7 +255,10 @@ def main(): # get regions region_names = check_region( - region_parameter=args.region_name, region_name=region_name, session=session, + region_parameter=args.region_name, + region_name=region_name, + session=session, + partition_code=partition_code, ) if "threshold" in args: @@ -264,22 +270,40 @@ def main(): exit_critical(_("Threshold must be between 0 and 100")) if args.command == "aws-vpc": - command = Vpc(vpc_id=args.vpc_id, region_names=region_names, session=session,) + command = Vpc( + vpc_id=args.vpc_id, + region_names=region_names, + session=session, + partition_code=partition_code, + ) elif args.command == "aws-policy": - command = Policy(region_names=region_names, session=session,) + command = Policy( + region_names=region_names, session=session, partition_code=partition_code + ) elif args.command == "aws-iot": command = Iot( - thing_name=args.thing_name, region_names=region_names, session=session, + thing_name=args.thing_name, + region_names=region_names, + session=session, + partition_code=partition_code, ) elif args.command == "aws-all": - command = All(region_names=region_names, session=session) + command = All( + region_names=region_names, session=session, partition_code=partition_code + ) elif args.command == "aws-limit": command = Limit( - region_names=region_names, session=session, threshold=args.threshold, + region_names=region_names, + session=session, + threshold=args.threshold, + partition_code=partition_code, ) elif args.command == "aws-security": command = Security( - region_names=region_names, session=session, commands=args.commands, + region_names=region_names, + session=session, + commands=args.commands, + partition_code=partition_code, ) else: raise NotImplementedError("Unknown command") @@ -292,6 +316,28 @@ def main(): command.run(diagram, args.verbose, services, filters) +def get_partition(session, region_name): + partition_code = DEFAULT_PARTITION_CODE # assume it's always default partition, even if we can't find a region + partition_name = "AWS Standard" + # pylint: disable=protected-access + loader = session._session.get_component("data_loader") + endpoints = loader.load_data("endpoints") + for partition in endpoints["partitions"]: + for region, _ in partition["regions"].items(): + if region == region_name: + partition_code = partition["partition"] + partition_name = partition["partitionName"] + + if partition_code != DEFAULT_PARTITION_CODE: + message_handler( + "Found non-default partition: {} ({})".format( + partition_code, partition_name + ), + "HEADER", + ) + return partition_code + + def check_diagram_version(diagram): if diagram: # Checking diagram version. Must be 0.13 or higher @@ -303,17 +349,19 @@ def check_diagram_version(diagram): def check_region_profile(arg_region_name, profile_region_name): - if arg_region_name is None and profile_region_name is None: exit_critical("Neither region parameter nor region config were passed") -def check_region(region_parameter, region_name, session): +def check_region(region_parameter, region_name, session, partition_code): """ - Region us-east-1 as a default region here + Region us-east-1 as a default region here, if not aws partition, just return asked region This is just to list aws regions, doesn't matter default region """ + if partition_code != "aws": + return [region_name] + client = session.client("ec2", region_name=DEFAULT_REGION) valid_region_names = [ diff --git a/cloudiscovery/provider/iot/command.py b/cloudiscovery/provider/iot/command.py index fc31442..5d6473d 100644 --- a/cloudiscovery/provider/iot/command.py +++ b/cloudiscovery/provider/iot/command.py @@ -21,15 +21,16 @@ def iot_digest(self): class Iot(BaseAwsCommand): # pylint: disable=too-many-arguments - def __init__(self, thing_name, region_names, session): + def __init__(self, thing_name, region_names, session, partition_code): """ Iot command :param thing_name: :param region_names: :param session: + :param partition_code: """ - super().__init__(region_names, session) + super().__init__(region_names, session, partition_code) self.thing_name = thing_name def run( diff --git a/cloudiscovery/provider/limit/command.py b/cloudiscovery/provider/limit/command.py index c7973ff..cfb0d71 100644 --- a/cloudiscovery/provider/limit/command.py +++ b/cloudiscovery/provider/limit/command.py @@ -129,15 +129,16 @@ def get_quota(self, quota_code, service_code, service_quota): class Limit(BaseAwsCommand): - def __init__(self, region_names, session, threshold): + def __init__(self, region_names, session, threshold, partition_code): """ All AWS resources :param region_names: :param session: :param threshold: + :param partition_code: """ - super().__init__(region_names, session) + super().__init__(region_names, session, partition_code) self.threshold = threshold def init_globalaws_limits_cache(self, region, services, options: LimitOptions): diff --git a/cloudiscovery/provider/security/command.py b/cloudiscovery/provider/security/command.py index 82226ea..22462b3 100644 --- a/cloudiscovery/provider/security/command.py +++ b/cloudiscovery/provider/security/command.py @@ -31,15 +31,16 @@ def __init__(self, session, region: str, commands, options: SecurityOptions): class Security(BaseAwsCommand): - def __init__(self, region_names, session, commands): + def __init__(self, region_names, session, commands, partition_code): """ All AWS resources :param region_names: :param session: :param commands: + :param partition_code: """ - super().__init__(region_names, session) + super().__init__(region_names, session, partition_code) self.commands = commands def run( diff --git a/cloudiscovery/provider/vpc/command.py b/cloudiscovery/provider/vpc/command.py index f8911d1..3894323 100644 --- a/cloudiscovery/provider/vpc/command.py +++ b/cloudiscovery/provider/vpc/command.py @@ -31,15 +31,16 @@ def vpc_digest(self): class Vpc(BaseAwsCommand): # pylint: disable=too-many-arguments - def __init__(self, vpc_id, region_names, session): + def __init__(self, vpc_id, region_names, session, partition_code): """ VPC command :param vpc_id: :param region_names: :param session: + :param partition_code: """ - super().__init__(region_names, session) + super().__init__(region_names, session, partition_code) self.vpc_id = vpc_id @staticmethod diff --git a/cloudiscovery/shared/common_aws.py b/cloudiscovery/shared/common_aws.py index 0385e31..74c9aff 100644 --- a/cloudiscovery/shared/common_aws.py +++ b/cloudiscovery/shared/common_aws.py @@ -81,19 +81,52 @@ def account_number(self): class GlobalParameters: - def __init__(self, session, region: str, path: str): + def __init__(self, session, region: str, path: str, partition_code: str): self.region = region - self.session = session.client("ssm", region_name="us-east-1") + self.session = session + self.client = None self.path = path + self.partition_code = partition_code self.cache = ResourceCache() - def get_parameters_by_path(self, next_token=None): + def paths(self): - params = {"Path": self.path, "Recursive": True, "MaxResults": 10} - if next_token is not None: - params["NextToken"] = next_token + cache_key = "aws_paths_" + self.region + cache = self.cache.get_key(cache_key) - return self.session.get_parameters_by_path(**params) + if cache is not None: + return cache + + paths_found = [] + if self.partition_code == "aws": + message_handler( + "Fetching available resources in region {} to cache...".format( + self.region + ), + "HEADER", + ) + self.client = self.session.client("ssm", region_name="us-east-1") + paths = self.parameters() + for path in paths: + paths_found.append(path["Value"]) + else: + message_handler( + "Loading available resources in region {} to cache...".format( + self.region + ), + "HEADER", + ) + # pylint: disable=protected-access + loader = self.session._session.get_component("data_loader") + endpoints = loader.load_data("endpoints") + for partition in endpoints["partitions"]: + for service, service_info in partition["services"].items(): + for endpoint_region, _ in service_info["endpoints"].items(): + if self.region == endpoint_region: + paths_found.append(service) + + self.cache.set_key(key=cache_key, value=paths_found, expire=86400) + return paths_found def parameters(self): next_token = None @@ -108,37 +141,27 @@ def parameters(self): break next_token = response["NextToken"] - def paths(self): - - cache_key = "aws_paths_" + self.region - cache = self.cache.get_key(cache_key) - - if cache is not None: - return cache + def get_parameters_by_path(self, next_token=None): - message_handler( - "Fetching available resources in region {} to cache...".format(self.region), - "HEADER", - ) - paths_found = [] - paths = self.parameters() - for path in paths: - paths_found.append(path["Value"]) + params = {"Path": self.path, "Recursive": True, "MaxResults": 10} + if next_token is not None: + params["NextToken"] = next_token - self.cache.set_key(key=cache_key, value=paths_found, expire=86400) - return paths_found + return self.client.get_parameters_by_path(**params) class BaseAwsCommand(BaseCommand): - def __init__(self, region_names, session): + def __init__(self, region_names, session, partition_code): """ Base class for discovery command :param region_names: :param session: + :param partition_code: """ self.region_names: List[str] = region_names self.session: Session = session + self.partition_code: str = partition_code def run( self, @@ -150,9 +173,14 @@ def run( raise NotImplementedError() def init_region_cache(self, region): - # Get and cache SSM services available in specific region + # Get and cache services available in specific region path = "/aws/service/global-infrastructure/regions/" + region + "/services/" - GlobalParameters(session=self.session, region=region, path=path).paths() + GlobalParameters( + session=self.session, + region=region, + path=path, + partition_code=self.partition_code, + ).paths() def resource_tags(resource_data: dict) -> List[ResourceTag]: