From 8d87a3964519eac19bd40a38a1b9a0767d0a006a Mon Sep 17 00:00:00 2001 From: Duco Sebel <74970928+DCSBL@users.noreply.github.com> Date: Wed, 1 Jan 2025 20:38:59 +0100 Subject: [PATCH] Add HomeWizardEnergy base class which is implemented in v1 and v2 --- homewizard_energy/homewizard_energy.py | 109 +++++++++++++++---------- homewizard_energy/v1/__init__.py | 57 +------------ homewizard_energy/v2/__init__.py | 54 +++--------- tests/test_homewizard_energy.py | 25 ++++++ tests/v1/test_v1_homewizard_energy.py | 12 +-- tests/v2/test_v2_homewizard_energy.py | 6 +- 6 files changed, 116 insertions(+), 147 deletions(-) create mode 100644 tests/test_homewizard_energy.py diff --git a/homewizard_energy/homewizard_energy.py b/homewizard_energy/homewizard_energy.py index b1fa40c..fe25fe8 100644 --- a/homewizard_energy/homewizard_energy.py +++ b/homewizard_energy/homewizard_energy.py @@ -2,40 +2,40 @@ from __future__ import annotations -import asyncio -import json -import ssl -from collections.abc import Callable, Coroutine -from http import HTTPStatus -from typing import Any, TypeVar - -import async_timeout -import backoff -from aiohttp.client import ( - ClientError, - ClientResponseError, - ClientSession, - ClientTimeout, - TCPConnector, -) -from aiohttp.hdrs import METH_DELETE, METH_GET, METH_POST, METH_PUT -from mashumaro.exceptions import InvalidFieldValue, MissingField -from ..models import Device, Measurement, System, SystemUpdate, Token, State - -from homewizard_energy.errors import ( - DisabledError, - RequestError, - ResponseError, - UnauthorizedError, -) +from typing import Any + +from aiohttp.client import ClientSession + +from .const import LOGGER +from .models import Device, Measurement, State, StateUpdate, System, SystemUpdate + class HomeWizardEnergy: """Base class for HomeWizard Energy API.""" - _clientsession: ClientSession | None = None + _session: ClientSession | None = None + _close_session: bool = False _request_timeout: int = 10 _host: str - + + def __init__( + self, + host: str, + clientsession: ClientSession = None, + timeout: int = 10, + ): + """Create a HomeWizard Energy object. + + Args: + host: IP or URL for device. + clientsession: The clientsession. + timeout: Request timeout in seconds. + """ + self._host = host + self._session = clientsession + self._close_session = clientsession is None + self._request_timeout = timeout + @property def host(self) -> str: """Return the hostname of the device. @@ -45,34 +45,59 @@ def host(self) -> str: """ return self._host - + async def device(self) -> Device: """Get the device information.""" raise NotImplementedError - + async def measurement(self) -> Measurement: """Get the current measurement.""" raise NotImplementedError - + async def system( self, update: SystemUpdate | None = None, ) -> System: - """Get the system information.""" + """Get/set the system.""" raise NotImplementedError - - async def state(self) -> State: - """Get the current state.""" + + async def state( + self, + update: StateUpdate | None = None, + ) -> State: + """Get/set the state.""" raise NotImplementedError - + async def identify( self, ) -> None: """Identify the device.""" raise NotImplementedError - - # async def reboot( - # self, - # ) -> None: - # """Reboot the device.""" - # raise NotImplementedError + + async def reboot( + self, + ) -> None: + """Reboot the device.""" + raise NotImplementedError + + async def close(self) -> None: + """Close client session.""" + LOGGER.debug("Closing clientsession") + if self._session and self._close_session: + await self._session.close() + + async def __aenter__(self) -> HomeWizardEnergy: + """Async enter. + + Returns: + The HomeWizardEnergy object. + """ + return self + + async def __aexit__(self, *_exc_info: Any) -> None: + """Async exit. + + Args: + _exc_info: Exec type. + """ + await self.close() diff --git a/homewizard_energy/v1/__init__.py b/homewizard_energy/v1/__init__.py index 3763fba..af032b2 100644 --- a/homewizard_energy/v1/__init__.py +++ b/homewizard_energy/v1/__init__.py @@ -20,6 +20,7 @@ ) from ..const import LOGGER +from ..homewizard_energy import HomeWizardEnergy from ..models import Device, Measurement, State, StateUpdate, System, SystemUpdate from .const import SUPPORTED_API_VERSION @@ -40,38 +41,10 @@ async def wrapper(self, *args, **kwargs) -> T: return wrapper -class HomeWizardEnergyV1: +# pylint: disable=abstract-method +class HomeWizardEnergyV1(HomeWizardEnergy): """Communicate with a HomeWizard Energy device.""" - _session: ClientSession | None - _close_session: bool = False - _request_timeout: int = 10 - - def __init__( - self, host: str, clientsession: ClientSession = None, timeout: int = 10 - ): - """Create a HomeWizard Energy object. - - Args: - host: IP or URL for device. - clientsession: The clientsession. - timeout: Request timeout in seconds. - """ - - self._host = host - self._session = clientsession - self._request_timeout = timeout - - @property - def host(self) -> str: - """Return the hostname of the device. - - Returns: - host: The used host - - """ - return self._host - async def device(self) -> Device: """Return the device object.""" _, response = await self._request("api") @@ -84,7 +57,7 @@ async def device(self) -> Device: return device - async def data(self) -> Measurement: + async def measurement(self) -> Measurement: """Return the data object.""" _, response = await self._request("api/v1/data") return Measurement.from_json(response) @@ -193,25 +166,3 @@ async def _request( raise RequestError(f"API request error ({resp.status})") return (resp.status, await resp.text()) - - async def close(self) -> None: - """Close client session.""" - LOGGER.debug("Closing clientsession") - if self._session and self._close_session: - await self._session.close() - - async def __aenter__(self) -> HomeWizardEnergyV1: - """Async enter. - - Returns: - The HomeWizardEnergyV1 object. - """ - return self - - async def __aexit__(self, *_exc_info: Any) -> None: - """Async exit. - - Args: - _exc_info: Exec type. - """ - await self.close() diff --git a/homewizard_energy/v2/__init__.py b/homewizard_energy/v2/__init__.py index 88807da..f8fb88a 100644 --- a/homewizard_energy/v2/__init__.py +++ b/homewizard_energy/v2/__init__.py @@ -29,6 +29,7 @@ ) from ..const import LOGGER +from ..homewizard_energy import HomeWizardEnergy from ..models import Device, Measurement, System, SystemUpdate, Token from .cacert import CACERT @@ -50,17 +51,18 @@ async def wrapper(self, *args, **kwargs) -> T: return wrapper -class HomeWizardEnergyV2: +# pylint: disable=abstract-method +class HomeWizardEnergyV2(HomeWizardEnergy): """Communicate with a HomeWizard Energy device.""" - _clientsession: ClientSession | None = None - _request_timeout: int = 10 - + # pylint: disable=too-many-arguments + # pylint: disable=too-many-positional-arguments def __init__( self, host: str, identifier: str | None = None, token: str | None = None, + clientsession: ClientSession = None, timeout: int = 10, ): """Create a HomeWizard Energy object. @@ -71,21 +73,9 @@ def __init__( token: Token for device. timeout: Request timeout in seconds. """ - - self._host = host + super().__init__(host, clientsession, timeout) self._identifier = identifier self._token = token - self._request_timeout = timeout - - @property - def host(self) -> str: - """Return the hostname of the device. - - Returns: - host: The used host - - """ - return self._host @authorized_method async def device(self) -> Device: @@ -192,7 +182,7 @@ async def delete_token( if name is None: self._token = None - async def _get_clientsession(self) -> ClientSession: + async def _get_session(self) -> ClientSession: """ Get a clientsession that is tuned for communication with the HomeWizard Energy Device """ @@ -227,8 +217,8 @@ async def _request( ) -> tuple[HTTPStatus, dict[str, Any] | None]: """Make a request to the API.""" - if self._clientsession is None: - self._clientsession = await self._get_clientsession() + if self._session is None: + self._session = await self._get_session() # Construct request url = f"https://{self.host}{path}" @@ -242,7 +232,7 @@ async def _request( try: async with async_timeout.timeout(self._request_timeout): - resp = await self._clientsession.request( + resp = await self._session.request( method, url, json=data, @@ -271,25 +261,3 @@ async def _request( pass return (resp.status, await resp.text()) - - async def close(self) -> None: - """Close client session.""" - LOGGER.debug("Closing clientsession") - if self._clientsession is not None: - await self._clientsession.close() - - async def __aenter__(self) -> HomeWizardEnergyV2: - """Async enter. - - Returns: - The HomeWizardEnergyV2 object. - """ - return self - - async def __aexit__(self, *_exc_info: Any) -> None: - """Async exit. - - Args: - _exc_info: Exec type. - """ - await self.close() diff --git a/tests/test_homewizard_energy.py b/tests/test_homewizard_energy.py new file mode 100644 index 0000000..4de5d17 --- /dev/null +++ b/tests/test_homewizard_energy.py @@ -0,0 +1,25 @@ +"""Test the base class.""" + +import pytest + +from homewizard_energy.homewizard_energy import HomeWizardEnergy + +pytestmark = [pytest.mark.asyncio] + + +@pytest.mark.parametrize( + ("function"), + [ + ("device"), + ("measurement"), + ("system"), + ("state"), + ("identify"), + ("reboot"), + ], +) +async def test_base_class_raises_notimplementederror(function: str): + """Test the base class raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + async with HomeWizardEnergy("host") as api: + await getattr(api, function)() diff --git a/tests/v1/test_v1_homewizard_energy.py b/tests/v1/test_v1_homewizard_energy.py index bc3c9a8..7b612c0 100644 --- a/tests/v1/test_v1_homewizard_energy.py +++ b/tests/v1/test_v1_homewizard_energy.py @@ -230,9 +230,9 @@ async def test_get_data_object( async with aiohttp.ClientSession() as session: api = HomeWizardEnergyV1("example.com", clientsession=session) - data = await api.data() - assert data is not None - assert data == snapshot + measurement = await api.measurement() + assert measurement is not None + assert measurement == snapshot await api.close() @@ -275,9 +275,9 @@ async def test_get_data_object_with_known_device( # pylint: disable=protected-access api._detected_api_version = "v1" - data = await api.data() - assert data is not None - assert data == snapshot + measurement = await api.measurement() + assert measurement is not None + assert measurement == snapshot await api.close() diff --git a/tests/v2/test_v2_homewizard_energy.py b/tests/v2/test_v2_homewizard_energy.py index 8846bf2..bb0c2a7 100644 --- a/tests/v2/test_v2_homewizard_energy.py +++ b/tests/v2/test_v2_homewizard_energy.py @@ -519,13 +519,13 @@ async def test_delete_token_returns_unexpected_response_code(aresponses): async def test_request_handles_timeout(): """Test request times out when request takes too long.""" async with HomeWizardEnergyV2("example.com", token="token") as api: - api._clientsession = AsyncMock() - api._clientsession.request = AsyncMock(side_effect=asyncio.TimeoutError()) + api._session = AsyncMock() + api._session.request = AsyncMock(side_effect=asyncio.TimeoutError()) with pytest.raises(RequestError): await api.device() - assert api._clientsession.request.call_count == 5 + assert api._session.request.call_count == 5 async def test_request_with_identifier_sets_common_name(aresponses):