diff --git a/homeassistant/components/aws/__init__.py b/homeassistant/components/aws/__init__.py new file mode 100644 index 0000000000000..bd1f6b550909a --- /dev/null +++ b/homeassistant/components/aws/__init__.py @@ -0,0 +1,147 @@ +"""Support for Amazon Web Services (AWS).""" +import asyncio +import logging +from collections import OrderedDict + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.const import ATTR_CREDENTIALS, CONF_NAME, CONF_PROFILE_NAME +from homeassistant.helpers import config_validation as cv, discovery + +# Loading the config flow file will register the flow +from . import config_flow # noqa +from .const import ( + CONF_ACCESS_KEY_ID, + CONF_SECRET_ACCESS_KEY, + DATA_CONFIG, + DATA_HASS_CONFIG, + DATA_SESSIONS, + DOMAIN, + CONF_NOTIFY, +) +from .notify import PLATFORM_SCHEMA as NOTIFY_PLATFORM_SCHEMA + +REQUIREMENTS = ["aiobotocore==0.10.2"] + +_LOGGER = logging.getLogger(__name__) + +AWS_CREDENTIAL_SCHEMA = vol.Schema( + { + vol.Required(CONF_NAME): cv.string, + vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string, + vol.Inclusive(CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS): cv.string, + vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string, + } +) + +DEFAULT_CREDENTIAL = [{CONF_NAME: "default", CONF_PROFILE_NAME: "default"}] + +CONFIG_SCHEMA = vol.Schema( + { + DOMAIN: vol.Schema( + { + vol.Optional( + ATTR_CREDENTIALS, default=DEFAULT_CREDENTIAL + ): vol.All(cv.ensure_list, [AWS_CREDENTIAL_SCHEMA]), + vol.Optional(CONF_NOTIFY): vol.All( + cv.ensure_list, [NOTIFY_PLATFORM_SCHEMA] + ), + } + ) + }, + extra=vol.ALLOW_EXTRA, +) + + +async def async_setup(hass, config): + """Set up AWS component.""" + hass.data[DATA_HASS_CONFIG] = config + + conf = config.get(DOMAIN) + if conf is None: + # create a default conf using default profile + conf = CONFIG_SCHEMA({ATTR_CREDENTIALS: DEFAULT_CREDENTIAL}) + + hass.data[DATA_CONFIG] = conf + hass.data[DATA_SESSIONS] = OrderedDict() + + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_IMPORT}, data=conf + ) + ) + + return True + + +async def async_setup_entry(hass, entry): + """Load a config entry. + + Validate and save sessions per aws credential. + """ + config = hass.data.get(DATA_HASS_CONFIG) + conf = hass.data.get(DATA_CONFIG) + + if entry.source == config_entries.SOURCE_IMPORT: + if conf is None: + # user removed config from configuration.yaml, abort setup + hass.async_create_task( + hass.config_entries.async_remove(entry.entry_id) + ) + return False + + if conf != entry.data: + # user changed config from configuration.yaml, use conf to setup + hass.config_entries.async_update_entry(entry, data=conf) + + if conf is None: + conf = CONFIG_SCHEMA({DOMAIN: entry.data})[DOMAIN] + + validation = True + tasks = [] + for cred in conf.get(ATTR_CREDENTIALS): + tasks.append(_validate_aws_credentials(hass, cred)) + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + for index, result in enumerate(results): + name = conf[ATTR_CREDENTIALS][index][CONF_NAME] + if isinstance(result, Exception): + _LOGGER.error( + "Validating credential [%s] failed: %s", + name, result, exc_info=result + ) + validation = False + else: + hass.data[DATA_SESSIONS][name] = result + + # No entry support for notify component yet + for notify_config in conf.get(CONF_NOTIFY, []): + discovery.load_platform(hass, "notify", DOMAIN, notify_config, config) + + return validation + + +async def _validate_aws_credentials(hass, credential): + """Validate AWS credential config.""" + import aiobotocore + + aws_config = credential.copy() + del aws_config[CONF_NAME] + + profile = aws_config.get(CONF_PROFILE_NAME) + + if profile is not None: + session = aiobotocore.AioSession(profile=profile, loop=hass.loop) + del aws_config[CONF_PROFILE_NAME] + if CONF_ACCESS_KEY_ID in aws_config: + del aws_config[CONF_ACCESS_KEY_ID] + if CONF_SECRET_ACCESS_KEY in aws_config: + del aws_config[CONF_SECRET_ACCESS_KEY] + else: + session = aiobotocore.AioSession(loop=hass.loop) + + async with session.create_client("iam", **aws_config) as client: + await client.get_user() + + return session diff --git a/homeassistant/components/aws/config_flow.py b/homeassistant/components/aws/config_flow.py new file mode 100644 index 0000000000000..c21f2a94137f6 --- /dev/null +++ b/homeassistant/components/aws/config_flow.py @@ -0,0 +1,22 @@ +"""Config flow for AWS component.""" + +from homeassistant import config_entries + +from .const import DOMAIN + + +@config_entries.HANDLERS.register(DOMAIN) +class AWSFlowHandler(config_entries.ConfigFlow): + """Handle a config flow.""" + + VERSION = 1 + CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH + + async def async_step_import(self, user_input): + """Import a config entry.""" + if self._async_current_entries(): + return self.async_abort(reason="single_instance_allowed") + + return self.async_create_entry( + title="configuration.yaml", data=user_input + ) diff --git a/homeassistant/components/aws/const.py b/homeassistant/components/aws/const.py new file mode 100644 index 0000000000000..c8b0eed8b6bbe --- /dev/null +++ b/homeassistant/components/aws/const.py @@ -0,0 +1,13 @@ +"""Constant for AWS component.""" +DOMAIN = "aws" +DATA_KEY = DOMAIN +DATA_CONFIG = "aws_config" +DATA_HASS_CONFIG = "aws_hass_config" +DATA_SESSIONS = "aws_sessions" + +CONF_REGION = "region_name" +CONF_ACCESS_KEY_ID = "aws_access_key_id" +CONF_SECRET_ACCESS_KEY = "aws_secret_access_key" +CONF_PROFILE_NAME = "profile_name" +CONF_CREDENTIAL_NAME = "credential_name" +CONF_NOTIFY = "notify" diff --git a/homeassistant/components/aws/notify.py b/homeassistant/components/aws/notify.py new file mode 100644 index 0000000000000..020d92200b98e --- /dev/null +++ b/homeassistant/components/aws/notify.py @@ -0,0 +1,278 @@ +"""AWS platform for notify component.""" +import asyncio +import logging +import json +import base64 + +import voluptuous as vol + +import homeassistant.helpers.config_validation as cv +from homeassistant.const import CONF_PLATFORM, CONF_NAME, ATTR_CREDENTIALS +from homeassistant.components.notify import ( + ATTR_TARGET, + ATTR_TITLE, + ATTR_TITLE_DEFAULT, + BaseNotificationService, + PLATFORM_SCHEMA, +) +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.json import JSONEncoder + +from .const import ( + CONF_ACCESS_KEY_ID, + CONF_CREDENTIAL_NAME, + CONF_PROFILE_NAME, + CONF_REGION, + CONF_SECRET_ACCESS_KEY, + DATA_SESSIONS, +) + +DEPENDENCIES = ["aws"] + +_LOGGER = logging.getLogger(__name__) + +CONF_CONTEXT = "context" +CONF_SERVICE = "service" + +SUPPORTED_SERVICES = ["lambda", "sns", "sqs"] + + +def _in_avilable_region(config): + """Check if region is available.""" + import aiobotocore + + session = aiobotocore.get_session() + available_regions = session.get_available_regions(config[CONF_SERVICE]) + if config[CONF_REGION] not in available_regions: + raise vol.Invalid( + "Region {} is not available for {} service, mustin {}".format( + config[CONF_REGION], config[CONF_SERVICE], available_regions + ) + ) + return config + + +PLATFORM_SCHEMA = vol.Schema( + vol.All( + PLATFORM_SCHEMA.extend( + { + # override notify.PLATFORM_SCHEMA.CONF_PLATFORM to Optional + # we don't need this field when we use discovery + vol.Optional(CONF_PLATFORM): cv.string, + vol.Required(CONF_SERVICE): vol.All( + cv.string, vol.Lower, vol.In(SUPPORTED_SERVICES) + ), + vol.Required(CONF_REGION): vol.All(cv.string, vol.Lower), + vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string, + vol.Inclusive( + CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS + ): cv.string, + vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string, + vol.Exclusive( + CONF_CREDENTIAL_NAME, ATTR_CREDENTIALS + ): cv.string, + vol.Optional(CONF_CONTEXT): vol.Coerce(dict), + }, + extra=vol.PREVENT_EXTRA, + ), + _in_avilable_region, + ) +) + + +async def async_get_service(hass, config, discovery_info=None): + """Get the AWS notification service.""" + import aiobotocore + + session = None + + if discovery_info is not None: + conf = discovery_info + else: + conf = config + + service = conf[CONF_SERVICE] + region_name = conf[CONF_REGION] + + aws_config = conf.copy() + + del aws_config[CONF_SERVICE] + del aws_config[CONF_REGION] + if CONF_PLATFORM in aws_config: + del aws_config[CONF_PLATFORM] + if CONF_NAME in aws_config: + del aws_config[CONF_NAME] + if CONF_CONTEXT in aws_config: + del aws_config[CONF_CONTEXT] + + if not aws_config: + # no platform config, use aws component config instead + if hass.data[DATA_SESSIONS]: + session = list(hass.data[DATA_SESSIONS].values())[0] + else: + raise ValueError( + "No available aws session for {}".format(config[CONF_NAME]) + ) + + if session is None: + credential_name = aws_config.get(CONF_CREDENTIAL_NAME) + if credential_name is not None: + session = hass.data[DATA_SESSIONS].get(credential_name) + if session is None: + _LOGGER.warning( + "No available aws session for %s", credential_name + ) + del aws_config[CONF_CREDENTIAL_NAME] + + if session is None: + profile = aws_config.get(CONF_PROFILE_NAME) + if profile is not None: + session = aiobotocore.AioSession(profile=profile, loop=hass.loop) + del aws_config[CONF_PROFILE_NAME] + else: + session = aiobotocore.AioSession(loop=hass.loop) + + aws_config[CONF_REGION] = region_name + + if service == "lambda": + context_str = json.dumps( + {"custom": conf.get(CONF_CONTEXT, {})}, cls=JSONEncoder + ) + context_b64 = base64.b64encode(context_str.encode("utf-8")) + context = context_b64.decode("utf-8") + return AWSLambda(session, aws_config, context) + + if service == "sns": + return AWSSNS(session, aws_config) + + if service == "sqs": + return AWSSQS(session, aws_config) + + raise ValueError("Unsupported service {}".format(service)) + + +class AWSNotify(BaseNotificationService): + """Implement the notification service for the AWS service.""" + + def __init__(self, session, aws_config): + """Initialize the service.""" + self.session = session + self.aws_config = aws_config + + def send_message(self, message, **kwargs): + """Send notification.""" + raise NotImplementedError("Please call async_send_message()") + + async def async_send_message(self, message="", **kwargs): + """Send notification.""" + targets = kwargs.get(ATTR_TARGET) + + if not targets: + raise HomeAssistantError("At least one target is required") + + +class AWSLambda(AWSNotify): + """Implement the notification service for the AWS Lambda service.""" + + service = "lambda" + + def __init__(self, session, aws_config, context): + """Initialize the service.""" + super().__init__(session, aws_config) + self.context = context + + async def async_send_message(self, message="", **kwargs): + """Send notification to specified LAMBDA ARN.""" + await super().async_send_message(message, **kwargs) + + cleaned_kwargs = dict((k, v) for k, v in kwargs.items() if v) + payload = {"message": message} + payload.update(cleaned_kwargs) + json_payload = json.dumps(payload) + + async with self.session.create_client( + self.service, **self.aws_config + ) as client: + tasks = [] + for target in kwargs.get(ATTR_TARGET, []): + tasks.append( + client.invoke( + FunctionName=target, + Payload=json_payload, + ClientContext=self.context, + ) + ) + + if tasks: + await asyncio.gather(*tasks) + + +class AWSSNS(AWSNotify): + """Implement the notification service for the AWS SNS service.""" + + service = "sns" + + async def async_send_message(self, message="", **kwargs): + """Send notification to specified SNS ARN.""" + await super().async_send_message(message, **kwargs) + + message_attributes = { + k: {"StringValue": json.dumps(v), "DataType": "String"} + for k, v in kwargs.items() + if v + } + subject = kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT) + + async with self.session.create_client( + self.service, **self.aws_config + ) as client: + tasks = [] + for target in kwargs.get(ATTR_TARGET, []): + tasks.append( + client.publish( + TargetArn=target, + Message=message, + Subject=subject, + MessageAttributes=message_attributes, + ) + ) + + if tasks: + await asyncio.gather(*tasks) + + +class AWSSQS(AWSNotify): + """Implement the notification service for the AWS SQS service.""" + + service = "sqs" + + async def async_send_message(self, message="", **kwargs): + """Send notification to specified SQS ARN.""" + await super().async_send_message(message, **kwargs) + + cleaned_kwargs = dict((k, v) for k, v in kwargs.items() if v) + message_body = {"message": message} + message_body.update(cleaned_kwargs) + json_body = json.dumps(message_body) + message_attributes = {} + for key, val in cleaned_kwargs.items(): + message_attributes[key] = { + "StringValue": json.dumps(val), + "DataType": "String", + } + + async with self.session.create_client( + self.service, **self.aws_config + ) as client: + tasks = [] + for target in kwargs.get(ATTR_TARGET, []): + tasks.append( + client.send_message( + QueueUrl=target, + MessageBody=json_body, + MessageAttributes=message_attributes, + ) + ) + + if tasks: + await asyncio.gather(*tasks) diff --git a/homeassistant/components/notify/aws_lambda.py b/homeassistant/components/notify/aws_lambda.py index e605f82c3f15a..8f639a653c3bf 100644 --- a/homeassistant/components/notify/aws_lambda.py +++ b/homeassistant/components/notify/aws_lambda.py @@ -38,6 +38,12 @@ def get_service(hass, config, discovery_info=None): """Get the AWS Lambda notification service.""" + _LOGGER.warning( + "aws_lambda notify platform is deprecated, please replace it" + " with aws component. This config will become invalid in version 0.92." + " See https://www.home-assistant.io/components/aws/ for details." + ) + context_str = json.dumps({'custom': config[CONF_CONTEXT]}, cls=JSONEncoder) context_b64 = base64.b64encode(context_str.encode('utf-8')) context = context_b64.decode('utf-8') diff --git a/homeassistant/components/notify/aws_sns.py b/homeassistant/components/notify/aws_sns.py index 9363576fc1ace..7fa0e25b32a21 100644 --- a/homeassistant/components/notify/aws_sns.py +++ b/homeassistant/components/notify/aws_sns.py @@ -35,6 +35,12 @@ def get_service(hass, config, discovery_info=None): """Get the AWS SNS notification service.""" + _LOGGER.warning( + "aws_sns notify platform is deprecated, please replace it" + " with aws component. This config will become invalid in version 0.92." + " See https://www.home-assistant.io/components/aws/ for details." + ) + import boto3 aws_config = config.copy() diff --git a/homeassistant/components/notify/aws_sqs.py b/homeassistant/components/notify/aws_sqs.py index ed22147cfedc3..927824299398b 100644 --- a/homeassistant/components/notify/aws_sqs.py +++ b/homeassistant/components/notify/aws_sqs.py @@ -33,6 +33,12 @@ def get_service(hass, config, discovery_info=None): """Get the AWS SQS notification service.""" + _LOGGER.warning( + "aws_sqs notify platform is deprecated, please replace it" + " with aws component. This config will become invalid in version 0.92." + " See https://www.home-assistant.io/components/aws/ for details." + ) + import boto3 aws_config = config.copy() diff --git a/requirements_all.txt b/requirements_all.txt index b82d38d409d7f..053c1225a02a4 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -105,6 +105,9 @@ aioasuswrt==1.1.21 # homeassistant.components.automatic.device_tracker aioautomatic==0.6.5 +# homeassistant.components.aws +aiobotocore==0.10.2 + # homeassistant.components.dnsip.sensor aiodns==1.1.1 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index adcd2db3c4310..51b49b003ab91 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -40,6 +40,9 @@ aioambient==0.1.3 # homeassistant.components.automatic.device_tracker aioautomatic==0.6.5 +# homeassistant.components.aws +aiobotocore==0.10.2 + # homeassistant.components.emulated_hue # homeassistant.components.http aiohttp_cors==0.7.0 diff --git a/script/gen_requirements_all.py b/script/gen_requirements_all.py index 6912c83f77082..f71791e8f5f3f 100755 --- a/script/gen_requirements_all.py +++ b/script/gen_requirements_all.py @@ -41,6 +41,7 @@ TEST_REQUIREMENTS = ( 'aioambient', 'aioautomatic', + 'aiobotocore', 'aiohttp_cors', 'aiohue', 'aiounifi', diff --git a/tests/components/aws/__init__.py b/tests/components/aws/__init__.py new file mode 100644 index 0000000000000..270922b1e1ed9 --- /dev/null +++ b/tests/components/aws/__init__.py @@ -0,0 +1 @@ +"""Tests for the aws component.""" diff --git a/tests/components/aws/test_init.py b/tests/components/aws/test_init.py new file mode 100644 index 0000000000000..89dd9deaa0ab8 --- /dev/null +++ b/tests/components/aws/test_init.py @@ -0,0 +1,199 @@ +"""Tests for the aws component config and setup.""" +from asynctest import patch as async_patch, MagicMock, CoroutineMock + +from homeassistant.components import aws +from homeassistant.setup import async_setup_component + + +class MockAioSession: + """Mock AioSession.""" + + def __init__(self, *args, **kwargs): + """Init a mock session.""" + + def create_client(self, *args, **kwargs): # pylint: disable=no-self-use + """Create a mocked client.""" + return MagicMock( + __aenter__=CoroutineMock(return_value=CoroutineMock( + get_user=CoroutineMock(), # iam + invoke=CoroutineMock(), # lambda + publish=CoroutineMock(), # sns + send_message=CoroutineMock(), # sqs + )), + __aexit__=CoroutineMock() + ) + + +async def test_empty_config(hass): + """Test a default config will be create for empty config.""" + with async_patch('aiobotocore.AioSession', new=MockAioSession): + await async_setup_component(hass, 'aws', { + 'aws': {} + }) + await hass.async_block_till_done() + + sessions = hass.data[aws.DATA_SESSIONS] + assert sessions is not None + assert len(sessions) == 1 + assert isinstance(sessions.get('default'), MockAioSession) + + +async def test_empty_credential(hass): + """Test a default config will be create for empty credential section.""" + with async_patch('aiobotocore.AioSession', new=MockAioSession): + await async_setup_component(hass, 'aws', { + 'aws': { + 'notify': [{ + 'service': 'lambda', + 'name': 'New Lambda Test', + 'region_name': 'us-east-1', + }] + } + }) + await hass.async_block_till_done() + + sessions = hass.data[aws.DATA_SESSIONS] + assert sessions is not None + assert len(sessions) == 1 + assert isinstance(sessions.get('default'), MockAioSession) + + assert hass.services.has_service('notify', 'new_lambda_test') is True + await hass.services.async_call( + 'notify', + 'new_lambda_test', + {'message': 'test', 'target': 'ARN'}, + blocking=True + ) + + +async def test_profile_credential(hass): + """Test credentials with profile name.""" + with async_patch('aiobotocore.AioSession', new=MockAioSession): + await async_setup_component(hass, 'aws', { + 'aws': { + 'credentials': { + 'name': 'test', + 'profile_name': 'test-profile', + }, + 'notify': [{ + 'service': 'sns', + 'credential_name': 'test', + 'name': 'SNS Test', + 'region_name': 'us-east-1', + }] + } + }) + await hass.async_block_till_done() + + sessions = hass.data[aws.DATA_SESSIONS] + assert sessions is not None + assert len(sessions) == 1 + assert isinstance(sessions.get('test'), MockAioSession) + + assert hass.services.has_service('notify', 'sns_test') is True + await hass.services.async_call( + 'notify', + 'sns_test', + {'title': 'test', 'message': 'test', 'target': 'ARN'}, + blocking=True + ) + + +async def test_access_key_credential(hass): + """Test credentials with access key.""" + with async_patch('aiobotocore.AioSession', new=MockAioSession): + await async_setup_component(hass, 'aws', { + 'aws': { + 'credentials': [ + { + 'name': 'test', + 'profile_name': 'test-profile', + }, + { + 'name': 'key', + 'aws_access_key_id': 'test-key', + 'aws_secret_access_key': 'test-secret', + }, + ], + 'notify': [{ + 'service': 'sns', + 'credential_name': 'key', + 'name': 'SNS Test', + 'region_name': 'us-east-1', + }] + } + }) + await hass.async_block_till_done() + + sessions = hass.data[aws.DATA_SESSIONS] + assert sessions is not None + assert len(sessions) == 2 + assert isinstance(sessions.get('key'), MockAioSession) + + assert hass.services.has_service('notify', 'sns_test') is True + await hass.services.async_call( + 'notify', + 'sns_test', + {'title': 'test', 'message': 'test', 'target': 'ARN'}, + blocking=True + ) + + +async def test_notify_credential(hass): + """Test notify service can use access key directly.""" + with async_patch('aiobotocore.AioSession', new=MockAioSession): + await async_setup_component(hass, 'aws', { + 'aws': { + 'notify': [{ + 'service': 'sqs', + 'credential_name': 'test', + 'name': 'SQS Test', + 'region_name': 'us-east-1', + 'aws_access_key_id': 'some-key', + 'aws_secret_access_key': 'some-secret', + }] + } + }) + await hass.async_block_till_done() + + sessions = hass.data[aws.DATA_SESSIONS] + assert sessions is not None + assert len(sessions) == 1 + assert isinstance(sessions.get('default'), MockAioSession) + + assert hass.services.has_service('notify', 'sqs_test') is True + await hass.services.async_call( + 'notify', + 'sqs_test', + {'message': 'test', 'target': 'ARN'}, + blocking=True + ) + + +async def test_notify_credential_profile(hass): + """Test notify service can use profile directly.""" + with async_patch('aiobotocore.AioSession', new=MockAioSession): + await async_setup_component(hass, 'aws', { + 'aws': { + 'notify': [{ + 'service': 'sqs', + 'name': 'SQS Test', + 'region_name': 'us-east-1', + 'profile_name': 'test', + }] + } + }) + await hass.async_block_till_done() + + sessions = hass.data[aws.DATA_SESSIONS] + assert sessions is not None + assert len(sessions) == 1 + assert isinstance(sessions.get('default'), MockAioSession) + + assert hass.services.has_service('notify', 'sqs_test') is True + await hass.services.async_call( + 'notify', + 'sqs_test', + {'message': 'test', 'target': 'ARN'}, + blocking=True + )