Skip to content

Commit

Permalink
Use mashumaro features for Update models
Browse files Browse the repository at this point in the history
  • Loading branch information
DCSBL committed Jan 1, 2025
1 parent 100ede2 commit dad8c48
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 32 deletions.
43 changes: 16 additions & 27 deletions homewizard_energy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
Expand All @@ -25,6 +25,19 @@ class Config(BaseConfig):
omit_none = True


class UpdateBaseModel(BaseModel):
"""Base model for all 'update' models."""

def __post_serialize__(self, d: dict, context: dict | None = None):
"""Post serialize hook for UpdateBaseModel object."""
_ = context # Unused

if not d:
raise ValueError("No values to update")

return d


def get_verification_hostname(model: str, serial_number: str) -> str:
"""Helper method to convert device model and serial to identifier
Expand Down Expand Up @@ -421,25 +434,13 @@ class DeviceType(Enum):


@dataclass(kw_only=True)
class StateUpdate(BaseModel):
class StateUpdate(UpdateBaseModel):
"""Represent State update config."""

power_on: bool | None = field(default=None)
switch_lock: bool | None = field(default=None)
brightness: int | None = field(default=None)

def as_dict(self) -> dict[str, bool | int]:
"""Return StateUpdate object as dict.
Only include values that are not None.
"""
_dict = {k: v for k, v in asdict(self).items() if v is not None}

if not _dict:
raise ValueError("No values to update")

return _dict


@dataclass(kw_only=True)
class State(BaseModel):
Expand All @@ -457,25 +458,13 @@ class State(BaseModel):


@dataclass
class SystemUpdate(BaseModel):
class SystemUpdate(UpdateBaseModel):
"""Represent System update config."""

cloud_enabled: bool | None = field(default=None)
status_led_brightness_pct: int | None = field(default=None)
api_v1_enabled: bool | None = field(default=None)

def as_dict(self) -> dict[str, bool | int]:
"""Return SystemUpdate object as dict.
Only include values that are not None.
"""
_dict = {k: v for k, v in asdict(self).items() if v is not None}

if not _dict:
raise ValueError("No values to update")

return _dict


@dataclass(kw_only=True)
class System(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions homewizard_energy/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def data(self) -> Measurement:
async def state(self, update: StateUpdate | None = None) -> State:
"""Return the state object."""
if update is not None:
data = update.as_dict()
data = update.to_dict()
status, response = await self._request(
"api/v1/state", method=METH_PUT, data=data
)
Expand All @@ -119,7 +119,7 @@ async def system(self, update: SystemUpdate | None = None) -> System:
"Setting status_led_brightness_pct and api_v1_enabled is not supported in v1"
)

data = update.as_dict()
data = update.to_dict()
status, response = await self._request(
"api/v1/system", method=METH_PUT, data=data
)
Expand Down
2 changes: 1 addition & 1 deletion homewizard_energy/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def system(
"""Return the system object."""

if update is not None:
data = update.as_dict()
data = update.to_dict()
status, response = await self._request(
"/api/system", method=METH_PUT, data=data
)
Expand Down
4 changes: 2 additions & 2 deletions tests/v2/test_v2_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def test_system_update(
status_led_brightness_pct=status_led_brightness_pct,
api_v1_enabled=api_v1_enabled,
)
assert snapshot == data.as_dict()
assert snapshot == data.to_dict()


async def test_system_update_raises_when_none_set():
Expand All @@ -109,4 +109,4 @@ async def test_system_update_raises_when_none_set():
)

with pytest.raises(ValueError):
update.as_dict()
update.to_dict()

0 comments on commit dad8c48

Please sign in to comment.