
Fix race in MultiWriterIdGenerator #11045

Merged
7 commits, merged on Oct 12, 2021
Changes from 4 commits
changelog.d/11045.bugfix (1 addition, 0 deletions)
@@ -0,0 +1 @@
+Fix bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race.
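
For context, the race this fixes can be reproduced in miniature. The sketch below is a toy model, not Synapse code; `unfinished`, `current_position`, and `on_finished` are hypothetical stand-ins for the generator's internal state. The window exists because a stream ID is handed out by the database sequence before it is recorded as unfinished, so another writer finishing in that window can drag the current position past the still-in-flight ID, and `/sync` then never sends the event it belongs to.

```python
# Toy model of the race (hypothetical names, not Synapse internals).
unfinished: set = set()
current_position = 5

def on_finished(finished_id: int) -> None:
    """Advance the position when an ID completes (simplified rule)."""
    global current_position
    if not unfinished:
        current_position = max(current_position, finished_id)

# Writer A is allocated stream ID 6 by the sequence...
allocated = 6
# ...but before A records it as unfinished, writer B finishes ID 7:
on_finished(7)
assert current_position == 7   # readers now believe everything <= 7 is done
unfinished.add(allocated)      # too late: ID 6 was already skipped
```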
synapse/storage/util/id_generators.py (55 additions, 10 deletions)
@@ -36,7 +36,7 @@
 )

 import attr
-from sortedcontainers import SortedSet
+from sortedcontainers import SortedList, SortedSet

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import (
@@ -265,6 +265,15 @@ def __init__(
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids: SortedSet[int] = SortedSet()

+        # We also need to track when we've requested some new stream IDs but
+        # they haven't yet been added to the `_unfinished_ids` set. Every time
+        # we request a new stream ID we add the current max stream ID to the
+        # list, and remove it once we've added the newly allocated IDs to the
+        # `_unfinished_ids` set. This means that we *may* be allocated stream
+        # IDs above those in the list, and so we can't advance the local current
+        # position beyond the minimum stream ID in this list.
+        self._in_flight_fetches: SortedList[int] = SortedList()
+
         # Set of local IDs that we've processed that are larger than the current
         # position, due to there being smaller unpersisted IDs.
         self._finished_ids: Set[int] = set()
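
A minimal sketch of how this bookkeeping is intended to behave, assuming only that `sortedcontainers` is available (the helper functions here are hypothetical, not part of the class): before a fetch starts we record the maximum stream ID seen so far, and while any such marker is present the local position may never advance past the smallest one, `_in_flight_fetches[0]`.

```python
from sortedcontainers import SortedList

in_flight_fetches: SortedList = SortedList()  # hypothetical standalone state
max_seen_allocated = 10

def begin_fetch() -> int:
    """Record a marker at the max allocated ID before fetching."""
    marker = max_seen_allocated
    in_flight_fetches.add(marker)
    return marker

def end_fetch(marker: int) -> None:
    """Drop the marker once the fetched IDs are registered."""
    in_flight_fetches.remove(marker)

def safe_position(candidate: int) -> int:
    """Clamp a would-be current position to below any in-flight fetch."""
    if in_flight_fetches:
        return min(candidate, in_flight_fetches[0])
    return candidate

m = begin_fetch()
print(safe_position(15))  # 10: a fetch is in flight, so we hold back
end_fetch(m)
print(safe_position(15))  # 15: nothing in flight, free to advance
```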
@@ -290,6 +299,9 @@ def __init__(
         )
         self._known_persisted_positions: List[int] = []

+        # The maximum stream ID that we have seen allocated across any writer.
+        self._max_seen_allocated_stream_id = 1
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)

         # We check that the table and sequence haven't diverged.
@@ -305,6 +317,10 @@ def __init__(
         # This goes and fills out the above state from the database.
         self._load_current_ids(db_conn, tables)

+        self._max_seen_allocated_stream_id = (
+            max(self._current_positions.values()) if self._current_positions else 1
+        )
+
     def _load_current_ids(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -411,10 +427,32 @@ def _load_current_ids(
         cur.close()

     def _load_next_id_txn(self, txn: Cursor) -> int:
-        return self._sequence_gen.get_next_id_txn(txn)
+        stream_ids = self._load_next_mult_id_txn(txn, 1)
+        return stream_ids[0]

     def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
-        return self._sequence_gen.get_next_mult_txn(txn, n)
+        # We need to track that we've requested some more stream IDs, and what
+        # the current max allocated stream ID is. This is to prevent a race
+        # where we've been allocated stream IDs but they have not yet been added
+        # to the `_unfinished_ids` set, allowing the current position to advance
+        # past them.
+        with self._lock:
+            current_max = self._max_seen_allocated_stream_id
+            self._in_flight_fetches.add(current_max)
+
+        try:
+            stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
+
+            with self._lock:
+                self._unfinished_ids.update(stream_ids)
+                self._max_seen_allocated_stream_id = max(
+                    self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
+                )
+        finally:
+            with self._lock:
+                self._in_flight_fetches.remove(current_max)
+
+        return stream_ids

     def get_next(self) -> AsyncContextManager[int]:
         """
@@ -463,9 +501,6 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

         next_id = self._load_next_id_txn(txn)

-        with self._lock:
-            self._unfinished_ids.add(next_id)
-
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)

@@ -524,6 +559,11 @@ def _mark_id_as_finished(self, next_id: int) -> None:
                 self._finished_ids.clear()

             if new_cur:
+                # If we are currently fetching new stream IDs we need to ensure
+                # that we don't advance the position past them.
+                if self._in_flight_fetches:
+                    new_cur = min(new_cur, self._in_flight_fetches[0])
+
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, new_cur)
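
`_mark_id_as_finished` is the only place the local position moves forward, so a single `min()` at this choke point closes the race. A simplified, hypothetical restatement of the advance rule (the real method also maintains `_finished_ids` to handle out-of-order completions):

```python
from sortedcontainers import SortedList, SortedSet

def advance_position(
    position: int,
    finished_id: int,
    unfinished_ids: SortedSet,
    in_flight_fetches: SortedList,
) -> int:
    """Advance `position` after `finished_id` completes (simplified)."""
    unfinished_ids.discard(finished_id)
    if unfinished_ids:
        # Can only advance to just below the smallest unfinished ID.
        new_cur = unfinished_ids[0] - 1
    else:
        new_cur = finished_id
    # The fix: never advance past a still-in-flight fetch marker.
    if in_flight_fetches:
        new_cur = min(new_cur, in_flight_fetches[0])
    return max(position, new_cur)

# ID 7 finishes while a fetch marked at 5 is still in flight: the
# position stays at 5 instead of jumping over the IDs being allocated.
print(advance_position(5, 7, SortedSet(), SortedList([5])))  # 5
```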

@@ -575,6 +615,10 @@ def advance(self, instance_name: str, new_id: int) -> None:
                 new_id, self._current_positions.get(instance_name, 0)
             )

+            self._max_seen_allocated_stream_id = max(
+                self._max_seen_allocated_stream_id, new_id
+            )
+
             self._add_persisted_position(new_id)

     def get_persisted_upto_position(self) -> int:
@@ -605,7 +649,11 @@ def _add_persisted_position(self, new_id: int) -> None:
         # to report a recent position when asked, rather than a potentially old
         # one (if this instance hasn't written anything for a while).
         our_current_position = self._current_positions.get(self._instance_name)
-        if our_current_position and not self._unfinished_ids:
+        if (
+            our_current_position
+            and not self._unfinished_ids
+            and not self._in_flight_fetches
+        ):
             self._current_positions[self._instance_name] = max(
                 our_current_position, new_id
             )
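
The same idle check guards adopting a remote writer's position as our own: that shortcut is only safe when this instance has nothing unfinished and, with this change, nothing in flight either. A hypothetical standalone form of the condition:

```python
def maybe_adopt_remote_position(
    our_position: int,
    remote_position: int,
    unfinished_ids: set,
    in_flight_fetches: list,
) -> int:
    """Adopt a newer remote position only while fully idle (simplified)."""
    if our_position and not unfinished_ids and not in_flight_fetches:
        return max(our_position, remote_position)
    return our_position

print(maybe_adopt_remote_position(5, 9, set(), [4]))  # 5: fetch in flight
print(maybe_adopt_remote_position(5, 9, set(), []))   # 9: safe to adopt
```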
@@ -697,9 +745,6 @@ async def __aenter__(self) -> Union[int, List[int]]:
             db_autocommit=True,
         )

-        with self.id_gen._lock:
-            self.id_gen._unfinished_ids.update(self.stream_ids)
-
         if self.multiple_ids is None:
             return self.stream_ids[0] * self.id_gen._return_factor
         else:
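
With registration now happening inside `_load_next_mult_id_txn`, callers such as this context manager no longer touch `_unfinished_ids` themselves: by the time the database interaction returns, the IDs are already protected. Caller-side usage is unchanged; a sketch, where `id_gen` is a configured `MultiWriterIdGenerator` and `persist_event` is a hypothetical coroutine:

```python
async def persist(event) -> None:
    async with id_gen.get_next() as stream_id:
        # The ID is already tracked as unfinished here, so no other
        # writer can advance the current position past it.
        await persist_event(event, stream_id)
    # On exit the ID is marked finished and the position may advance.
```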