diff --git a/homeassistant/components/energyzero/__init__.py b/homeassistant/components/energyzero/__init__.py index 0eac874f1ed74b..8878a99e562a7d 100644 --- a/homeassistant/components/energyzero/__init__.py +++ b/homeassistant/components/energyzero/__init__.py @@ -5,12 +5,23 @@ from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.typing import ConfigType from .const import DOMAIN from .coordinator import EnergyZeroDataUpdateCoordinator -from .services import async_register_services +from .services import async_setup_services PLATFORMS = [Platform.SENSOR] +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up EnergyZero services.""" + + async_setup_services(hass) + + return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: @@ -27,8 +38,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - async_register_services(hass, coordinator) - return True diff --git a/homeassistant/components/energyzero/services.py b/homeassistant/components/energyzero/services.py index fb451c40401f6f..d8e548c22f8d6a 100644 --- a/homeassistant/components/energyzero/services.py +++ b/homeassistant/components/energyzero/services.py @@ -9,6 +9,7 @@ from energyzero import Electricity, Gas, VatOption import voluptuous as vol +from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.core import ( HomeAssistant, ServiceCall, @@ -17,11 +18,13 @@ callback, ) from homeassistant.exceptions import ServiceValidationError +from homeassistant.helpers import selector from homeassistant.util import dt as dt_util from .const import DOMAIN from .coordinator import EnergyZeroDataUpdateCoordinator +ATTR_CONFIG_ENTRY: Final = "config_entry" ATTR_START: Final = "start" ATTR_END: Final = "end" ATTR_INCL_VAT: Final = "incl_vat" @@ -30,6 +33,11 @@ ENERGY_SERVICE_NAME: Final = "get_energy_prices" SERVICE_SCHEMA: Final = vol.Schema( { + vol.Required(ATTR_CONFIG_ENTRY): selector.ConfigEntrySelector( + { + "integration": DOMAIN, + } + ), vol.Required(ATTR_INCL_VAT): bool, vol.Optional(ATTR_START): str, vol.Optional(ATTR_END): str, @@ -75,12 +83,43 @@ def __serialize_prices(prices: Electricity | Gas) -> ServiceResponse: } +def __get_coordinator( + hass: HomeAssistant, call: ServiceCall +) -> EnergyZeroDataUpdateCoordinator: + """Get the coordinator from the entry.""" + entry_id: str = call.data[ATTR_CONFIG_ENTRY] + entry: ConfigEntry | None = hass.config_entries.async_get_entry(entry_id) + + if not entry: + raise ServiceValidationError( + f"Invalid config entry: {entry_id}", + translation_domain=DOMAIN, + translation_key="invalid_config_entry", + translation_placeholders={ + "config_entry": entry_id, + }, + ) + if entry.state != ConfigEntryState.LOADED: + raise ServiceValidationError( + f"{entry.title} is not loaded", + translation_domain=DOMAIN, + translation_key="unloaded_config_entry", + translation_placeholders={ + "config_entry": entry.title, + }, + ) + + return hass.data[DOMAIN][entry_id] + + async def __get_prices( call: ServiceCall, *, - coordinator: EnergyZeroDataUpdateCoordinator, + hass: HomeAssistant, price_type: PriceType, ) -> ServiceResponse: + coordinator = __get_coordinator(hass, call) + start = __get_date(call.data.get(ATTR_START)) end = __get_date(call.data.get(ATTR_END)) @@ -108,22 +147,20 @@ async def __get_prices( @callback -def async_register_services( - hass: HomeAssistant, coordinator: EnergyZeroDataUpdateCoordinator -): +def async_setup_services(hass: HomeAssistant) -> None: """Set up EnergyZero services.""" hass.services.async_register( DOMAIN, GAS_SERVICE_NAME, - partial(__get_prices, coordinator=coordinator, price_type=PriceType.GAS), + partial(__get_prices, hass=hass, price_type=PriceType.GAS), schema=SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) hass.services.async_register( DOMAIN, ENERGY_SERVICE_NAME, - partial(__get_prices, coordinator=coordinator, price_type=PriceType.ENERGY), + partial(__get_prices, hass=hass, price_type=PriceType.ENERGY), schema=SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) diff --git a/homeassistant/components/energyzero/services.yaml b/homeassistant/components/energyzero/services.yaml index 1bcc5ae34bef01..dc8df9aa6d0c4c 100644 --- a/homeassistant/components/energyzero/services.yaml +++ b/homeassistant/components/energyzero/services.yaml @@ -1,5 +1,10 @@ get_gas_prices: fields: + config_entry: + required: true + selector: + config_entry: + integration: energyzero incl_vat: required: true default: true @@ -17,6 +22,11 @@ get_gas_prices: datetime: get_energy_prices: fields: + config_entry: + required: true + selector: + config_entry: + integration: energyzero incl_vat: required: true default: true diff --git a/homeassistant/components/energyzero/strings.json b/homeassistant/components/energyzero/strings.json index 81f54f4222ad62..9858838aff7f01 100644 --- a/homeassistant/components/energyzero/strings.json +++ b/homeassistant/components/energyzero/strings.json @@ -12,6 +12,12 @@ "exceptions": { "invalid_date": { "message": "Invalid date provided. Got {date}" + }, + "invalid_config_entry": { + "message": "Invalid config entry provided. Got {config_entry}" + }, + "unloaded_config_entry": { + "message": "Invalid config entry provided. {config_entry} is not loaded." } }, "entity": { @@ -50,6 +56,10 @@ "name": "Get gas prices", "description": "Request gas prices from EnergyZero.", "fields": { + "config_entry": { + "name": "Config Entry", + "description": "The config entry to use for this service." + }, "incl_vat": { "name": "Including VAT", "description": "Include VAT in the prices." @@ -68,6 +78,10 @@ "name": "Get energy prices", "description": "Request energy prices from EnergyZero.", "fields": { + "config_entry": { + "name": "[%key:component::energyzero::services::get_gas_prices::fields::config_entry::name%]", + "description": "[%key:component::energyzero::services::get_gas_prices::fields::config_entry::description%]" + }, "incl_vat": { "name": "[%key:component::energyzero::services::get_gas_prices::fields::incl_vat::name%]", "description": "[%key:component::energyzero::services::get_gas_prices::fields::incl_vat::description%]" diff --git a/tests/components/energyzero/test_services.py b/tests/components/energyzero/test_services.py index 7939b06ce8e0e3..c0b54729e03360 100644 --- a/tests/components/energyzero/test_services.py +++ b/tests/components/energyzero/test_services.py @@ -6,12 +6,15 @@ from homeassistant.components.energyzero.const import DOMAIN from homeassistant.components.energyzero.services import ( + ATTR_CONFIG_ENTRY, ENERGY_SERVICE_NAME, GAS_SERVICE_NAME, ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import ServiceValidationError +from tests.common import MockConfigEntry + @pytest.mark.usefixtures("init_integration") async def test_has_services( @@ -29,6 +32,7 @@ async def test_has_services( @pytest.mark.parametrize("end", [{"end": "2023-01-01 00:00:00"}, {}]) async def test_service( hass: HomeAssistant, + mock_config_entry: MockConfigEntry, snapshot: SnapshotAssertion, service: str, incl_vat: dict[str, bool], @@ -36,8 +40,9 @@ async def test_service( end: dict[str, str], ) -> None: """Test the EnergyZero Service.""" + entry = {ATTR_CONFIG_ENTRY: mock_config_entry.entry_id} - data = incl_vat | start | end + data = entry | incl_vat | start | end assert snapshot == await hass.services.async_call( DOMAIN, @@ -48,32 +53,72 @@ async def test_service( ) +@pytest.fixture +def config_entry_data( + mock_config_entry: MockConfigEntry, request: pytest.FixtureRequest +) -> dict[str, str]: + """Fixture for the config entry.""" + if "config_entry" in request.param and request.param["config_entry"] is True: + return {"config_entry": mock_config_entry.entry_id} + + return request.param + + @pytest.mark.usefixtures("init_integration") @pytest.mark.parametrize("service", [GAS_SERVICE_NAME, ENERGY_SERVICE_NAME]) @pytest.mark.parametrize( - ("service_data", "error", "error_message"), + ("config_entry_data", "service_data", "error", "error_message"), [ - ({}, vol.er.Error, "required key not provided .+"), + ({}, {}, vol.er.Error, "required key not provided .+"), ( + {"config_entry": True}, + {}, + vol.er.Error, + "required key not provided .+", + ), + ( + {}, + {"incl_vat": True}, + vol.er.Error, + "required key not provided .+", + ), + ( + {"config_entry": True}, {"incl_vat": "incorrect vat"}, vol.er.Error, "expected bool for dictionary value .+", ), ( - {"incl_vat": True, "start": "incorrect date"}, + {"config_entry": "incorrect entry"}, + {"incl_vat": True}, + ServiceValidationError, + "Invalid config entry.+", + ), + ( + {"config_entry": True}, + { + "incl_vat": True, + "start": "incorrect date", + }, ServiceValidationError, "Invalid datetime provided.", ), ( - {"incl_vat": True, "end": "incorrect date"}, + {"config_entry": True}, + { + "incl_vat": True, + "end": "incorrect date", + }, ServiceValidationError, "Invalid datetime provided.", ), ], + indirect=["config_entry_data"], ) async def test_service_validation( hass: HomeAssistant, service: str, + config_entry_data: dict[str, str], service_data: dict[str, str], error: type[Exception], error_message: str, @@ -84,7 +129,32 @@ async def test_service_validation( await hass.services.async_call( DOMAIN, service, - service_data, + config_entry_data | service_data, + blocking=True, + return_response=True, + ) + + +@pytest.mark.usefixtures("init_integration") +@pytest.mark.parametrize("service", [GAS_SERVICE_NAME, ENERGY_SERVICE_NAME]) +async def test_service_called_with_unloaded_entry( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + service: str, +) -> None: + """Test service calls with unloaded config entry.""" + + await mock_config_entry.async_unload(hass) + + data = {"config_entry": mock_config_entry.entry_id, "incl_vat": True} + + with pytest.raises( + ServiceValidationError, match=f"{mock_config_entry.title} is not loaded" + ): + await hass.services.async_call( + DOMAIN, + service, + data, blocking=True, return_response=True, )