diff --git a/docs/api_reference.md b/docs/api_reference.md index b65c05b..2050c17 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -8,8 +8,10 @@ - BaseType - Array - ArrayEvent + - Awareness - Decoder - Doc + - Encoder - Map - MapEvent - NewTransaction @@ -24,6 +26,7 @@ - UndoManager - YMessageType - YSyncMessageType + - create_awareness_message - create_sync_message - create_update_message - handle_sync_message @@ -31,4 +34,5 @@ - get_update - merge_updates - read_message + - write_message - write_var_uint diff --git a/python/pycrdt/__init__.py b/python/pycrdt/__init__.py index 21c8e09..e3139e3 100644 --- a/python/pycrdt/__init__.py +++ b/python/pycrdt/__init__.py @@ -9,12 +9,15 @@ from ._pycrdt import Subscription as Subscription from ._pycrdt import TransactionEvent as TransactionEvent from ._sync import Decoder as Decoder +from ._sync import Encoder as Encoder from ._sync import YMessageType as YMessageType from ._sync import YSyncMessageType as YSyncMessageType +from ._sync import create_awareness_message as create_awareness_message from ._sync import create_sync_message as create_sync_message from ._sync import create_update_message as create_update_message from ._sync import handle_sync_message as handle_sync_message from ._sync import read_message as read_message +from ._sync import write_message as write_message from ._sync import write_var_uint as write_var_uint from ._text import Text as Text from ._text import TextEvent as TextEvent diff --git a/python/pycrdt/_awareness.py b/python/pycrdt/_awareness.py index 43d412a..f336d5a 100644 --- a/python/pycrdt/_awareness.py +++ b/python/pycrdt/_awareness.py @@ -2,52 +2,161 @@ import json import time -from typing import Any +from typing import Any, Callable, cast +from uuid import uuid4 from ._doc import Doc -from ._sync import Decoder, read_message +from ._sync import Decoder, Encoder -class Awareness: # pragma: no cover +class Awareness: + client_id: int + _meta: dict[int, dict[str, Any]] + _states: dict[int, dict[str, Any]] + _subscriptions: dict[str, Callable[[str, tuple[dict[str, Any], Any]], None]] + def __init__(self, ydoc: Doc): + """ + Args: + ydoc: The [Doc][pycrdt.Doc] to associate the awareness with. + """ self.client_id = ydoc.client_id - self.meta: dict[int, dict[str, Any]] = {} - self.states: dict[int, dict[str, Any]] = {} + self._meta = {} + self._states = {} + self._subscriptions = {} + self.set_local_state({}) + + @property + def meta(self) -> dict[int, dict[str, Any]]: + """The clients' metadata.""" + return self._meta + + @property + def states(self) -> dict[int, dict[str, Any]]: + """The client states.""" + return self._states + + def get_local_state(self) -> dict[str, Any] | None: + """ + Returns: + The local state, if any. + """ + return self._states.get(self.client_id) + + def set_local_state(self, state: dict[str, Any] | None) -> None: + """ + Updates the local state and meta, and sends the changes to subscribers. + + Args: + state: The new local state, if any. + """ + client_id = self.client_id + curr_local_meta = self._meta.get(client_id) + clock = 0 if curr_local_meta is None else curr_local_meta["clock"] + 1 + prev_state = self._states.get(client_id) + if state is None: + if client_id in self._states: + del self._states[client_id] + else: + self._states[client_id] = state + timestamp = int(time.time() * 1000) + self._meta[client_id] = {"clock": clock, "lastUpdated": timestamp} + added = [] + updated = [] + filtered_updated = [] + removed = [] + if state is None: + removed.append(client_id) + elif prev_state is None: + if state is not None: + added.append(client_id) + else: + updated.append(client_id) + if prev_state != state: + filtered_updated.append(client_id) + if added or filtered_updated or removed: + for callback in self._subscriptions.values(): + callback( + "change", + ({"added": added, "updated": filtered_updated, "removed": removed}, "local"), + ) + for callback in self._subscriptions.values(): + callback("update", ({"added": added, "updated": updated, "removed": removed}, "local")) - def get_changes(self, message: bytes) -> dict[str, Any]: - message = read_message(message) - decoder = Decoder(message) + def set_local_state_field(self, field: str, value: Any) -> None: + """ + Sets a local state field. + + Args: + field: The field of the local state to set. + value: The value associated with the field. + """ + state = self.get_local_state() + if state is not None: + state[field] = value + self.set_local_state(state) + + def encode_awareness_update(self, client_ids: list[int]) -> bytes: + """ + Creates an encoded awareness update of the clients given by their IDs. + + Args: + client_ids: The list of client IDs for which to create an update. + + Returns: + The encoded awareness update. + """ + encoder = Encoder() + encoder.write_var_uint(len(client_ids)) + for client_id in client_ids: + state = self._states.get(client_id) + clock = cast(int, self._meta.get(client_id, {}).get("clock")) + encoder.write_var_uint(client_id) + encoder.write_var_uint(clock) + encoder.write_var_string(json.dumps(state, separators=(",", ":"))) + return encoder.to_bytes() + + def apply_awareness_update(self, update: bytes, origin: Any) -> None: + """ + Applies the binary update and notifies subscribers with changes. + + Args: + update: The binary update. + origin: The origin of the update. + """ + decoder = Decoder(update) timestamp = int(time.time() * 1000) added = [] updated = [] filtered_updated = [] removed = [] - states = [] length = decoder.read_var_uint() for _ in range(length): client_id = decoder.read_var_uint() clock = decoder.read_var_uint() state_str = decoder.read_var_string() state = None if not state_str else json.loads(state_str) - if state is not None: - states.append(state) - client_meta = self.meta.get(client_id) - prev_state = self.states.get(client_id) + client_meta = self._meta.get(client_id) + prev_state = self._states.get(client_id) curr_clock = 0 if client_meta is None else client_meta["clock"] if curr_clock < clock or ( - curr_clock == clock and state is None and client_id in self.states + curr_clock == clock and state is None and client_id in self._states ): if state is None: - if client_id == self.client_id and self.states.get(client_id) is not None: + # Never let a remote client remove this local state. + if client_id == self.client_id and self.get_local_state() is not None: + # Remote client removed the local state. Do not remove state. + # Broadcast a message indicating that this client still exists by increasing + # the clock. clock += 1 else: - if client_id in self.states: - del self.states[client_id] + if client_id in self._states: + del self._states[client_id] else: - self.states[client_id] = state - self.meta[client_id] = { + self._states[client_id] = state + self._meta[client_id] = { "clock": clock, - "last_updated": timestamp, + "lastUpdated": timestamp, } if client_meta is None and state is not None: added.append(client_id) @@ -57,10 +166,37 @@ def get_changes(self, message: bytes) -> dict[str, Any]: if state != prev_state: filtered_updated.append(client_id) updated.append(client_id) - return { - "added": added, - "updated": updated, - "filtered_updated": filtered_updated, - "removed": removed, - "states": states, - } + if added or filtered_updated or removed: + for callback in self._subscriptions.values(): + callback( + "change", + ({"added": added, "updated": filtered_updated, "removed": removed}, origin), + ) + if added or updated or removed: + for callback in self._subscriptions.values(): + callback( + "update", ({"added": added, "updated": updated, "removed": removed}, origin) + ) + + def observe(self, callback: Callable[[str, tuple[dict[str, Any], Any]], None]) -> str: + """ + Registers the given callback to awareness changes. + + Args: + callback: The callback to call with the awareness changes. + + Returns: + The subscription ID that can be used to unobserve. + """ + id = str(uuid4()) + self._subscriptions[id] = callback + return id + + def unobserve(self, id: str) -> None: + """ + Unregisters the given subscription ID from awareness changes. + + Args: + id: The subscription ID to unregister. + """ + del self._subscriptions[id] diff --git a/python/pycrdt/_sync.py b/python/pycrdt/_sync.py index c9b78fe..79939ef 100644 --- a/python/pycrdt/_sync.py +++ b/python/pycrdt/_sync.py @@ -54,18 +54,31 @@ def write_var_uint(num: int) -> bytes: return bytes(res) +def create_awareness_message(data: bytes) -> bytes: + """ + Creates an [AWARENESS][pycrdt.YMessageType] message. + + Args: + data: The data to send in the message. + + Returns: + The [AWARENESS][pycrdt.YMessageType] message. + """ + return bytes([YMessageType.AWARENESS]) + write_message(data) + + def create_message(data: bytes, msg_type: int) -> bytes: """ - Creates a binary Y message. + Creates a SYNC message. Args: data: The data to send in the message. - msg_type: The [message type][pycrdt.YSyncMessageType]. + msg_type: The [SYNC message type][pycrdt.YSyncMessageType]. Returns: - The binary Y message. + The SYNC message. """ - return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data + return bytes([YMessageType.SYNC, msg_type]) + write_message(data) def create_sync_step1_message(data: bytes) -> bytes: @@ -110,6 +123,43 @@ def create_update_message(data: bytes) -> bytes: return create_message(data, YSyncMessageType.SYNC_UPDATE) +class Encoder: + """ + An encoder capable of writing messages to a binary stream. + """ + + stream: list[bytes] + + def __init__(self) -> None: + self.stream = [] + + def write_var_uint(self, num: int) -> None: + """ + Encodes a number. + + Args: + num: The number to encode. + """ + self.stream.append(write_var_uint(num)) + + def write_var_string(self, text: str) -> None: + """ + Encodes a string. + + Args: + text: The string to encode. + """ + self.stream.append(write_var_uint(len(text))) + self.stream.append(text.encode()) + + def to_bytes(self) -> bytes: + """ + Returns: + The binary stream. + """ + return b"".join(self.stream) + + class Decoder: """ A decoder capable of reading messages from a byte stream. @@ -205,6 +255,19 @@ def read_message(stream: bytes) -> bytes: return message +def write_message(stream: bytes) -> bytes: + """ + Writes a stream in a message. + + Args: + stream: The byte stream to write in a message. + + Returns: + The message containing the stream. + """ + return write_var_uint(len(stream)) + stream + + def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None: """ Processes a [synchronization message][pycrdt.YSyncMessageType] on a document. diff --git a/tests/test_awareness.py b/tests/test_awareness.py new file mode 100644 index 0000000..c52aeae --- /dev/null +++ b/tests/test_awareness.py @@ -0,0 +1,337 @@ +import json +from copy import deepcopy +from uuid import uuid4 + +import pytest +from pycrdt import Awareness, Doc, Encoder, YMessageType, create_awareness_message, read_message + +TEST_USER = {"username": str(uuid4()), "name": "Test user"} +REMOTE_CLIENT_ID = 853790970 +REMOTE_USER = { + "user": { + "username": "2460ab00fd28415b87e49ec5aa2d482d", + "name": "Anonymous Ersa", + "display_name": "Anonymous Ersa", + "initials": "AE", + "avatar_url": None, + "color": "var(--jp-collaborator-color7)", + } +} + + +def create_awareness_update(client_id, user, clock=1) -> bytes: + if isinstance(user, str): + new_user_str = user + else: + new_user_str = json.dumps(user, separators=(",", ":")) + encoder = Encoder() + encoder.write_var_uint(1) + encoder.write_var_uint(client_id) + encoder.write_var_uint(clock) + encoder.write_var_string(new_user_str) + return encoder.to_bytes() + + +def test_awareness_get_local_state(): + ydoc = Doc() + awareness = Awareness(ydoc) + + assert awareness.get_local_state() == {} + + +def test_awareness_set_local_state(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + + awareness.set_local_state({"foo": "bar"}) + assert awareness.get_local_state() == {"foo": "bar"} + + awareness.set_local_state(None) + assert awareness.get_local_state() is None + + assert changes == [ + ("change", ({"added": [], "updated": [awareness.client_id], "removed": []}, "local")), + ("update", ({"added": [], "updated": [awareness.client_id], "removed": []}, "local")), + ("change", ({"added": [], "updated": [], "removed": [awareness.client_id]}, "local")), + ("update", ({"added": [], "updated": [], "removed": [awareness.client_id]}, "local")), + ] + + +def test_awareness_set_local_state_field(): + ydoc = Doc() + awareness = Awareness(ydoc) + + awareness.set_local_state_field("new_field", "new_value") + assert awareness.get_local_state() == {"new_field": "new_value"} + + +def test_awareness_add_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [REMOTE_CLIENT_ID], + "updated": [], + "removed": [], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [REMOTE_CLIENT_ID], + "updated": [], + "removed": [], + }, + "custom_origin", + ), + ) + assert awareness.states == { + REMOTE_CLIENT_ID: REMOTE_USER, + ydoc.client_id: {}, + } + + +def test_awareness_update_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + # Add a remote user. + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) + + # Update it. + remote_user = deepcopy(REMOTE_USER) + remote_user["user"]["name"] = "New user name" + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, remote_user, 2), + "custom_origin", + ) + + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "removed": [], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [], + "updated": [REMOTE_CLIENT_ID], + "removed": [], + }, + "custom_origin", + ), + ) + assert awareness.states == { + REMOTE_CLIENT_ID: remote_user, + ydoc.client_id: {}, + } + + +def test_awareness_remove_user(): + ydoc = Doc() + awareness = Awareness(ydoc) + + # Add a remote user. + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) + + # Remove it. + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, "null", 2), + "custom_origin", + ) + + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [], + "updated": [], + "removed": [REMOTE_CLIENT_ID], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [], + "updated": [], + "removed": [REMOTE_CLIENT_ID], + }, + "custom_origin", + ), + ) + assert awareness.states == {ydoc.client_id: {}} + + +def test_awareness_increment_clock(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = [] + + def callback(topic, event): + changes.append((topic, event)) + + awareness.observe(callback) + awareness.apply_awareness_update( + create_awareness_update(awareness.client_id, "null"), + "custom_origin", + ) + assert len(changes) == 2 + assert changes[0] == ( + "change", + ( + { + "added": [], + "updated": [], + "removed": [awareness.client_id], + }, + "custom_origin", + ), + ) + assert changes[1] == ( + "update", + ( + { + "added": [], + "updated": [], + "removed": [awareness.client_id], + }, + "custom_origin", + ), + ) + assert awareness.meta.get(awareness.client_id, {}).get("clock") == 2 + + +def test_awareness_observes(): + ydoc = Doc() + awareness = Awareness(ydoc) + + changes1 = [] + changes2 = [] + + def callback1(topic, value): + changes1.append((topic, value)) + + def callback2(topic, value): + changes2.append((topic, value)) + + awareness.observe(callback1) + sub2 = awareness.observe(callback2) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, REMOTE_USER), + "custom_origin", + ) + assert len(changes1) == 2 + assert len(changes2) == 2 + + changes1.clear() + changes2.clear() + + awareness.unobserve(sub2) + awareness.apply_awareness_update( + create_awareness_update(REMOTE_CLIENT_ID, "null"), + "custom_origin", + ) + assert len(changes1) == 2 + assert len(changes2) == 0 + + +def test_awareness_observes_local_change(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = [] + + def callback(topic, value): + changes.append((topic, value)) + + awareness.observe(callback) + awareness.set_local_state_field("new_field", "new_value") + assert len(changes) == 1 + assert changes[0] == ( + "update", + ( + { + "added": [], + "removed": [], + "updated": [ydoc.client_id], + }, + "local", + ), + ) + + +def test_awareness_encode(): + ydoc = Doc() + awareness = Awareness(ydoc) + changes = [] + + def callback(topic, value): + changes.append((topic, value)) + + awareness.observe(callback) + awareness.set_local_state_field("new_field", "new_value") + awareness_update = awareness.encode_awareness_update(changes[0][1][0]["updated"]) + assert awareness_update == create_awareness_update( + awareness.client_id, awareness.get_local_state() + ) + message = create_awareness_message(awareness_update) + assert message[0] == YMessageType.AWARENESS + assert read_message(message[1:]) == awareness_update + + +def test_awareness_encode_wrong_id(): + ydoc = Doc() + awareness = Awareness(ydoc) + with pytest.raises(TypeError): + awareness.encode_awareness_update([10])