Skip to content

Commit

Permalink
#148 added support for other partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
meshuga committed Nov 19, 2020
1 parent 6290ae5 commit 219d1eb
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 47 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ This features is experimental, but now you can run commands to check and analyze
* Incoming SSH Disabled
* Cloudtrail enabled
## Regions outside of main partition
If you wish to analyze accounts in regions outside the main AWS partition (e.g. GovCloud or China), you should provide credentials (e.g. a profile) that are applicable to a given partition. It's not possible to analyze regions from multiple partitions.
## Using a Docker container
To build docker container using Dockerfile
Expand Down
72 changes: 60 additions & 12 deletions cloudiscovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
exit_critical,
Filterable,
parse_filters,
message_handler,
)
from shared.common_aws import aws_verbose, generate_session

Expand All @@ -51,6 +52,7 @@

AVAILABLE_LANGUAGES = ["en_US", "pt_BR"]
DEFAULT_REGION = "us-east-1"
DEFAULT_PARTITION_CODE = "aws"


def str2bool(v):
Expand Down Expand Up @@ -184,7 +186,7 @@ def add_default_arguments(
)


# pylint: disable=too-many-branches,too-many-statements
# pylint: disable=too-many-branches,too-many-statements,too-many-locals
def main():
# Entry point for the CLI.
# Load commands
Expand Down Expand Up @@ -237,10 +239,11 @@ def main():
session.get_credentials()
region_name = session.region_name

partition_code = get_partition(session, region_name)

if "region_name" not in args:
region_names = [DEFAULT_REGION]
else:

# checking region configuration
check_region_profile(
arg_region_name=args.region_name, profile_region_name=region_name
Expand All @@ -252,7 +255,10 @@ def main():

# get regions
region_names = check_region(
region_parameter=args.region_name, region_name=region_name, session=session,
region_parameter=args.region_name,
region_name=region_name,
session=session,
partition_code=partition_code,
)

if "threshold" in args:
Expand All @@ -264,22 +270,40 @@ def main():
exit_critical(_("Threshold must be between 0 and 100"))

if args.command == "aws-vpc":
command = Vpc(vpc_id=args.vpc_id, region_names=region_names, session=session,)
command = Vpc(
vpc_id=args.vpc_id,
region_names=region_names,
session=session,
partition_code=partition_code,
)
elif args.command == "aws-policy":
command = Policy(region_names=region_names, session=session,)
command = Policy(
region_names=region_names, session=session, partition_code=partition_code
)
elif args.command == "aws-iot":
command = Iot(
thing_name=args.thing_name, region_names=region_names, session=session,
thing_name=args.thing_name,
region_names=region_names,
session=session,
partition_code=partition_code,
)
elif args.command == "aws-all":
command = All(region_names=region_names, session=session)
command = All(
region_names=region_names, session=session, partition_code=partition_code
)
elif args.command == "aws-limit":
command = Limit(
region_names=region_names, session=session, threshold=args.threshold,
region_names=region_names,
session=session,
threshold=args.threshold,
partition_code=partition_code,
)
elif args.command == "aws-security":
command = Security(
region_names=region_names, session=session, commands=args.commands,
region_names=region_names,
session=session,
commands=args.commands,
partition_code=partition_code,
)
else:
raise NotImplementedError("Unknown command")
Expand All @@ -292,6 +316,28 @@ def main():
command.run(diagram, args.verbose, services, filters)


def get_partition(session, region_name):
partition_code = DEFAULT_PARTITION_CODE # assume it's always default partition, even if we can't find a region
partition_name = "AWS Standard"
# pylint: disable=protected-access
loader = session._session.get_component("data_loader")
endpoints = loader.load_data("endpoints")
for partition in endpoints["partitions"]:
for region, _ in partition["regions"].items():
if region == region_name:
partition_code = partition["partition"]
partition_name = partition["partitionName"]

if partition_code != DEFAULT_PARTITION_CODE:
message_handler(
"Found non-default partition: {} ({})".format(
partition_code, partition_name
),
"HEADER",
)
return partition_code


def check_diagram_version(diagram):
if diagram:
# Checking diagram version. Must be 0.13 or higher
Expand All @@ -303,17 +349,19 @@ def check_diagram_version(diagram):


def check_region_profile(arg_region_name, profile_region_name):

if arg_region_name is None and profile_region_name is None:
exit_critical("Neither region parameter nor region config were passed")


def check_region(region_parameter, region_name, session):
def check_region(region_parameter, region_name, session, partition_code):
"""
Region us-east-1 as a default region here
Region us-east-1 as a default region here, if not aws partition, just return asked region
This is just to list aws regions, doesn't matter default region
"""
if partition_code != "aws":
return [region_name]

client = session.client("ec2", region_name=DEFAULT_REGION)

valid_region_names = [
Expand Down
5 changes: 3 additions & 2 deletions cloudiscovery/provider/iot/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ def iot_digest(self):

class Iot(BaseAwsCommand):
# pylint: disable=too-many-arguments
def __init__(self, thing_name, region_names, session):
def __init__(self, thing_name, region_names, session, partition_code):
"""
Iot command
:param thing_name:
:param region_names:
:param session:
:param partition_code:
"""
super().__init__(region_names, session)
super().__init__(region_names, session, partition_code)
self.thing_name = thing_name

def run(
Expand Down
5 changes: 3 additions & 2 deletions cloudiscovery/provider/limit/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,16 @@ def get_quota(self, quota_code, service_code, service_quota):


class Limit(BaseAwsCommand):
def __init__(self, region_names, session, threshold):
def __init__(self, region_names, session, threshold, partition_code):
"""
All AWS resources
:param region_names:
:param session:
:param threshold:
:param partition_code:
"""
super().__init__(region_names, session)
super().__init__(region_names, session, partition_code)
self.threshold = threshold

def init_globalaws_limits_cache(self, region, services, options: LimitOptions):
Expand Down
5 changes: 3 additions & 2 deletions cloudiscovery/provider/security/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ def __init__(self, session, region: str, commands, options: SecurityOptions):


class Security(BaseAwsCommand):
def __init__(self, region_names, session, commands):
def __init__(self, region_names, session, commands, partition_code):
"""
All AWS resources
:param region_names:
:param session:
:param commands:
:param partition_code:
"""
super().__init__(region_names, session)
super().__init__(region_names, session, partition_code)
self.commands = commands

def run(
Expand Down
5 changes: 3 additions & 2 deletions cloudiscovery/provider/vpc/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ def vpc_digest(self):

class Vpc(BaseAwsCommand):
# pylint: disable=too-many-arguments
def __init__(self, vpc_id, region_names, session):
def __init__(self, vpc_id, region_names, session, partition_code):
"""
VPC command
:param vpc_id:
:param region_names:
:param session:
:param partition_code:
"""
super().__init__(region_names, session)
super().__init__(region_names, session, partition_code)
self.vpc_id = vpc_id

@staticmethod
Expand Down
82 changes: 55 additions & 27 deletions cloudiscovery/shared/common_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,52 @@ def account_number(self):


class GlobalParameters:
def __init__(self, session, region: str, path: str):
def __init__(self, session, region: str, path: str, partition_code: str):
self.region = region
self.session = session.client("ssm", region_name="us-east-1")
self.session = session
self.client = None
self.path = path
self.partition_code = partition_code
self.cache = ResourceCache()

def get_parameters_by_path(self, next_token=None):
def paths(self):

params = {"Path": self.path, "Recursive": True, "MaxResults": 10}
if next_token is not None:
params["NextToken"] = next_token
cache_key = "aws_paths_" + self.region
cache = self.cache.get_key(cache_key)

return self.session.get_parameters_by_path(**params)
if cache is not None:
return cache

paths_found = []
if self.partition_code == "aws":
message_handler(
"Fetching available resources in region {} to cache...".format(
self.region
),
"HEADER",
)
self.client = self.session.client("ssm", region_name="us-east-1")
paths = self.parameters()
for path in paths:
paths_found.append(path["Value"])
else:
message_handler(
"Loading available resources in region {} to cache...".format(
self.region
),
"HEADER",
)
# pylint: disable=protected-access
loader = self.session._session.get_component("data_loader")
endpoints = loader.load_data("endpoints")
for partition in endpoints["partitions"]:
for service, service_info in partition["services"].items():
for endpoint_region, _ in service_info["endpoints"].items():
if self.region == endpoint_region:
paths_found.append(service)

self.cache.set_key(key=cache_key, value=paths_found, expire=86400)
return paths_found

def parameters(self):
next_token = None
Expand All @@ -108,37 +141,27 @@ def parameters(self):
break
next_token = response["NextToken"]

def paths(self):

cache_key = "aws_paths_" + self.region
cache = self.cache.get_key(cache_key)

if cache is not None:
return cache
def get_parameters_by_path(self, next_token=None):

message_handler(
"Fetching available resources in region {} to cache...".format(self.region),
"HEADER",
)
paths_found = []
paths = self.parameters()
for path in paths:
paths_found.append(path["Value"])
params = {"Path": self.path, "Recursive": True, "MaxResults": 10}
if next_token is not None:
params["NextToken"] = next_token

self.cache.set_key(key=cache_key, value=paths_found, expire=86400)
return paths_found
return self.client.get_parameters_by_path(**params)


class BaseAwsCommand(BaseCommand):
def __init__(self, region_names, session):
def __init__(self, region_names, session, partition_code):
"""
Base class for discovery command
:param region_names:
:param session:
:param partition_code:
"""
self.region_names: List[str] = region_names
self.session: Session = session
self.partition_code: str = partition_code

def run(
self,
Expand All @@ -150,9 +173,14 @@ def run(
raise NotImplementedError()

def init_region_cache(self, region):
# Get and cache SSM services available in specific region
# Get and cache services available in specific region
path = "/aws/service/global-infrastructure/regions/" + region + "/services/"
GlobalParameters(session=self.session, region=region, path=path).paths()
GlobalParameters(
session=self.session,
region=region,
path=path,
partition_code=self.partition_code,
).paths()


def resource_tags(resource_data: dict) -> List[ResourceTag]:
Expand Down

0 comments on commit 219d1eb

Please sign in to comment.