Skip to content

Commit

Permalink
Add periodic awareness updates (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Oct 11, 2024
1 parent 632769f commit b1cf6b8
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 21 deletions.
116 changes: 95 additions & 21 deletions python/pycrdt/_awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

import copy
import json
import time
from typing import Any, Callable, cast
from time import time
from typing import Any, Callable, Literal, cast
from uuid import uuid4

from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from anyio.abc import TaskGroup, TaskStatus

from ._doc import Doc
from ._sync import Decoder, Encoder

Expand All @@ -15,16 +18,20 @@ class Awareness:
_meta: dict[int, dict[str, Any]]
_states: dict[int, dict[str, Any]]
_subscriptions: dict[str, Callable[[str, tuple[dict[str, Any], Any]], None]]
_task_group: TaskGroup | None

def __init__(self, ydoc: Doc):
def __init__(self, ydoc: Doc, *, outdated_timeout: int = 30000) -> None:
"""
Args:
ydoc: The [Doc][pycrdt.Doc] to associate the awareness with.
outdated_timeout: The timeout (in milliseconds) to consider a client gone.
"""
self.client_id = ydoc.client_id
self._outdated_timeout = outdated_timeout
self._meta = {}
self._states = {}
self._subscriptions = {}
self._task_group = None
self.set_local_state({})

@property
Expand All @@ -37,6 +44,62 @@ def states(self) -> dict[int, dict[str, Any]]:
"""The client states."""
return self._states

def _emit(
self,
topic: Literal["change", "update"],
added: list[int],
updated: list[int],
removed: list[int],
origin: Any,
):
for callback in self._subscriptions.values():
callback(topic, ({"added": added, "updated": updated, "removed": removed}, origin))

def _get_time(self) -> int:
return int(time() * 1000)

async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
"""
Starts updating the awareness periodically.
"""
if self._task_group is not None:
raise RuntimeError("Awareness already started")

async with create_task_group() as tg:
self._task_group = tg
task_status.started()
tg.start_soon(self._start)

async def _start(self) -> None:
while True:
await sleep(self._outdated_timeout / 1000 / 10)
now = self._get_time()
if (
self.get_local_state() is not None
and self._outdated_timeout / 2 <= now - self._meta[self.client_id]["lastUpdated"]
):
# renew local clock
self.set_local_state(self.get_local_state())
remove: list[int] = []
for client_id, meta in self._meta.items():
if (
client_id != self.client_id
and self._outdated_timeout <= now - meta["lastUpdated"]
and client_id in self._states
):
remove.append(client_id)
if remove:
self.remove_awareness_states(remove, "timeout")

async def stop(self) -> None:
"""
Stops updating the awareness periodically.
"""
if self._task_group is None:
raise RuntimeError("Awareness not started")
self._task_group.cancel_scope.cancel()
self._task_group = None

def get_local_state(self) -> dict[str, Any] | None:
"""
Returns:
Expand All @@ -62,7 +125,7 @@ def set_local_state(self, state: dict[str, Any] | None) -> None:
del self._states[client_id]
else:
self._states[client_id] = state
timestamp = int(time.time() * 1000)
timestamp = self._get_time()
self._meta[client_id] = {"clock": clock, "lastUpdated": timestamp}
added = []
updated = []
Expand All @@ -78,13 +141,8 @@ def set_local_state(self, state: dict[str, Any] | None) -> None:
if prev_state != state:
filtered_updated.append(client_id)
if added or filtered_updated or removed:
for callback in self._subscriptions.values():
callback(
"change",
({"added": added, "updated": filtered_updated, "removed": removed}, "local"),
)
for callback in self._subscriptions.values():
callback("update", ({"added": added, "updated": updated, "removed": removed}, "local"))
self._emit("change", added, filtered_updated, removed, "local")
self._emit("update", added, updated, removed, "local")

def set_local_state_field(self, field: str, value: Any) -> None:
"""
Expand All @@ -100,6 +158,29 @@ def set_local_state_field(self, field: str, value: Any) -> None:
state[field] = value
self.set_local_state(state)

def remove_awareness_states(self, client_ids: list[int], origin: Any) -> None:
"""
Removes awareness states for clients given by their IDs.
Args:
client_ids: The list of client IDs for which to remove the awareness states.
origin: The origin of the update.
"""
removed = []
for client_id in client_ids:
if client_id in self._states:
del self._states[client_id]
if client_id == self.client_id:
cur_meta = self._meta[client_id]
self._meta[client_id] = {
"clock": cur_meta["clock"] + 1,
"lastUpdted": self._get_time(),
}
removed.append(client_id)
if removed:
self._emit("change", [], [], removed, origin)
self._emit("update", [], [], removed, origin)

def encode_awareness_update(self, client_ids: list[int]) -> bytes:
"""
Creates an encoded awareness update of the clients given by their IDs.
Expand Down Expand Up @@ -129,7 +210,7 @@ def apply_awareness_update(self, update: bytes, origin: Any) -> None:
origin: The origin of the update.
"""
decoder = Decoder(update)
timestamp = int(time.time() * 1000)
timestamp = self._get_time()
added = []
updated = []
filtered_updated = []
Expand Down Expand Up @@ -171,16 +252,9 @@ def apply_awareness_update(self, update: bytes, origin: Any) -> None:
filtered_updated.append(client_id)
updated.append(client_id)
if added or filtered_updated or removed:
for callback in self._subscriptions.values():
callback(
"change",
({"added": added, "updated": filtered_updated, "removed": removed}, origin),
)
self._emit("change", added, filtered_updated, removed, origin)
if added or updated or removed:
for callback in self._subscriptions.values():
callback(
"update", ({"added": added, "updated": updated, "removed": removed}, origin)
)
self._emit("update", added, updated, removed, origin)

def observe(self, callback: Callable[[str, tuple[dict[str, Any], Any]], None]) -> str:
"""
Expand Down
87 changes: 87 additions & 0 deletions tests/test_awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from uuid import uuid4

import pytest
from anyio import create_task_group, sleep
from pycrdt import Awareness, Doc, Encoder, YMessageType, create_awareness_message, read_message

pytestmark = pytest.mark.anyio

TEST_USER = {"username": str(uuid4()), "name": "Test user"}
REMOTE_CLIENT_ID = 853790970
REMOTE_USER = {
Expand Down Expand Up @@ -346,3 +349,87 @@ def test_awareness_encode_wrong_id():
awareness = Awareness(ydoc)
with pytest.raises(TypeError):
awareness.encode_awareness_update([10])


async def test_awareness_periodic_updates():
ydoc = Doc()
outdated_timeout = 200
awareness = Awareness(ydoc, outdated_timeout=outdated_timeout)
remote_client_id = 0
awareness._meta[remote_client_id] = {"clock": 0, "lastUpdated": 0}
awareness._states[remote_client_id] = {}
changes = []

def callback(topic, value):
changes.append((topic, value))

awareness.observe(callback)
async with create_task_group() as tg:
await tg.start(awareness.start)
with pytest.raises(RuntimeError) as excinfo:
await tg.start(awareness.start)
assert str(excinfo.value) == "Awareness already started"
await sleep((outdated_timeout - outdated_timeout / 10) / 1000)
awareness.remove_awareness_states([awareness.client_id], "local")
await sleep(outdated_timeout / 1000)
await awareness.stop()
with pytest.raises(RuntimeError) as excinfo:
await awareness.stop()
assert str(excinfo.value) == "Awareness not started"

assert len(changes) == 5
assert changes[0] == (
"change",
(
{
"added": [],
"removed": [remote_client_id],
"updated": [],
},
"timeout",
),
)
assert changes[1] == (
"update",
(
{
"added": [],
"removed": [remote_client_id],
"updated": [],
},
"timeout",
),
)
assert changes[2] == (
"update",
(
{
"added": [],
"removed": [],
"updated": [awareness.client_id],
},
"local",
),
)
assert changes[3] == (
"change",
(
{
"added": [],
"removed": [awareness.client_id],
"updated": [],
},
"local",
),
)
assert changes[4] == (
"update",
(
{
"added": [],
"removed": [awareness.client_id],
"updated": [],
},
"local",
),
)

0 comments on commit b1cf6b8

Please sign in to comment.