Skip to content

Commit

Permalink
Add awareness features to handle server state (#170)
Browse files Browse the repository at this point in the history
* Move the Awareness from pycrdt_websocket to pycrdt project, and add some features to it

* Add tests on awareness

* use google style docstring

* Generate the message in test for clarity

* Add docstring and tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove the unused logger

* Remove typing from test

* Apply suggestions from code review

Co-authored-by: David Brochart <[email protected]>

* Add missing docstring

* Apply suggestions from code review

Co-authored-by: David Brochart <[email protected]>

* Remove the default user in the awareness

* Remove totally the conept of user in the awareness

* Add subscription id

* update docstring according to review

* Remove on_change callback

* Observe both local and remote changes, and add a function to encode the changes

* mypy

* Mimic awareness.js

* Check if state is set before deleting

* Add create_awareness_message() and write_message()

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David Brochart <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent bbf4b1d commit fd74268
Show file tree
Hide file tree
Showing 5 changed files with 574 additions and 31 deletions.
4 changes: 4 additions & 0 deletions docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
- BaseType
- Array
- ArrayEvent
- Awareness
- Decoder
- Doc
- Encoder
- Map
- MapEvent
- NewTransaction
Expand All @@ -24,11 +26,13 @@
- UndoManager
- YMessageType
- YSyncMessageType
- create_awareness_message
- create_sync_message
- create_update_message
- handle_sync_message
- get_state
- get_update
- merge_updates
- read_message
- write_message
- write_var_uint
3 changes: 3 additions & 0 deletions python/pycrdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
from ._pycrdt import Subscription as Subscription
from ._pycrdt import TransactionEvent as TransactionEvent
from ._sync import Decoder as Decoder
from ._sync import Encoder as Encoder
from ._sync import YMessageType as YMessageType
from ._sync import YSyncMessageType as YSyncMessageType
from ._sync import create_awareness_message as create_awareness_message
from ._sync import create_sync_message as create_sync_message
from ._sync import create_update_message as create_update_message
from ._sync import handle_sync_message as handle_sync_message
from ._sync import read_message as read_message
from ._sync import write_message as write_message
from ._sync import write_var_uint as write_var_uint
from ._text import Text as Text
from ._text import TextEvent as TextEvent
Expand Down
190 changes: 163 additions & 27 deletions python/pycrdt/_awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,161 @@

import json
import time
from typing import Any
from typing import Any, Callable, cast
from uuid import uuid4

from ._doc import Doc
from ._sync import Decoder, read_message
from ._sync import Decoder, Encoder


class Awareness: # pragma: no cover
class Awareness:
client_id: int
_meta: dict[int, dict[str, Any]]
_states: dict[int, dict[str, Any]]
_subscriptions: dict[str, Callable[[str, tuple[dict[str, Any], Any]], None]]

def __init__(self, ydoc: Doc):
"""
Args:
ydoc: The [Doc][pycrdt.Doc] to associate the awareness with.
"""
self.client_id = ydoc.client_id
self.meta: dict[int, dict[str, Any]] = {}
self.states: dict[int, dict[str, Any]] = {}
self._meta = {}
self._states = {}
self._subscriptions = {}
self.set_local_state({})

@property
def meta(self) -> dict[int, dict[str, Any]]:
"""The clients' metadata."""
return self._meta

@property
def states(self) -> dict[int, dict[str, Any]]:
"""The client states."""
return self._states

def get_local_state(self) -> dict[str, Any] | None:
"""
Returns:
The local state, if any.
"""
return self._states.get(self.client_id)

def set_local_state(self, state: dict[str, Any] | None) -> None:
"""
Updates the local state and meta, and sends the changes to subscribers.
Args:
state: The new local state, if any.
"""
client_id = self.client_id
curr_local_meta = self._meta.get(client_id)
clock = 0 if curr_local_meta is None else curr_local_meta["clock"] + 1
prev_state = self._states.get(client_id)
if state is None:
if client_id in self._states:
del self._states[client_id]
else:
self._states[client_id] = state
timestamp = int(time.time() * 1000)
self._meta[client_id] = {"clock": clock, "lastUpdated": timestamp}
added = []
updated = []
filtered_updated = []
removed = []
if state is None:
removed.append(client_id)
elif prev_state is None:
if state is not None:
added.append(client_id)
else:
updated.append(client_id)
if prev_state != state:
filtered_updated.append(client_id)
if added or filtered_updated or removed:
for callback in self._subscriptions.values():
callback(
"change",
({"added": added, "updated": filtered_updated, "removed": removed}, "local"),
)
for callback in self._subscriptions.values():
callback("update", ({"added": added, "updated": updated, "removed": removed}, "local"))

def get_changes(self, message: bytes) -> dict[str, Any]:
message = read_message(message)
decoder = Decoder(message)
def set_local_state_field(self, field: str, value: Any) -> None:
"""
Sets a local state field.
Args:
field: The field of the local state to set.
value: The value associated with the field.
"""
state = self.get_local_state()
if state is not None:
state[field] = value
self.set_local_state(state)

def encode_awareness_update(self, client_ids: list[int]) -> bytes:
"""
Creates an encoded awareness update of the clients given by their IDs.
Args:
client_ids: The list of client IDs for which to create an update.
Returns:
The encoded awareness update.
"""
encoder = Encoder()
encoder.write_var_uint(len(client_ids))
for client_id in client_ids:
state = self._states.get(client_id)
clock = cast(int, self._meta.get(client_id, {}).get("clock"))
encoder.write_var_uint(client_id)
encoder.write_var_uint(clock)
encoder.write_var_string(json.dumps(state, separators=(",", ":")))
return encoder.to_bytes()

def apply_awareness_update(self, update: bytes, origin: Any) -> None:
"""
Applies the binary update and notifies subscribers with changes.
Args:
update: The binary update.
origin: The origin of the update.
"""
decoder = Decoder(update)
timestamp = int(time.time() * 1000)
added = []
updated = []
filtered_updated = []
removed = []
states = []
length = decoder.read_var_uint()
for _ in range(length):
client_id = decoder.read_var_uint()
clock = decoder.read_var_uint()
state_str = decoder.read_var_string()
state = None if not state_str else json.loads(state_str)
if state is not None:
states.append(state)
client_meta = self.meta.get(client_id)
prev_state = self.states.get(client_id)
client_meta = self._meta.get(client_id)
prev_state = self._states.get(client_id)
curr_clock = 0 if client_meta is None else client_meta["clock"]
if curr_clock < clock or (
curr_clock == clock and state is None and client_id in self.states
curr_clock == clock and state is None and client_id in self._states
):
if state is None:
if client_id == self.client_id and self.states.get(client_id) is not None:
# Never let a remote client remove this local state.
if client_id == self.client_id and self.get_local_state() is not None:
# Remote client removed the local state. Do not remove state.
# Broadcast a message indicating that this client still exists by increasing
# the clock.
clock += 1
else:
if client_id in self.states:
del self.states[client_id]
if client_id in self._states:
del self._states[client_id]
else:
self.states[client_id] = state
self.meta[client_id] = {
self._states[client_id] = state
self._meta[client_id] = {
"clock": clock,
"last_updated": timestamp,
"lastUpdated": timestamp,
}
if client_meta is None and state is not None:
added.append(client_id)
Expand All @@ -57,10 +166,37 @@ def get_changes(self, message: bytes) -> dict[str, Any]:
if state != prev_state:
filtered_updated.append(client_id)
updated.append(client_id)
return {
"added": added,
"updated": updated,
"filtered_updated": filtered_updated,
"removed": removed,
"states": states,
}
if added or filtered_updated or removed:
for callback in self._subscriptions.values():
callback(
"change",
({"added": added, "updated": filtered_updated, "removed": removed}, origin),
)
if added or updated or removed:
for callback in self._subscriptions.values():
callback(
"update", ({"added": added, "updated": updated, "removed": removed}, origin)
)

def observe(self, callback: Callable[[str, tuple[dict[str, Any], Any]], None]) -> str:
"""
Registers the given callback to awareness changes.
Args:
callback: The callback to call with the awareness changes.
Returns:
The subscription ID that can be used to unobserve.
"""
id = str(uuid4())
self._subscriptions[id] = callback
return id

def unobserve(self, id: str) -> None:
"""
Unregisters the given subscription ID from awareness changes.
Args:
id: The subscription ID to unregister.
"""
del self._subscriptions[id]
71 changes: 67 additions & 4 deletions python/pycrdt/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,31 @@ def write_var_uint(num: int) -> bytes:
return bytes(res)


def create_awareness_message(data: bytes) -> bytes:
"""
Creates an [AWARENESS][pycrdt.YMessageType] message.
Args:
data: The data to send in the message.
Returns:
The [AWARENESS][pycrdt.YMessageType] message.
"""
return bytes([YMessageType.AWARENESS]) + write_message(data)


def create_message(data: bytes, msg_type: int) -> bytes:
"""
Creates a binary Y message.
Creates a SYNC message.
Args:
data: The data to send in the message.
msg_type: The [message type][pycrdt.YSyncMessageType].
msg_type: The [SYNC message type][pycrdt.YSyncMessageType].
Returns:
The binary Y message.
The SYNC message.
"""
return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data
return bytes([YMessageType.SYNC, msg_type]) + write_message(data)


def create_sync_step1_message(data: bytes) -> bytes:
Expand Down Expand Up @@ -110,6 +123,43 @@ def create_update_message(data: bytes) -> bytes:
return create_message(data, YSyncMessageType.SYNC_UPDATE)


class Encoder:
"""
An encoder capable of writing messages to a binary stream.
"""

stream: list[bytes]

def __init__(self) -> None:
self.stream = []

def write_var_uint(self, num: int) -> None:
"""
Encodes a number.
Args:
num: The number to encode.
"""
self.stream.append(write_var_uint(num))

def write_var_string(self, text: str) -> None:
"""
Encodes a string.
Args:
text: The string to encode.
"""
self.stream.append(write_var_uint(len(text)))
self.stream.append(text.encode())

def to_bytes(self) -> bytes:
"""
Returns:
The binary stream.
"""
return b"".join(self.stream)


class Decoder:
"""
A decoder capable of reading messages from a byte stream.
Expand Down Expand Up @@ -205,6 +255,19 @@ def read_message(stream: bytes) -> bytes:
return message


def write_message(stream: bytes) -> bytes:
"""
Writes a stream in a message.
Args:
stream: The byte stream to write in a message.
Returns:
The message containing the stream.
"""
return write_var_uint(len(stream)) + stream


def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None:
"""
Processes a [synchronization message][pycrdt.YSyncMessageType] on a document.
Expand Down
Loading

0 comments on commit fd74268

Please sign in to comment.