diff --git a/motionblindsble/device.py b/motionblindsble/device.py index 0d5b951..4be2210 100644 --- a/motionblindsble/device.py +++ b/motionblindsble/device.py @@ -207,24 +207,17 @@ class ConnectionQueue: """Class used to ensure the first caller connects, but the last caller's command goes through after connection.""" - _ha_create_task: Callable[[Coroutine], Task] | None = None _connection_task: Task | Any | None = None _last_caller_cancel: Future | None = None - def set_ha_create_task( - self, ha_create_task: Callable[[Coroutine], Task] - ) -> None: - """Set the Home Assistant create_task function.""" - self._ha_create_task = ha_create_task - def _create_connection_task(self, device: MotionDevice) -> Task | Any: """Create a connection task.""" - if self._ha_create_task: + if device._ha_create_task: _LOGGER.debug( "(%s) Connecting using Home Assistant", device.ble_device.address, ) - return self._ha_create_task( + return device._ha_create_task( target=device.establish_connection() ) # type: ignore[call-arg] _LOGGER.debug("(%s) Connecting", device.ble_device.address) @@ -303,6 +296,7 @@ class MotionDevice: _disconnect_timer: TimerHandle | Callable | None # Callbacks that are used to interface with HA + _ha_create_task: Callable[[Coroutine], Task] | None = None _ha_call_later: Callable[[int, Coroutine], Callable] | None = None # Callbacks @@ -409,7 +403,11 @@ def set_ble_device( def set_custom_disconnect_time(self, timeout: float | None): """Set a custom disconnect time.""" _LOGGER.debug( - "(%s) Set custom disconnect time to %.2fs", + ( + "(%s) Set custom disconnect time to %.2fs" + if timeout is not None + else "(%s) Set custom disconnect time to %s" + ), self.ble_device.address, timeout, ) @@ -420,7 +418,9 @@ async def set_permanent_connection( ) -> None: """Enable or disable a permanent connection.""" self._permanent_connection = permanent_connection - if not permanent_connection: + if permanent_connection: + await self.connect() + else: await self.disconnect() @property @@ -432,7 +432,7 @@ def set_ha_create_task( self, ha_create_task: Callable[[Coroutine], Task] ) -> None: """Set the create_task function to use.""" - self._connection_queue.set_ha_create_task(ha_create_task) + self._ha_create_task = ha_create_task def set_ha_call_later( self, ha_call_later: Callable[[int, Coroutine], Callable] @@ -568,6 +568,19 @@ def _disconnect_callback(self, _: BleakClient) -> None: self.update_speed(None) self.update_connection(MotionConnectionType.DISCONNECTED) self._current_bleak_client = None + if self._permanent_connection: + if self._ha_create_task: + _LOGGER.debug( + "(%s) Automatically reconnecting using Home Assistant", + self.ble_device.address, + ) + self._ha_create_task( + target=self.connect() + ) # type: ignore[call-arg] + _LOGGER.debug( + "(%s) Automatically reconnecting", self.ble_device.address + ) + get_event_loop().create_task(self.connect()) async def connect( self, disable_callbacks: list[MotionCallback] | None = None diff --git a/tests/test_device.py b/tests/test_device.py index 72ecbfa..9d0ad48 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -174,8 +174,8 @@ def test_create_connection_task(self) -> None: # Test creation of connect task with Home Assistant mock_ha_create_task = Mock() - connection_queue.set_ha_create_task(mock_ha_create_task) - assert connection_queue._ha_create_task is not None + device.set_ha_create_task(mock_ha_create_task) + assert device._ha_create_task is not None with patch( "motionblindsble.device.MotionDevice.establish_connection", AsyncMock(return_value=True), @@ -393,6 +393,14 @@ def side_effect_disconnect(*args, **kwargs): assert not await device.disconnect() device._connection_queue.cancel.assert_called_once() + @patch( + "motionblindsble.device.establish_connection", + AsyncMock(), + ) + @patch( + "motionblindsble.crypt.MotionCrypt.get_time", + Mock(return_value=""), + ) @patch("motionblindsble.device.MotionDevice.disconnect") @patch("motionblindsble.device.time_ns") async def test_refresh_disconnect_timer( @@ -447,8 +455,8 @@ def call_later(delay: int, action): mock_disconnect.assert_called_once() # Test permanent connection, no disconnect timer - mock_time_ns.reset_mock() await device.set_permanent_connection(True) + mock_time_ns.reset_mock() device.refresh_disconnect_timer() assert mock_time_ns.call_count == 0 @@ -483,6 +491,26 @@ def call_later_condition(delay: int, action: Callable) -> None: permanent_connection_enabled.set() assert mock_disconnect.call_count == 1 + @patch("motionblindsble.device.MotionDevice.connect") + async def test_permanent_connection(self, mock_connect) -> None: + """Test the permanent connection function.""" + device = MotionDevice("00:11:22:33:44:55") + device._disconnect_callback(Mock()) + assert mock_connect.call_count == 0 + + # Test normal permanent connection + await device.set_permanent_connection(True) + assert mock_connect.call_count == 1 + device._disconnect_callback(Mock()) + assert mock_connect.call_count == 2 + + # Test permanent connection with Home Assistant + mock_ha_create_task = Mock() + device.set_ha_create_task(mock_ha_create_task) + await device.set_permanent_connection(True) + device._disconnect_callback(Mock()) + assert mock_ha_create_task.call_count == 1 + class TestDevice: """Test the Device in device.py module.""" @@ -527,7 +555,7 @@ def test_setters(self) -> None: mock2 = Mock() device.set_ha_create_task(mock) device.set_ha_call_later(mock2) - assert device._connection_queue._ha_create_task is mock + assert device._ha_create_task is mock assert device._ha_call_later is mock2 @patch("motionblindsble.device.MotionCrypt.decrypt", lambda x: x)