Skip to content

Commit

Permalink
make sure service has a unique name
Browse files Browse the repository at this point in the history
  • Loading branch information
jbouwh committed Jan 24, 2022
1 parent 10b3864 commit 76343bc
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 94 deletions.
58 changes: 49 additions & 9 deletions homeassistant/components/mqtt/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
CONF_TARGETS = "targets"
CONF_TITLE = "title"

MQTT_EVENT_RELOADED = "event_{}_reloaded"
MQTT_NOTIFY_SERVICES_SETUP = "mqtt_notify_services_setup"

PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
{
Expand Down Expand Up @@ -87,7 +87,19 @@ async def _async_setup_notify(
):
"""Set up the MQTT notify service with auto discovery."""
config = DISCOVERY_SCHEMA(discovery_data[ATTR_DISCOVERY_PAYLOAD])
service_name = slugify(config.get(CONF_NAME) or DOMAIN)
services = hass.data.setdefault(MQTT_NOTIFY_SERVICES_SETUP, set())
has_services = hass.services.has_service(notify.DOMAIN, service_name)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
if service_name in services or has_services:
_LOGGER.error(
"Notify service '%s' already exists, cannot register service(s)",
service_name,
)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
clear_discovery_hash(hass, discovery_hash)
return
services.add(service_name)
device_id = None
if CONF_DEVICE in config:
await _update_device(hass, config_entry, config)
Expand All @@ -114,7 +126,6 @@ async def _async_setup_notify(
device_id=device_id,
config_entry=config_entry,
)
service_name = slugify(config.get(CONF_NAME, DOMAIN))
await service.async_setup(hass, service_name, service_name)
await service.async_register_services()

Expand All @@ -128,7 +139,16 @@ async def async_get_service(
name = config.get(CONF_NAME)
if CONF_NAME not in config:
config[CONF_NAME] = DOMAIN

service_name = slugify(name or DOMAIN)
services = hass.data.setdefault(MQTT_NOTIFY_SERVICES_SETUP, set())
has_services = hass.services.has_service(notify.DOMAIN, service_name)
if service_name in services or has_services:
_LOGGER.error(
"Notify service '%s' is not unique, cannot register service(s)",
service_name,
)
return None
services.add(service_name)
await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
return MqttNotificationService(
hass,
Expand Down Expand Up @@ -164,7 +184,7 @@ async def async_discovery_update(
# update notify service through auto discovery
await service.async_update_service(discovery_payload)
_LOGGER.debug(
"Notify service %s has been updated",
"Notify service %s updated has been processed",
service.discovery_hash,
)

Expand All @@ -182,6 +202,9 @@ async def async_device_removed(event):

async def async_tear_down_service():
"""Handle the removal of the service."""
services = hass.data.setdefault(MQTT_NOTIFY_SERVICES_SETUP, set())
if self._service.service_name in services:
services.remove(self._service.service_name)
if not self._device_removed and service.device_id:
self._device_removed = True
await cleanup_device_registry(hass, service.device_id)
Expand Down Expand Up @@ -245,28 +268,42 @@ def __init__(
self._discovery_hash = discovery_hash
self._device_id = device_id
self._config_entry = config_entry
self._service_name = slugify(name or command_topic)
self._service_name = slugify(name or DOMAIN)

self._updater = (
MqttNotificationServiceUpdater(hass, self) if discovery_hash else None
)

@property
def device_id(self) -> str | None:
"""Return the device ID."""
return self._device_id

@property
def discovery_hash(self) -> tuple | None:
"""Return the discovery hash."""
return self._discovery_hash

@property
def device_id(self) -> str | None:
"""Return the device ID."""
return self._device_id
def service_name(self) -> str:
"""Return the service ma,e."""
return self._service_name

async def async_update_service(
self,
discovery_payload: DiscoveryInfoType,
) -> None:
"""Update the notify service through auto discovery."""
config = DISCOVERY_SCHEMA(discovery_payload)
new_service_name = slugify(config.get(CONF_NAME, DOMAIN))
if new_service_name != self._service_name and self.hass.services.has_service(
notify.DOMAIN, new_service_name
):
_LOGGER.error(
"Notify service '%s' already exists, cannot update the existing service(s)",
new_service_name,
)
return
self._command_topic = config[CONF_COMMAND_TOPIC]
self._command_template = MqttCommandTemplate(
config.get(CONF_COMMAND_TEMPLATE), hass=self.hass
Expand All @@ -276,15 +313,18 @@ async def async_update_service(
self._qos = config[CONF_QOS]
self._retain = config[CONF_RETAIN]
self._title = config[CONF_TITLE]
new_service_name = slugify(config.get(CONF_NAME, DOMAIN))
if (
new_service_name != self._service_name
or config[CONF_TARGETS] != self._targets
):
services = self.hass.data.setdefault(MQTT_NOTIFY_SERVICES_SETUP, set())
await self.async_unregister_services()
if self._service_name in services:
services.remove(self._service_name)
self._targets = config[CONF_TARGETS]
self._service_name = new_service_name
await self.async_register_services()
services.add(new_service_name)
if self.device_id:
await _update_device(self.hass, self._config_entry, config)

Expand Down
Loading

0 comments on commit 76343bc

Please sign in to comment.