Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework MQTT config merging and adding defaults #90529

Merged
merged 7 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 22 additions & 89 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SERVICE_RELOAD,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.exceptions import ConfigEntryError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
Expand All @@ -45,11 +45,7 @@
publish,
subscribe,
)
from .config_integration import (
CONFIG_SCHEMA_ENTRY,
DEFAULT_VALUES,
PLATFORM_CONFIG_SCHEMA_BASE,
)
from .config_integration import CONFIG_SCHEMA_ENTRY, PLATFORM_CONFIG_SCHEMA_BASE
from .const import ( # noqa: F401
ATTR_PAYLOAD,
ATTR_QOS,
Expand Down Expand Up @@ -83,6 +79,7 @@
)
from .models import ( # noqa: F401
MqttCommandTemplate,
MqttData,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
Expand All @@ -102,8 +99,6 @@
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"

MANDATORY_DEFAULT_VALUES = (CONF_PORT, CONF_DISCOVERY_PREFIX)

ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"

Expand Down Expand Up @@ -193,50 +188,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True


def _filter_entry_config(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove unknown keys from config entry data.

Extra keys may have been added when importing MQTT yaml configuration.
"""
filtered_data = {
k: entry.data[k] for k in CONFIG_ENTRY_CONFIG_KEYS if k in entry.data
}
if entry.data.keys() != filtered_data.keys():
_LOGGER.warning(
(
"The following unsupported configuration options were removed from the "
"MQTT config entry: %s"
),
entry.data.keys() - filtered_data.keys(),
)
hass.config_entries.async_update_entry(entry, data=filtered_data)


async def _async_auto_mend_config(
hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
"""Mends config fetched from config entry and adds missing values.

This mends incomplete migration from old version of HA Core.
"""
entry_updated = False
entry_config = {**entry.data}
for key in MANDATORY_DEFAULT_VALUES:
if key not in entry_config:
entry_config[key] = DEFAULT_VALUES[key]
entry_updated = True

if entry_updated:
hass.config_entries.async_update_entry(entry, data=entry_config)


def _merge_extended_config(entry: ConfigEntry, conf: ConfigType) -> dict[str, Any]:
"""Merge advanced options in configuration.yaml config with config entry."""
# Add default values
conf = {**DEFAULT_VALUES, **conf}
return {**conf, **entry.data}


async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle signals of config entry being updated.

Expand All @@ -245,45 +196,29 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
await hass.config_entries.async_reload(entry.entry_id)


async def async_fetch_config(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any] | None:
"""Fetch fresh MQTT yaml config from the hass config."""
mqtt_data = get_mqtt_data(hass)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))

# Remove unknown keys from config entry data
_filter_entry_config(hass, entry)

# Add missing defaults to migrate older config entries
await _async_auto_mend_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing
if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it")
return None

# If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))

# Merge advanced configuration values from configuration.yaml
conf = _merge_extended_config(entry, conf)
return conf


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
mqtt_data = get_mqtt_data(hass, True)
# validate entry config
try:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
except vol.MultipleInvalid as ex:
raise ConfigEntryError(
f"The MQTT config entry is invalid, please correct it: {ex}"
) from ex
jbouwh marked this conversation as resolved.
Show resolved Hide resolved

# Fetch configuration and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None:
# Bail out
return False
# Fetch configuration and add default values
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
client.start(mqtt_data)

await async_create_certificate_temp_files(hass, dict(entry.data))
mqtt_data.client = MQTT(hass, entry, conf)
# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
Expand Down Expand Up @@ -349,7 +284,7 @@ async def async_publish_service(call: ServiceCall) -> None:
)
return

assert mqtt_data.client is not None and msg_topic is not None
assert msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)

hass.services.async_register(
Expand Down Expand Up @@ -585,7 +520,6 @@ def unsubscribe() -> None:
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
return mqtt_data.client.connected


Expand All @@ -603,7 +537,6 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client

# Unload publish and dump services.
Expand Down
28 changes: 16 additions & 12 deletions homeassistant/components/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from .models import (
AsyncMessageCallbackType,
MessageCallbackType,
MqttData,
PublishMessage,
PublishPayloadType,
ReceiveMessage,
Expand Down Expand Up @@ -111,11 +112,11 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING,
) -> None:
"""Publish message to a MQTT topic."""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
outgoing_payload = payload
if not isinstance(payload, bytes):
if not encoding:
Expand Down Expand Up @@ -161,11 +162,11 @@ async def async_subscribe(

Call the return value to unsubscribe.
"""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
# Support for a deprecated callback type was removed with HA core 2023.3.0
# The signature validation code can be removed from HA core 2023.5.0
non_default = 0
Expand Down Expand Up @@ -377,19 +378,16 @@ class MQTT:

_mqttc: mqtt.Client
_last_subscribe: float
_mqtt_data: MqttData

def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
conf: ConfigType,
self, hass: HomeAssistant, config_entry: ConfigEntry, conf: ConfigType
) -> None:
"""Initialize Home Assistant MQTT client."""
self._mqtt_data = get_mqtt_data(hass)

self.hass = hass
self.config_entry = config_entry
self.conf = conf

self._simple_subscriptions: dict[str, list[Subscription]] = {}
self._wildcard_subscriptions: list[Subscription] = []
self.connected = False
Expand All @@ -415,8 +413,6 @@ def ha_started(_: Event) -> None:

self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)

self.init_client()

async def async_stop_mqtt(_event: Event) -> None:
"""Stop MQTT component."""
await self.async_disconnect()
Expand All @@ -425,6 +421,14 @@ async def async_stop_mqtt(_event: Event) -> None:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
)

def start(
self,
mqtt_data: MqttData,
) -> None:
"""Start Home Assistant MQTT client."""
self._mqtt_data = mqtt_data
self.init_client()

@property
def subscriptions(self) -> list[Subscription]:
"""Return the tracked subscriptions."""
Expand Down
59 changes: 14 additions & 45 deletions homeassistant/components/mqtt/config_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@

DEFAULT_TLS_PROTOCOL = "auto"

DEFAULT_VALUES = {
CONF_BIRTH_MESSAGE: DEFAULT_BIRTH,
CONF_DISCOVERY: DEFAULT_DISCOVERY,
CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX,
CONF_PORT: DEFAULT_PORT,
CONF_PROTOCOL: DEFAULT_PROTOCOL,
CONF_TRANSPORT: DEFAULT_TRANSPORT,
CONF_WILL_MESSAGE: DEFAULT_WILL,
CONF_KEEPALIVE: DEFAULT_KEEPALIVE,
}

PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema(
{
Platform.ALARM_CONTROL_PANEL.value: vol.All(
Expand Down Expand Up @@ -169,9 +158,11 @@
CONFIG_SCHEMA_ENTRY = vol.Schema(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a follow-up, I think we can drop the schema? There's no reason to validate config entry options / data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema adds default values. If we want to drop the schema we need to change the config flow to save basic broker settings with broker CONFIG_SCHEMA_ENTRY. So I don't see an option to drop it yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand, but there's no need to add default values via a schema, just use dict.get instead and pass the wanted default value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So let's implement this in a follow up PR

{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All(
vol.Coerce(int), vol.Range(min=15)
),
vol.Required(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): str,
Expand All @@ -180,13 +171,17 @@
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): str,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All(
cv.string, vol.In(SUPPORTED_PROTOCOLS)
),
vol.Optional(CONF_WILL_MESSAGE, default=DEFAULT_WILL): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE, default=DEFAULT_BIRTH): valid_birth_will,
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
vol.Optional(
CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX
): valid_publish_topic,
vol.Optional(CONF_TRANSPORT, default=DEFAULT_TRANSPORT): vol.All(
cv.string, vol.In([TRANSPORT_TCP, TRANSPORT_WEBSOCKETS])
),
Expand All @@ -195,32 +190,6 @@
}
)

CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile),
vol.Inclusive(
CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Inclusive(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
}
)

DEPRECATED_CONFIG_KEYS = [
CONF_BIRTH_MESSAGE,
CONF_BROKER,
Expand Down
1 change: 0 additions & 1 deletion homeassistant/components/mqtt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,6 @@ async def async_will_remove_from_hass(self) -> None:
def available(self) -> bool:
"""Return if the device is available."""
mqtt_data = get_mqtt_data(self.hass)
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/mqtt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def write_state_request(self, entity: Entity) -> None:
class MqttData:
"""Keep the MQTT entry data."""

client: MQTT | None = None
config: ConfigType | None = None
client: MQTT
config: ConfigType
debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
default_factory=dict
Expand Down
5 changes: 1 addition & 4 deletions homeassistant/components/mqtt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,9 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config


def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData
if ensure_exists:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of just removing the check, maybe raise instead:

if DATA_MQTT not in hass.data:
    raise HomeAssistantError

Or maybe the KeyError is good enough?

Copy link
Contributor Author

@jbouwh jbouwh Apr 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KeyError should be enough I think

mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
return mqtt_data
mqtt_data = hass.data[DATA_MQTT]
return mqtt_data

Expand Down
Loading