Commit: Store state

erikjohnston committed Aug 22, 2024
1 parent 69ca857 commit 5f78552
Showing 9 changed files with 601 additions and 103 deletions.
2 changes: 2 additions & 0 deletions synapse/app/generic_worker.py
@@ -98,6 +98,7 @@
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -159,6 +160,7 @@ class GenericWorkerStore(
SessionStore,
TaskSchedulerWorkerStore,
ExperimentalFeaturesStore,
SlidingSyncStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
2 changes: 1 addition & 1 deletion synapse/handlers/sliding_sync/__init__.py
@@ -208,7 +208,7 @@ def __init__(self, hs: "HomeServer"):
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
self.is_mine_id = hs.is_mine_id

self.connection_store = SlidingSyncConnectionStore()
self.connection_store = SlidingSyncConnectionStore(self.store)
self.extensions = SlidingSyncExtensionHandler(hs)

async def wait_for_sync_for_user(
129 changes: 30 additions & 99 deletions synapse/handlers/sliding_sync/store.py
@@ -13,12 +13,12 @@
#

import logging
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Optional

import attr

from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.logging.opentracing import trace
from synapse.storage.databases.main import DataStore
from synapse.types import SlidingSyncStreamToken
from synapse.types.handlers.sliding_sync import (
MutablePerConnectionState,
@@ -61,20 +61,7 @@ class SlidingSyncConnectionStore:
to mapping of room ID to `HaveSentRoom`.
"""

# `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
_connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
dict
)

async def is_valid_token(
self, sync_config: SlidingSyncConfig, connection_token: int
) -> bool:
"""Return whether the connection token is valid/recognized"""
if connection_token == 0:
return True

conn_key = self._get_connection_key(sync_config)
return connection_token in self._connections.get(conn_key, {})
store: "DataStore"

async def get_per_connection_state(
self,
@@ -86,23 +73,20 @@ async def get_per_connection_state(
Raises:
SlidingSyncUnknownPosition if the connection_token is unknown
"""
if from_token is None:
return PerConnectionState()

connection_position = from_token.connection_position
if connection_position == 0:
# Initial sync (request without a `from_token`) starts at `0` so
# there is no existing per-connection state
if from_token is None or from_token.connection_position == 0:
return PerConnectionState()

conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.get(conn_key, {})
connection_state = sync_statuses.get(connection_position)
conn_id = sync_config.conn_id or ""

if connection_state is None:
raise SlidingSyncUnknownPosition()
device_id = sync_config.requester.device_id
assert device_id is not None

return connection_state
return await self.store.get_per_connection_state(
sync_config.user.to_string(),
device_id,
conn_id,
from_token.connection_position,
)

@trace
async def record_new_state(
@@ -116,85 +100,32 @@ async def record_new_state(
If there are no changes to the state this may return the same token as
the existing per-connection state.
"""
prev_connection_token = 0
if from_token is not None:
prev_connection_token = from_token.connection_position

if not new_connection_state.has_updates():
return prev_connection_token
if from_token is not None:
return from_token.connection_position
else:
return 0

conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.setdefault(conn_key, {})
if from_token is not None and from_token.connection_position == 0:
from_token = None

# Generate a new token, removing any existing entries in that token
# (which can happen if requests get resent).
new_store_token = prev_connection_token + 1
sync_statuses.pop(new_store_token, None)
conn_id = sync_config.conn_id or ""

# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
# don't grow forever.
sync_statuses[new_store_token] = new_connection_state.copy()
device_id = sync_config.requester.device_id
assert device_id is not None

return new_store_token
return await self.store.persist_per_connection_state(
sync_config.user.to_string(),
device_id,
conn_id,
from_token.connection_position if from_token else None,
new_connection_state,
)

@trace
async def mark_token_seen(
self,
sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken],
) -> None:
"""We have received a request with the given token, so we can clear out
any other tokens associated with the connection.
If there is no from token then we have started afresh, and so we delete
all tokens associated with the device.
"""
# Clear out any tokens for the connection that doesn't match the one
# from the request.

conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.pop(conn_key, {})
if from_token is None:
return

sync_statuses = {
connection_token: room_statuses
for connection_token, room_statuses in sync_statuses.items()
if connection_token == from_token.connection_position
}
if sync_statuses:
self._connections[conn_key] = sync_statuses

@staticmethod
def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
"""Return a unique identifier for this connection.
The first part is simply the user ID.
The second part is generally a combination of device ID and conn_id.
However, both these two are optional (e.g. puppet access tokens don't
have device IDs), so this handles those edge cases.
We use this over the raw `conn_id` to avoid clashes between different
clients that use the same `conn_id`. Imagine a user uses a web client
that uses `conn_id: main_sync_loop` and an Android client that also has
a `conn_id: main_sync_loop`.
"""

user_id = sync_config.user.to_string()

# Only one sliding sync connection is allowed per given conn_id (empty
# or not).
conn_id = sync_config.conn_id or ""

if sync_config.requester.device_id:
return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")

if sync_config.requester.access_token_id:
# If we don't have a device, then the access token ID should be a
# stable ID.
return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")

# If we have neither then its likely an AS or some weird token. Either
# way we can just fail here.
raise Exception("Cannot use sliding sync with access token type")
pass
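
In short, this file drops the in-memory `_connections` map and delegates to the database via the injected `DataStore` (see the constructor change in `synapse/handlers/sliding_sync/__init__.py` above). Below is a minimal, self-contained sketch of the bookkeeping contract the store now provides. `InMemorySlidingSyncStore` and the simplified `PerConnectionState` are illustrative stand-ins, not Synapse classes; only the method names and argument order mirror the calls visible in this diff.

import asyncio
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple


@dataclass
class PerConnectionState:
    """Simplified stand-in for Synapse's per-connection state (e.g. which rooms
    have already been sent down the connection)."""

    room_statuses: Dict[str, str] = field(default_factory=dict)


@dataclass
class InMemorySlidingSyncStore:
    """Illustrative stand-in for the database-backed store this commit introduces.

    State is keyed by (user_id, device_id, conn_id); each persisted snapshot gets
    a new, monotonically increasing connection position, and position 0 is
    reserved for "no previous state" (an initial sync).
    """

    _states: Dict[Tuple[str, str, str], Dict[int, PerConnectionState]] = field(
        default_factory=dict
    )

    async def get_per_connection_state(
        self, user_id: str, device_id: str, conn_id: str, position: int
    ) -> PerConnectionState:
        states = self._states.get((user_id, device_id, conn_id), {})
        if position not in states:
            # The real handler surfaces an unknown position as
            # SlidingSyncUnknownPosition, telling the client to restart the
            # connection from scratch.
            raise KeyError(f"unknown connection position {position}")
        return states[position]

    async def persist_per_connection_state(
        self,
        user_id: str,
        device_id: str,
        conn_id: str,
        previous_position: Optional[int],
        new_state: PerConnectionState,
    ) -> int:
        states = self._states.setdefault((user_id, device_id, conn_id), {})
        new_position = (previous_position or 0) + 1
        states[new_position] = new_state
        return new_position


async def main() -> None:
    store = InMemorySlidingSyncStore()
    # First request on a connection: nothing to load, persist the first snapshot.
    pos = await store.persist_per_connection_state(
        "@alice:example.org", "DEVICEID", "", None,
        PerConnectionState({"!room:example.org": "live"}),
    )
    # The next request echoes that position back and loads the stored state.
    state = await store.get_per_connection_state("@alice:example.org", "DEVICEID", "", pos)
    print(pos, state.room_statuses)


asyncio.run(main())

The real persistence methods come from the new `SlidingSyncStore` in `synapse/storage/databases/main/sliding_sync.py` (imported in the other changed files in this commit), so the exact storage schema and position allocation above are assumptions for illustration only.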
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/__init__.py
@@ -33,6 +33,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
@@ -156,6 +157,7 @@ class DataStore(
LockStore,
SessionStore,
TaskSchedulerWorkerStore,
SlidingSyncStore,
):
def __init__(
self,
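
The registration changes in `synapse/app/generic_worker.py` and here follow the same pattern: the new `SlidingSyncStore` mixin is added to the composed store class so its methods are reachable on `self.store` on both the main process and workers. A tiny sketch of that composition pattern, with everything except the `SlidingSyncStore` name simplified for illustration:

class SlidingSyncStore:
    """Mixin contributing the sliding sync persistence methods (stub body here)."""

    async def persist_per_connection_state(
        self, user_id: str, device_id: str, conn_id: str, prev_position, new_state
    ) -> int:
        raise NotImplementedError  # the real mixin writes to the database


class DataStore(SlidingSyncStore):
    """Simplified stand-in for the composed store used by the main process."""


class GenericWorkerStore(SlidingSyncStore):
    """Simplified stand-in for the composed store used by worker processes."""


# Adding the mixin to both composed classes is what lets handlers call
# self.store.persist_per_connection_state(...) regardless of process type.
assert hasattr(DataStore, "persist_per_connection_state")
assert hasattr(GenericWorkerStore, "persist_per_connection_state")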
(Diffs for the remaining 5 changed files are not shown here.)
