diff --git a/home_assistant_datasets/fixtures.py b/home_assistant_datasets/fixtures.py index d2f74bb3..c8848925 100644 --- a/home_assistant_datasets/fixtures.py +++ b/home_assistant_datasets/fixtures.py @@ -1,10 +1,10 @@ """Test fixtures for evaluations.""" -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator import logging import os import pathlib -from typing import Any, TextIO +from typing import Any, TextIO, cast from unittest.mock import patch, mock_open import pytest @@ -73,7 +73,7 @@ def mock_synthetic_home_content(synthetic_home_config: str | None) -> str | None @pytest.fixture(autouse=True, name="synthetic_home_config_entry") async def mock_synthetic_home( hass: HomeAssistant, synthetic_home_yaml: str | None -) -> Generator[MockConfigEntry | None]: +) -> AsyncGenerator[MockConfigEntry | None, None]: """Fixture for mock configuration entry.""" if synthetic_home_yaml is None: yield None @@ -163,19 +163,20 @@ async def async_process(self, hass: HomeAssistant, text: str) -> str: blocking=True, return_response=True, ) - response = service_response["response"] - return response["speech"]["plain"]["speech"] + assert service_response + response: dict[str, Any] = service_response["response"] + return str(response["speech"]["plain"]["speech"]) @pytest.fixture(name="conversation_agent_id") async def mock_conversation_agent_id( model_config: ModelConfig, - conversation_agent_config_entry: MockConfigEntry | None, + conversation_agent_config_entry: MockConfigEntry, ) -> str: """Return the id for the conversation agent under test.""" if model_config.domain == "homeassistant": return "conversation.home_assistant" - return conversation_agent_config_entry.entry_id + return cast(str, conversation_agent_config_entry.entry_id) @pytest.fixture(name="agent") @@ -200,6 +201,7 @@ def open(self) -> None: def write(self, record: dict[str, Any]) -> None: """Write an eval record.""" + assert self._fd self._fd.write(yaml.dump(record, 
sort_keys=False, explicit_start=True)) self._fd.flush() self._records += 1 diff --git a/home_assistant_datasets/model_client.py b/home_assistant_datasets/model_client.py index f47be35b..0813581b 100644 --- a/home_assistant_datasets/model_client.py +++ b/home_assistant_datasets/model_client.py @@ -16,14 +16,8 @@ def complete(self, prompt: str, user_message: str) -> str: response = self.client.chat.completions.create( model=self.model_id, messages=[ - { - "content": prompt, - "role": "system" - }, - { - "content": user_message, - "role": "user" - } + {"content": prompt, "role": "system"}, + {"content": user_message, "role": "user"}, ], ) - return "".join([choice.message.content for choice in response.choices]) + return "".join([choice.message.content or "" for choice in response.choices]) diff --git a/home_assistant_datasets/secrets.py b/home_assistant_datasets/secrets.py index 47aaef8a..ef46ee3a 100644 --- a/home_assistant_datasets/secrets.py +++ b/home_assistant_datasets/secrets.py @@ -35,4 +35,4 @@ def get_secret(secret_name: str) -> str: "Could not find secret_name %s in keys (%s)", secret_name, secrets.keys() ) raise KeyError(f"Could not find '{secret_name}' in secrets file {secrets_file}") - return secrets[secret_name] + return str(secrets[secret_name]) diff --git a/home_assistant_datasets/tool/assist/conftest.py b/home_assistant_datasets/tool/assist/conftest.py index 1926bd61..0c318f29 100644 --- a/home_assistant_datasets/tool/assist/conftest.py +++ b/home_assistant_datasets/tool/assist/conftest.py @@ -3,6 +3,8 @@ import pathlib import logging import datetime +from typing import Any +from collections.abc import Generator import pytest from homeassistant.util import dt as dt_util @@ -15,7 +17,7 @@ _LOGGER = logging.getLogger(__name__) -def pytest_addoption(parser): +def pytest_addoption(parser: Any) -> None: """Pytest arguments passed from the `collect` action to the test.""" parser.addoption("--dataset") parser.addoption("--models") @@ -23,7 +25,7 @@ def 
pytest_addoption(parser): parser.addoption("--categories") -def pytest_generate_tests(metafunc) -> None: +def pytest_generate_tests(metafunc: Any) -> None: """Generate test parameters for the evaluation from flags.""" # Parameterize tests by the models under development models = metafunc.config.getoption("models").split(",") @@ -47,7 +49,7 @@ def pytest_generate_tests(metafunc) -> None: raise ValueError(f"Could not find any dataset files in path: {dataset}") categories_str = metafunc.config.getoption("categories") - categories = set(categories_str.split(",")) if categories_str else {} + categories = set(categories_str.split(",") if categories_str else {}) dataset_path = pathlib.Path(dataset) output_path = pathlib.Path(output_dir) @@ -72,16 +74,15 @@ def pytest_generate_tests(metafunc) -> None: @pytest.fixture(autouse=True) -def restore_tz() -> None: +def restore_tz() -> Generator[None, None, None]: yield # Home Assistant teardown seems to run too soon and expects this so try to # patch it in first. dt_util.set_default_time_zone(datetime.UTC) - @pytest.fixture(name="eval_output_file") -def eval_output_file_fixture(model_id: str, eval_task: EvalTask) -> str: +def eval_output_file_fixture(model_id: str, eval_task: EvalTask) -> pathlib.Path: """Sets the output filename for the evaluation run. This output file needs to be unique across the test instances to avoid overwriting.
For diff --git a/home_assistant_datasets/tool/assist/data_model.py b/home_assistant_datasets/tool/assist/data_model.py index 0b7cf3bd..8f3c6e27 100644 --- a/home_assistant_datasets/tool/assist/data_model.py +++ b/home_assistant_datasets/tool/assist/data_model.py @@ -7,7 +7,7 @@ import logging from typing import Any from dataclasses import dataclass, field -from collections.abc import AsyncGenerator +from collections.abc import Generator import pathlib from slugify import slugify @@ -117,7 +117,7 @@ def generate_tasks( dataset_path: pathlib.Path, output_dir: pathlib.Path, categories: set[str], -) -> AsyncGenerator[EvalTask, None]: +) -> Generator[EvalTask, None, None]: """Read and validate the dataset.""" # Generate the record id based on the file path relpath = record_path.relative_to(dataset_path) @@ -137,7 +137,7 @@ def generate_tasks( "Skipping record with category %s (not in %s)", record.category, categories ) return - for action in record.tests: + for action in record.tests or (): if not action.sentences: raise ValueError("No sentences defined for the action") if not action.expect_changes: diff --git a/home_assistant_datasets/tool/assist/eval.py b/home_assistant_datasets/tool/assist/eval.py index 964c1253..8d1cb961 100644 --- a/home_assistant_datasets/tool/assist/eval.py +++ b/home_assistant_datasets/tool/assist/eval.py @@ -34,7 +34,7 @@ def find_llm_call(trace_events: list[dict[str, Any]]) -> dict[str, Any] | None: event for event in trace_events if event["event_type"] - in (trace.ConversationTraceEventType.TOOL_CALL, "llm_tool_call") + in (trace.ConversationTraceEventType.TOOL_CALL, "llm_tool_call") # type: ignore[attr-defined] ), None, ) @@ -48,7 +48,7 @@ def find_llm_call(trace_events: list[dict[str, Any]]) -> dict[str, Any] | None: } -def yaml_decoder(data: EncodedData) -> dict[Any, Any]: +def yaml_decoder(data: EncodedData) -> Any: return yaml.load(data, yaml.UnsafeLoader) diff --git a/home_assistant_datasets/tool/assist/eval_output.py 
b/home_assistant_datasets/tool/assist/eval_output.py index a5eba297..8e2fa7c7 100644 --- a/home_assistant_datasets/tool/assist/eval_output.py +++ b/home_assistant_datasets/tool/assist/eval_output.py @@ -29,7 +29,7 @@ class OutputType(enum.StrEnum): class WriterBase: """Base class for eval output.""" - diff: dict | str = dict + diff: type[dict] | type[str] = dict def start(self) -> None: """Write the output start.""" diff --git a/home_assistant_datasets/tool/assist/test_collect.py b/home_assistant_datasets/tool/assist/test_collect.py index 32548553..c9c7838b 100644 --- a/home_assistant_datasets/tool/assist/test_collect.py +++ b/home_assistant_datasets/tool/assist/test_collect.py @@ -5,7 +5,7 @@ import logging import uuid import dataclasses -from typing import Any +from typing import Any, cast import enum import json @@ -30,6 +30,7 @@ _LOGGER = logging.getLogger(__name__) TIMEOUT = 25 + @pytest.fixture(name="get_state") def get_state_fixture( hass: HomeAssistant, @@ -45,6 +46,9 @@ def func() -> dict[str, EntityState]: results = {} for entity_entry in entity_entries: state = hass.states.get(entity_entry.entity_id) + assert state + assert state.state + assert state.attributes results[entity_entry.entity_id] = EntityState( state=state.state, attributes=dict(state.attributes) ) @@ -64,12 +68,12 @@ def compare_state(v: Any, other_v: Any) -> bool: if isinstance(v, tuple) or isinstance(other_v, tuple): v = list(v) - other_v = list(v) + other_v = list(other_v) - return v == other_v + return cast(bool, v == other_v) if isinstance(v, enum.StrEnum) or isinstance(other_v, enum.StrEnum): v = str(v) other_v = str(other_v) - return v == other_v + return cast(bool, v == other_v) if v == other_v: return True @@ -87,7 +91,7 @@ def compute_entity_diff( a = a_state.as_dict() b = b_state.as_dict() - diff_attributes = set({}) + diff_attributes = set() for k, v in a.items(): other_v = b.get(k) if not compare_state(other_v, v): @@ -95,7 +99,7 @@ for k in b: if k not in a and k:
diff_attributes.add(k) - diff_attributes = [k for k in diff_attributes if k not in ignored] + diff_attributes = {k for k in diff_attributes if k not in ignored} if not diff_attributes: return None return { @@ -132,13 +136,13 @@ async def func( if states[entity_id].attributes is None: states[entity_id].attributes = {} states[entity_id].attributes = { - **states[entity_id].attributes, + **states[entity_id].attributes, # type: ignore[dict-item] **entity_state.attributes, } for entity_id in updated_states: if entity_id not in states: - return ValueError(f"Unexpected new entity found: {entity_id}") + raise ValueError(f"Unexpected new entity found: {entity_id}") diffs = {} for entity_id in states: @@ -156,7 +160,7 @@ async def func( return func -def dump_conversation_trace(trace: trace.ConversationTrace) -> dict[str, Any]: +def dump_conversation_trace(trace: trace.ConversationTrace) -> list[dict[str, Any]]: """Serialize the conversation trace for evaluation.""" trace_data = trace.as_dict() trace_events = trace_data["events"] @@ -216,6 +220,7 @@ async def test_assist_actions( _LOGGER.debug("Response: %s", response) updated_states = get_state() + unexpected_states: dict[str, Any] | str try: unexpected_states = await verify_state(eval_task, states, updated_states) except ValueError as err: diff --git a/home_assistant_datasets/yaml_loaders.py b/home_assistant_datasets/yaml_loaders.py index cc9530c5..203edbe2 100644 --- a/home_assistant_datasets/yaml_loaders.py +++ b/home_assistant_datasets/yaml_loaders.py @@ -13,7 +13,7 @@ _DEFAULT_LOADER = getattr(yaml, "CSafeLoader", yaml.SafeLoader) -class FastSafeLoader(_DEFAULT_LOADER): +class FastSafeLoader(_DEFAULT_LOADER): # type: ignore """The fastest available safe loader, either C or Python.
This exists to support capturing the stream file name in the same way as the @@ -47,7 +47,7 @@ def yaml_decode(stream: Any, shape_type: Type[T] | Any) -> T: but accepts a stream rather than content string in order to implement custom tags based on the current filename. """ - return YAMLDecoder(shape_type, pre_decoder_func=_default_decoder).decode(stream) + return YAMLDecoder(shape_type, pre_decoder_func=_default_decoder).decode(stream) # type: ignore[no-any-return] def _include_tag_constructor( diff --git a/pyproject.toml b/pyproject.toml index b621b46b..7c1e731e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ collect-area-data = "home_assistant_datasets.tools.collect_area_data:main" home-assistant-datasets = "home_assistant_datasets.tool.__main__:main" [tool.mypy] -exclude = ["setup.py", "venv/", "pyproject.toml"] +exclude = ["setup.py", "venv/", "home_assistant_datasets/tool/archive_evals/"] platform = "linux" show_error_codes = true follow_imports = "normal" diff --git a/script/human_eval_metrics.py b/script/human_eval_metrics.py index 1e4e569e..21581b10 100644 --- a/script/human_eval_metrics.py +++ b/script/human_eval_metrics.py @@ -3,6 +3,7 @@ import argparse import logging import pathlib +from typing import Any import yaml import random @@ -47,7 +48,7 @@ def get_arguments() -> argparse.Namespace: return parser.parse_args() -def main(): +def main() -> None: args = get_arguments() logging.basicConfig(level=getattr(logging, args.log_level.upper())) @@ -67,8 +68,8 @@ def main(): all_label_values.add(label) all_label_values.add("Unavailable") - model_results = {} - model_samples = {} + model_results: dict[str, dict[str, Any]] = {} + model_samples: dict[str, dict[str, Any]] = {} output_files =
model_outputs.glob("**/*.yaml") for output_file in output_files: diff --git a/script/import.py b/script/import.py index da1d04dc..acfc4f0f 100644 --- a/script/import.py +++ b/script/import.py @@ -5,7 +5,7 @@ import sys -def main(): +def main() -> None: for doc in yaml.safe_load_all(sys.stdin.read()): print(json.dumps({"text": json.dumps(doc, indent=2)}))