
Fix stack overflow in _PerHostRatelimiter due to synchronous requests #14812

Merged
merged 6 commits on Jan 13, 2023
Changes from 4 commits
1 change: 1 addition & 0 deletions changelog.d/14812.bugfix
@@ -0,0 +1 @@
Fix a long-standing bug where Synapse would exhaust the stack when processing many federation requests where the remote homeserver has disconnected early.
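
For context on the bug being fixed, the pattern below is a minimal, self-contained sketch (not Synapse code; all names are illustrative) of how synchronously-completing requests can recurse through a queue of Twisted `Deferred`s until Python's stack limit is hit:

```python
from twisted.internet import defer

# One waiter per queued request; each is released by firing its Deferred.
pending = [defer.Deferred() for _ in range(100_000)]

def finish_one(index: int) -> None:
    """Complete request `index` and immediately start the next one."""
    if index < len(pending):
        # Because the callback runs synchronously inside `callback()`, each
        # completed request adds a stack frame before starting the next one.
        pending[index].addCallback(lambda _: finish_one(index + 1))
        pending[index].callback(None)

# finish_one(0)  # a long enough queue exhausts the stack (RecursionError)
```

In `_PerHostRatelimiter`, the equivalent chain was `_on_exit` firing the next queued request's deferred directly, which is what the change further down breaks up.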
1 change: 1 addition & 0 deletions synapse/rest/client/register.py
@@ -310,6 +310,7 @@ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
hs.get_reactor(),
hs.get_clock(),
FederationRatelimitSettings(
# Time window of 2s
1 change: 1 addition & 0 deletions synapse/server.py
@@ -768,6 +768,7 @@ def get_replication_streams(self) -> Dict[str, Stream]:
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
return FederationRateLimiter(
self.get_reactor(),
self.get_clock(),
config=self.config.ratelimiting.rc_federation,
metrics_name="federation_servlets",
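
For orientation only (this caller sketch is not part of the diff; the handler body and any names other than `get_federation_ratelimiter` and `ratelimit` are made up): the limiter is used as a per-origin context manager that yields a `Deferred` to wait on before handling a request, matching the `ratelimit()` signature and the tests further down.

```python
# Illustrative usage sketch; `hs` is a HomeServer instance.
ratelimiter = hs.get_federation_ratelimiter()

async def handle_incoming(origin: str) -> None:
    with ratelimiter.ratelimit(origin) as wait_deferred:
        # Resolves immediately if `origin` is under its limits; otherwise it
        # waits (or, presumably, fails with LimitExceededError when the
        # origin is over its reject limit).
        await wait_deferred
        ...  # process the request here
```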
34 changes: 25 additions & 9 deletions synapse/util/ratelimitutils.py
@@ -34,6 +34,7 @@
from typing_extensions import ContextManager

from twisted.internet import defer
from twisted.internet.interfaces import IReactorFromThreads

from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRatelimitSettings
@@ -146,12 +147,14 @@ class FederationRateLimiter:

def __init__(
self,
reactor: IReactorFromThreads,
clock: Clock,
config: FederationRatelimitSettings,
metrics_name: Optional[str] = None,
):
"""
Args:
reactor
clock
config
metrics_name: The name of the rate limiter so we can differentiate it
@@ -163,7 +166,7 @@ def __init__(self):

def new_limiter() -> "_PerHostRatelimiter":
return _PerHostRatelimiter(
clock=clock, config=config, metrics_name=metrics_name
reactor=reactor, clock=clock, config=config, metrics_name=metrics_name
)

self.ratelimiters: DefaultDict[
@@ -194,19 +197,22 @@ def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]
class _PerHostRatelimiter:
def __init__(
self,
reactor: IReactorFromThreads,
clock: Clock,
config: FederationRatelimitSettings,
metrics_name: Optional[str] = None,
):
"""
Args:
reactor
clock
config
metrics_name: The name of the rate limiter so we can differentiate it
from the rest in the metrics. If `None`, we don't track metrics
for this rate limiter.
from the rest in the metrics
"""
self.reactor = reactor
self.clock = clock
self.metrics_name = metrics_name

@@ -364,12 +370,22 @@ def on_both(r: object) -> object:

def _on_exit(self, request_id: object) -> None:
logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
_, deferred = self.ready_request_queue.popitem(last=False)

with PreserveLoggingContext():
deferred.callback(None)
except KeyError:
pass
# When requests complete synchronously, we will recursively start the next
# request in the queue. To avoid stack exhaustion, we defer starting the next
# request until the next reactor tick.

def start_next_request() -> None:
# We only remove the completed request from the list when we're about to
# start the next one, otherwise we can admit extra requests.
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
_, deferred = self.ready_request_queue.popitem(last=False)

with PreserveLoggingContext():
deferred.callback(None)
except KeyError:
pass

self.reactor.callFromThread(start_next_request)
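
The key to the fix is that `callFromThread` hands `start_next_request` to the reactor to run on a later iteration of its loop rather than invoking it inline, so a burst of synchronously-completing requests is drained one reactor tick at a time instead of one stack frame at a time (`reactor.callLater(0, ...)` would have a similar effect; `callFromThread` simply queues the call without involving timed calls). A rough sketch of the same idea outside Synapse (illustrative names, global reactor assumed; Synapse instead injects an `IReactorFromThreads` as shown above):

```python
from twisted.internet import defer, reactor

pending = [defer.Deferred() for _ in range(100_000)]

def finish_one(index: int) -> None:
    if index < len(pending):
        # Schedule the next completion via the reactor instead of calling it
        # directly, so the current stack frame unwinds first.
        pending[index].addCallback(
            lambda _: reactor.callFromThread(finish_one, index + 1)
        )
        pending[index].callback(None)

# reactor.callWhenRunning(finish_one, 0); reactor.run()
```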
3 changes: 2 additions & 1 deletion tests/server.py
@@ -49,6 +49,7 @@
IProtocol,
IPullProducer,
IPushProducer,
IReactorFromThreads,
IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple,
@@ -401,7 +402,7 @@ def make_request(
return channel


@implementer(IReactorPluggableNameResolver)
@implementer(IReactorPluggableNameResolver, IReactorFromThreads)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
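
Declaring `IReactorFromThreads` here lets the fake test reactor be passed wherever `FederationRateLimiter` now expects a reactor. The practical consequence for tests, presumably (an assumption about the test harness, not visible in this hunk), is that callables handed to `callFromThread` only run once the fake reactor is pumped, hence the new `reactor.advance(0.0)` calls in the tests below:

```python
# Hypothetical test-style sketch using the helpers from these test files.
reactor, clock = get_clock()
ratelimiter = FederationRateLimiter(reactor, clock, build_rc_config())

cm = ratelimiter.ratelimit("testhost")
d = cm.__enter__()             # first request is admitted immediately
cm.__exit__(None, None, None)  # finishing it only *schedules* the next one
reactor.advance(0.0)           # pump the fake reactor so the scheduled call runs
```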
45 changes: 42 additions & 3 deletions tests/util/test_ratelimitutils.py
@@ -13,6 +13,7 @@
# limitations under the License.
from typing import Optional

from twisted.internet import defer
from twisted.internet.defer import Deferred

from synapse.config.homeserver import HomeServerConfig
@@ -29,7 +30,7 @@ def test_ratelimit(self) -> None:
"""A simple test with the default values"""
reactor, clock = get_clock()
rc_config = build_rc_config()
ratelimiter = FederationRateLimiter(clock, rc_config)
ratelimiter = FederationRateLimiter(reactor, clock, rc_config)

with ratelimiter.ratelimit("testhost") as d1:
# shouldn't block
@@ -39,7 +40,7 @@ def test_concurrent_limit(self) -> None:
"""Test what happens when we hit the concurrent limit"""
reactor, clock = get_clock()
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
ratelimiter = FederationRateLimiter(clock, rc_config)
ratelimiter = FederationRateLimiter(reactor, clock, rc_config)

with ratelimiter.ratelimit("testhost") as d1:
# shouldn't block
@@ -57,6 +58,7 @@ def test_concurrent_limit(self) -> None:

# ... until we complete an earlier request
cm2.__exit__(None, None, None)
reactor.advance(0.0)
self.successResultOf(d3)

def test_sleep_limit(self) -> None:
@@ -65,7 +67,7 @@ def test_sleep_limit(self) -> None:
rc_config = build_rc_config(
{"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}}
)
ratelimiter = FederationRateLimiter(clock, rc_config)
ratelimiter = FederationRateLimiter(reactor, clock, rc_config)

with ratelimiter.ratelimit("testhost") as d1:
# shouldn't block
@@ -81,6 +83,43 @@ def test_sleep_limit(self) -> None:
sleep_time = _await_resolution(reactor, d3)
self.assertAlmostEqual(sleep_time, 500, places=3)

def test_lots_of_queued_things(self) -> None:
"""Tests lots of synchronous things queued up behind a slow thing.

The stack should *not* explode when the slow thing completes.
"""
reactor, clock = get_clock()
rc_config = build_rc_config(
{
"rc_federation": {
"sleep_limit": 1000000000, # never sleep
"reject_limit": 1000000000, # never reject requests
"concurrent": 1,
}
}
)
ratelimiter = FederationRateLimiter(reactor, clock, rc_config)

with ratelimiter.ratelimit("testhost") as d:
# shouldn't block
self.successResultOf(d)

async def task() -> None:
with ratelimiter.ratelimit("testhost") as d:
await d

for _ in range(1, 100):
defer.ensureDeferred(task())

last_task = defer.ensureDeferred(task())

# Upon exiting the context manager, all the synchronous things will resume.
# If a stack overflow occurs, the final task will not complete.

# Wait for all the things to complete.
reactor.advance(0.0)
self.successResultOf(last_task)


def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
"""advance the clock until the deferred completes.