Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Change device list streams to have one row per ID #7010

Merged
merged 8 commits into from
Mar 19, 2020
Merged
1 change: 1 addition & 0 deletions changelog.d/7010.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Change device list streams to have one row per ID.
10 changes: 7 additions & 3 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,9 @@ async def process_and_notify(self, stream_name, token, rows):
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
room_ids = await self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
await self.presence_handler.process_replication_rows(token, rows)
Expand Down Expand Up @@ -774,7 +775,10 @@ def process_replication_rows(self, stream_name, token, rows):

# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
hosts = {row.destination for row in rows}
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
for host in hosts:
self.federation_sender.send_device_messages(host)

Expand Down
36 changes: 23 additions & 13 deletions synapse/replication/slave/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ def __init__(self, database: Database, db_conn, hs):
self.hs = hs

self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id"
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
Expand All @@ -55,23 +61,27 @@ def stream_positions(self):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)

def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(user_id, token)

if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate_many((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))

self.get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)
13 changes: 9 additions & 4 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,13 @@ class CachesStreamRow:
"network_id", # str, optional
),
)
DeviceListsStreamRow = namedtuple(
"DeviceListsStreamRow", ("user_id", "destination") # str # str
)


@attr.s
class DeviceListsStreamRow:
entity = attr.ib(type=str)


ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
Expand Down Expand Up @@ -363,7 +367,8 @@ def __init__(self, hs):


class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""Either a user has updated their devices or a remote server needs to be
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
told about a device update.
"""

NAME = "device_lists"
Expand Down
5 changes: 4 additions & 1 deletion synapse/storage/data_stores/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def __init__(self, database: Database, db_conn, hs):
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[("user_signature_stream", "stream_id")],
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
Expand Down
132 changes: 68 additions & 64 deletions synapse/storage/data_stores/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple

from six import iteritems

Expand All @@ -31,7 +32,7 @@
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
Expand Down Expand Up @@ -112,23 +113,13 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
if not has_changed:
return now_stream_id, []

# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
# device has the same stream_id as the second-to-last device. If so,
# then we ignore all devices with that stream_id and only send the
# devices with a lower stream_id.
#
# If when culling the list we end up with no devices afterwards, we
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
limit + 1,
limit,
)

# Return an empty list if there are no updates
Expand Down Expand Up @@ -166,14 +157,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"device_id": verify_key.version,
}

# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
stream_id_cutoff = updates[-1][2]
now_stream_id = stream_id_cutoff - 1
else:
stream_id_cutoff = None

# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
Expand All @@ -192,10 +175,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break

if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
Expand All @@ -218,17 +197,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)

# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# steam_id.

# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []

results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
Expand Down Expand Up @@ -607,21 +575,26 @@ def get_users_whose_signatures_changed(self, user_id, from_key):
else:
return set()

def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
either a user ID (starting with '@') or a remote destination.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.

# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self.db.execute(

return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)

Expand Down Expand Up @@ -1017,29 +990,49 @@ def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
if not device_ids:
return

with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
user_id,
device_ids,
stream_ids,
)

if not hosts:
return stream_ids[-1]

context = get_active_span_text_map()
with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
yield self.db.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
stream_id,
stream_ids,
context,
)
return stream_id

def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
return stream_ids[-1]

def _add_device_change_to_stream_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
stream_ids: List[str],
):
txn.call_after(
self._device_list_stream_cache.entity_has_changed, user_id, stream_id
self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_id,
)

min_stream_id = stream_ids[0]

# Delete older entries in the table, as we really only care about
# when the latest change happened.
Expand All @@ -1048,27 +1041,38 @@ def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
[(user_id, device_id, stream_id) for device_id in device_ids],
[(user_id, device_id, min_stream_id) for device_id in device_ids],
)

self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
for device_id in device_ids
for stream_id, device_id in zip(stream_ids, device_ids)
],
)

context = get_active_span_text_map()
def _add_device_outbound_poke_to_stream_txn(
self, txn, user_id, device_ids, hosts, stream_ids, context,
):
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)

now = self._clock.time_msec()
next_stream_id = iter(stream_ids)

self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
Expand Down
Loading