Skip to content

Commit

Permalink
Add permanent connection reconnecting
Browse files Browse the repository at this point in the history
  • Loading branch information
LennP committed Mar 1, 2024
1 parent 92c616d commit e45a68a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
37 changes: 25 additions & 12 deletions motionblindsble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,17 @@ class ConnectionQueue:
"""Class used to ensure the first caller connects,
but the last caller's command goes through after connection."""

_ha_create_task: Callable[[Coroutine], Task] | None = None
_connection_task: Task | Any | None = None
_last_caller_cancel: Future | None = None

def set_ha_create_task(
self, ha_create_task: Callable[[Coroutine], Task]
) -> None:
"""Set the Home Assistant create_task function."""
self._ha_create_task = ha_create_task

def _create_connection_task(self, device: MotionDevice) -> Task | Any:
"""Create a connection task."""
if self._ha_create_task:
if device._ha_create_task:
_LOGGER.debug(
"(%s) Connecting using Home Assistant",
device.ble_device.address,
)
return self._ha_create_task(
return device._ha_create_task(
target=device.establish_connection()
) # type: ignore[call-arg]
_LOGGER.debug("(%s) Connecting", device.ble_device.address)
Expand Down Expand Up @@ -303,6 +296,7 @@ class MotionDevice:
_disconnect_timer: TimerHandle | Callable | None

# Callbacks that are used to interface with HA
_ha_create_task: Callable[[Coroutine], Task] | None = None
_ha_call_later: Callable[[int, Coroutine], Callable] | None = None

# Callbacks
Expand Down Expand Up @@ -409,7 +403,11 @@ def set_ble_device(
def set_custom_disconnect_time(self, timeout: float | None):
"""Set a custom disconnect time."""
_LOGGER.debug(
"(%s) Set custom disconnect time to %.2fs",
(
"(%s) Set custom disconnect time to %.2fs"
if timeout is not None
else "(%s) Set custom disconnect time to %s"
),
self.ble_device.address,
timeout,
)
Expand All @@ -420,7 +418,9 @@ async def set_permanent_connection(
) -> None:
"""Enable or disable a permanent connection."""
self._permanent_connection = permanent_connection
if not permanent_connection:
if permanent_connection:
await self.connect()
else:
await self.disconnect()

@property
Expand All @@ -432,7 +432,7 @@ def set_ha_create_task(
self, ha_create_task: Callable[[Coroutine], Task]
) -> None:
"""Set the create_task function to use."""
self._connection_queue.set_ha_create_task(ha_create_task)
self._ha_create_task = ha_create_task

def set_ha_call_later(
self, ha_call_later: Callable[[int, Coroutine], Callable]
Expand Down Expand Up @@ -568,6 +568,19 @@ def _disconnect_callback(self, _: BleakClient) -> None:
self.update_speed(None)
self.update_connection(MotionConnectionType.DISCONNECTED)
self._current_bleak_client = None
if self._permanent_connection:
if self._ha_create_task:
_LOGGER.debug(
"(%s) Automatically reconnecting using Home Assistant",
self.ble_device.address,
)
self._ha_create_task(
target=self.connect()
) # type: ignore[call-arg]
_LOGGER.debug(
"(%s) Automatically reconnecting", self.ble_device.address
)
get_event_loop().create_task(self.connect())

async def connect(
self, disable_callbacks: list[MotionCallback] | None = None
Expand Down
36 changes: 32 additions & 4 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def test_create_connection_task(self) -> None:

# Test creation of connect task with Home Assistant
mock_ha_create_task = Mock()
connection_queue.set_ha_create_task(mock_ha_create_task)
assert connection_queue._ha_create_task is not None
device.set_ha_create_task(mock_ha_create_task)
assert device._ha_create_task is not None
with patch(
"motionblindsble.device.MotionDevice.establish_connection",
AsyncMock(return_value=True),
Expand Down Expand Up @@ -393,6 +393,14 @@ def side_effect_disconnect(*args, **kwargs):
assert not await device.disconnect()
device._connection_queue.cancel.assert_called_once()

@patch(
"motionblindsble.device.establish_connection",
AsyncMock(),
)
@patch(
"motionblindsble.crypt.MotionCrypt.get_time",
Mock(return_value=""),
)
@patch("motionblindsble.device.MotionDevice.disconnect")
@patch("motionblindsble.device.time_ns")
async def test_refresh_disconnect_timer(
Expand Down Expand Up @@ -447,8 +455,8 @@ def call_later(delay: int, action):
mock_disconnect.assert_called_once()

# Test permanent connection, no disconnect timer
mock_time_ns.reset_mock()
await device.set_permanent_connection(True)
mock_time_ns.reset_mock()
device.refresh_disconnect_timer()
assert mock_time_ns.call_count == 0

Expand Down Expand Up @@ -483,6 +491,26 @@ def call_later_condition(delay: int, action: Callable) -> None:
permanent_connection_enabled.set()
assert mock_disconnect.call_count == 1

@patch("motionblindsble.device.MotionDevice.connect")
async def test_permanent_connection(self, mock_connect) -> None:
"""Test the permanent connection function."""
device = MotionDevice("00:11:22:33:44:55")
device._disconnect_callback(Mock())
assert mock_connect.call_count == 0

# Test normal permanent connection
await device.set_permanent_connection(True)
assert mock_connect.call_count == 1
device._disconnect_callback(Mock())
assert mock_connect.call_count == 2

# Test permanent connection with Home Assistant
mock_ha_create_task = Mock()
device.set_ha_create_task(mock_ha_create_task)
await device.set_permanent_connection(True)
device._disconnect_callback(Mock())
assert mock_ha_create_task.call_count == 1


class TestDevice:
"""Test the Device in device.py module."""
Expand Down Expand Up @@ -527,7 +555,7 @@ def test_setters(self) -> None:
mock2 = Mock()
device.set_ha_create_task(mock)
device.set_ha_call_later(mock2)
assert device._connection_queue._ha_create_task is mock
assert device._ha_create_task is mock
assert device._ha_call_later is mock2

@patch("motionblindsble.device.MotionCrypt.decrypt", lambda x: x)
Expand Down

0 comments on commit e45a68a

Please sign in to comment.