This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Optimise queueing of inbound replication commands
When we get behind on replication, we tend to stack up background processes
behind a linearizer. Bg processes are heavy (particularly with respect to
prometheus metrics) and linearizers aren't terribly efficient once the queue
gets long either.

A better approach is to maintain a queue of requests to be processed, and
nominate a single process to work its way through the queue.

Fixes: #7444
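
A minimal, self-contained sketch of the approach described above -- not the Synapse implementation itself; StreamCommandQueue, handle() and _process() are illustrative names. Each stream gets its own deque, and whichever caller finds the stream idle becomes the single worker that drains the queue before marking the stream idle again.

import asyncio
from collections import defaultdict, deque
from typing import Deque, Dict, Set, Tuple


class StreamCommandQueue:
    def __init__(self) -> None:
        # per-stream queue of (command, connection) pairs awaiting processing
        self._queues: Dict[str, Deque[Tuple[str, str]]] = defaultdict(deque)
        # streams which currently have a worker draining their queue
        self._processing: Set[str] = set()

    async def handle(self, stream_name: str, cmd: str, conn: str) -> None:
        if stream_name in self._processing:
            # another caller is already working through this stream's queue:
            # just enqueue the command and return immediately.
            self._queues[stream_name].append((cmd, conn))
            return

        # otherwise, this caller becomes the single worker for the stream.
        self._processing.add(stream_name)
        try:
            await self._process(stream_name, cmd, conn)
            # drain anything that was queued while we were busy.
            queue = self._queues[stream_name]
            while queue:
                cmd, conn = queue.popleft()
                await self._process(stream_name, cmd, conn)
        finally:
            self._processing.discard(stream_name)

    async def _process(self, stream_name: str, cmd: str, conn: str) -> None:
        # stand-in for the real per-command work (parsing rows, catching up, etc.)
        await asyncio.sleep(0)
        print("%s: processed %r from %s" % (stream_name, cmd, conn))


async def main() -> None:
    q = StreamCommandQueue()
    # three commands arrive "at once"; only the first caller does any work,
    # and it processes all three in arrival order.
    await asyncio.gather(*(q.handle("events", "RDATA %d" % i, "conn1") for i in range(3)))


if __name__ == "__main__":
    asyncio.run(main())
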
richvdh committed Jul 15, 2020
1 parent 8c7d0f1 commit 1bc4e93
Showing 2 changed files with 213 additions and 116 deletions.
1 change: 1 addition & 0 deletions changelog.d/7861.misc
@@ -0,0 +1 @@
Optimise queueing of inbound replication commands.
328 changes: 212 additions & 116 deletions synapse/replication/tcp/handler.py
@@ -14,7 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from typing import (
Any,
Deque,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)

from prometheus_client import Counter

@@ -43,7 +55,6 @@
FederationStream,
Stream,
)
from synapse.util.async_helpers import Linearizer

logger = logging.getLogger(__name__)

@@ -107,10 +118,6 @@ def __init__(self, hs):

self._streams_to_replicate.append(stream)

self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)

# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
@@ -122,17 +129,42 @@ def __init__(self, hs):
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]

# For each connection, the incoming stream names that are coming from
# that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]

LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self._connections),
)

# When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process.

# the streams which are currently being processed by _unsafe_process_stream
self._processing_streams = set() # type: Set[str]

# for each stream, a queue of commands that are awaiting processing, and the
# connection that they arrived on.
self._command_queues_by_stream = {
stream_name: Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
]()
for stream_name in self._streams
}

# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]

LaterGauge(
"synapse_replication_tcp_commmand_queue",
"Number of inbound RDATA/POSITION commands queued for processing",
["stream_name"],
lambda: {
(stream_name,): len(queue)
for stream_name, queue in self._command_queues_by_stream.items()
},
)

self._is_master = hs.config.worker_app is None

self._federation_sender = None
@@ -143,6 +175,64 @@ def __init__(self, hs):
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()

async def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
Adds the given command to the per-stream queue, and processes the queue if
necessary
"""
stream_name = cmd.stream_name
queue = self._command_queues_by_stream.get(stream_name)
if queue is None:
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return

# if we're already processing this stream, stick the new command in the
# queue, and we're done.
if stream_name in self._processing_streams:
queue.append((cmd, conn))
return

# otherwise, process the new command.

# arguably we should start off a new background process here, but nothing
# will be too upset if we don't return for ages, so let's save the overhead
# and use the existing logcontext.

self._processing_streams.add(stream_name)
try:
# might as well skip the queue for this one, since it must be empty
assert not queue
await self._process_command(cmd, conn, stream_name)

# now process any other commands that have built up while we were
# dealing with that one.
while queue:
cmd, conn = queue.popleft()
try:
await self._process_command(cmd, conn, stream_name)
except Exception:
logger.exception("Failed to handle command %s", cmd)

finally:
self._processing_streams.discard(stream_name)

async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
await self._process_position(stream_name, conn, cmd)
elif isinstance(cmd, RdataCommand):
await self._process_rdata(stream_name, conn, cmd)
else:
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)

def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
@@ -276,63 +366,71 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()

try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise

# We linearize here for two reasons:
# We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return

if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)

stream = self._streams.get(stream_name)
if not stream:
logger.error("Got RDATA for unknown stream: %s", stream_name)
return

# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)

# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

await self._add_command_to_stream_queue(conn, cmd)

async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
Called after the command has been popped off the queue of inbound commands
"""
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception as e:
raise Exception(
"Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
) from e

# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return

if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
return

# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)

stream = self._streams[stream_name]

# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)

# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -358,67 +456,65 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):

logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())

stream_name = cmd.stream_name
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", stream_name)
return
await self._add_command_to_stream_queue(conn, cmd)

# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
streams.discard(stream_name)

# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(stream_name, [])

# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)

# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
# TODO: add some tests for this
Called after the command has been popped off the queue of inbound commands
"""
stream = self._streams[stream_name]

# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
streams.discard(stream_name)

for token, rows in _batch_updates(updates):
await self.on_rdata(
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(stream_name, [])

logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)

# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)

self._streams_by_connection.setdefault(conn, set()).add(stream_name)
# TODO: add some tests for this

# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.

for token, rows in _batch_updates(updates):
await self.on_rdata(
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)

logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)

# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)

self._streams_by_connection.setdefault(conn, set()).add(stream_name)

async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
