Skip to content

Commit

Permalink
Improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
allenporter committed Aug 3, 2024
1 parent 9919bb5 commit e026591
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 44 deletions.
16 changes: 9 additions & 7 deletions home_assistant_datasets/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test fixtures for evaluations."""

from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
import logging
import os
import pathlib
from typing import Any, TextIO
from typing import Any, TextIO, cast
from unittest.mock import patch, mock_open

import pytest
Expand Down Expand Up @@ -73,7 +73,7 @@ def mock_synthetic_home_content(synthetic_home_config: str | None) -> str | None
@pytest.fixture(autouse=True, name="synthetic_home_config_entry")
async def mock_synthetic_home(
hass: HomeAssistant, synthetic_home_yaml: str | None
) -> Generator[MockConfigEntry | None]:
) -> AsyncGenerator[MockConfigEntry | None, None]:
"""Fixture for mock configuration entry."""
if synthetic_home_yaml is None:
yield None
Expand Down Expand Up @@ -163,19 +163,20 @@ async def async_process(self, hass: HomeAssistant, text: str) -> str:
blocking=True,
return_response=True,
)
response = service_response["response"]
return response["speech"]["plain"]["speech"]
assert service_response
response: dict[str, Any] = service_response["response"]
return str(response["speech"]["plain"]["speech"])


@pytest.fixture(name="conversation_agent_id")
async def mock_conversation_agent_id(
model_config: ModelConfig,
conversation_agent_config_entry: MockConfigEntry | None,
conversation_agent_config_entry: MockConfigEntry,
) -> str:
"""Return the id for the conversation agent under test."""
if model_config.domain == "homeassistant":
return "conversation.home_assistant"
return conversation_agent_config_entry.entry_id
return cast(str, conversation_agent_config_entry.entry_id)


@pytest.fixture(name="agent")
Expand All @@ -200,6 +201,7 @@ def open(self) -> None:

def write(self, record: dict[str, Any]) -> None:
"""Write an eval record."""
assert self._fd
self._fd.write(yaml.dump(record, sort_keys=False, explicit_start=True))
self._fd.flush()
self._records += 1
Expand Down
12 changes: 3 additions & 9 deletions home_assistant_datasets/model_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,8 @@ def complete(self, prompt: str, user_message: str) -> str:
response = self.client.chat.completions.create(
model=self.model_id,
messages=[
{
"content": prompt,
"role": "system"
},
{
"content": user_message,
"role": "user"
}
{"content": prompt, "role": "system"},
{"content": user_message, "role": "user"},
],
)
return "".join([choice.message.content for choice in response.choices])
return "".join([choice.message.content or "" for choice in response.choices])
2 changes: 1 addition & 1 deletion home_assistant_datasets/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ def get_secret(secret_name: str) -> str:
"Could not find secret_name %s in keys (%s)", secret_name, secrets.keys()
)
raise KeyError(f"Could not find '{secret_name}' in secrets file {secrets_file}")
return secrets[secret_name]
return str(secrets[secret_name])
13 changes: 7 additions & 6 deletions home_assistant_datasets/tool/assist/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pathlib
import logging
import datetime
from typing import Any
from collections.abc import Generator

import pytest
from homeassistant.util import dt as dt_util
Expand All @@ -15,15 +17,15 @@
_LOGGER = logging.getLogger(__name__)


def pytest_addoption(parser):
def pytest_addoption(parser: Any) -> None:
"""Pytest arguments passed from the `collect` action to the test."""
parser.addoption("--dataset")
parser.addoption("--models")
parser.addoption("--model_output_dir")
parser.addoption("--categories")


def pytest_generate_tests(metafunc) -> None:
def pytest_generate_tests(metafunc: Any) -> None:
"""Generate test parameters for the evaluation from flags."""
# Parameterize tests by the models under development
models = metafunc.config.getoption("models").split(",")
Expand All @@ -47,7 +49,7 @@ def pytest_generate_tests(metafunc) -> None:
raise ValueError(f"Could not find any dataset files in path: {dataset}")

categories_str = metafunc.config.getoption("categories")
categories = set(categories_str.split(",")) if categories_str else {}
categories = set(categories_str.split(",") if categories_str else {})

dataset_path = pathlib.Path(dataset)
output_path = pathlib.Path(output_dir)
Expand All @@ -72,16 +74,15 @@ def pytest_generate_tests(metafunc) -> None:


@pytest.fixture(autouse=True)
def restore_tz() -> None:
def restore_tz() -> Generator[None, None]:
yield
# Home Assistant teardown seems to run too soon and expects this so try to
# patch it in first.
dt_util.set_default_time_zone(datetime.UTC)



@pytest.fixture(name="eval_output_file")
def eval_output_file_fixture(model_id: str, eval_task: EvalTask) -> str:
def eval_output_file_fixture(model_id: str, eval_task: EvalTask) -> pathlib.Path:
"""Sets the output filename for the evaluation run.
This output file needs to be unique across the test instances to avoid overwriting. For
Expand Down
6 changes: 3 additions & 3 deletions home_assistant_datasets/tool/assist/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
from typing import Any
from dataclasses import dataclass, field
from collections.abc import AsyncGenerator
from collections.abc import Generator
import pathlib
from slugify import slugify

Expand Down Expand Up @@ -117,7 +117,7 @@ def generate_tasks(
dataset_path: pathlib.Path,
output_dir: pathlib.Path,
categories: set[str],
) -> AsyncGenerator[EvalTask, None]:
) -> Generator[EvalTask, None, None]:
"""Read and validate the dataset."""
# Generate the record id based on the file path
relpath = record_path.relative_to(dataset_path)
Expand All @@ -137,7 +137,7 @@ def generate_tasks(
"Skipping record with category %s (not in %s)", record.category, categories
)
return
for action in record.tests:
for action in record.tests or ():
if not action.sentences:
raise ValueError("No sentences defined for the action")
if not action.expect_changes:
Expand Down
4 changes: 2 additions & 2 deletions home_assistant_datasets/tool/assist/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def find_llm_call(trace_events: list[dict[str, Any]]) -> dict[str, Any] | None:
event
for event in trace_events
if event["event_type"]
in (trace.ConversationTraceEventType.TOOL_CALL, "llm_tool_call")
in (trace.ConversationTraceEventType.TOOL_CALL, "llm_tool_call") # type: ignore[attr-defined]
),
None,
)
Expand All @@ -48,7 +48,7 @@ def find_llm_call(trace_events: list[dict[str, Any]]) -> dict[str, Any] | None:
}


def yaml_decoder(data: EncodedData) -> dict[Any, Any]:
def yaml_decoder(data: EncodedData) -> Any:
return yaml.load(data, yaml.UnsafeLoader)


Expand Down
2 changes: 1 addition & 1 deletion home_assistant_datasets/tool/assist/eval_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OutputType(enum.StrEnum):
class WriterBase:
"""Base class for eval output."""

diff: dict | str = dict
diff: type[dict] | type[str] = dict

def start(self) -> None:
"""Write the output start."""
Expand Down
21 changes: 13 additions & 8 deletions home_assistant_datasets/tool/assist/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import uuid
import dataclasses
from typing import Any
from typing import Any, cast
import enum
import json

Expand All @@ -30,6 +30,7 @@
_LOGGER = logging.getLogger(__name__)
TIMEOUT = 25


@pytest.fixture(name="get_state")
def get_state_fixture(
hass: HomeAssistant,
Expand All @@ -45,6 +46,9 @@ def func() -> dict[str, EntityState]:
results = {}
for entity_entry in entity_entries:
state = hass.states.get(entity_entry.entity_id)
assert state
assert state.state
assert state.attributes
results[entity_entry.entity_id] = EntityState(
state=state.state, attributes=dict(state.attributes)
)
Expand All @@ -64,12 +68,12 @@ def compare_state(v: Any, other_v: Any) -> bool:
if isinstance(v, tuple) or isinstance(other_v, tuple):
v = list(v)
other_v = list(v)
return v == other_v
return cast(bool, v == other_v)

if isinstance(v, enum.StrEnum) or isinstance(other_v, enum.StrEnum):
v = str(v)
other_v = str(other_v)
return v == other_v
return cast(bool, v == other_v)

if v == other_v:
return True
Expand All @@ -87,15 +91,15 @@ def compute_entity_diff(
a = a_state.as_dict()
b = b_state.as_dict()

diff_attributes = set({})
diff_attributes = set([])
for k, v in a.items():
other_v = b.get(k)
if not compare_state(other_v, v):
diff_attributes.add(k)
for k in b:
if k not in a and k:
diff_attributes.add(k)
diff_attributes = [k for k in diff_attributes if k not in ignored]
diff_attributes = set({k for k in diff_attributes if k not in ignored})
if not diff_attributes:
return None
return {
Expand Down Expand Up @@ -132,13 +136,13 @@ async def func(
if states[entity_id].attributes is None:
states[entity_id].attributes = {}
states[entity_id].attributes = {
**states[entity_id].attributes,
**states[entity_id].attributes, # type: ignore[dict-item]
**entity_state.attributes,
}

for entity_id in updated_states:
if entity_id not in states:
return ValueError(f"Unexpected new entity found: {entity_id}")
raise ValueError(f"Unexpected new entity found: {entity_id}")

diffs = {}
for entity_id in states:
Expand All @@ -156,7 +160,7 @@ async def func(
return func


def dump_conversation_trace(trace: trace.ConversationTrace) -> dict[str, Any]:
def dump_conversation_trace(trace: trace.ConversationTrace) -> list[dict[str, Any]]:
"""Serialize the conversation trace for evaluation."""
trace_data = trace.as_dict()
trace_events = trace_data["events"]
Expand Down Expand Up @@ -216,6 +220,7 @@ async def test_assist_actions(
_LOGGER.debug("Response: %s", response)

updated_states = get_state()
unexpected_states: dict[str, Any] | str
try:
unexpected_states = await verify_state(eval_task, states, updated_states)
except ValueError as err:
Expand Down
4 changes: 2 additions & 2 deletions home_assistant_datasets/yaml_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_DEFAULT_LOADER = getattr(yaml, "CSafeLoader", yaml.SafeLoader)


class FastSafeLoader(_DEFAULT_LOADER):
class FastSafeLoader(_DEFAULT_LOADER): # type: ignore
"""The fastest available safe loader, either C or Python.
This exists to support capturing the stream file name in the same way as the
Expand Down Expand Up @@ -47,7 +47,7 @@ def yaml_decode(stream: Any, shape_type: Type[T] | Any) -> T:
but accepts a stream rather than content string in order to implement
custom tags based on the current filename.
"""
return YAMLDecoder(shape_type, pre_decoder_func=_default_decoder).decode(stream)
return YAMLDecoder(shape_type, pre_decoder_func=_default_decoder).decode(stream) # type: ignore[no-any-return]


def _include_tag_constructor(
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"google-generativeai>=0.5.4",
"hass-client>=1.2.0",
"synthetic_home>=4.3.1",
"hass-client>=1.2.0",
]
requires-python = ">= 3.12"
authors = [{ name = "Allen Porter", email = "[email protected]" }]
Expand All @@ -29,7 +30,7 @@ collect-area-data = "home_assistant_datasets.tools.collect_area_data:main"
home-assistant-datasets = "home_assistant_datasets.tool.__main__:main"

[tool.mypy]
exclude = ["setup.py", "venv/", "pyproject.toml"]
exclude = ["setup.py", "venv/", "home_assistant_datasets/tool/archive_evals/"]
platform = "linux"
show_error_codes = true
follow_imports = "normal"
Expand Down
7 changes: 4 additions & 3 deletions script/human_eval_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import logging
import pathlib
from typing import Any
import yaml
import random

Expand Down Expand Up @@ -47,7 +48,7 @@ def get_arguments() -> argparse.Namespace:
return parser.parse_args()


def main():
def main() -> None:
args = get_arguments()
logging.basicConfig(level=getattr(logging, args.log_level.upper()))

Expand All @@ -67,8 +68,8 @@ def main():
all_label_values.add(label)
all_label_values.add("Unavailable")

model_results = {}
model_samples = {}
model_results: dict[str, dict[str, Any]] = {}
model_samples: dict[str, dict[str, Any]] = {}

output_files = model_outputs.glob("**/*.yaml")
for output_file in output_files:
Expand Down
2 changes: 1 addition & 1 deletion script/import.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys


def main():
def main() -> None:
for doc in yaml.safe_load_all(sys.stdin.read()):
print(json.dumps({"text": json.dumps(doc, indent=2)}))

Expand Down

0 comments on commit e026591

Please sign in to comment.