-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Significate additions for cloud and local models. Added language and removed scheduler.
- Loading branch information
1 parent
6655ee3
commit 5884036
Showing
12 changed files
with
1,533 additions
and
394 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,121 @@ | ||
"""The AI Automation Suggester integration.""" | ||
# custom_components/ai_automation_suggester/__init__.py | ||
|
||
"""The AI Automation Suggester integration.""" | ||
import logging | ||
from homeassistant.config_entries import ConfigEntry | ||
from homeassistant.core import HomeAssistant | ||
from .const import DOMAIN, PLATFORMS | ||
from homeassistant.core import HomeAssistant, ServiceCall | ||
from homeassistant.exceptions import ConfigEntryNotReady, ServiceValidationError | ||
from homeassistant.helpers.typing import ConfigType | ||
|
||
from .const import ( | ||
DOMAIN, | ||
PLATFORMS, | ||
CONF_PROVIDER, | ||
SERVICE_GENERATE_SUGGESTIONS, | ||
ATTR_PROVIDER_CONFIG, | ||
ATTR_CUSTOM_PROMPT, | ||
) | ||
from .coordinator import AIAutomationCoordinator | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
async def async_setup(hass: HomeAssistant, config: dict): | ||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: | ||
"""Set up the AI Automation Suggester component.""" | ||
hass.data.setdefault(DOMAIN, {}) | ||
return True | ||
|
||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): | ||
"""Set up AI Automation Suggester from a config entry.""" | ||
coordinator = AIAutomationCoordinator(hass, entry) | ||
hass.data[DOMAIN][entry.entry_id] = coordinator | ||
|
||
await coordinator.async_config_entry_first_refresh() | ||
|
||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) | ||
|
||
async def handle_generate_suggestions(call): | ||
"""Handle the service call to generate suggestions.""" | ||
await coordinator.async_request_refresh() | ||
|
||
hass.services.async_register(DOMAIN, "generate_suggestions", handle_generate_suggestions) | ||
async def handle_generate_suggestions(call: ServiceCall) -> None: | ||
"""Handle the generate_suggestions service call.""" | ||
provider_config = call.data.get(ATTR_PROVIDER_CONFIG) | ||
custom_prompt = call.data.get(ATTR_CUSTOM_PROMPT) | ||
|
||
try: | ||
coordinator = None | ||
if provider_config: | ||
coordinator = hass.data[DOMAIN][provider_config] | ||
else: | ||
for entry_id, coord in hass.data[DOMAIN].items(): | ||
if isinstance(coord, AIAutomationCoordinator): | ||
coordinator = coord | ||
break | ||
|
||
if coordinator is None: | ||
raise ServiceValidationError("No AI Automation Suggester provider configured") | ||
|
||
if custom_prompt: | ||
original_prompt = coordinator.SYSTEM_PROMPT | ||
try: | ||
coordinator.SYSTEM_PROMPT = custom_prompt | ||
await coordinator.async_request_refresh() | ||
finally: | ||
coordinator.SYSTEM_PROMPT = original_prompt | ||
else: | ||
await coordinator.async_request_refresh() | ||
|
||
except KeyError: | ||
raise ServiceValidationError(f"Provider configuration not found") | ||
except Exception as err: | ||
raise ServiceValidationError(f"Failed to generate suggestions: {err}") | ||
|
||
# Register the service | ||
hass.services.async_register( | ||
DOMAIN, | ||
SERVICE_GENERATE_SUGGESTIONS, | ||
handle_generate_suggestions | ||
) | ||
|
||
return True | ||
|
||
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry): | ||
"""Migrate old entry.""" | ||
_LOGGER.debug(f"Starting migration for entry version {entry.version}") | ||
|
||
if entry.version == 1: | ||
# Example: If moving from version 1 to 2, make changes to the data | ||
new_data = {**entry.data} | ||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: | ||
"""Set up AI Automation Suggester from a config entry.""" | ||
try: | ||
# Ensure required config values are present | ||
if CONF_PROVIDER not in entry.data: | ||
raise ConfigEntryNotReady("Provider not specified in config") | ||
|
||
# Create and store coordinator | ||
coordinator = AIAutomationCoordinator(hass, entry) | ||
hass.data[DOMAIN][entry.entry_id] = coordinator | ||
|
||
# Set up platforms | ||
for platform in PLATFORMS: | ||
try: | ||
await hass.config_entries.async_forward_entry_setup(entry, platform) | ||
except Exception as err: | ||
_LOGGER.error("Failed to setup platform %s: %s", platform, err) | ||
raise ConfigEntryNotReady from err | ||
|
||
_LOGGER.debug( | ||
"Setup complete for %s with provider %s", | ||
entry.title, | ||
entry.data.get(CONF_PROVIDER) | ||
) | ||
|
||
entry.async_on_unload(entry.add_update_listener(async_reload_entry)) | ||
|
||
# Handle any changes in your schema or structure | ||
if 'scan_frequency' not in new_data: | ||
new_data['scan_frequency'] = 24 # Set a default scan frequency if it doesn't exist | ||
return True | ||
|
||
if 'initial_lag_time' not in new_data: | ||
new_data['initial_lag_time'] = 10 # Add default lag time if missing | ||
except Exception as err: | ||
_LOGGER.error("Failed to setup integration: %s", err) | ||
raise ConfigEntryNotReady from err | ||
|
||
# Update the entry data | ||
entry.version = 2 | ||
hass.config_entries.async_update_entry(entry, data=new_data) | ||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: | ||
"""Unload a config entry.""" | ||
try: | ||
# Unload platforms | ||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) | ||
|
||
if unload_ok: | ||
# Clean up coordinator | ||
coordinator = hass.data[DOMAIN].pop(entry.entry_id) | ||
await coordinator.async_shutdown() | ||
|
||
_LOGGER.info(f"Migration to version {entry.version} successful") | ||
return unload_ok | ||
|
||
return True | ||
except Exception as err: | ||
_LOGGER.error("Error unloading entry: %s", err) | ||
return False | ||
|
||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): | ||
"""Unload a config entry.""" | ||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) | ||
hass.data[DOMAIN].pop(entry.entry_id) | ||
return unload_ok | ||
async def async_reload_entry(hass: HomeAssistant, entry: ConfigEntry) -> None: | ||
"""Reload config entry.""" | ||
await async_unload_entry(hass, entry) | ||
await async_setup_entry(hass, entry) |
Oops, something went wrong.