Skip to content

Commit

Permalink
list accounts and publish them
Browse files Browse the repository at this point in the history
  • Loading branch information
tmclaugh committed Jan 2, 2025
1 parent d08edf0 commit afa3238
Show file tree
Hide file tree
Showing 8 changed files with 1,610 additions and 19 deletions.
4 changes: 2 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ common = {editable = true, path = "src/common"}
aws-lambda-powertools = "*"

[dev-packages]
boto3-stubs = { extras = [ "sns", ], version = "*"}
boto3-stubs = { extras = [ "sns", "organizations" ], version = "*"}
cfn-lint = "*"
flake8 = "*"
genson = "*"
jsonschema = "*"
json2python-models = "*"
moto = { extras = [ "sns", ], version = "*"}
moto = { extras = [ "sns", "organizations" ], version = "*"}
mypy = "*"
pylint = "*"
pytest = "*"
Expand Down
1,430 changes: 1,430 additions & 0 deletions Pipfile.lock

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions src/common/common/model/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import List

from mypy_boto3_organizations.type_defs import AccountTypeDef

@dataclass
class AccountType(AccountTypeDef):
pass

@dataclass
class AccountTypeWithTags(AccountType):
Tags: List[dict]
9 changes: 9 additions & 0 deletions src/common/common/util/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import datetime
from json import JSONEncoder

class JSONDateTimeEncoder(JSONEncoder):
'''Encode JSON when datetime objects are present'''

def default(self, obj):
if isinstance(obj, (datetime.datetime, datetime.date)):
return obj.isoformat()
Empty file added src/common/common/util/aws.py
Empty file.
3 changes: 2 additions & 1 deletion src/common/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
include_package_data=True,
install_requires=[
'aws_lambda_powertools',
'boto3'
'boto3',
'boto3-stubs[organizations]',
],
classifiers=[
'Environment :: Console',
Expand Down
64 changes: 58 additions & 6 deletions src/handlers/ListAccounts/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
'''List AWS accounts'''
import os
import boto3
import json
from typing import TYPE_CHECKING, List, Optional

from aws_lambda_powertools.logging import Logger
from aws_lambda_powertools.utilities.typing import LambdaContext
Expand All @@ -10,20 +12,70 @@
EventBridgeEvent
)

if TYPE_CHECKING:
from mypy_boto3_sns.type_defs import PublishResponseTypeDef

from common.model.account import AccountType, AccountTypeWithTags
from common.util import JSONDateTimeEncoder

LOGGER = Logger(utc=True)

ORG_CLIENT = boto3.client('organizations')
SNS_CLIENT = boto3.client('sns')
SNS_TOPIC_ARN = os.environ.get('SNS_TOPIC_ARN', 'UNSET')


def _get_account_tags(accounts: List[AccountType]) -> List[AccountTypeWithTags]:
'''Get tags for accounts'''
accounts_with_tags = []
for account in accounts:
tags = ORG_CLIENT.list_tags_for_resource(
# Haven't seen a situation where Id is not present
ResourceId=account.get('Id', '')
)
account_with_tags = {**account, **tags}
accounts_with_tags.append(account_with_tags)
return accounts_with_tags

def _main(data) -> None:
'''Main work of function'''
# Transform data

# Send data to destination
def _list_all_accounts(NextToken: Optional[str] = None) -> List[AccountType]:
'''List AWS accounts'''
accounts = []
while True:
response = ORG_CLIENT.list_accounts(
**{ 'NextToken': NextToken } if NextToken else {}
)
if 'Accounts' in response:
accounts += response['Accounts']

return
if 'NextToken' in response:
next_page = _list_all_accounts(NextToken=response['NextToken'])
response['Accounts'] += next_page
else:
break
return accounts


def _publish_accounts(accounts: List[AccountTypeWithTags]) -> List['PublishResponseTypeDef']:
'''Publish account to SNS'''
responses = []
for account in accounts:
LOGGER.debug('Publishing {}'.format(account.get('Id')), extra={"message_object": account})
response = SNS_CLIENT.publish(
TopicArn = SNS_TOPIC_ARN,
Subject = 'AWS Account',
Message = json.dumps(account, cls=JSONDateTimeEncoder)
)
LOGGER.debug('SNS Response for {}'.format(account.get('Id')), extra={"message_object": response})
responses.append(response)
return responses


def _main() -> None:
'''List AWS accounts and publish to SNS'''
accounts = _list_all_accounts()
accounts_with_tags = _get_account_tags(accounts)
_publish_accounts(accounts_with_tags)


@LOGGER.inject_lambda_context
Expand All @@ -32,6 +84,6 @@ def handler(event: EventBridgeEvent, context: LambdaContext) -> None:
'''Event handler'''
LOGGER.debug('Event', extra={"message_object": event})

_main(event.detail)
_main()

return
107 changes: 97 additions & 10 deletions tests/unit/handlers/ListAccounts/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import jsonschema
import os
from types import ModuleType
from typing import Generator
from typing import Generator, List

import pytest
from pytest_mock import MockerFixture


import boto3
from mypy_boto3_organizations import OrganizationsClient
from mypy_boto3_organizations.type_defs import AccountTypeDef, TagTypeDef
from mypy_boto3_sns import SNSClient
from moto import mock_aws

Expand Down Expand Up @@ -59,18 +61,69 @@ def mocked_aws(aws_credentials):
with mock_aws():
yield

@pytest.fixture()
def mock_orgs_client(mocked_aws) -> Generator[OrganizationsClient, None, None]:
orgs_client = boto3.client('organizations')
yield orgs_client

@pytest.fixture()
def mock_sns_client(mocked_aws) -> Generator[SNSClient, None, None]:
sns_client = boto3.client('sns')
yield sns_client

@pytest.fixture()
def mock_sns_topic_name(mock_sns_client) -> str:
def mock_sns_topic_arn(mock_sns_client) -> str:
'''Create a mock resource'''
mock_topic_name = 'MockTopic'
mock_sns_client.create_topic(Name=mock_topic_name)
return mock_topic_name
r = mock_sns_client.create_topic(Name=mock_topic_name)
return r.get('TopicArn')

@pytest.fixture()
def mock_organization(mock_orgs_client) -> None:
'''Mock organization'''
mock_orgs_client.create_organization()


@pytest.fixture()
def mock_account(
mock_orgs_client: OrganizationsClient,
mock_organization
) -> AccountTypeDef:
'''Mock account'''
account_config = {
'Email': '[email protected]',
'AccountName': 'Mock Account',
'Tags': [
{
"Key": "org:system",
"Value": "mock_system"
},
{
"Key": "org:domain",
"Value": "mock_domain"
},
{
"Key": "org:owner",
"Value": "group:mock_group"
}
]
}

response = mock_orgs_client.create_account(**account_config)
account_id = response.get('CreateAccountStatus', {}).get('AccountId', '')
return mock_orgs_client.describe_account(AccountId=account_id).get('Account')


@pytest.fixture()
def mock_account_tags(
mock_orgs_client: OrganizationsClient,
mock_account: AccountTypeDef,
) -> List[TagTypeDef]:
'''Return account tags'''
return mock_orgs_client.list_tags_for_resource(
ResourceId=mock_account.get('Id', '')
).get('Tags', [])


# Function
@pytest.fixture()
Expand All @@ -80,7 +133,7 @@ def mock_context(function_name=FN_NAME):

@pytest.fixture()
def mock_fn(
mock_sns_topic_name: str,
mock_sns_topic_arn: str,
mocker: MockerFixture
) -> Generator[ModuleType, None, None]:
'''Return mocked function'''
Expand All @@ -89,7 +142,7 @@ def mock_fn(
# NOTE: use mocker to mock any top-level variables outside of the handler function.
mocker.patch(
'src.handlers.ListAccounts.function.SNS_TOPIC_ARN',
mock_sns_topic_name
mock_sns_topic_arn
)

yield fn
Expand All @@ -103,20 +156,54 @@ def test_validate_event(mock_event, event_schema):


### Code Tests
def test__get_account_tags(
mock_fn: ModuleType,
mock_account: AccountTypeDef,
mock_account_tags: List[TagTypeDef],
):
'''Test _get_account_tags function'''
account_with_tags = mock_fn._get_account_tags([mock_account])[0]
assert 'Tags' in account_with_tags
assert len(account_with_tags['Tags']) > 0
assert account_with_tags['Tags'] == mock_account_tags


def test__list_all_accounts(
mock_fn: ModuleType,
mock_orgs_client: OrganizationsClient,
mock_account: AccountTypeDef,
):
'''Test _list_all_accounts function'''
# Call the function
accounts = mock_fn._list_all_accounts()
account_ids = [account.get('Id', '') for account in accounts]

# Assertions
assert len(accounts) > 0
assert mock_account.get('Id') in account_ids


def test__publish_accounts(
mock_fn: ModuleType,
mock_account: AccountTypeDef,
):
'''Test _publish_accounts function'''
account_with_tags = mock_fn._get_account_tags([mock_account])[0]
response = mock_fn._publish_accounts([account_with_tags])
assert len(response) > 0


def test__main(
mock_fn: ModuleType,
mock_data
):
'''Test _main function'''
mock_fn._main(mock_data)
mock_fn._main()


def test_handler(
mock_fn: ModuleType,
mock_context,
mock_event: EventBridgeEvent,
mock_data
mock_sns_client: SNSClient,
):
'''Test calling handler'''
# Call the function
Expand Down

0 comments on commit afa3238

Please sign in to comment.