Skip to content

Commit

Permalink
PR #64 - major refactor of copy/paste connect() methods out of servic…
Browse files Browse the repository at this point in the history
…e classes
  • Loading branch information
jantman committed Oct 2, 2015
1 parent 9a603a8 commit 47a0602
Show file tree
Hide file tree
Showing 20 changed files with 120 additions and 520 deletions.
11 changes: 2 additions & 9 deletions awslimitchecker/services/autoscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,8 @@
class _AutoscalingService(_AwsService):

service_name = 'AutoScaling'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.autoscale.connect_to_region)
else:
self.conn = boto.connect_autoscale()
connect_function = boto.connect_autoscale
region_connect_function = boto.ec2.autoscale.connect_to_region

def find_usage(self):
"""
Expand Down
71 changes: 39 additions & 32 deletions awslimitchecker/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,6 @@ def __init__(self, warning_threshold, critical_threshold, account_id=None,
self.conn = None
self._have_usage = False

@abc.abstractmethod
def connect(self):
"""
If not already done, establish a connection to the relevant AWS service
and save as ``self.conn``.
"""
"""
if self.conn is None:
logger.debug("Connecting to %s", self.service_name)
# self.conn = boto.<connect to something>
logger.info("Connected to %s", self.service_name)
"""
raise NotImplementedError('abstract base class')

@abc.abstractmethod
def find_usage(self):
"""
Expand Down Expand Up @@ -161,28 +147,49 @@ def required_iam_permissions(self):
"""
raise NotImplementedError('abstract base class')

def connect_via(self, driver):
def connect(self):
"""
Connect to API if not already connected; set self.conn
Use STS to assume a role as another user if self.account_id has been set
Connect to AWS API for this service, if not already connected; set
``self.conn`` to the connection object.
:param driver: the connect_to_region() function of the boto
submodule to use to create this connection
:type driver: :py:obj:`function`
If ``self.region`` is None, connect by calling
``self.connect_function()`` and setting ``self.conn`` to that value.
If ``self.region`` is not None, connect by calling
``self.region_connect_function()``. Arguments passed are ``self.region``
and, if ``self.account_id`` is not None, ``self.region`` and the STS
credentials returned by :py:meth:`~._get_sts_token` (to assume a role
using STS).
"""
if(self.account_id):
logger.debug("Connecting to %s for account %s (STS; %s)",
self.service_name, self.account_id, self.region)
self.credentials = self._get_sts_token()
conn = driver(
self.region,
aws_access_key_id=self.credentials.access_key,
aws_secret_access_key=self.credentials.secret_key,
security_token=self.credentials.session_token)
else:
logger.debug("Connecting to %s (%s)",
if self.conn is not None:
# already connected
return self.conn

if self.region is None:
# use regionless self.connect_function
logger.debug("Connecting to %s", self.service_name)
self.conn = self.connect_function()
logger.info("Connected to %s", self.service_name)
return self.conn

if self.account_id is None:
# region but no account_id
logger.debug("Connecting to %s (region %s)",
self.service_name, self.region)
conn = driver(self.region)
conn = self.region_connect_function(self.region)
logger.info("Connected to %s", self.service_name)
return conn

# else we have self.account_id set; use STS
logger.debug("Connecting to %s for account %s (STS; %s)",
self.service_name, self.account_id, self.region)
self.credentials = self._get_sts_token()
conn = self.region_connect_function(
self.region,
aws_access_key_id=self.credentials.access_key,
aws_secret_access_key=self.credentials.secret_key,
security_token=self.credentials.session_token
)
logger.info("Connected to %s", self.service_name)
return conn

Expand Down
11 changes: 2 additions & 9 deletions awslimitchecker/services/ebs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,8 @@
class _EbsService(_AwsService):

service_name = 'EBS'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.connect_to_region)
else:
self.conn = boto.connect_ec2()
connect_function = boto.connect_ec2
region_connect_function = boto.ec2.connect_to_region

def find_usage(self):
"""
Expand Down
12 changes: 3 additions & 9 deletions awslimitchecker/services/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import abc # noqa
import boto
import boto.ec2
import logging
from collections import defaultdict
from copy import deepcopy
Expand All @@ -50,15 +51,8 @@
class _Ec2Service(_AwsService):

service_name = 'EC2'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.connect_to_region)
else:
self.conn = boto.connect_ec2()
connect_function = boto.connect_ec2
region_connect_function = boto.ec2.connect_to_region

def find_usage(self):
"""
Expand Down
11 changes: 2 additions & 9 deletions awslimitchecker/services/elasticache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,8 @@
class _ElastiCacheService(_AwsService):

service_name = 'ElastiCache'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.elasticache.connect_to_region)
else:
self.conn = ElastiCacheConnection()
connect_function = ElastiCacheConnection
region_connect_function = boto.elasticache.connect_to_region

def find_usage(self):
"""
Expand Down
11 changes: 2 additions & 9 deletions awslimitchecker/services/elb.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,8 @@
class _ElbService(_AwsService):

service_name = 'ELB'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.ec2.elb.connect_to_region)
else:
self.conn = boto.connect_elb()
connect_function = boto.connect_elb
region_connect_function = boto.ec2.elb.connect_to_region

def find_usage(self):
"""
Expand Down
17 changes: 5 additions & 12 deletions awslimitchecker/services/newservice.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,11 @@ class _XXNewServiceXXService(_AwsService):

service_name = 'XXNewServiceXX'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
# TODO: set this to the correct connection methods:
elif self.region:
self.conn = self.connect_via(boto.XXnewserviceXX.connect_to_region)
else:
logger.debug("Connecting to %s (no region specified)",
self.service_name)
self.conn = boto.connect_XXnewserviceXX()
logger.info("Connected to %s", self.service_name)
# TODO: ensure these point to the correct functions
# boto function used to connect to service with no region or STS
connect_function = boto.connect_XXnewserviceXX()
# boto function used to connect to service with region or STS
region_connect_function = boto.XXnewserviceXX.connect_to_region

def find_usage(self):
"""
Expand Down
11 changes: 2 additions & 9 deletions awslimitchecker/services/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,8 @@
class _RDSService(_AwsService):

service_name = 'RDS'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.rds2.connect_to_region)
else:
self.conn = boto.connect_rds2()
connect_function = boto.connect_rds2
region_connect_function = boto.rds2.connect_to_region

def find_usage(self):
"""
Expand Down
11 changes: 2 additions & 9 deletions awslimitchecker/services/vpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,8 @@
class _VpcService(_AwsService):

service_name = 'VPC'

def connect(self):
"""Connect to API if not already connected; set self.conn."""
if self.conn is not None:
return
elif self.region:
self.conn = self.connect_via(boto.vpc.connect_to_region)
else:
self.conn = boto.connect_vpc()
connect_function = boto.connect_vpc
region_connect_function = boto.vpc.connect_to_region

def find_usage(self):
"""
Expand Down
44 changes: 0 additions & 44 deletions awslimitchecker/tests/services/test_autoscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,50 +66,6 @@ def test_init(self):
assert cls.warning_threshold == 21
assert cls.critical_threshold == 43

def test_connect(self):
"""test connect()"""
mock_conn = Mock()
mock_conn_via = Mock()
cls = _AutoscalingService(21, 43)
with patch('%s.boto.connect_autoscale' % self.pbm) as mock_autoscaling:
with patch('%s.connect_via' % self.pb) as mock_connect_via:
mock_autoscaling.return_value = mock_conn
mock_connect_via.return_value = mock_conn_via
cls.connect()
assert mock_autoscaling.mock_calls == [call()]
assert mock_connect_via.mock_calls == []
assert mock_conn.mock_calls == []
assert cls.conn == mock_conn

def test_connect_region(self):
"""test connect()"""
mock_conn = Mock()
mock_conn_via = Mock()
cls = _AutoscalingService(21, 43, region='myreg')
with patch('%s.boto.connect_autoscale' % self.pbm) as mock_autoscaling:
with patch('%s.connect_via' % self.pb) as mock_connect_via:
mock_autoscaling.return_value = mock_conn
mock_connect_via.return_value = mock_conn_via
cls.connect()
assert mock_autoscaling.mock_calls == []
assert mock_connect_via.mock_calls == [
call(connect_to_region)
]
assert mock_conn.mock_calls == []
assert cls.conn == mock_conn_via

def test_connect_again(self):
"""make sure we re-use the connection"""
mock_conn = Mock()
cls = _AutoscalingService(21, 43)
cls.conn = mock_conn
with patch('awslimitchecker.services.autoscaling.boto.connect_'
'autoscale') as mock_autoscaling:
mock_autoscaling.return_value = mock_conn
cls.connect()
assert mock_autoscaling.mock_calls == []
assert mock_conn.mock_calls == []

def test_get_limits(self):
cls = _AutoscalingService(21, 43)
cls.limits = {}
Expand Down
Loading

0 comments on commit 47a0602

Please sign in to comment.