Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

STS cross-account access #64

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions awslimitchecker/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@

class AwsLimitChecker(object):

def __init__(self, warning_threshold=80, critical_threshold=99):
def __init__(self, warning_threshold=80, critical_threshold=99,
account_id=None, account_role=None, region=None):
"""
Main AwsLimitChecker class - this should be the only externally-used
portion of awslimitchecker.
Expand All @@ -65,6 +66,10 @@ def __init__(self, warning_threshold=80, critical_threshold=99):
integer percentage, for any limits without a specifically-set
threshold.
:type critical_threshold: int
:param account_id: connect via STS to this AWS account
:type account_id: str
:param account_role: connect via STS as this IAM role
:type account_role: str
"""
# ###### IMPORTANT license notice ##########
# Pursuant to Sections 5(b) and 13 of the GNU Affero General Public
Expand All @@ -91,10 +96,14 @@ def __init__(self, warning_threshold=80, critical_threshold=99):
)
self.warning_threshold = warning_threshold
self.critical_threshold = critical_threshold
self.account_id = account_id
self.account_role = account_role
self.region = region
self.services = {}
self.ta = TrustedAdvisor()
self.ta = TrustedAdvisor(account_id=self.account_id, account_role=self.account_role, region=self.region)
for sname, cls in _services.items():
self.services[sname] = cls(warning_threshold, critical_threshold)
self.services[sname] = cls(warning_threshold, critical_threshold,
account_id, account_role, region)

def get_version(self):
"""
Expand Down
14 changes: 13 additions & 1 deletion awslimitchecker/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def parse_args(self, argv):
type=int, default=99,
help='default critical threshold (percentage of '
'limit); default: 99')
p.add_argument('-A', '--sts-account-id', action='store',
type=str, default=None,
help='the AWS account to control')
p.add_argument('-R', '--sts-account-role', action='store',
type=str, default=None,
help='the IAM role to assume')
p.add_argument('-r', '--region', action='store',
type=str, default=None,
help='connect to this AWS region; required for STS')
p.add_argument('--skip-ta', action='store_true', default=False,
help='do not attempt to pull *any* information on limits'
' from Trusted Advisor')
Expand Down Expand Up @@ -281,7 +290,10 @@ def console_entry_point(self):
# the rest of these actually use the checker
self.checker = AwsLimitChecker(
warning_threshold=args.warning_threshold,
critical_threshold=args.critical_threshold
critical_threshold=args.critical_threshold,
account_id=args.sts_account_id,
account_role=args.sts_account_role,
region=args.region
)

if args.version:
Expand Down
11 changes: 7 additions & 4 deletions awslimitchecker/services/autoscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.ec2.autoscale
import logging

from .base import _AwsService
Expand All @@ -52,11 +53,13 @@ class _AutoscalingService(_AwsService):
service_name = 'AutoScaling'

def connect(self):
"""connect to API if not already connected; set self.conn"""
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.autoscale.connect_to_region)
else:
self.conn = boto.connect_autoscale()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
44 changes: 43 additions & 1 deletion awslimitchecker/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

import abc
import logging
import boto.sts

logger = logging.getLogger(__name__)


Expand All @@ -47,7 +49,8 @@ class _AwsService(object):

service_name = 'baseclass'

def __init__(self, warning_threshold, critical_threshold):
def __init__(self, warning_threshold, critical_threshold, account_id=None,
account_role=None, region=None):
"""
Describes an AWS service and its limits, and provides methods to
query current utilization.
Expand All @@ -65,9 +68,17 @@ def __init__(self, warning_threshold, critical_threshold):
integer percentage, for any limits without a specifically-set
threshold.
:type critical_threshold: int
:param account_id: connect via STS to this AWS account
:type account_id: str
:param account_role: connect via STS as this IAM role
:type account_role: str
"""
self.warning_threshold = warning_threshold
self.critical_threshold = critical_threshold
self.account_id = account_id
self.account_role = account_role
self.region = region

self.limits = {}
self.limits = self.get_limits()
self.conn = None
Expand Down Expand Up @@ -137,6 +148,37 @@ def required_iam_permissions(self):
"""
raise NotImplementedError('abstract base class')

def connect_via(self, driver):
"""
Connect to API if not already connected; set self.conn
Use STS to assume a role as another user if self.account_id has been set

:param driver: the connect_to_region() function of the boto
submodule to use to create this connection
:type driver: :py:func:
"""
if(self.account_id):
logger.debug("Connecting to %s for account %s", self.service_name,
self.account_id)
self.credentials = self._get_sts_token()
conn = driver(
self.region,
aws_access_key_id=self.credentials.access_key,
aws_secret_access_key=self.credentials.secret_key,
security_token=self.credentials.session_token)
else:
logger.debug("Connecting to %s", self.service_name)
conn = driver.connect_to_region(self.region)
logger.info("Connected to %s", self.service_name)
return conn

def _get_sts_token(self):
"""Attempt to get STS token, exit if fail."""
sts = boto.sts.connect_to_region(self.region)
arn = "arn:aws:iam::%s:role/%s" % (self.account_id, self.account_role)
role = sts.assume_role(arn, "awslimitchecker")
return role.credentials

def set_limit_override(self, limit_name, value, override_ta=True):
"""
Set a new limit ``value`` for the specified limit, overriding
Expand Down
11 changes: 7 additions & 4 deletions awslimitchecker/services/ebs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.ec2
import logging
from .base import _AwsService
from ..limit import AwsLimit
Expand All @@ -50,11 +51,13 @@ class _EbsService(_AwsService):
service_name = 'EBS'

def connect(self):
"""connect to API if not already connected; set self.conn"""
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.connect_to_region)
else:
self.conn = boto.connect_ec2()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
11 changes: 7 additions & 4 deletions awslimitchecker/services/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.ec2
import logging
from collections import defaultdict
from copy import deepcopy
Expand All @@ -52,11 +53,13 @@ class _Ec2Service(_AwsService):
service_name = 'EC2'

def connect(self):
"""connect to API if not already connected; set self.conn"""
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.connect_to_region)
else:
self.conn = boto.connect_ec2()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
10 changes: 7 additions & 3 deletions awslimitchecker/services/elasticache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"""

import abc # noqa
import boto.elasticache
from boto.elasticache.layer1 import ElastiCacheConnection
from boto.exception import BotoServerError
import logging
Expand All @@ -53,10 +54,13 @@ class _ElastiCacheService(_AwsService):
service_name = 'ElastiCache'

def connect(self):
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.elasticache.connect_to_region)
else:
self.conn = ElastiCacheConnection()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
10 changes: 7 additions & 3 deletions awslimitchecker/services/elb.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.ec2.elb
import logging

from .base import _AwsService
Expand All @@ -52,10 +53,13 @@ class _ElbService(_AwsService):
service_name = 'ELB'

def connect(self):
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.elb.connect_to_region)
else:
self.conn = boto.connect_elb()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
12 changes: 8 additions & 4 deletions awslimitchecker/services/newservice.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ class _XXNewServiceXXService(_AwsService):
service_name = 'XXNewServiceXX'

def connect(self):
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
# TODO: set this to the correct connection method:
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
# TODO: set this to the correct connection methods:
elif self.region:
import boto.XXnewserviceXX
self.conn = self.connect_via(boto.XXnewserviceXX.connect_to_region)
else:
self.conn = boto.connect_XXnewserviceXX()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
11 changes: 7 additions & 4 deletions awslimitchecker/services/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.rds2
import logging

from .base import _AwsService
Expand All @@ -52,11 +53,13 @@ class _RDSService(_AwsService):
service_name = 'RDS'

def connect(self):
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
# TODO: set this to the correct connection method:
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.rds2.connect_to_region)
else:
self.conn = boto.connect_rds2()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
11 changes: 7 additions & 4 deletions awslimitchecker/services/vpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.vpc
import logging
from collections import defaultdict

Expand All @@ -53,11 +54,13 @@ class _VpcService(_AwsService):
service_name = 'VPC'

def connect(self):
"""connect to API if not already connected; set self.conn"""
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.vpc.connect_to_region)
else:
self.conn = boto.connect_vpc()
logger.info("Connected to %s", self.service_name)

def find_usage(self):
"""
Expand Down
8 changes: 4 additions & 4 deletions awslimitchecker/tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def test_init(self):
'SvcBar': self.mock_svc2
}
# _AwsService instances should exist, but have no other calls
assert self.mock_foo.mock_calls == [call(80, 99)]
assert self.mock_bar.mock_calls == [call(80, 99)]
assert self.mock_foo.mock_calls == [call(80, 99, None, None, None)]
assert self.mock_bar.mock_calls == [call(80, 99, None, None, None)]
assert self.mock_svc1.mock_calls == []
assert self.mock_svc2.mock_calls == []
assert self.cls.ta == self.mock_ta
Expand Down Expand Up @@ -151,8 +151,8 @@ def test_init_thresholds(self):
'SvcBar': mock_svc2
}
# _AwsService instances should exist, but have no other calls
assert mock_foo.mock_calls == [call(5, 22)]
assert mock_bar.mock_calls == [call(5, 22)]
assert mock_foo.mock_calls == [call(5, 22, None, None, None)]
assert mock_bar.mock_calls == [call(5, 22, None, None, None)]
assert mock_svc1.mock_calls == []
assert mock_svc2.mock_calls == []
assert self.mock_version.mock_calls == [call()]
Expand Down
Loading