Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert some util functions to async #8035

Merged
merged 4 commits into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8035.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
39 changes: 21 additions & 18 deletions synapse/util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
from functools import wraps

from prometheus_client import Counter

from twisted.internet import defer

from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge

Expand Down Expand Up @@ -62,25 +59,31 @@


def measure_func(name=None):
def wrapper(func):
block_name = func.__name__ if name is None else name
"""
Used to decorate an async function with a `Measure` context manager.

Usage:

if inspect.iscoroutinefunction(func):
@measure_func()
async def foo(...):
...

@wraps(func)
async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
Which is analogous to:

else:
async def foo(...):
with Measure(...):
...

"""

def wrapper(func):
block_name = func.__name__ if name is None else name

@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
@wraps(func)
async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r

return measured_func

Expand Down
16 changes: 6 additions & 10 deletions synapse/util/retryutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import logging
import random

from twisted.internet import defer

import synapse.logging.context
from synapse.api.errors import CodeMessageException

Expand Down Expand Up @@ -54,8 +52,7 @@ def __init__(self, retry_last_ts, retry_interval, destination):
self.destination = destination


@defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
Expand All @@ -73,17 +70,17 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
Example usage:

try:
limiter = yield get_retry_limiter(destination, clock, store)
limiter = await get_retry_limiter(destination, clock, store)
with limiter:
response = yield do_request()
response = await do_request()
except NotRetryingDestination:
# We aren't ready to retry that destination.
raise
"""
failure_ts = None
retry_last_ts, retry_interval = (0, 0)

retry_timings = yield store.get_destination_retry_timings(destination)
retry_timings = await store.get_destination_retry_timings(destination)

if retry_timings:
failure_ts = retry_timings["failure_ts"]
Expand Down Expand Up @@ -222,10 +219,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if self.failure_ts is None:
self.failure_ts = retry_last_ts

@defer.inlineCallbacks
def store_retry_timings():
async def store_retry_timings():
try:
yield self.store.set_destination_retry_timings(
await self.store.set_destination_retry_timings(
self.destination,
self.failure_ts,
retry_last_ts,
Expand Down
44 changes: 11 additions & 33 deletions tests/util/test_retryutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore()
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

# advance the clock a bit before making the request
self.pump(1)

with limiter:
pass

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)

def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastore()

d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

self.pump(1)
try:
Expand All @@ -58,29 +52,22 @@ def test_limiter(self):
except AssertionError:
pass

# wait for the update to land
self.pump()

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)

# now if we try again we should get a failure
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
self.failureResultOf(d, NotRetryingDestination)
self.get_failure(
get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
)

#
# advance the clock and try again
#

self.pump(MIN_RETRY_INTERVAL)
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

self.pump(1)
try:
Expand All @@ -91,12 +78,7 @@ def test_limiter(self):
except AssertionError:
pass

# wait for the update to land
self.pump()

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual(
Expand All @@ -110,9 +92,7 @@ def test_limiter(self):
# one more go, with success
#
self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

self.pump(1)
with limiter:
Expand All @@ -121,7 +101,5 @@ def test_limiter(self):
# wait for the update to land
self.pump()

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)