diff --git a/aws_google_auth/configuration.py b/aws_google_auth/configuration.py index 819905c..ba1c081 100644 --- a/aws_google_auth/configuration.py +++ b/aws_google_auth/configuration.py @@ -107,18 +107,19 @@ def write(self, amazon_object): with open(self.config_file, 'w+') as f: config_parser.write(f) - # Write to the credentials file - credentials_parser = configparser.RawConfigParser() - credentials_parser.read(self.credentials_file) - if not credentials_parser.has_section(self.profile): - credentials_parser.add_section(self.profile) - credentials_parser.set(self.profile, 'aws_access_key_id', amazon_object.access_key_id) - credentials_parser.set(self.profile, 'aws_secret_access_key', amazon_object.secret_access_key) - credentials_parser.set(self.profile, 'aws_security_token', amazon_object.session_token) - credentials_parser.set(self.profile, 'aws_session_expiration', amazon_object.expiration.strftime('%Y-%m-%dT%H:%M:%S%z')) - credentials_parser.set(self.profile, 'aws_session_token', amazon_object.session_token) - with open(self.credentials_file, 'w+') as f: - credentials_parser.write(f) + # Write to the credentials file (only if we have credentials) + if amazon_object is not None: + credentials_parser = configparser.RawConfigParser() + credentials_parser.read(self.credentials_file) + if not credentials_parser.has_section(self.profile): + credentials_parser.add_section(self.profile) + credentials_parser.set(self.profile, 'aws_access_key_id', amazon_object.access_key_id) + credentials_parser.set(self.profile, 'aws_secret_access_key', amazon_object.secret_access_key) + credentials_parser.set(self.profile, 'aws_security_token', amazon_object.session_token) + credentials_parser.set(self.profile, 'aws_session_expiration', amazon_object.expiration.strftime('%Y-%m-%dT%H:%M:%S%z')) + credentials_parser.set(self.profile, 'aws_session_token', amazon_object.session_token) + with open(self.credentials_file, 'w+') as f: + credentials_parser.write(f) # Read from the configuration file and override ALL values currently stored # in the configuration object. As this is potentially destructive, it's diff --git a/aws_google_auth/tests/test_configuration_persistence.py b/aws_google_auth/tests/test_configuration_persistence.py new file mode 100644 index 0000000..2dd5ad5 --- /dev/null +++ b/aws_google_auth/tests/test_configuration_persistence.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python + +import configparser +import unittest +from aws_google_auth import configuration +from random import randint + + +class TestConfigurationPersistence(unittest.TestCase): + + def setUp(self): + self.c = configuration.Configuration() + + # Pick a profile name that is clear it's for testing. We'll delete it + # after, but in case something goes wrong we don't want to use + # something that could clobber user input. + self.c.profile = "aws_google_auth_test_{}".format(randint(100, 999)) + + # Pick a string to do password leakage tests. + self.c.password = "aws_google_auth_test_password_{}".format(randint(100, 999)) + + self.c.region = "us-east-1" + self.c.ask_role = False + self.c.duration = 1234 + self.c.idp_id = "sample_idp_id" + self.c.role_arn = "arn:aws:iam::sample_arn" + self.c.sp_id = "sample_sp_id" + self.c.u2f_disabled = False + self.c.username = "sample_username" + self.c.raise_if_invalid() + self.c.write(None) + + self.config_parser = configparser.RawConfigParser() + self.config_parser.read(self.c.config_file) + + def tearDown(self): + self.config_parser.remove_section(self.c.profile) + with open(self.c.config_file, 'w') as config_file: + self.config_parser.write(config_file) + + def test_creating_new_profile(self): + self.assertTrue(self.config_parser.has_section(self.c.profile)) + self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_idp_id'), self.c.idp_id) + self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_role_arn'), self.c.role_arn) + self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_sp_id'), self.c.sp_id) + self.assertEqual(self.config_parser[self.c.profile].get('aws_google_auth_username'), self.c.username) + self.assertEqual(self.config_parser[self.c.profile].get('region'), self.c.region) + self.assertEqual(self.config_parser[self.c.profile].getboolean('aws_google_auth_ask_role'), self.c.ask_role) + self.assertEqual(self.config_parser[self.c.profile].getboolean('aws_google_auth_u2f_disabled'), self.c.u2f_disabled) + self.assertEqual(self.config_parser[self.c.profile].getint('aws_google_auth_duration'), self.c.duration) + + def test_password_not_written(self): + self.assertIsNone(self.config_parser[self.c.profile].get('aws_google_auth_password', None)) + self.assertIsNone(self.config_parser[self.c.profile].get('password', None)) + + # Check for password leakage (It didn't get written in an odd way) + with open(self.c.config_file, 'r') as config_file: + for line in config_file: + self.assertFalse(self.c.password in line) + + def test_can_read_all_values(self): + test_configuration = configuration.Configuration() + test_configuration.read(self.c.profile) + + # Reading won't get password, so we need to set for the configuration + # to be considered valid + test_configuration.password = "test_password" + + test_configuration.raise_if_invalid() + + self.assertEqual(test_configuration.profile, self.c.profile) + self.assertEqual(test_configuration.idp_id, self.c.idp_id) + self.assertEqual(test_configuration.role_arn, self.c.role_arn) + self.assertEqual(test_configuration.sp_id, self.c.sp_id) + self.assertEqual(test_configuration.username, self.c.username) + self.assertEqual(test_configuration.region, self.c.region) + self.assertEqual(test_configuration.ask_role, self.c.ask_role) + self.assertEqual(test_configuration.u2f_disabled, self.c.u2f_disabled) + self.assertEqual(test_configuration.duration, self.c.duration) diff --git a/aws_google_auth/tests/test_util.py b/aws_google_auth/tests/test_util.py new file mode 100644 index 0000000..7a8aff6 --- /dev/null +++ b/aws_google_auth/tests/test_util.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python + +import unittest +from aws_google_auth import util + + +class TestUtilMethods(unittest.TestCase): + + def test_default_if_none(self): + value = "non_none_value" + self.assertEqual(util.Util.default_if_none(value, None), value) + self.assertEqual(util.Util.default_if_none(None, value), value) + self.assertEqual(util.Util.default_if_none(None, None), None)