Skip to content

Commit

Permalink
Add more tests, only write credentials if Amazon object provided.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Ide committed Jan 10, 2018
1 parent 7f57028 commit 0008520
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 12 deletions.
25 changes: 13 additions & 12 deletions aws_google_auth/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,19 @@ def write(self, amazon_object):
with open(self.config_file, 'w+') as f:
config_parser.write(f)

# Write to the credentials file
credentials_parser = configparser.RawConfigParser()
credentials_parser.read(self.credentials_file)
if not credentials_parser.has_section(self.profile):
credentials_parser.add_section(self.profile)
credentials_parser.set(self.profile, 'aws_access_key_id', amazon_object.access_key_id)
credentials_parser.set(self.profile, 'aws_secret_access_key', amazon_object.secret_access_key)
credentials_parser.set(self.profile, 'aws_security_token', amazon_object.session_token)
credentials_parser.set(self.profile, 'aws_session_expiration', amazon_object.expiration.strftime('%Y-%m-%dT%H:%M:%S%z'))
credentials_parser.set(self.profile, 'aws_session_token', amazon_object.session_token)
with open(self.credentials_file, 'w+') as f:
credentials_parser.write(f)
# Write to the credentials file (only if we have credentials)
if amazon_object is not None:
credentials_parser = configparser.RawConfigParser()
credentials_parser.read(self.credentials_file)
if not credentials_parser.has_section(self.profile):
credentials_parser.add_section(self.profile)
credentials_parser.set(self.profile, 'aws_access_key_id', amazon_object.access_key_id)
credentials_parser.set(self.profile, 'aws_secret_access_key', amazon_object.secret_access_key)
credentials_parser.set(self.profile, 'aws_security_token', amazon_object.session_token)
credentials_parser.set(self.profile, 'aws_session_expiration', amazon_object.expiration.strftime('%Y-%m-%dT%H:%M:%S%z'))
credentials_parser.set(self.profile, 'aws_session_token', amazon_object.session_token)
with open(self.credentials_file, 'w+') as f:
credentials_parser.write(f)

# Read from the configuration file and override ALL values currently stored
# in the configuration object. As this is potentially destructive, it's
Expand Down
79 changes: 79 additions & 0 deletions aws_google_auth/tests/test_configuration_persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python

import configparser
import unittest
from aws_google_auth import configuration
from random import randint


class TestConfigurationPersistence(unittest.TestCase):

def setUp(self):
self.c = configuration.Configuration()

# Pick a profile name that is clear it's for testing. We'll delete it
# after, but in case something goes wrong we don't want to use
# something that could clobber user input.
self.c.profile = "aws_google_auth_test_{}".format(randint(100, 999))

# Pick a string to do password leakage tests.
self.c.password = "aws_google_auth_test_password_{}".format(randint(100, 999))

self.c.region = "us-east-1"
self.c.ask_role = False
self.c.duration = 1234
self.c.idp_id = "sample_idp_id"
self.c.role_arn = "arn:aws:iam::sample_arn"
self.c.sp_id = "sample_sp_id"
self.c.u2f_disabled = False
self.c.username = "sample_username"
self.c.raise_if_invalid()
self.c.write(None)

self.config_parser = configparser.RawConfigParser()
self.config_parser.read(self.c.config_file)

def tearDown(self):
self.config_parser.remove_section(self.c.profile)
with open(self.c.config_file, 'w') as config_file:
self.config_parser.write(config_file)

def test_creating_new_profile(self):
self.assertTrue(self.config_parser.has_section(self.c.profile))
self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_idp_id'), self.c.idp_id)
self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_role_arn'), self.c.role_arn)
self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_sp_id'), self.c.sp_id)
self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_username'), self.c.username)
self.assertEqual(self.config_parser[self.c.profile].get('region'), self.c.region)
self.assertEqual(self.config_parser[self.c.profile].getboolean('aws_google_auth_ask_role'), self.c.ask_role)
self.assertEqual(self.config_parser[self.c.profile].getboolean('aws_google_auth_u2f_disabled'), self.c.u2f_disabled)
self.assertEqual(self.config_parser[self.c.profile].getint('aws_google_auth_duration'), self.c.duration)

def test_password_not_written(self):
self.assertIsNone(self.config_parser[self.c.profile].get('aws_google_auth_password', None))
self.assertIsNone(self.config_parser[self.c.profile].get('password', None))

# Check for password leakage (It didn't get written in an odd way)
with open(self.c.config_file, 'r') as config_file:
for line in config_file:
self.assertFalse(self.c.password in line)

def test_can_read_all_values(self):
test_configuration = configuration.Configuration()
test_configuration.read(self.c.profile)

# Reading won't get password, so we need to set for the configuration
# to be considered valid
test_configuration.password = "test_password"

test_configuration.raise_if_invalid()

self.assertEqual(test_configuration.profile, self.c.profile)
self.assertEqual(test_configuration.idp_id, self.c.idp_id)
self.assertEqual(test_configuration.role_arn, self.c.role_arn)
self.assertEqual(test_configuration.sp_id, self.c.sp_id)
self.assertEqual(test_configuration.username, self.c.username)
self.assertEqual(test_configuration.region, self.c.region)
self.assertEqual(test_configuration.ask_role, self.c.ask_role)
self.assertEqual(test_configuration.u2f_disabled, self.c.u2f_disabled)
self.assertEqual(test_configuration.duration, self.c.duration)
13 changes: 13 additions & 0 deletions aws_google_auth/tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env python

import unittest
from aws_google_auth import util


class TestUtilMethods(unittest.TestCase):

def test_default_if_none(self):
value = "non_none_value"
self.assertEqual(util.Util.default_if_none(value, None), value)
self.assertEqual(util.Util.default_if_none(None, value), value)
self.assertEqual(util.Util.default_if_none(None, None), None)

0 comments on commit 0008520

Please sign in to comment.