From 36c0dfb37f771e723f1f8b944f45ac5991919a5d Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 11:34:57 +0200 Subject: [PATCH 01/21] Move the Awareness from pycrdt_websocket to pycrdt project, and add some features to it --- python/pycrdt/_awareness.py | 127 ++++++++++++++++++++++++++++++++---- 1 file changed, 114 insertions(+), 13 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 43d412a..428b699 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -2,17 +2,58 @@ import json import time -from typing import Any +from logging import Logger, getLogger +from typing import Any, Callable, Coroutine, Never +from uuid import uuid4 from ._doc import Doc -from ._sync import Decoder, read_message +from ._sync import Decoder, YMessageType, read_message, write_var_uint -class Awareness: # pragma: no cover - def __init__(self, ydoc: Doc): +DEFAULT_USER = {"username": str(uuid4()), "name": "Jupyter server"} + + +class Awareness: + client_id: int + log: Logger + meta: dict[int, dict[str, Any]] + _states: dict[int, dict[str, Any]] + _subscriptions: list[Callable[[dict[str, Any]], None]] + _user: dict[str, str] | None + + def __init__( + self, + ydoc: Doc, + log: Logger | None = None, + on_change: Callable[[bytes], Coroutine[Any, Any, Never]] | None = None, + user: dict[str, str] | None = None, + ): self.client_id = ydoc.client_id - self.meta: dict[int, dict[str, Any]] = {} - self.states: dict[int, dict[str, Any]] = {} + self.log = log or getLogger(__name__) + self.meta = {} + self._states = {} + + if user is not None: + self.user = user + else: + self._user = DEFAULT_USER + self._states[self.client_id] = {"user": DEFAULT_USER} + self.on_change = on_change + + self._subscriptions = [] + + @property + def states(self) -> dict[int, dict[str, Any]]: + return self._states + + @property + def user(self) -> dict[str, str] | None: + return self._user + + @user.setter + def user(self, user: dict[str, str]): + self._user = user + self.set_local_state_field("user", self._user) def get_changes(self, message: bytes) -> dict[str, Any]: message = read_message(message) @@ -32,19 +73,19 @@ def get_changes(self, message: bytes) -> dict[str, Any]: if state is not None: states.append(state) client_meta = self.meta.get(client_id) - prev_state = self.states.get(client_id) + prev_state = self._states.get(client_id) curr_clock = 0 if client_meta is None else client_meta["clock"] if curr_clock < clock or ( - curr_clock == clock and state is None and client_id in self.states + curr_clock == clock and state is None and client_id in self._states ): if state is None: - if client_id == self.client_id and self.states.get(client_id) is not None: + if client_id == self.client_id and self._states.get(client_id) is not None: clock += 1 else: - if client_id in self.states: - del self.states[client_id] + if client_id in self._states: + del self._states[client_id] else: - self.states[client_id] = state + self._states[client_id] = state self.meta[client_id] = { "clock": clock, "last_updated": timestamp, @@ -57,10 +98,70 @@ def get_changes(self, message: bytes) -> dict[str, Any]: if state != prev_state: filtered_updated.append(client_id) updated.append(client_id) - return { + + changes = { "added": added, "updated": updated, "filtered_updated": filtered_updated, "removed": removed, "states": states, } + + # Do not trigger the callbacks if it is only a keep alive update + if added or filtered_updated or removed: + for callback in self._subscriptions: + callback(changes) + + return changes + + def get_local_state(self) -> dict[str, Any]: + return self._states.get(self.client_id, {}) + + def set_local_state(self, state: dict[str, Any]) -> None: + self.log('SET LOCAL CHANGE') + # Update the state and the meta. + timestamp = int(time.time() * 1000) + clock = self.meta.get(self.client_id, {}).get("clock", -1) + 1 + self._states[self.client_id] = state + self.meta[self.client_id] = {"clock": clock, "last_updated": timestamp} + # Build the message to broadcast, with the following information: + # - message type + # - length in bytes of the updates + # - number of updates + # - for each update + # - client_id + # - clock + # - length in bytes of the update + # - encoded update + msg = json.dumps(state, separators=(",", ":")).encode("utf-8") + msg = write_var_uint(len(msg)) + msg + msg = write_var_uint(clock) + msg + msg = write_var_uint(self.client_id) + msg + msg = write_var_uint(1) + msg + msg = write_var_uint(len(msg)) + msg + msg = write_var_uint(YMessageType.AWARENESS) + msg + + if self.on_change: + self.on_change(msg) + + def set_local_state_field(self, field: str, value: Any) -> None: + current_state = self.get_local_state() + current_state[field] = value + self.set_local_state(current_state) + + def observe(self, callback: Callable[[dict[str, Any]], None]) -> None: + """ + Subscribes to awareness changes. + + :param callback: Callback that will be called when the document changes. + :type callback: Callable[[str, Any], None] + """ + self._subscriptions.append(callback) + + def unobserve(self) -> None: + """ + Unsubscribes to awareness changes. + + This method removes all the callbacks. + """ + self._subscriptions = [] From 0df6b20fafb81deca5a151c40d5be0069b12c9ef Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 12:02:15 +0200 Subject: [PATCH 02/21] Add tests on awareness --- pyproject.toml | 1 + python/pycrdt/_awareness.py | 8 +-- tests/test_awareness.py | 108 ++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 5 deletions(-) create mode 100644 tests/test_awareness.py diff --git a/pyproject.toml b/pyproject.toml index bf26c74..30268f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ test = [ "mypy", "coverage[toml] >=7", "exceptiongroup; python_version<'3.11'", + "dirty_equals", ] docs = [ "mkdocs", diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 428b699..7532ea2 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -3,13 +3,12 @@ import json import time from logging import Logger, getLogger -from typing import Any, Callable, Coroutine, Never +from typing import Any, Callable from uuid import uuid4 from ._doc import Doc from ._sync import Decoder, YMessageType, read_message, write_var_uint - DEFAULT_USER = {"username": str(uuid4()), "name": "Jupyter server"} @@ -25,7 +24,7 @@ def __init__( self, ydoc: Doc, log: Logger | None = None, - on_change: Callable[[bytes], Coroutine[Any, Any, Never]] | None = None, + on_change: Callable[[bytes], None] | None = None, user: dict[str, str] | None = None, ): self.client_id = ydoc.client_id @@ -118,7 +117,6 @@ def get_local_state(self) -> dict[str, Any]: return self._states.get(self.client_id, {}) def set_local_state(self, state: dict[str, Any]) -> None: - self.log('SET LOCAL CHANGE') # Update the state and the meta. timestamp = int(time.time() * 1000) clock = self.meta.get(self.client_id, {}).get("clock", -1) + 1 @@ -142,7 +140,7 @@ def set_local_state(self, state: dict[str, Any]) -> None: msg = write_var_uint(YMessageType.AWARENESS) + msg if self.on_change: - self.on_change(msg) + self.on_change(msg) def set_local_state_field(self, field: str, value: Any) -> None: current_state = self.get_local_state() diff --git a/tests/test_awareness.py b/tests/test_awareness.py new file mode 100644 index 0000000..6853c38 --- /dev/null +++ b/tests/test_awareness.py @@ -0,0 +1,108 @@ +import json + +from dirty_equals import IsStr +from pycrdt import Awareness, Doc + +DEFAULT_USER = {"username": IsStr(), "name": "Jupyter server"} + + +def test_awareness_default_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + assert awareness.user == DEFAULT_USER + + +def test_awareness_set_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + user = {"username": "test_username", "name": "test_name"} + awareness.user = user + assert awareness.user == user + + +def test_awareness_get_local_state(): + ydoc = Doc() + awareness = Awareness(ydoc) + + assert awareness.get_local_state() == {"user": DEFAULT_USER} + + +def test_awareness_set_local_state_field(): + ydoc = Doc() + awareness = Awareness(ydoc) + + awareness.set_local_state_field("new_field", "new_value") + assert awareness.get_local_state() == {"user": DEFAULT_USER, "new_field": "new_value"} + + +def test_awareness_get_changes(): + ydoc = Doc() + awareness = Awareness(ydoc) + + new_user = { + "user": { + "username": "2460ab00fd28415b87e49ec5aa2d482d", + "name": "Anonymous Ersa", + "display_name": "Anonymous Ersa", + "initials": "AE", + "avatar_url": None, + "color": "var(--jp-collaborator-color7)", + } + } + new_user_bytes = json.dumps(new_user, separators=(",", ":")).encode("utf-8") + new_user_message = b"\xc3\x01\x01\xfa\xa1\x8f\x97\x03\x03\xba\x01" + new_user_bytes + changes = awareness.get_changes(new_user_message) + assert changes == { + "added": [853790970], + "updated": [], + "filtered_updated": [], + "removed": [], + "states": [new_user], + } + assert awareness.states == {awareness.client_id: {"user": DEFAULT_USER}, 853790970: new_user} + + +def test_awareness_observes(): + ydoc = Doc() + awareness = Awareness(ydoc) + + called = {} + + def callback(value): + called.update(value) + + awareness.observe(callback) + + new_user = { + "user": { + "username": "2460ab00fd28415b87e49ec5aa2d482d", + "name": "Anonymous Ersa", + "display_name": "Anonymous Ersa", + "initials": "AE", + "avatar_url": None, + "color": "var(--jp-collaborator-color7)", + } + } + new_user_bytes = json.dumps(new_user, separators=(",", ":")).encode("utf-8") + new_user_message = b"\xc3\x01\x01\xfa\xa1\x8f\x97\x03\x03\xba\x01" + new_user_bytes + changes = awareness.get_changes(new_user_message) + + assert called == changes + + +def test_awareness_on_change(): + ydoc = Doc() + + changes = [] + + def callback(value): + changes.append(value) + + awareness = Awareness(ydoc, on_change=callback) + + awareness.set_local_state_field("new_field", "new_value") + + assert len(changes) == 1 + + assert type(changes[0]) is bytes From ac39dae209a2e0e50d24c2b11c27d9c9c942845c Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 15:05:20 +0200 Subject: [PATCH 03/21] use google style docstring --- python/pycrdt/_awareness.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 7532ea2..a81fcf2 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -151,8 +151,8 @@ def observe(self, callback: Callable[[dict[str, Any]], None]) -> None: """ Subscribes to awareness changes. - :param callback: Callback that will be called when the document changes. - :type callback: Callable[[str, Any], None] + Args: + callback: Callback that will be called when the document changes. """ self._subscriptions.append(callback) From 8c3d21b18871f1f6a2b900bf02e05e2447cead8a Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 15:18:36 +0200 Subject: [PATCH 04/21] Generate the message in test for clarity --- tests/test_awareness.py | 61 +++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 6853c38..b13ef48 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -1,10 +1,32 @@ import json from dirty_equals import IsStr -from pycrdt import Awareness, Doc +from pycrdt import Awareness, Doc, write_var_uint DEFAULT_USER = {"username": IsStr(), "name": "Jupyter server"} +TEST_CLIENT_ID = 853790970 +TEST_USER = { + "user": { + "username": "2460ab00fd28415b87e49ec5aa2d482d", + "name": "Anonymous Ersa", + "display_name": "Anonymous Ersa", + "initials": "AE", + "avatar_url": None, + "color": "var(--jp-collaborator-color7)", + } +} + + +def create_bytes_message(client_id: int, user: dict[str, dict[str, str | None]]) -> bytes: + new_user_bytes = json.dumps(user, separators=(",", ":")).encode("utf-8") + msg = write_var_uint(len(new_user_bytes)) + new_user_bytes + msg = write_var_uint(1) + msg + msg = write_var_uint(client_id) + msg + msg = write_var_uint(1) + msg + msg = write_var_uint(len(msg)) + msg + return msg + def test_awareness_default_user(): ydoc = Doc() @@ -40,27 +62,18 @@ def test_awareness_get_changes(): ydoc = Doc() awareness = Awareness(ydoc) - new_user = { - "user": { - "username": "2460ab00fd28415b87e49ec5aa2d482d", - "name": "Anonymous Ersa", - "display_name": "Anonymous Ersa", - "initials": "AE", - "avatar_url": None, - "color": "var(--jp-collaborator-color7)", - } - } - new_user_bytes = json.dumps(new_user, separators=(",", ":")).encode("utf-8") - new_user_message = b"\xc3\x01\x01\xfa\xa1\x8f\x97\x03\x03\xba\x01" + new_user_bytes - changes = awareness.get_changes(new_user_message) + changes = awareness.get_changes(create_bytes_message(TEST_CLIENT_ID, TEST_USER)) assert changes == { - "added": [853790970], + "added": [TEST_CLIENT_ID], "updated": [], "filtered_updated": [], "removed": [], - "states": [new_user], + "states": [TEST_USER], + } + assert awareness.states == { + awareness.client_id: {"user": DEFAULT_USER}, + TEST_CLIENT_ID: TEST_USER, } - assert awareness.states == {awareness.client_id: {"user": DEFAULT_USER}, 853790970: new_user} def test_awareness_observes(): @@ -74,19 +87,7 @@ def callback(value): awareness.observe(callback) - new_user = { - "user": { - "username": "2460ab00fd28415b87e49ec5aa2d482d", - "name": "Anonymous Ersa", - "display_name": "Anonymous Ersa", - "initials": "AE", - "avatar_url": None, - "color": "var(--jp-collaborator-color7)", - } - } - new_user_bytes = json.dumps(new_user, separators=(",", ":")).encode("utf-8") - new_user_message = b"\xc3\x01\x01\xfa\xa1\x8f\x97\x03\x03\xba\x01" + new_user_bytes - changes = awareness.get_changes(new_user_message) + changes = awareness.get_changes(create_bytes_message(TEST_CLIENT_ID, TEST_USER)) assert called == changes From a183971d30193af4685a19adac6de0a3cd6f585b Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 17:00:33 +0200 Subject: [PATCH 05/21] Add docstring and tests --- python/pycrdt/_awareness.py | 25 ++++++++- tests/test_awareness.py | 107 +++++++++++++++++++++++++++++++----- 2 files changed, 116 insertions(+), 16 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index a81fcf2..11e02de 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -31,13 +31,13 @@ def __init__( self.log = log or getLogger(__name__) self.meta = {} self._states = {} + self.on_change = on_change if user is not None: self.user = user else: self._user = DEFAULT_USER self._states[self.client_id] = {"user": DEFAULT_USER} - self.on_change = on_change self._subscriptions = [] @@ -55,6 +55,13 @@ def user(self, user: dict[str, str]): self.set_local_state_field("user", self._user) def get_changes(self, message: bytes) -> dict[str, Any]: + """ + Update the states with a user state. + This function send the changes to subscribers. + + Args: + msg: Bytes representing the user state. + """ message = read_message(message) decoder = Decoder(message) timestamp = int(time.time() * 1000) @@ -117,7 +124,14 @@ def get_local_state(self) -> dict[str, Any]: return self._states.get(self.client_id, {}) def set_local_state(self, state: dict[str, Any]) -> None: - # Update the state and the meta. + """ + Update the local state and meta. + This function call the on_change() callback (if provided), with the states + formatted (bytes) as argument. + + Args: + state: The dictionary representing the state. + """ timestamp = int(time.time() * 1000) clock = self.meta.get(self.client_id, {}).get("clock", -1) + 1 self._states[self.client_id] = state @@ -143,6 +157,13 @@ def set_local_state(self, state: dict[str, Any]) -> None: self.on_change(msg) def set_local_state_field(self, field: str, value: Any) -> None: + """ + Set a local state field. + + Args: + field: The field to set (str) + value: the value of the field + """ current_state = self.get_local_state() current_state[field] = value self.set_local_state(current_state) diff --git a/tests/test_awareness.py b/tests/test_awareness.py index b13ef48..ce55a7f 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -1,12 +1,14 @@ +from copy import deepcopy import json +from uuid import uuid4 from dirty_equals import IsStr from pycrdt import Awareness, Doc, write_var_uint DEFAULT_USER = {"username": IsStr(), "name": "Jupyter server"} - -TEST_CLIENT_ID = 853790970 -TEST_USER = { +TEST_USER = {"username": str(uuid4()), "name": "Test user"} +REMOTE_CLIENT_ID = 853790970 +REMOTE_USER = { "user": { "username": "2460ab00fd28415b87e49ec5aa2d482d", "name": "Anonymous Ersa", @@ -18,10 +20,17 @@ } -def create_bytes_message(client_id: int, user: dict[str, dict[str, str | None]]) -> bytes: - new_user_bytes = json.dumps(user, separators=(",", ":")).encode("utf-8") +def create_bytes_message( + client_id: int, + user: dict[str, dict[str, str | None]] | str, + clock: int = 1 +) -> bytes: + if type(user) is str: + new_user_bytes = user.encode("utf-8") + else: + new_user_bytes = json.dumps(user, separators=(",", ":")).encode("utf-8") msg = write_var_uint(len(new_user_bytes)) + new_user_bytes - msg = write_var_uint(1) + msg + msg = write_var_uint(clock) + msg msg = write_var_uint(client_id) + msg msg = write_var_uint(1) + msg msg = write_var_uint(len(msg)) + msg @@ -35,6 +44,13 @@ def test_awareness_default_user(): assert awareness.user == DEFAULT_USER +def test_awareness_with_user(): + ydoc = Doc() + awareness = Awareness(ydoc, user=TEST_USER) + + assert awareness.user == TEST_USER + + def test_awareness_set_user(): ydoc = Doc() awareness = Awareness(ydoc) @@ -58,24 +74,83 @@ def test_awareness_set_local_state_field(): assert awareness.get_local_state() == {"user": DEFAULT_USER, "new_field": "new_value"} -def test_awareness_get_changes(): +def test_awareness_add_user(): ydoc = Doc() awareness = Awareness(ydoc) - changes = awareness.get_changes(create_bytes_message(TEST_CLIENT_ID, TEST_USER)) + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) assert changes == { - "added": [TEST_CLIENT_ID], + "added": [REMOTE_CLIENT_ID], "updated": [], "filtered_updated": [], "removed": [], - "states": [TEST_USER], + "states": [REMOTE_USER], + } + assert awareness.states == { + awareness.client_id: {"user": DEFAULT_USER}, + REMOTE_CLIENT_ID: REMOTE_USER, + } + + +def test_awareness_update_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + # Add a remote user. + awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + + # Update it + remote_user = deepcopy(REMOTE_USER) + remote_user["user"]["name"] = "New user name" + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2)) + + assert changes == { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "filtered_updated": [REMOTE_CLIENT_ID], + "removed": [], + "states": [remote_user], } assert awareness.states == { awareness.client_id: {"user": DEFAULT_USER}, - TEST_CLIENT_ID: TEST_USER, + REMOTE_CLIENT_ID: remote_user, } +def test_awareness_remove_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + # Add a remote user. + awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + + # Remove it + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, "null", 2)) + + assert changes == { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [REMOTE_CLIENT_ID], + "states": [], + } + assert awareness.states == {awareness.client_id: {"user": DEFAULT_USER}} + + +def test_awareness_increment_clock(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = awareness.get_changes(create_bytes_message(awareness.client_id, "null")) + assert changes == { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [], + "states": [], + } + assert awareness.meta.get(awareness.client_id, {}).get("clock", 0) == 2 + + def test_awareness_observes(): ydoc = Doc() awareness = Awareness(ydoc) @@ -86,11 +161,15 @@ def callback(value): called.update(value) awareness.observe(callback) - - changes = awareness.get_changes(create_bytes_message(TEST_CLIENT_ID, TEST_USER)) - + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) assert called == changes + called = {} + awareness.unobserve() + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + assert called != changes + assert called == {} + def test_awareness_on_change(): ydoc = Doc() From 888dca995fc4d483d7ec754a4e0f81fbdc9bc744 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:00:43 +0000 Subject: [PATCH 06/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_awareness.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_awareness.py b/tests/test_awareness.py index ce55a7f..1d82bb7 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -1,5 +1,5 @@ -from copy import deepcopy import json +from copy import deepcopy from uuid import uuid4 from dirty_equals import IsStr @@ -21,9 +21,7 @@ def create_bytes_message( - client_id: int, - user: dict[str, dict[str, str | None]] | str, - clock: int = 1 + client_id: int, user: dict[str, dict[str, str | None]] | str, clock: int = 1 ) -> bytes: if type(user) is str: new_user_bytes = user.encode("utf-8") From 879b9ad1af4d70a84f8c4313e320b6a838175559 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 17:02:20 +0200 Subject: [PATCH 07/21] remove the unused logger --- python/pycrdt/_awareness.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 11e02de..fe7eb00 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -2,7 +2,6 @@ import json import time -from logging import Logger, getLogger from typing import Any, Callable from uuid import uuid4 @@ -14,7 +13,6 @@ class Awareness: client_id: int - log: Logger meta: dict[int, dict[str, Any]] _states: dict[int, dict[str, Any]] _subscriptions: list[Callable[[dict[str, Any]], None]] @@ -23,12 +21,10 @@ class Awareness: def __init__( self, ydoc: Doc, - log: Logger | None = None, on_change: Callable[[bytes], None] | None = None, user: dict[str, str] | None = None, ): self.client_id = ydoc.client_id - self.log = log or getLogger(__name__) self.meta = {} self._states = {} self.on_change = on_change From bf148d2aeff81fe58bd8e17f7af9de56ed05b139 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 3 Oct 2024 17:07:05 +0200 Subject: [PATCH 08/21] Remove typing from test --- tests/test_awareness.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 1d82bb7..805842d 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -20,9 +20,7 @@ } -def create_bytes_message( - client_id: int, user: dict[str, dict[str, str | None]] | str, clock: int = 1 -) -> bytes: +def create_bytes_message(client_id, user, clock=1) -> bytes: if type(user) is str: new_user_bytes = user.encode("utf-8") else: From 6d023688946c7d3e4cccb6e02ec5c516effa48e7 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet <32258950+brichet@users.noreply.github.com> Date: Thu, 3 Oct 2024 17:51:26 +0200 Subject: [PATCH 09/21] Apply suggestions from code review Co-authored-by: David Brochart --- python/pycrdt/_awareness.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index fe7eb00..0538f07 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -52,11 +52,11 @@ def user(self, user: dict[str, str]): def get_changes(self, message: bytes) -> dict[str, Any]: """ - Update the states with a user state. - This function send the changes to subscribers. + Updates the states with a user state. + This function sends the changes to subscribers. Args: - msg: Bytes representing the user state. + message: Bytes representing the user state. """ message = read_message(message) decoder = Decoder(message) @@ -121,9 +121,9 @@ def get_local_state(self) -> dict[str, Any]: def set_local_state(self, state: dict[str, Any]) -> None: """ - Update the local state and meta. - This function call the on_change() callback (if provided), with the states - formatted (bytes) as argument. + Updates the local state and meta. + This function calls the `on_change()` callback (if provided), with the serialized states + as argument. Args: state: The dictionary representing the state. @@ -175,8 +175,6 @@ def observe(self, callback: Callable[[dict[str, Any]], None]) -> None: def unobserve(self) -> None: """ - Unsubscribes to awareness changes. - - This method removes all the callbacks. + Unsubscribes to awareness changes. This method removes all the callbacks. """ self._subscriptions = [] From 14ac84f03597885f8b5013dc227afb7a75bdd380 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Fri, 4 Oct 2024 07:23:44 +0200 Subject: [PATCH 10/21] Add missing docstring --- python/pycrdt/_awareness.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 0538f07..460eb51 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -117,6 +117,9 @@ def get_changes(self, message: bytes) -> dict[str, Any]: return changes def get_local_state(self) -> dict[str, Any]: + """ + Returns the local state (the state of the current awareness client). + """ return self._states.get(self.client_id, {}) def set_local_state(self, state: dict[str, Any]) -> None: From 4fc8c936c3cca21c80d427004397b07c7b2af6be Mon Sep 17 00:00:00 2001 From: Nicolas Brichet <32258950+brichet@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:03:13 +0200 Subject: [PATCH 11/21] Apply suggestions from code review Co-authored-by: David Brochart --- python/pycrdt/_awareness.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 460eb51..e5e8acb 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -118,7 +118,8 @@ def get_changes(self, message: bytes) -> dict[str, Any]: def get_local_state(self) -> dict[str, Any]: """ - Returns the local state (the state of the current awareness client). + Returns: + The local state (the state of the current awareness client). """ return self._states.get(self.client_id, {}) @@ -157,11 +158,11 @@ def set_local_state(self, state: dict[str, Any]) -> None: def set_local_state_field(self, field: str, value: Any) -> None: """ - Set a local state field. + Sets a local state field. Args: - field: The field to set (str) - value: the value of the field + field: The field to set. + value: The value of the field. """ current_state = self.get_local_state() current_state[field] = value From 5bde016ba87742f9ac443cd3455ac380bc0dbb41 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Fri, 4 Oct 2024 11:18:32 +0200 Subject: [PATCH 12/21] Remove the default user in the awareness --- pyproject.toml | 1 - python/pycrdt/_awareness.py | 6 ------ tests/test_awareness.py | 22 ++++++---------------- 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 30268f4..bf26c74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ test = [ "mypy", "coverage[toml] >=7", "exceptiongroup; python_version<'3.11'", - "dirty_equals", ] docs = [ "mkdocs", diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index e5e8acb..6da43f6 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -3,13 +3,10 @@ import json import time from typing import Any, Callable -from uuid import uuid4 from ._doc import Doc from ._sync import Decoder, YMessageType, read_message, write_var_uint -DEFAULT_USER = {"username": str(uuid4()), "name": "Jupyter server"} - class Awareness: client_id: int @@ -31,9 +28,6 @@ def __init__( if user is not None: self.user = user - else: - self._user = DEFAULT_USER - self._states[self.client_id] = {"user": DEFAULT_USER} self._subscriptions = [] diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 805842d..49756b9 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -2,10 +2,8 @@ from copy import deepcopy from uuid import uuid4 -from dirty_equals import IsStr from pycrdt import Awareness, Doc, write_var_uint -DEFAULT_USER = {"username": IsStr(), "name": "Jupyter server"} TEST_USER = {"username": str(uuid4()), "name": "Test user"} REMOTE_CLIENT_ID = 853790970 REMOTE_USER = { @@ -33,11 +31,11 @@ def create_bytes_message(client_id, user, clock=1) -> bytes: return msg -def test_awareness_default_user(): +def test_awareness_get_local_state(): ydoc = Doc() awareness = Awareness(ydoc) - assert awareness.user == DEFAULT_USER + assert awareness.get_local_state() == {} def test_awareness_with_user(): @@ -45,6 +43,7 @@ def test_awareness_with_user(): awareness = Awareness(ydoc, user=TEST_USER) assert awareness.user == TEST_USER + assert awareness.get_local_state() == {"user": TEST_USER} def test_awareness_set_user(): @@ -55,19 +54,12 @@ def test_awareness_set_user(): assert awareness.user == user -def test_awareness_get_local_state(): - ydoc = Doc() - awareness = Awareness(ydoc) - - assert awareness.get_local_state() == {"user": DEFAULT_USER} - - def test_awareness_set_local_state_field(): ydoc = Doc() awareness = Awareness(ydoc) awareness.set_local_state_field("new_field", "new_value") - assert awareness.get_local_state() == {"user": DEFAULT_USER, "new_field": "new_value"} + assert awareness.get_local_state() == {"new_field": "new_value"} def test_awareness_add_user(): @@ -83,7 +75,6 @@ def test_awareness_add_user(): "states": [REMOTE_USER], } assert awareness.states == { - awareness.client_id: {"user": DEFAULT_USER}, REMOTE_CLIENT_ID: REMOTE_USER, } @@ -108,7 +99,6 @@ def test_awareness_update_user(): "states": [remote_user], } assert awareness.states == { - awareness.client_id: {"user": DEFAULT_USER}, REMOTE_CLIENT_ID: remote_user, } @@ -130,7 +120,7 @@ def test_awareness_remove_user(): "removed": [REMOTE_CLIENT_ID], "states": [], } - assert awareness.states == {awareness.client_id: {"user": DEFAULT_USER}} + assert awareness.states == {} def test_awareness_increment_clock(): @@ -144,7 +134,7 @@ def test_awareness_increment_clock(): "removed": [], "states": [], } - assert awareness.meta.get(awareness.client_id, {}).get("clock", 0) == 2 + assert awareness.meta.get(awareness.client_id, {}).get("clock", 0) == 1 def test_awareness_observes(): From f523e3dc03265af37df0483f5e0716a7c9e44c39 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Fri, 4 Oct 2024 11:33:06 +0200 Subject: [PATCH 13/21] Remove totally the conept of user in the awareness --- python/pycrdt/_awareness.py | 20 ++++---------------- tests/test_awareness.py | 28 ++++++++++------------------ 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 6da43f6..c0f5dc4 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -13,44 +13,32 @@ class Awareness: meta: dict[int, dict[str, Any]] _states: dict[int, dict[str, Any]] _subscriptions: list[Callable[[dict[str, Any]], None]] - _user: dict[str, str] | None def __init__( self, ydoc: Doc, on_change: Callable[[bytes], None] | None = None, - user: dict[str, str] | None = None, ): self.client_id = ydoc.client_id self.meta = {} self._states = {} self.on_change = on_change - if user is not None: - self.user = user - self._subscriptions = [] @property def states(self) -> dict[int, dict[str, Any]]: return self._states - @property - def user(self) -> dict[str, str] | None: - return self._user - - @user.setter - def user(self, user: dict[str, str]): - self._user = user - self.set_local_state_field("user", self._user) - def get_changes(self, message: bytes) -> dict[str, Any]: """ - Updates the states with a user state. + Updates the states with a client state. This function sends the changes to subscribers. Args: - message: Bytes representing the user state. + message: Bytes representing the client state. + Returns: + A dictionary summarizing the changes. """ message = read_message(message) decoder = Decoder(message) diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 49756b9..a39dc85 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -38,22 +38,6 @@ def test_awareness_get_local_state(): assert awareness.get_local_state() == {} -def test_awareness_with_user(): - ydoc = Doc() - awareness = Awareness(ydoc, user=TEST_USER) - - assert awareness.user == TEST_USER - assert awareness.get_local_state() == {"user": TEST_USER} - - -def test_awareness_set_user(): - ydoc = Doc() - awareness = Awareness(ydoc) - user = {"username": "test_username", "name": "test_name"} - awareness.user = user - assert awareness.user == user - - def test_awareness_set_local_state_field(): ydoc = Doc() awareness = Awareness(ydoc) @@ -123,7 +107,7 @@ def test_awareness_remove_user(): assert awareness.states == {} -def test_awareness_increment_clock(): +def test_awareness_do_not_increment_clock(): ydoc = Doc() awareness = Awareness(ydoc) changes = awareness.get_changes(create_bytes_message(awareness.client_id, "null")) @@ -134,7 +118,15 @@ def test_awareness_increment_clock(): "removed": [], "states": [], } - assert awareness.meta.get(awareness.client_id, {}).get("clock", 0) == 1 + assert awareness.meta.get(awareness.client_id, {}).get("clock") == 1 + + +def test_awareness_increment_clock(): + ydoc = Doc() + awareness = Awareness(ydoc) + awareness.set_local_state_field("new_field", "new_value") + awareness.get_changes(create_bytes_message(awareness.client_id, "null")) + assert awareness.meta.get(awareness.client_id, {}).get("clock") == 2 def test_awareness_observes(): From 6b06af1a2ba5bc9d26f3dafaf8220550f8b2b2a3 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Fri, 4 Oct 2024 12:03:25 +0200 Subject: [PATCH 14/21] Add subscription id --- python/pycrdt/_awareness.py | 21 ++++++++++++--------- tests/test_awareness.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index c0f5dc4..411ccaa 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -3,6 +3,7 @@ import json import time from typing import Any, Callable +from uuid import uuid4 from ._doc import Doc from ._sync import Decoder, YMessageType, read_message, write_var_uint @@ -12,7 +13,7 @@ class Awareness: client_id: int meta: dict[int, dict[str, Any]] _states: dict[int, dict[str, Any]] - _subscriptions: list[Callable[[dict[str, Any]], None]] + _subscriptions: dict[str, Callable[[dict[str, Any]], None]] def __init__( self, @@ -24,7 +25,7 @@ def __init__( self._states = {} self.on_change = on_change - self._subscriptions = [] + self._subscriptions = {} @property def states(self) -> dict[int, dict[str, Any]]: @@ -32,8 +33,7 @@ def states(self) -> dict[int, dict[str, Any]]: def get_changes(self, message: bytes) -> dict[str, Any]: """ - Updates the states with a client state. - This function sends the changes to subscribers. + Updates the states and sends the changes to subscribers. Args: message: Bytes representing the client state. @@ -93,7 +93,7 @@ def get_changes(self, message: bytes) -> dict[str, Any]: # Do not trigger the callbacks if it is only a keep alive update if added or filtered_updated or removed: - for callback in self._subscriptions: + for callback in self._subscriptions.values(): callback(changes) return changes @@ -150,17 +150,20 @@ def set_local_state_field(self, field: str, value: Any) -> None: current_state[field] = value self.set_local_state(current_state) - def observe(self, callback: Callable[[dict[str, Any]], None]) -> None: + def observe(self, callback: Callable[[dict[str, Any]], None]) -> str: """ Subscribes to awareness changes. Args: callback: Callback that will be called when the document changes. """ - self._subscriptions.append(callback) + id = str(uuid4()) + self._subscriptions[id] = callback + return id - def unobserve(self) -> None: + def unobserve(self, id: str) -> None: """ Unsubscribes to awareness changes. This method removes all the callbacks. """ - self._subscriptions = [] + if id in self._subscriptions.keys(): + del self._subscriptions[id] diff --git a/tests/test_awareness.py b/tests/test_awareness.py index a39dc85..6522e08 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -133,20 +133,34 @@ def test_awareness_observes(): ydoc = Doc() awareness = Awareness(ydoc) - called = {} + called_1 = {} + called_2 = {} - def callback(value): - called.update(value) + def callback_1(value): + called_1.update(value) - awareness.observe(callback) - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) - assert called == changes + def callback_2(value): + called_2.update(value) - called = {} - awareness.unobserve() + awareness.observe(callback_1) + sub_2 = awareness.observe(callback_2) changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) - assert called != changes - assert called == {} + assert called_1 == changes + assert called_2 == changes + + keys = list(called_1.keys()) + for k in keys: + del called_1[k] + + keys = list(called_2.keys()) + for k in keys: + del called_2[k] + + awareness.unobserve(sub_2) + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, "null")) + assert called_1 == changes + assert called_2 != changes + assert called_2 == {} def test_awareness_on_change(): From edb66a59ff4a170a7b2cbe856efd376f0ca7f2a5 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Fri, 4 Oct 2024 12:06:25 +0200 Subject: [PATCH 15/21] update docstring according to review --- python/pycrdt/_awareness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 411ccaa..3a2f246 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -36,7 +36,7 @@ def get_changes(self, message: bytes) -> dict[str, Any]: Updates the states and sends the changes to subscribers. Args: - message: Bytes representing the client state. + message: The binary changes. Returns: A dictionary summarizing the changes. """ From 98540a57c78b63c6cfb6fc070a97df4b6d8d1692 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 4 Oct 2024 16:51:01 +0200 Subject: [PATCH 16/21] Remove on_change callback --- docs/api_reference.md | 1 + python/pycrdt/_awareness.py | 109 ++++++++++++++++++++---------------- tests/test_awareness.py | 22 ++------ 3 files changed, 68 insertions(+), 64 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index b65c05b..debb7e1 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -8,6 +8,7 @@ - BaseType - Array - ArrayEvent + - Awareness - Decoder - Doc - Map diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 3a2f246..a936f70 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -11,24 +11,29 @@ class Awareness: client_id: int - meta: dict[int, dict[str, Any]] + _meta: dict[int, dict[str, Any]] _states: dict[int, dict[str, Any]] _subscriptions: dict[str, Callable[[dict[str, Any]], None]] - def __init__( - self, - ydoc: Doc, - on_change: Callable[[bytes], None] | None = None, - ): + def __init__(self, ydoc: Doc): + """ + Args: + ydoc: The [Doc][pycrdt.Doc] to associate the awareness with. + """ self.client_id = ydoc.client_id - self.meta = {} + self._meta = {} self._states = {} - self.on_change = on_change self._subscriptions = {} + @property + def meta(self) -> dict[int, dict[str, Any]]: + """The clients' metadata.""" + return self._meta + @property def states(self) -> dict[int, dict[str, Any]]: + """The client states.""" return self._states def get_changes(self, message: bytes) -> dict[str, Any]: @@ -37,6 +42,7 @@ def get_changes(self, message: bytes) -> dict[str, Any]: Args: message: The binary changes. + Returns: A dictionary summarizing the changes. """ @@ -56,7 +62,7 @@ def get_changes(self, message: bytes) -> dict[str, Any]: state = None if not state_str else json.loads(state_str) if state is not None: states.append(state) - client_meta = self.meta.get(client_id) + client_meta = self._meta.get(client_id) prev_state = self._states.get(client_id) curr_clock = 0 if client_meta is None else client_meta["clock"] if curr_clock < clock or ( @@ -70,7 +76,7 @@ def get_changes(self, message: bytes) -> dict[str, Any]: del self._states[client_id] else: self._states[client_id] = state - self.meta[client_id] = { + self._meta[client_id] = { "clock": clock, "last_updated": timestamp, } @@ -101,61 +107,66 @@ def get_changes(self, message: bytes) -> dict[str, Any]: def get_local_state(self) -> dict[str, Any]: """ Returns: - The local state (the state of the current awareness client). + The local state. """ return self._states.get(self.client_id, {}) - def set_local_state(self, state: dict[str, Any]) -> None: + def set_local_state(self, state: dict[str, Any], encode: bool = True) -> bytes | None: """ - Updates the local state and meta. - This function calls the `on_change()` callback (if provided), with the serialized states - as argument. + Updates the local state and meta, and optionally returns the encoded new state. Args: - state: The dictionary representing the state. + state: The new local state. + encode: Whether to encode the new state and return it. + + Returns: + The encoded new state, if `encode==True`. """ timestamp = int(time.time() * 1000) - clock = self.meta.get(self.client_id, {}).get("clock", -1) + 1 + clock = self._meta.get(self.client_id, {}).get("clock", -1) + 1 self._states[self.client_id] = state - self.meta[self.client_id] = {"clock": clock, "last_updated": timestamp} - # Build the message to broadcast, with the following information: - # - message type - # - length in bytes of the updates - # - number of updates - # - for each update - # - client_id - # - clock - # - length in bytes of the update - # - encoded update - msg = json.dumps(state, separators=(",", ":")).encode("utf-8") - msg = write_var_uint(len(msg)) + msg - msg = write_var_uint(clock) + msg - msg = write_var_uint(self.client_id) + msg - msg = write_var_uint(1) + msg - msg = write_var_uint(len(msg)) + msg - msg = write_var_uint(YMessageType.AWARENESS) + msg - - if self.on_change: - self.on_change(msg) - - def set_local_state_field(self, field: str, value: Any) -> None: + self._meta[self.client_id] = {"clock": clock, "last_updated": timestamp} + if encode: + update = json.dumps(state, separators=(",", ":")).encode() + message0 = [update] + message0.insert(0, write_var_uint(len(update))) + message0.insert(0, write_var_uint(clock)) + message0.insert(0, write_var_uint(self.client_id)) + message0.insert(0, bytes(1)) + message0_bytes = b"".join(message0) + message1 = [ + bytes(YMessageType.AWARENESS), + write_var_uint(len(message0_bytes)), + message0_bytes, + ] + message = b"".join(message1) + return message + return None + + def set_local_state_field(self, field: str, value: Any, encode: bool = True) -> bytes | None: """ - Sets a local state field. + Sets a local state field, and optionally returns the encoded new state. Args: - field: The field to set. - value: The value of the field. + field: The field of the local state to set. + value: The value associated with the field. + + Returns: + The encoded new state, if `encode==True`. """ current_state = self.get_local_state() current_state[field] = value - self.set_local_state(current_state) + return self.set_local_state(current_state, encode) def observe(self, callback: Callable[[dict[str, Any]], None]) -> str: """ - Subscribes to awareness changes. + Registers the given callback to awareness changes. Args: - callback: Callback that will be called when the document changes. + callback: The callback to call with the awareness changes. + + Returns: + The subscription ID that can be used to unobserve. """ id = str(uuid4()) self._subscriptions[id] = callback @@ -163,7 +174,9 @@ def observe(self, callback: Callable[[dict[str, Any]], None]) -> str: def unobserve(self, id: str) -> None: """ - Unsubscribes to awareness changes. This method removes all the callbacks. + Unregisters the given subscription ID from awareness changes. + + Args: + id: The subscription ID to unregister. """ - if id in self._subscriptions.keys(): - del self._subscriptions[id] + del self._subscriptions[id] diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 6522e08..5e1f570 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -42,7 +42,7 @@ def test_awareness_set_local_state_field(): ydoc = Doc() awareness = Awareness(ydoc) - awareness.set_local_state_field("new_field", "new_value") + awareness.set_local_state_field("new_field", "new_value", encode=False) assert awareness.get_local_state() == {"new_field": "new_value"} @@ -124,7 +124,7 @@ def test_awareness_do_not_increment_clock(): def test_awareness_increment_clock(): ydoc = Doc() awareness = Awareness(ydoc) - awareness.set_local_state_field("new_field", "new_value") + awareness.set_local_state_field("new_field", "new_value", encode=False) awareness.get_changes(create_bytes_message(awareness.client_id, "null")) assert awareness.meta.get(awareness.client_id, {}).get("clock") == 2 @@ -163,18 +163,8 @@ def callback_2(value): assert called_2 == {} -def test_awareness_on_change(): +def test_awareness_encode(): ydoc = Doc() - - changes = [] - - def callback(value): - changes.append(value) - - awareness = Awareness(ydoc, on_change=callback) - - awareness.set_local_state_field("new_field", "new_value") - - assert len(changes) == 1 - - assert type(changes[0]) is bytes + awareness = Awareness(ydoc) + encoded_state = awareness.set_local_state_field("new_field", "new_value", encode=True) + assert encoded_state.endswith(b'{"new_field":"new_value"}') From 3d4beafd1561d21ff571ebadd839bb1242cbb6a1 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Mon, 7 Oct 2024 16:00:39 +0200 Subject: [PATCH 17/21] Observe both local and remote changes, and add a function to encode the changes --- python/pycrdt/_awareness.py | 213 ++++++++++++++++++++++++------------ tests/test_awareness.py | 195 +++++++++++++++++++++++++++------ 2 files changed, 307 insertions(+), 101 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index a936f70..f25cbdd 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -5,6 +5,8 @@ from typing import Any, Callable from uuid import uuid4 +from typing_extensions import deprecated + from ._doc import Doc from ._sync import Decoder, YMessageType, read_message, write_var_uint @@ -36,9 +38,10 @@ def states(self) -> dict[int, dict[str, Any]]: """The client states.""" return self._states + @deprecated("Use `apply_awareness_update()` instead") def get_changes(self, message: bytes) -> dict[str, Any]: """ - Updates the states and sends the changes to subscribers. + Apply states update and sends the changes to subscribers. Args: message: The binary changes. @@ -46,15 +49,35 @@ def get_changes(self, message: bytes) -> dict[str, Any]: Returns: A dictionary summarizing the changes. """ - message = read_message(message) - decoder = Decoder(message) - timestamp = int(time.time() * 1000) - added = [] - updated = [] - filtered_updated = [] - removed = [] + changes = self.apply_awareness_update(message, "remote") + states_changes = changes["changes"] + client_ids = [*states_changes["added"], *states_changes["filtered_updated"]] + states = [self._states[client_id] for client_id in client_ids] + states_changes["states"] = states + return states_changes + + def apply_awareness_update(self, update: bytes, origin: str) -> dict[str, Any]: + """ + Apply states update and sends the changes to subscribers. + + Args: + message: The binary changes. + origin: The origin of the change. + + Returns: + A dictionary with the changes and the origin. + """ + update = read_message(update) + decoder = Decoder(update) states = [] length = decoder.read_var_uint() + states_changes = { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [], + } + for _ in range(length): client_id = decoder.read_var_uint() clock = decoder.read_var_uint() @@ -62,43 +85,19 @@ def get_changes(self, message: bytes) -> dict[str, Any]: state = None if not state_str else json.loads(state_str) if state is not None: states.append(state) - client_meta = self._meta.get(client_id) - prev_state = self._states.get(client_id) - curr_clock = 0 if client_meta is None else client_meta["clock"] - if curr_clock < clock or ( - curr_clock == clock and state is None and client_id in self._states - ): - if state is None: - if client_id == self.client_id and self._states.get(client_id) is not None: - clock += 1 - else: - if client_id in self._states: - del self._states[client_id] - else: - self._states[client_id] = state - self._meta[client_id] = { - "clock": clock, - "last_updated": timestamp, - } - if client_meta is None and state is not None: - added.append(client_id) - elif client_meta is not None and state is None: - removed.append(client_id) - elif state is not None: - if state != prev_state: - filtered_updated.append(client_id) - updated.append(client_id) + self._update_states(client_id, clock, state, states_changes) changes = { - "added": added, - "updated": updated, - "filtered_updated": filtered_updated, - "removed": removed, - "states": states, + "changes": states_changes, + "origin": origin, } # Do not trigger the callbacks if it is only a keep alive update - if added or filtered_updated or removed: + if ( + states_changes["added"] + or states_changes["filtered_updated"] + or states_changes["removed"] + ): for callback in self._subscriptions.values(): callback(changes) @@ -111,39 +110,41 @@ def get_local_state(self) -> dict[str, Any]: """ return self._states.get(self.client_id, {}) - def set_local_state(self, state: dict[str, Any], encode: bool = True) -> bytes | None: + def set_local_state(self, state: dict[str, Any]) -> dict[str, Any]: """ - Updates the local state and meta, and optionally returns the encoded new state. + Updates the local state and meta, and sends the changes to subscribers. Args: state: The new local state. - encode: Whether to encode the new state and return it. Returns: - The encoded new state, if `encode==True`. + A dictionary with the changes and the origin (="local"). """ - timestamp = int(time.time() * 1000) - clock = self._meta.get(self.client_id, {}).get("clock", -1) + 1 - self._states[self.client_id] = state - self._meta[self.client_id] = {"clock": clock, "last_updated": timestamp} - if encode: - update = json.dumps(state, separators=(",", ":")).encode() - message0 = [update] - message0.insert(0, write_var_uint(len(update))) - message0.insert(0, write_var_uint(clock)) - message0.insert(0, write_var_uint(self.client_id)) - message0.insert(0, bytes(1)) - message0_bytes = b"".join(message0) - message1 = [ - bytes(YMessageType.AWARENESS), - write_var_uint(len(message0_bytes)), - message0_bytes, - ] - message = b"".join(message1) - return message - return None - - def set_local_state_field(self, field: str, value: Any, encode: bool = True) -> bytes | None: + clock = self._meta.get(self.client_id, {}).get("clock", 0) + 1 + states_changes = { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [], + } + self._update_states(self.client_id, clock, state, states_changes) + + changes = { + "changes": states_changes, + "origin": "local", + } + + if ( + states_changes["added"] + or states_changes["filtered_updated"] + or states_changes["removed"] + ): + for callback in self._subscriptions.values(): + callback(changes) + + return changes + + def set_local_state_field(self, field: str, value: Any) -> dict[str, Any]: """ Sets a local state field, and optionally returns the encoded new state. @@ -152,11 +153,47 @@ def set_local_state_field(self, field: str, value: Any, encode: bool = True) -> value: The value associated with the field. Returns: - The encoded new state, if `encode==True`. + A dictionary with the changes and the origin (="local"). """ current_state = self.get_local_state() current_state[field] = value - return self.set_local_state(current_state, encode) + return self.set_local_state(current_state) + + def encode_awareness_update(self, client_ids: list[int]) -> bytes | None: + """ + Encode the states of the client ids. + + Args: + client_ids: The list of clients' state to update. + + Returns: + The encoded clients' state. + """ + messages = [] + for client_id in client_ids: + if client_id not in self._states: + continue + state = self._states[client_id] + meta = self._meta[client_id] + update = json.dumps(state, separators=(",", ":")).encode() + client_msg = [update] + client_msg.insert(0, write_var_uint(len(update))) + client_msg.insert(0, write_var_uint(meta.get("clock", 0))) + client_msg.insert(0, write_var_uint(client_id)) + messages.append(b"".join(client_msg)) + + if not messages: + return + + messages.insert(0, write_var_uint(len(client_ids))) + encoded_messages = b"".join(messages) + + message = [ + write_var_uint(YMessageType.AWARENESS), + write_var_uint(len(encoded_messages)), + encoded_messages, + ] + return b"".join(message) def observe(self, callback: Callable[[dict[str, Any]], None]) -> str: """ @@ -180,3 +217,43 @@ def unobserve(self, id: str) -> None: id: The subscription ID to unregister. """ del self._subscriptions[id] + + def _update_states( + self, client_id: int, clock: int, state: Any, states_changes: dict[str, list[str]] + ) -> None: + """ + Update the states of the clients, and the states_changes dictionary. + + Args: + client_id: The client's state to update. + clock: The clock of this client. + state: The updated state. + states_changes: The changes to updates. + """ + timestamp = int(time.time() * 1000) + client_meta = self._meta.get(client_id) + prev_state = self._states.get(client_id) + curr_clock = 0 if client_meta is None else client_meta["clock"] + if curr_clock < clock or ( + curr_clock == clock and state is None and client_id in self._states + ): + if state is None: + if client_id == self.client_id and self._states.get(client_id) is not None: + clock += 1 + else: + if client_id in self._states: + del self._states[client_id] + else: + self._states[client_id] = state + self._meta[client_id] = { + "clock": clock, + "last_updated": timestamp, + } + if client_meta is None and state is not None: + states_changes["added"].append(client_id) + elif client_meta is not None and state is None: + states_changes["removed"].append(client_id) + elif state is not None: + if state != prev_state: + states_changes["filtered_updated"].append(client_id) + states_changes["updated"].append(client_id) diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 5e1f570..d56e068 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -42,7 +42,7 @@ def test_awareness_set_local_state_field(): ydoc = Doc() awareness = Awareness(ydoc) - awareness.set_local_state_field("new_field", "new_value", encode=False) + awareness.set_local_state_field("new_field", "new_value") assert awareness.get_local_state() == {"new_field": "new_value"} @@ -50,13 +50,18 @@ def test_awareness_add_user(): ydoc = Doc() awareness = Awareness(ydoc) - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + changes = awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) assert changes == { - "added": [REMOTE_CLIENT_ID], - "updated": [], - "filtered_updated": [], - "removed": [], - "states": [REMOTE_USER], + "changes": { + "added": [REMOTE_CLIENT_ID], + "updated": [], + "filtered_updated": [], + "removed": [], + }, + "origin": "custom_origin", } assert awareness.states == { REMOTE_CLIENT_ID: REMOTE_USER, @@ -68,19 +73,27 @@ def test_awareness_update_user(): awareness = Awareness(ydoc) # Add a remote user. - awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) # Update it remote_user = deepcopy(REMOTE_USER) remote_user["user"]["name"] = "New user name" - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2)) + changes = awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2), + "custom_origin", + ) assert changes == { - "added": [], - "updated": [REMOTE_CLIENT_ID], - "filtered_updated": [REMOTE_CLIENT_ID], - "removed": [], - "states": [remote_user], + "changes": { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "filtered_updated": [REMOTE_CLIENT_ID], + "removed": [], + }, + "origin": "custom_origin", } assert awareness.states == { REMOTE_CLIENT_ID: remote_user, @@ -92,17 +105,25 @@ def test_awareness_remove_user(): awareness = Awareness(ydoc) # Add a remote user. - awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) # Remove it - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, "null", 2)) + changes = awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, "null", 2), + "custom_origin", + ) assert changes == { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [REMOTE_CLIENT_ID], - "states": [], + "changes": { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [REMOTE_CLIENT_ID], + }, + "origin": "custom_origin", } assert awareness.states == {} @@ -110,13 +131,18 @@ def test_awareness_remove_user(): def test_awareness_do_not_increment_clock(): ydoc = Doc() awareness = Awareness(ydoc) - changes = awareness.get_changes(create_bytes_message(awareness.client_id, "null")) + changes = awareness.apply_awareness_update( + create_bytes_message(awareness.client_id, "null"), + "custom_origin", + ) assert changes == { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [], - "states": [], + "changes": { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [], + }, + "origin": "custom_origin", } assert awareness.meta.get(awareness.client_id, {}).get("clock") == 1 @@ -124,8 +150,11 @@ def test_awareness_do_not_increment_clock(): def test_awareness_increment_clock(): ydoc = Doc() awareness = Awareness(ydoc) - awareness.set_local_state_field("new_field", "new_value", encode=False) - awareness.get_changes(create_bytes_message(awareness.client_id, "null")) + awareness.set_local_state_field("new_field", "new_value") + awareness.apply_awareness_update( + create_bytes_message(awareness.client_id, "null"), + "custom_origin", + ) assert awareness.meta.get(awareness.client_id, {}).get("clock") == 2 @@ -144,7 +173,10 @@ def callback_2(value): awareness.observe(callback_1) sub_2 = awareness.observe(callback_2) - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + changes = awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) assert called_1 == changes assert called_2 == changes @@ -157,14 +189,111 @@ def callback_2(value): del called_2[k] awareness.unobserve(sub_2) - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, "null")) + changes = awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, "null"), + "custom_origin", + ) assert called_1 == changes assert called_2 != changes assert called_2 == {} +def test_awareness_observes_local_change(): + ydoc = Doc() + awareness = Awareness(ydoc) + + called_1 = {} + + def callback_1(value): + called_1.update(value) + + awareness.observe(callback_1) + changes = awareness.set_local_state_field("new_field", "new_value") + assert changes["origin"] == "local" + assert called_1 == changes + + def test_awareness_encode(): ydoc = Doc() awareness = Awareness(ydoc) - encoded_state = awareness.set_local_state_field("new_field", "new_value", encode=True) - assert encoded_state.endswith(b'{"new_field":"new_value"}') + + changes = awareness.set_local_state_field("new_field", "new_value") + states_bytes = awareness.encode_awareness_update(changes["changes"]["added"]) + assert states_bytes[1:] == create_bytes_message( + awareness.client_id, awareness.get_local_state() + ) + + +def test_awareness_encode_wrong_id(): + ydoc = Doc() + awareness = Awareness(ydoc) + + states_bytes = awareness.encode_awareness_update([10]) + assert states_bytes is None + + +def test_awareness_deprecated_add_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) + assert changes == { + "added": [REMOTE_CLIENT_ID], + "updated": [], + "filtered_updated": [], + "removed": [], + "states": [REMOTE_USER], + } + assert awareness.states == { + REMOTE_CLIENT_ID: REMOTE_USER, + } + + +def test_awareness_deprecated_update_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + # Add a remote user. + awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) + + # Update it + remote_user = deepcopy(REMOTE_USER) + remote_user["user"]["name"] = "New user name" + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2)) + + assert changes == { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "filtered_updated": [REMOTE_CLIENT_ID], + "removed": [], + "states": [remote_user], + } + assert awareness.states == { + REMOTE_CLIENT_ID: remote_user, + } + + +def test_awareness_deprecated_remove_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + # Add a remote user. + awareness.apply_awareness_update( + create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) + + # Remove it + changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, "null", 2)) + + assert changes == { + "added": [], + "updated": [], + "filtered_updated": [], + "removed": [REMOTE_CLIENT_ID], + "states": [], + } + assert awareness.states == {} From 0a37d1732c30d05ba3698a46002239a565b1b307 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Mon, 7 Oct 2024 16:06:22 +0200 Subject: [PATCH 18/21] mypy --- python/pycrdt/_awareness.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index f25cbdd..1b776da 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -71,7 +71,7 @@ def apply_awareness_update(self, update: bytes, origin: str) -> dict[str, Any]: decoder = Decoder(update) states = [] length = decoder.read_var_uint() - states_changes = { + states_changes: dict[str, list[int]] = { "added": [], "updated": [], "filtered_updated": [], @@ -121,7 +121,7 @@ def set_local_state(self, state: dict[str, Any]) -> dict[str, Any]: A dictionary with the changes and the origin (="local"). """ clock = self._meta.get(self.client_id, {}).get("clock", 0) + 1 - states_changes = { + states_changes: dict[str, list[int]] = { "added": [], "updated": [], "filtered_updated": [], @@ -183,7 +183,7 @@ def encode_awareness_update(self, client_ids: list[int]) -> bytes | None: messages.append(b"".join(client_msg)) if not messages: - return + return None messages.insert(0, write_var_uint(len(client_ids))) encoded_messages = b"".join(messages) @@ -219,7 +219,7 @@ def unobserve(self, id: str) -> None: del self._subscriptions[id] def _update_states( - self, client_id: int, clock: int, state: Any, states_changes: dict[str, list[str]] + self, client_id: int, clock: int, state: Any, states_changes: dict[str, list[int]] ) -> None: """ Update the states of the clients, and the states_changes dictionary. From c75c8f95697bc0c7b9a55d98f76d99ec576c33d0 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Tue, 8 Oct 2024 10:05:37 +0200 Subject: [PATCH 19/21] Mimic awareness.js --- docs/api_reference.md | 1 + python/pycrdt/__init__.py | 1 + python/pycrdt/_awareness.py | 294 +++++++++++---------------- python/pycrdt/_sync.py | 37 ++++ tests/test_awareness.py | 384 ++++++++++++++++++++---------------- 5 files changed, 367 insertions(+), 350 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index debb7e1..d570aff 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -11,6 +11,7 @@ - Awareness - Decoder - Doc + - Encoder - Map - MapEvent - NewTransaction diff --git a/python/pycrdt/__init__.py b/python/pycrdt/__init__.py index 21c8e09..fdfb790 100644 --- a/python/pycrdt/__init__.py +++ b/python/pycrdt/__init__.py @@ -9,6 +9,7 @@ from ._pycrdt import Subscription as Subscription from ._pycrdt import TransactionEvent as TransactionEvent from ._sync import Decoder as Decoder +from ._sync import Encoder as Encoder from ._sync import YMessageType as YMessageType from ._sync import YSyncMessageType as YSyncMessageType from ._sync import create_sync_message as create_sync_message diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 1b776da..121da73 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -2,20 +2,18 @@ import json import time -from typing import Any, Callable +from typing import Any, Callable, cast from uuid import uuid4 -from typing_extensions import deprecated - from ._doc import Doc -from ._sync import Decoder, YMessageType, read_message, write_var_uint +from ._sync import Decoder, Encoder class Awareness: client_id: int _meta: dict[int, dict[str, Any]] _states: dict[int, dict[str, Any]] - _subscriptions: dict[str, Callable[[dict[str, Any]], None]] + _subscriptions: dict[str, Callable[[str, tuple[dict[str, Any], Any]], None]] def __init__(self, ydoc: Doc): """ @@ -25,8 +23,8 @@ def __init__(self, ydoc: Doc): self.client_id = ydoc.client_id self._meta = {} self._states = {} - self._subscriptions = {} + self.set_local_state({}) @property def meta(self) -> dict[int, dict[str, Any]]: @@ -38,164 +36,148 @@ def states(self) -> dict[int, dict[str, Any]]: """The client states.""" return self._states - @deprecated("Use `apply_awareness_update()` instead") - def get_changes(self, message: bytes) -> dict[str, Any]: + def get_local_state(self) -> dict[str, Any] | None: """ - Apply states update and sends the changes to subscribers. - - Args: - message: The binary changes. - Returns: - A dictionary summarizing the changes. - """ - changes = self.apply_awareness_update(message, "remote") - states_changes = changes["changes"] - client_ids = [*states_changes["added"], *states_changes["filtered_updated"]] - states = [self._states[client_id] for client_id in client_ids] - states_changes["states"] = states - return states_changes - - def apply_awareness_update(self, update: bytes, origin: str) -> dict[str, Any]: + The local state, if any. """ - Apply states update and sends the changes to subscribers. - - Args: - message: The binary changes. - origin: The origin of the change. + return self._states.get(self.client_id) - Returns: - A dictionary with the changes and the origin. - """ - update = read_message(update) - decoder = Decoder(update) - states = [] - length = decoder.read_var_uint() - states_changes: dict[str, list[int]] = { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [], - } - - for _ in range(length): - client_id = decoder.read_var_uint() - clock = decoder.read_var_uint() - state_str = decoder.read_var_string() - state = None if not state_str else json.loads(state_str) - if state is not None: - states.append(state) - self._update_states(client_id, clock, state, states_changes) - - changes = { - "changes": states_changes, - "origin": origin, - } - - # Do not trigger the callbacks if it is only a keep alive update - if ( - states_changes["added"] - or states_changes["filtered_updated"] - or states_changes["removed"] - ): - for callback in self._subscriptions.values(): - callback(changes) - - return changes - - def get_local_state(self) -> dict[str, Any]: - """ - Returns: - The local state. - """ - return self._states.get(self.client_id, {}) - - def set_local_state(self, state: dict[str, Any]) -> dict[str, Any]: + def set_local_state(self, state: dict[str, Any] | None) -> None: """ Updates the local state and meta, and sends the changes to subscribers. Args: - state: The new local state. - - Returns: - A dictionary with the changes and the origin (="local"). + state: The new local state, if any. """ - clock = self._meta.get(self.client_id, {}).get("clock", 0) + 1 - states_changes: dict[str, list[int]] = { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [], - } - self._update_states(self.client_id, clock, state, states_changes) - - changes = { - "changes": states_changes, - "origin": "local", - } - - if ( - states_changes["added"] - or states_changes["filtered_updated"] - or states_changes["removed"] - ): + client_id = self.client_id + curr_local_meta = self._meta.get(client_id) + clock = 0 if curr_local_meta is None else curr_local_meta["clock"] + 1 + prev_state = self._states.get(client_id) + if state is None: + del self._states[client_id] + else: + self._states[client_id] = state + timestamp = int(time.time() * 1000) + self._meta[client_id] = {"clock": clock, "lastUpdated": timestamp} + added = [] + updated = [] + filtered_updated = [] + removed = [] + if state is None: + removed.append(client_id) + elif prev_state is None: + if state is not None: + added.append(client_id) + else: + updated.append(client_id) + if prev_state != state: + filtered_updated.append(client_id) + if added or filtered_updated or removed: for callback in self._subscriptions.values(): - callback(changes) - - return changes + callback( + "change", + ({"added": added, "updated": filtered_updated, "removed": removed}, "local"), + ) + for callback in self._subscriptions.values(): + callback("update", ({"added": added, "updated": updated, "removed": removed}, "local")) - def set_local_state_field(self, field: str, value: Any) -> dict[str, Any]: + def set_local_state_field(self, field: str, value: Any) -> None: """ - Sets a local state field, and optionally returns the encoded new state. + Sets a local state field. Args: field: The field of the local state to set. value: The value associated with the field. - - Returns: - A dictionary with the changes and the origin (="local"). """ - current_state = self.get_local_state() - current_state[field] = value - return self.set_local_state(current_state) + state = self.get_local_state() + if state is not None: + state[field] = value + self.set_local_state(state) - def encode_awareness_update(self, client_ids: list[int]) -> bytes | None: + def encode_awareness_update(self, client_ids: list[int]) -> bytes: """ - Encode the states of the client ids. + Creates an encoded awareness update of the clients given by their IDs. Args: - client_ids: The list of clients' state to update. + client_ids: The list of client IDs for which to create an update. Returns: - The encoded clients' state. + The encoded awareness update. """ - messages = [] + encoder = Encoder() + encoder.write_var_uint(len(client_ids)) for client_id in client_ids: - if client_id not in self._states: - continue - state = self._states[client_id] - meta = self._meta[client_id] - update = json.dumps(state, separators=(",", ":")).encode() - client_msg = [update] - client_msg.insert(0, write_var_uint(len(update))) - client_msg.insert(0, write_var_uint(meta.get("clock", 0))) - client_msg.insert(0, write_var_uint(client_id)) - messages.append(b"".join(client_msg)) + state = self._states.get(client_id) + clock = cast(int, self._meta.get(client_id, {}).get("clock")) + encoder.write_var_uint(client_id) + encoder.write_var_uint(clock) + encoder.write_var_string(json.dumps(state, separators=(",", ":"))) + return encoder.to_bytes() - if not messages: - return None - - messages.insert(0, write_var_uint(len(client_ids))) - encoded_messages = b"".join(messages) + def apply_awareness_update(self, update: bytes, origin: Any) -> None: + """ + Applies the binary update and notifies subscribers with changes. - message = [ - write_var_uint(YMessageType.AWARENESS), - write_var_uint(len(encoded_messages)), - encoded_messages, - ] - return b"".join(message) + Args: + update: The binary update. + origin: The origin of the update. + """ + decoder = Decoder(update) + timestamp = int(time.time() * 1000) + added = [] + updated = [] + filtered_updated = [] + removed = [] + length = decoder.read_var_uint() + for _ in range(length): + client_id = decoder.read_var_uint() + clock = decoder.read_var_uint() + state_str = decoder.read_var_string() + state = None if not state_str else json.loads(state_str) + client_meta = self._meta.get(client_id) + prev_state = self._states.get(client_id) + curr_clock = 0 if client_meta is None else client_meta["clock"] + if curr_clock < clock or ( + curr_clock == clock and state is None and client_id in self._states + ): + if state is None: + # Never let a remote client remove this local state. + if client_id == self.client_id and self.get_local_state() is not None: + # Remote client removed the local state. Do not remove state. + # Broadcast a message indicating that this client still exists by increasing + # the clock. + clock += 1 + else: + if client_id in self._states: + del self._states[client_id] + else: + self._states[client_id] = state + self._meta[client_id] = { + "clock": clock, + "lastUpdated": timestamp, + } + if client_meta is None and state is not None: + added.append(client_id) + elif client_meta is not None and state is None: + removed.append(client_id) + elif state is not None: + if state != prev_state: + filtered_updated.append(client_id) + updated.append(client_id) + if added or filtered_updated or removed: + for callback in self._subscriptions.values(): + callback( + "change", + ({"added": added, "updated": filtered_updated, "removed": removed}, origin), + ) + if added or updated or removed: + for callback in self._subscriptions.values(): + callback( + "update", ({"added": added, "updated": updated, "removed": removed}, origin) + ) - def observe(self, callback: Callable[[dict[str, Any]], None]) -> str: + def observe(self, callback: Callable[[str, tuple[dict[str, Any], Any]], None]) -> str: """ Registers the given callback to awareness changes. @@ -217,43 +199,3 @@ def unobserve(self, id: str) -> None: id: The subscription ID to unregister. """ del self._subscriptions[id] - - def _update_states( - self, client_id: int, clock: int, state: Any, states_changes: dict[str, list[int]] - ) -> None: - """ - Update the states of the clients, and the states_changes dictionary. - - Args: - client_id: The client's state to update. - clock: The clock of this client. - state: The updated state. - states_changes: The changes to updates. - """ - timestamp = int(time.time() * 1000) - client_meta = self._meta.get(client_id) - prev_state = self._states.get(client_id) - curr_clock = 0 if client_meta is None else client_meta["clock"] - if curr_clock < clock or ( - curr_clock == clock and state is None and client_id in self._states - ): - if state is None: - if client_id == self.client_id and self._states.get(client_id) is not None: - clock += 1 - else: - if client_id in self._states: - del self._states[client_id] - else: - self._states[client_id] = state - self._meta[client_id] = { - "clock": clock, - "last_updated": timestamp, - } - if client_meta is None and state is not None: - states_changes["added"].append(client_id) - elif client_meta is not None and state is None: - states_changes["removed"].append(client_id) - elif state is not None: - if state != prev_state: - states_changes["filtered_updated"].append(client_id) - states_changes["updated"].append(client_id) diff --git a/python/pycrdt/_sync.py b/python/pycrdt/_sync.py index c9b78fe..837703a 100644 --- a/python/pycrdt/_sync.py +++ b/python/pycrdt/_sync.py @@ -110,6 +110,43 @@ def create_update_message(data: bytes) -> bytes: return create_message(data, YSyncMessageType.SYNC_UPDATE) +class Encoder: + """ + An encoder capable of writing messages to a binary stream. + """ + + stream: list[bytes] + + def __init__(self) -> None: + self.stream = [] + + def write_var_uint(self, num: int) -> None: + """ + Encodes a number. + + Args: + num: The number to encode. + """ + self.stream.append(write_var_uint(num)) + + def write_var_string(self, text: str) -> None: + """ + Encodes a string. + + Args: + text: The string to encode. + """ + self.stream.append(write_var_uint(len(text))) + self.stream.append(text.encode()) + + def to_bytes(self) -> bytes: + """ + Returns: + The binary stream. + """ + return b"".join(self.stream) + + class Decoder: """ A decoder capable of reading messages from a byte stream. diff --git a/tests/test_awareness.py b/tests/test_awareness.py index d56e068..5a5c1b5 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -2,7 +2,8 @@ from copy import deepcopy from uuid import uuid4 -from pycrdt import Awareness, Doc, write_var_uint +import pytest +from pycrdt import Awareness, Doc, Encoder TEST_USER = {"username": str(uuid4()), "name": "Test user"} REMOTE_CLIENT_ID = 853790970 @@ -18,17 +19,17 @@ } -def create_bytes_message(client_id, user, clock=1) -> bytes: - if type(user) is str: - new_user_bytes = user.encode("utf-8") +def create_awareness_update(client_id, user, clock=1) -> bytes: + if isinstance(user, str): + new_user_str = user else: - new_user_bytes = json.dumps(user, separators=(",", ":")).encode("utf-8") - msg = write_var_uint(len(new_user_bytes)) + new_user_bytes - msg = write_var_uint(clock) + msg - msg = write_var_uint(client_id) + msg - msg = write_var_uint(1) + msg - msg = write_var_uint(len(msg)) + msg - return msg + new_user_str = json.dumps(user, separators=(",", ":")) + encoder = Encoder() + encoder.write_var_uint(1) + encoder.write_var_uint(client_id) + encoder.write_var_uint(clock) + encoder.write_var_string(new_user_str) + return encoder.to_bytes() def test_awareness_get_local_state(): @@ -38,6 +39,30 @@ def test_awareness_get_local_state(): assert awareness.get_local_state() == {} +def test_awareness_set_local_state(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + + awareness.set_local_state({"foo": "bar"}) + assert awareness.get_local_state() == {"foo": "bar"} + + awareness.set_local_state(None) + assert awareness.get_local_state() is None + + assert changes == [ + ("change", ({"added": [], "updated": [awareness.client_id], "removed": []}, "local")), + ("update", ({"added": [], "updated": [awareness.client_id], "removed": []}, "local")), + ("change", ({"added": [], "updated": [], "removed": [awareness.client_id]}, "local")), + ("update", ({"added": [], "updated": [], "removed": [awareness.client_id]}, "local")), + ] + + def test_awareness_set_local_state_field(): ydoc = Doc() awareness = Awareness(ydoc) @@ -49,22 +74,42 @@ def test_awareness_set_local_state_field(): def test_awareness_add_user(): ydoc = Doc() awareness = Awareness(ydoc) + changes = [] - changes = awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), "custom_origin", ) - assert changes == { - "changes": { - "added": [REMOTE_CLIENT_ID], - "updated": [], - "filtered_updated": [], - "removed": [], - }, - "origin": "custom_origin", - } + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [REMOTE_CLIENT_ID], + "updated": [], + "removed": [], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [REMOTE_CLIENT_ID], + "updated": [], + "removed": [], + }, + "custom_origin", + ), + ) assert awareness.states == { REMOTE_CLIENT_ID: REMOTE_USER, + ydoc.client_id: {}, } @@ -74,29 +119,50 @@ def test_awareness_update_user(): # Add a remote user. awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), "custom_origin", ) - # Update it + # Update it. remote_user = deepcopy(REMOTE_USER) remote_user["user"]["name"] = "New user name" - changes = awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2), + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, remote_user, 2), "custom_origin", ) - assert changes == { - "changes": { - "added": [], - "updated": [REMOTE_CLIENT_ID], - "filtered_updated": [REMOTE_CLIENT_ID], - "removed": [], - }, - "origin": "custom_origin", - } + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "removed": [], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "removed": [], + }, + "custom_origin", + ), + ) assert awareness.states == { REMOTE_CLIENT_ID: remote_user, + ydoc.client_id: {}, } @@ -106,55 +172,84 @@ def test_awareness_remove_user(): # Add a remote user. awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), "custom_origin", ) - # Remove it - changes = awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, "null", 2), - "custom_origin", - ) - - assert changes == { - "changes": { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [REMOTE_CLIENT_ID], - }, - "origin": "custom_origin", - } - assert awareness.states == {} + # Remove it. + changes = [] + def callback(topic, event): + changes.append((topic, event)) -def test_awareness_do_not_increment_clock(): - ydoc = Doc() - awareness = Awareness(ydoc) - changes = awareness.apply_awareness_update( - create_bytes_message(awareness.client_id, "null"), + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, "null", 2), "custom_origin", ) - assert changes == { - "changes": { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [], - }, - "origin": "custom_origin", - } - assert awareness.meta.get(awareness.client_id, {}).get("clock") == 1 + + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [], + "updated": [], + "removed": [REMOTE_CLIENT_ID], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [], + "updated": [], + "removed": [REMOTE_CLIENT_ID], + }, + "custom_origin", + ), + ) + assert awareness.states == {ydoc.client_id: {}} def test_awareness_increment_clock(): ydoc = Doc() awareness = Awareness(ydoc) - awareness.set_local_state_field("new_field", "new_value") + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) awareness.apply_awareness_update( - create_bytes_message(awareness.client_id, "null"), + create_awareness_update(awareness.client_id, "null"), "custom_origin", ) + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [], + "updated": [], + "removed": [awareness.client_id], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [], + "updated": [], + "removed": [awareness.client_id], + }, + "custom_origin", + ), + ) assert awareness.meta.get(awareness.client_id, {}).get("clock") == 2 @@ -162,64 +257,72 @@ def test_awareness_observes(): ydoc = Doc() awareness = Awareness(ydoc) - called_1 = {} - called_2 = {} + changes1 = [] + changes2 = [] - def callback_1(value): - called_1.update(value) + def callback1(topic, value): + changes1.append((topic, value)) - def callback_2(value): - called_2.update(value) + def callback2(topic, value): + changes2.append((topic, value)) - awareness.observe(callback_1) - sub_2 = awareness.observe(callback_2) - changes = awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), + awareness.observe(callback1) + sub2 = awareness.observe(callback2) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), "custom_origin", ) - assert called_1 == changes - assert called_2 == changes - - keys = list(called_1.keys()) - for k in keys: - del called_1[k] + assert len(changes1) == 2 + assert len(changes2) == 2 - keys = list(called_2.keys()) - for k in keys: - del called_2[k] + changes1.clear() + changes2.clear() - awareness.unobserve(sub_2) - changes = awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, "null"), + awareness.unobserve(sub2) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, "null"), "custom_origin", ) - assert called_1 == changes - assert called_2 != changes - assert called_2 == {} + assert len(changes1) == 2 + assert len(changes2) == 0 def test_awareness_observes_local_change(): ydoc = Doc() awareness = Awareness(ydoc) + changes = [] - called_1 = {} + def callback(topic, value): + changes.append((topic, value)) - def callback_1(value): - called_1.update(value) - - awareness.observe(callback_1) - changes = awareness.set_local_state_field("new_field", "new_value") - assert changes["origin"] == "local" - assert called_1 == changes + awareness.observe(callback) + awareness.set_local_state_field("new_field", "new_value") + assert len(changes) == 1 + assert changes[0] == ( + "update", + ( + { + "added": [], + "removed": [], + "updated": [ydoc.client_id], + }, + "local", + ), + ) def test_awareness_encode(): ydoc = Doc() awareness = Awareness(ydoc) + changes = [] + + def callback(topic, value): + changes.append((topic, value)) - changes = awareness.set_local_state_field("new_field", "new_value") - states_bytes = awareness.encode_awareness_update(changes["changes"]["added"]) - assert states_bytes[1:] == create_bytes_message( + awareness.observe(callback) + awareness.set_local_state_field("new_field", "new_value") + awareness_update = awareness.encode_awareness_update(changes[0][1][0]["updated"]) + assert awareness_update == create_awareness_update( awareness.client_id, awareness.get_local_state() ) @@ -228,72 +331,5 @@ def test_awareness_encode_wrong_id(): ydoc = Doc() awareness = Awareness(ydoc) - states_bytes = awareness.encode_awareness_update([10]) - assert states_bytes is None - - -def test_awareness_deprecated_add_user(): - ydoc = Doc() - awareness = Awareness(ydoc) - - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER)) - assert changes == { - "added": [REMOTE_CLIENT_ID], - "updated": [], - "filtered_updated": [], - "removed": [], - "states": [REMOTE_USER], - } - assert awareness.states == { - REMOTE_CLIENT_ID: REMOTE_USER, - } - - -def test_awareness_deprecated_update_user(): - ydoc = Doc() - awareness = Awareness(ydoc) - - # Add a remote user. - awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), - "custom_origin", - ) - - # Update it - remote_user = deepcopy(REMOTE_USER) - remote_user["user"]["name"] = "New user name" - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, remote_user, 2)) - - assert changes == { - "added": [], - "updated": [REMOTE_CLIENT_ID], - "filtered_updated": [REMOTE_CLIENT_ID], - "removed": [], - "states": [remote_user], - } - assert awareness.states == { - REMOTE_CLIENT_ID: remote_user, - } - - -def test_awareness_deprecated_remove_user(): - ydoc = Doc() - awareness = Awareness(ydoc) - - # Add a remote user. - awareness.apply_awareness_update( - create_bytes_message(REMOTE_CLIENT_ID, REMOTE_USER), - "custom_origin", - ) - - # Remove it - changes = awareness.get_changes(create_bytes_message(REMOTE_CLIENT_ID, "null", 2)) - - assert changes == { - "added": [], - "updated": [], - "filtered_updated": [], - "removed": [REMOTE_CLIENT_ID], - "states": [], - } - assert awareness.states == {} + with pytest.raises(TypeError): + awareness.encode_awareness_update([10]) From dc4468dcb1c9e10283858a7b96712fd8a175c045 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Tue, 8 Oct 2024 15:36:25 +0200 Subject: [PATCH 20/21] Check if state is set before deleting --- python/pycrdt/_awareness.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 121da73..f336d5a 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -55,7 +55,8 @@ def set_local_state(self, state: dict[str, Any] | None) -> None: clock = 0 if curr_local_meta is None else curr_local_meta["clock"] + 1 prev_state = self._states.get(client_id) if state is None: - del self._states[client_id] + if client_id in self._states: + del self._states[client_id] else: self._states[client_id] = state timestamp = int(time.time() * 1000) From b02eb3bef1fa35b9b0c01aedf720194cd297bd6b Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 9 Oct 2024 09:49:02 +0200 Subject: [PATCH 21/21] Add create_awareness_message() and write_message() --- docs/api_reference.md | 2 ++ python/pycrdt/__init__.py | 2 ++ python/pycrdt/_sync.py | 34 ++++++++++++++++++++++++++++++---- tests/test_awareness.py | 6 ++++-- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index d570aff..2050c17 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -26,6 +26,7 @@ - UndoManager - YMessageType - YSyncMessageType + - create_awareness_message - create_sync_message - create_update_message - handle_sync_message @@ -33,4 +34,5 @@ - get_update - merge_updates - read_message + - write_message - write_var_uint diff --git a/python/pycrdt/__init__.py b/python/pycrdt/__init__.py index fdfb790..e3139e3 100644 --- a/python/pycrdt/__init__.py +++ b/python/pycrdt/__init__.py @@ -12,10 +12,12 @@ from ._sync import Encoder as Encoder from ._sync import YMessageType as YMessageType from ._sync import YSyncMessageType as YSyncMessageType +from ._sync import create_awareness_message as create_awareness_message from ._sync import create_sync_message as create_sync_message from ._sync import create_update_message as create_update_message from ._sync import handle_sync_message as handle_sync_message from ._sync import read_message as read_message +from ._sync import write_message as write_message from ._sync import write_var_uint as write_var_uint from ._text import Text as Text from ._text import TextEvent as TextEvent diff --git a/python/pycrdt/_sync.py b/python/pycrdt/_sync.py index 837703a..79939ef 100644 --- a/python/pycrdt/_sync.py +++ b/python/pycrdt/_sync.py @@ -54,18 +54,31 @@ def write_var_uint(num: int) -> bytes: return bytes(res) +def create_awareness_message(data: bytes) -> bytes: + """ + Creates an [AWARENESS][pycrdt.YMessageType] message. + + Args: + data: The data to send in the message. + + Returns: + The [AWARENESS][pycrdt.YMessageType] message. + """ + return bytes([YMessageType.AWARENESS]) + write_message(data) + + def create_message(data: bytes, msg_type: int) -> bytes: """ - Creates a binary Y message. + Creates a SYNC message. Args: data: The data to send in the message. - msg_type: The [message type][pycrdt.YSyncMessageType]. + msg_type: The [SYNC message type][pycrdt.YSyncMessageType]. Returns: - The binary Y message. + The SYNC message. """ - return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data + return bytes([YMessageType.SYNC, msg_type]) + write_message(data) def create_sync_step1_message(data: bytes) -> bytes: @@ -242,6 +255,19 @@ def read_message(stream: bytes) -> bytes: return message +def write_message(stream: bytes) -> bytes: + """ + Writes a stream in a message. + + Args: + stream: The byte stream to write in a message. + + Returns: + The message containing the stream. + """ + return write_var_uint(len(stream)) + stream + + def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None: """ Processes a [synchronization message][pycrdt.YSyncMessageType] on a document. diff --git a/tests/test_awareness.py b/tests/test_awareness.py index 5a5c1b5..c52aeae 100644 --- a/tests/test_awareness.py +++ b/tests/test_awareness.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pytest -from pycrdt import Awareness, Doc, Encoder +from pycrdt import Awareness, Doc, Encoder, YMessageType, create_awareness_message, read_message TEST_USER = {"username": str(uuid4()), "name": "Test user"} REMOTE_CLIENT_ID = 853790970 @@ -325,11 +325,13 @@ def callback(topic, value): assert awareness_update == create_awareness_update( awareness.client_id, awareness.get_local_state() ) + message = create_awareness_message(awareness_update) + assert message[0] == YMessageType.AWARENESS + assert read_message(message[1:]) == awareness_update def test_awareness_encode_wrong_id(): ydoc = Doc() awareness = Awareness(ydoc) - with pytest.raises(TypeError): awareness.encode_awareness_update([10])