Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework MQTT config merging and adding defaults #90529

Merged
merged 7 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 18 additions & 85 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SERVICE_RELOAD,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.exceptions import ConfigEntryError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
Expand All @@ -45,11 +45,7 @@
publish,
subscribe,
)
from .config_integration import (
CONFIG_SCHEMA_ENTRY,
DEFAULT_VALUES,
PLATFORM_CONFIG_SCHEMA_BASE,
)
from .config_integration import CONFIG_SCHEMA_ENTRY, PLATFORM_CONFIG_SCHEMA_BASE
from .const import ( # noqa: F401
ATTR_PAYLOAD,
ATTR_QOS,
Expand Down Expand Up @@ -83,6 +79,7 @@
)
from .models import ( # noqa: F401
MqttCommandTemplate,
MqttData,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
Expand All @@ -102,8 +99,6 @@
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"

MANDATORY_DEFAULT_VALUES = (CONF_PORT, CONF_DISCOVERY_PREFIX)

ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"

Expand Down Expand Up @@ -193,50 +188,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True


def _filter_entry_config(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove unknown keys from config entry data.

Extra keys may have been added when importing MQTT yaml configuration.
"""
filtered_data = {
k: entry.data[k] for k in CONFIG_ENTRY_CONFIG_KEYS if k in entry.data
}
if entry.data.keys() != filtered_data.keys():
_LOGGER.warning(
(
"The following unsupported configuration options were removed from the "
"MQTT config entry: %s"
),
entry.data.keys() - filtered_data.keys(),
)
hass.config_entries.async_update_entry(entry, data=filtered_data)


async def _async_auto_mend_config(
hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
"""Mends config fetched from config entry and adds missing values.

This mends incomplete migration from old version of HA Core.
"""
entry_updated = False
entry_config = {**entry.data}
for key in MANDATORY_DEFAULT_VALUES:
if key not in entry_config:
entry_config[key] = DEFAULT_VALUES[key]
entry_updated = True

if entry_updated:
hass.config_entries.async_update_entry(entry, data=entry_config)


def _merge_extended_config(entry: ConfigEntry, conf: ConfigType) -> dict[str, Any]:
"""Merge advanced options in configuration.yaml config with config entry."""
# Add default values
conf = {**DEFAULT_VALUES, **conf}
return {**conf, **entry.data}


async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle signals of config entry being updated.

Expand All @@ -245,42 +196,24 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
await hass.config_entries.async_reload(entry.entry_id)


async def async_fetch_config(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any] | None:
"""Fetch fresh MQTT yaml config from the hass config."""
mqtt_data = get_mqtt_data(hass)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))

# Remove unknown keys from config entry data
_filter_entry_config(hass, entry)

# Add missing defaults to migrate older config entries
await _async_auto_mend_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing
if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it")
return None

# If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))

# Merge advanced configuration values from configuration.yaml
conf = _merge_extended_config(entry, conf)
return conf


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
mqtt_data = get_mqtt_data(hass, True)
# validate entry config
try:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
except vol.MultipleInvalid as ex:
raise ConfigEntryError(
f"The MQTT config entry is invalid, please correct it: {ex}"
) from ex
jbouwh marked this conversation as resolved.
Show resolved Hide resolved

# Fetch configuration and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None:
# Bail out
return False
# Fetch configuration and add default values
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
else:
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml)
jbouwh marked this conversation as resolved.
Show resolved Hide resolved

await async_create_certificate_temp_files(hass, dict(entry.data))
mqtt_data.client = MQTT(hass, entry, conf)
Expand Down
10 changes: 6 additions & 4 deletions homeassistant/components/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING,
) -> None:
"""Publish message to a MQTT topic."""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
outgoing_payload = payload
if not isinstance(payload, bytes):
if not encoding:
Expand Down Expand Up @@ -161,11 +162,12 @@ async def async_subscribe(

Call the return value to unsubscribe.
"""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
jbouwh marked this conversation as resolved.
Show resolved Hide resolved
# Support for a deprecated callback type was removed with HA core 2023.3.0
# The signature validation code can be removed from HA core 2023.5.0
non_default = 0
Expand Down
59 changes: 14 additions & 45 deletions homeassistant/components/mqtt/config_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@

DEFAULT_TLS_PROTOCOL = "auto"

DEFAULT_VALUES = {
CONF_BIRTH_MESSAGE: DEFAULT_BIRTH,
CONF_DISCOVERY: DEFAULT_DISCOVERY,
CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX,
CONF_PORT: DEFAULT_PORT,
CONF_PROTOCOL: DEFAULT_PROTOCOL,
CONF_TRANSPORT: DEFAULT_TRANSPORT,
CONF_WILL_MESSAGE: DEFAULT_WILL,
CONF_KEEPALIVE: DEFAULT_KEEPALIVE,
}

PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema(
{
Platform.ALARM_CONTROL_PANEL.value: vol.All(
Expand Down Expand Up @@ -169,9 +158,11 @@
CONFIG_SCHEMA_ENTRY = vol.Schema(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a follow-up, I think we can drop the schema? There's no reason to validate config entry options / data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema adds default values. If we want to drop the schema we need to change the config flow to save basic broker settings with broker CONFIG_SCHEMA_ENTRY. So I don't see an option to drop it yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand, but there's no need to add default values via a schema, just use dict.get instead and pass the wanted default value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So let's implement this in a follow up PR

{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All(
vol.Coerce(int), vol.Range(min=15)
),
vol.Required(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): str,
Expand All @@ -180,13 +171,17 @@
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): str,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All(
cv.string, vol.In(SUPPORTED_PROTOCOLS)
),
vol.Optional(CONF_WILL_MESSAGE, default=DEFAULT_WILL): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE, default=DEFAULT_BIRTH): valid_birth_will,
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
vol.Optional(
CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX
): valid_publish_topic,
vol.Optional(CONF_TRANSPORT, default=DEFAULT_TRANSPORT): vol.All(
cv.string, vol.In([TRANSPORT_TCP, TRANSPORT_WEBSOCKETS])
),
Expand All @@ -195,32 +190,6 @@
}
)

CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile),
vol.Inclusive(
CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Inclusive(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
}
)

DEPRECATED_CONFIG_KEYS = [
CONF_BIRTH_MESSAGE,
CONF_BROKER,
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/mqtt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def write_state_request(self, entity: Entity) -> None:
class MqttData:
"""Keep the MQTT entry data."""

config: ConfigType
client: MQTT | None = None
config: ConfigType | None = None
debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
default_factory=dict
Expand Down
5 changes: 1 addition & 4 deletions homeassistant/components/mqtt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,9 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config


def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData
if ensure_exists:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of just removing the check, maybe raise instead:

if DATA_MQTT not in hass.data:
    raise HomeAssistantError

Or maybe the KeyError is good enough?

Copy link
Contributor Author

@jbouwh jbouwh Apr 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KeyError should be enough I think

mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
return mqtt_data
mqtt_data = hass.data[DATA_MQTT]
return mqtt_data

Expand Down
8 changes: 2 additions & 6 deletions tests/components/mqtt/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,14 @@ async def test_user_connection_works(
assert result["type"] == "form"

result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"broker": "127.0.0.1", "advanced_options": False}
Copy link
Contributor Author

@jbouwh jbouwh Mar 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"advanced_options" was filtered out, that was why tests pass, it should not be a setting in the entry, and in real live this setting was never added to the MQTT config entry data.

result["flow_id"], {"broker": "127.0.0.1"}
)

assert result["type"] == "create_entry"
assert result["result"].data == {
"broker": "127.0.0.1",
"port": 1883,
"discovery": True,
"discovery_prefix": "homeassistant",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default settings are not added to the entry by default, they are added during entry load by applying the schema.

}
# Check we tried the connection
assert len(mock_try_connection.mock_calls) == 1
Expand Down Expand Up @@ -231,7 +230,6 @@ async def test_user_v5_connection_works(
assert result["result"].data == {
"broker": "another-broker",
"discovery": True,
"discovery_prefix": "homeassistant",
"port": 2345,
"protocol": "5",
}
Expand Down Expand Up @@ -283,15 +281,14 @@ async def test_manual_config_set(
assert result["type"] == "form"

result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"broker": "127.0.0.1"}
result["flow_id"], {"broker": "127.0.0.1", "port": "1883"}
)

assert result["type"] == "create_entry"
assert result["result"].data == {
"broker": "127.0.0.1",
"port": 1883,
"discovery": True,
"discovery_prefix": "homeassistant",
}
# Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with(
Expand Down Expand Up @@ -395,7 +392,6 @@ async def test_hassio_confirm(
"username": "mock-user",
"password": "mock-pass",
"discovery": True,
"discovery_prefix": "homeassistant",
}
# Check we tried the connection
assert len(mock_try_connection_success.mock_calls)
Expand Down
3 changes: 3 additions & 0 deletions tests/components/mqtt/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
"retain": False,
"topic": "homeassistant/status",
},
"ws_headers": {},
"ws_path": "/",
}


Expand Down Expand Up @@ -265,6 +267,7 @@ async def test_redact_diagnostics(
"name_by_user": None,
}

await get_diagnostics_for_config_entry(hass, hass_client, config_entry)
assert await get_diagnostics_for_config_entry(hass, hass_client, config_entry) == {
"connected": True,
"devices": [expected_device],
Expand Down
Loading