This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Wait for streams to catch up when processing HTTP replication. (#14820)
This should hopefully mitigate a class of races where data gets out of
sync due to an HTTP replication request racing with the replication streams.
erikjohnston authored Jan 18, 2023
1 parent e8f2bf5 commit 9187fd9
Showing 21 changed files with 226 additions and 145 deletions.
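The race this commit addresses: worker A makes an HTTP replication call to worker B, but writes made by either side may not yet have arrived at the other over the replication streams, so stale caches can be read on either end. The fix threads stream positions through the request and response bodies under a reserved `_INT_STREAM_POS` key, and each side waits for its local view of the streams to catch up before proceeding. Below is a minimal, self-contained sketch of the client-side half of that handshake; `ToyReplication`, the "events" stream name, and the token values are illustrative stand-ins, not Synapse's real classes.

import asyncio
from typing import Dict

_STREAM_POSITION_KEY = "_INT_STREAM_POS"  # reserved body key, as in the diff below


class ToyReplication:
    """Illustrative stand-in for Synapse's replication data handler: tracks
    the latest position seen per stream and lets callers block until a
    target position has arrived over replication."""

    def __init__(self) -> None:
        self._positions: Dict[str, int] = {}
        self._cond = asyncio.Condition()

    async def advance(self, stream: str, position: int) -> None:
        # Called as rows arrive over the replication connection.
        async with self._cond:
            self._positions[stream] = max(self._positions.get(stream, 0), position)
            self._cond.notify_all()

    async def wait_for_stream_position(self, stream: str, position: int) -> None:
        async with self._cond:
            await self._cond.wait_for(
                lambda: self._positions.get(stream, 0) >= position
            )


async def main() -> None:
    repl = ToyReplication()

    # Pretend a remote instance has just answered an HTTP replication request
    # and, as in this commit, echoed back the positions it wrote up to.
    response = {"result": "ok", _STREAM_POSITION_KEY: {"events": 42}}

    async def stream_catches_up() -> None:
        await asyncio.sleep(0.1)          # the stream lags the HTTP response...
        await repl.advance("events", 42)  # ...then delivers the remote's write

    catch_up = asyncio.create_task(stream_catches_up())

    # The client-side half of the fix: don't act on the response until our
    # streams have caught up with everything the remote says it wrote.
    for stream, position in response[_STREAM_POSITION_KEY].items():
        await repl.wait_for_stream_position(stream, position)
    print("streams caught up; safe to read local caches")

    await catch_up


asyncio.run(main())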
1 change: 1 addition & 0 deletions changelog.d/14820.bugfix
@@ -0,0 +1 @@
Fix rare races when using workers.
4 changes: 4 additions & 0 deletions synapse/handlers/federation_event.py
@@ -2259,6 +2259,10 @@ async def persist_events_and_notify(
event_and_contexts, backfilled=backfilled
)

# After persistence we always need to notify replication there may
# be new data.
self._notifier.notify_replication()

if self._ephemeral_messages_enabled:
for event in events:
# If there's an expiry timestamp on the event, schedule its expiry.
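This one-line change is load-bearing for the rest of the commit: the new waits in `_base.py` below only return early when the local stream position advances, and positions only advance once replication is notified. An instance that persisted events without calling `notify_replication()` would leave a peer blocked in `wait_for_stream_position` until its timeout; since the new call sites pass `raise_on_timeout=False`, the peer would then carry on with possibly stale data rather than fail. A tiny self-contained sketch of that failure mode, with `asyncio.Event` standing in for the notifier:

import asyncio


async def main() -> None:
    woken = asyncio.Event()  # stand-in for the notifier poking replication

    # Like the new call sites (raise_on_timeout=False): give up quietly
    # instead of erroring if the stream never appears to catch up.
    try:
        await asyncio.wait_for(woken.wait(), timeout=0.2)
        print("caught up")
    except asyncio.TimeoutError:
        print("timed out; proceeding with possibly stale data")

    # Had the persisting side called notify_replication() -> woken.set(),
    # the wait above would have returned immediately instead.


asyncio.run(main())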
97 changes: 88 additions & 9 deletions synapse/replication/http/_base.py
@@ -17,7 +17,7 @@
import re
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple

from prometheus_client import Counter, Gauge

@@ -27,6 +27,7 @@
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname
@@ -53,6 +54,9 @@
)


_STREAM_POSITION_KEY = "_INT_STREAM_POS"


class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
a connection error is received.
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
receiving connection errors, each will backoff exponentially longer.
WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
catch up before processing the request and/or response. Defaults to
True.
"""

NAME: str = abc.abstractproperty() # type: ignore
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
RETRY_ON_CONNECT_ERROR = True
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)

WAIT_FOR_STREAMS: ClassVar[bool] = True

def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache: ResponseCache[str] = ResponseCache(
@@ -126,6 +135,10 @@ def __init__(self, hs: "HomeServer"):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret

self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
self._replication = hs.get_replication_data_handler()
self._instance_name = hs.get_instance_name()

def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -160,7 +173,7 @@ async def _serialize_payload(**kwargs) -> JsonDict:

@abc.abstractmethod
async def _handle_request(
self, request: Request, **kwargs: Any
self, request: Request, content: JsonDict, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.
@@ -201,6 +214,10 @@ def make_client(cls, hs: "HomeServer") -> Callable:

@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
# We have to pull these out here to avoid circular dependencies...
streams = hs.get_replication_command_handler().get_streams_to_replicate()
replication = hs.get_replication_data_handler()

with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
@@ -219,6 +236,24 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:

data = await cls._serialize_payload(**kwargs)

if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
# Include the current stream positions that we write to. We
# don't do this for GETs as they don't have a body, and we
# generally assume that a GET won't rely on data we have
# written.
if _STREAM_POSITION_KEY in data:
raise Exception(
"data to send contains %r key", _STREAM_POSITION_KEY
)

data[_STREAM_POSITION_KEY] = {
"streams": {
stream.NAME: stream.current_token(local_instance_name)
for stream in streams
},
"instance_name": local_instance_name,
}

url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
@@ -308,6 +343,18 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
) from e

_outgoing_request_counter.labels(cls.NAME, 200).inc()

# Wait on any streams that the remote may have written to.
for stream_name, position in result.get(
_STREAM_POSITION_KEY, {}
).items():
await replication.wait_for_stream_position(
instance_name=instance_name,
stream_name=stream_name,
position=position,
raise_on_timeout=False,
)

return result

return send_request
@@ -353,6 +400,23 @@ async def _check_auth_and_handle(
if self._replication_secret:
self._check_auth(request)

if self.METHOD == "GET":
# GET APIs always have an empty body.
content = {}
else:
content = parse_json_object_from_request(request)

# Wait on any streams that the remote may have written to.
for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
"streams"
].items():
await self._replication.wait_for_stream_position(
instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
stream_name=stream_name,
position=position,
raise_on_timeout=False,
)

if self.CACHE:
txn_id = kwargs.pop("txn_id")

@@ -361,13 +425,28 @@
# correctly yet. In particular, there may be issues to do with logging
# context lifetimes.

-            return await self.response_cache.wrap(
-                txn_id, self._handle_request, request, **kwargs
+            code, response = await self.response_cache.wrap(
+                txn_id, self._handle_request, request, content, **kwargs
             )
+        else:
+            # The `@cancellable` decorator may be applied to `_handle_request`. But we
+            # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+            # so we have to set up the cancellable flag ourselves.
+            request.is_render_cancellable = is_function_cancellable(
+                self._handle_request
+            )
+
+            code, response = await self._handle_request(request, content, **kwargs)
+
+        # Return streams we may have written to in the course of processing this
+        # request.
+        if _STREAM_POSITION_KEY in response:
+            raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)
+
+        if self.WAIT_FOR_STREAMS:
+            response[_STREAM_POSITION_KEY] = {
+                stream.NAME: stream.current_token(self._instance_name)
+                for stream in self._streams
+            }

-        # The `@cancellable` decorator may be applied to `_handle_request`. But we
-        # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
-        # so we have to set up the cancellable flag ourselves.
-        request.is_render_cancellable = is_function_cancellable(self._handle_request)
-
-        return await self._handle_request(request, **kwargs)
+        return code, response
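Note the asymmetry in how the reserved key travels: requests wrap positions as `{"streams": {...}, "instance_name": ...}`, while responses carry a flat name-to-position map, which is why `send_request` above iterates `result.get(_STREAM_POSITION_KEY, {}).items()` directly. Illustrative request/response bodies; the endpoint payload, stream names, and token values here are all made up:

# Hypothetical request body sent by a worker named "client1":
request_body = {
    "content": {"some": "payload"},  # the endpoint's own serialized payload
    "_INT_STREAM_POS": {
        "instance_name": "client1",
        "streams": {"events": 1036, "to_device": 58},
    },
}

# Hypothetical response body from the handling instance: a flat map of the
# positions it had written up to by the time it responded.
response_body = {
    "max_stream_id": 1037,
    "_INT_STREAM_POS": {"events": 1037, "to_device": 58},
}

GET endpoints skip the request-side half entirely: they have no body, and as the comment in `send_request` notes, a GET is assumed not to depend on data the caller just wrote.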
29 changes: 16 additions & 13 deletions synapse/replication/http/account_data.py
@@ -18,7 +18,6 @@
from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

@@ -61,10 +60,8 @@ async def _serialize_payload( # type: ignore[override]
return payload

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
@@ -101,7 +98,7 @@ async def _serialize_payload( # type: ignore[override]
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_user(
user_id, account_data_type
@@ -143,10 +140,13 @@ async def _serialize_payload( # type: ignore[override]
return payload

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
self,
request: Request,
content: JsonDict,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
@@ -183,7 +183,12 @@ async def _serialize_payload( # type: ignore[override]
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
self,
request: Request,
content: JsonDict,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
@@ -225,10 +230,8 @@ async def _serialize_payload( # type: ignore[override]
return payload

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
@@ -266,7 +269,7 @@ async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict:
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
10 changes: 3 additions & 7 deletions synapse/replication/http/devices.py
@@ -18,7 +18,6 @@
from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -78,7 +77,7 @@ async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

@@ -138,9 +137,8 @@ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override]
return {"user_ids": user_ids}

async def _handle_request( # type: ignore[override]
self, request: Request
self, request: Request, content: JsonDict
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
content = parse_json_object_from_request(request)
user_ids: List[str] = content["user_ids"]

logger.info("Resync for %r", user_ids)
@@ -205,10 +203,8 @@ async def _serialize_payload( # type: ignore[override]
}

async def _handle_request( # type: ignore[override]
self, request: Request
self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

user_id = content["user_id"]
device_id = content["device_id"]
keys = content["keys"]
28 changes: 9 additions & 19 deletions synapse/replication/http/federation.py
@@ -21,7 +21,6 @@
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.util.metrics import Measure
@@ -114,10 +113,8 @@ async def _serialize_payload( # type: ignore[override]

return payload

async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)

room_id = content["room_id"]
backfilled = content["backfilled"]

@@ -181,13 +178,10 @@ async def _serialize_payload( # type: ignore[override]
return {"origin": origin, "content": content}

async def _handle_request( # type: ignore[override]
self, request: Request, edu_type: str
self, request: Request, content: JsonDict, edu_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)

origin = content["origin"]
edu_content = content["content"]
origin = content["origin"]
edu_content = content["content"]

logger.info("Got %r edu from %s", edu_type, origin)

@@ -231,13 +225,10 @@ async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override]
return {"args": args}

async def _handle_request( # type: ignore[override]
self, request: Request, query_type: str
self, request: Request, content: JsonDict, query_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)

args = content["args"]
args["origin"] = content["origin"]
args = content["args"]
args["origin"] = content["origin"]

logger.info("Got %r query from %s", query_type, args["origin"])

@@ -274,7 +265,7 @@ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
return {}

async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id)

@@ -307,9 +298,8 @@ async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict:
return {"room_version": room_version.identifier}

async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
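The endpoint files above all follow the same mechanical migration: `parse_json_object_from_request` moves out of each handler into `_base.py`, and the parsed body arrives as the new `content` argument. The new `WAIT_FOR_STREAMS` class variable is the opt-out for endpoints that should skip the position handshake; a hypothetical subclass (not part of this diff) might look like:

from typing import Tuple

from twisted.web.server import Request

from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict


class ReplicationPingRestServlet(ReplicationEndpoint):
    """Hypothetical endpoint that neither reads nor writes recent stream
    data, so it skips embedding and waiting on stream positions."""

    NAME = "ping"
    PATH_ARGS = ()
    CACHE = False
    WAIT_FOR_STREAMS = False  # opt out of the stream-position handshake

    @staticmethod
    async def _serialize_payload() -> JsonDict:  # type: ignore[override]
        return {}

    async def _handle_request(  # type: ignore[override]
        self, request: Request, content: JsonDict
    ) -> Tuple[int, JsonDict]:
        return 200, {}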
[15 more changed files are not shown in this view.]
