Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #209 from matrix-org/erikj/cached_keyword_args
Browse files Browse the repository at this point in the history
Add support for using keyword arguments with cached functions
  • Loading branch information
erikjohnston committed Aug 6, 2015
2 parents 1e62a3d + 953dbd2 commit 8049c9a
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 23 deletions.
40 changes: 34 additions & 6 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from collections import namedtuple, OrderedDict

import functools
import inspect
import sys
import time
import threading
Expand Down Expand Up @@ -141,13 +142,28 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True):
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig

if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig

self.max_entries = max_entries
self.num_args = num_args
self.lru = lru

self.arg_names = inspect.getargspec(orig).args[1:num_args+1]

if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)

def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
Expand All @@ -158,11 +174,13 @@ def __get__(self, obj, objtype=None):

@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*keyargs):
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try:
cached_result = cache.get(*keyargs[:self.num_args])
cached_result = cache.get(*keyargs)
if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs)
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
Expand All @@ -177,9 +195,9 @@ def wrapped(*keyargs):
# while the SELECT is executing (SYN-369)
sequence = cache.sequence

ret = yield self.orig(obj, *keyargs)
ret = yield self.function_to_call(obj, *args, **kwargs)

cache.update(sequence, *keyargs[:self.num_args] + (ret,))
cache.update(sequence, *(keyargs + [ret]))

defer.returnValue(ret)

Expand All @@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=True):
)


def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)


class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from _base import SQLBaseStore, cached
from _base import SQLBaseStore, cachedInlineCallbacks

from twisted.internet import defer

Expand Down Expand Up @@ -71,8 +71,7 @@ def store_server_certificate(self, server_name, from_server, time_now_ms,
desc="store_server_certificate",
)

@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
Expand Down
8 changes: 3 additions & 5 deletions synapse/storage/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer

import logging
Expand All @@ -23,8 +23,7 @@


class PushRuleStore(SQLBaseStore):
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
Expand All @@ -41,8 +40,7 @@ def get_push_rules_for_user(self, user_name):

defer.returnValue(rows)

@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks

from twisted.internet import defer

Expand Down Expand Up @@ -128,8 +128,7 @@ def f(txn):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)

@cached
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from synapse.api.errors import StoreError

from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks

import collections
import logging
Expand Down Expand Up @@ -186,8 +186,7 @@ def _store_room_name_txn(self, txn, event):
}
)

@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
sql = (
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cached, cachedInlineCallbacks

from twisted.internet import defer

Expand Down Expand Up @@ -189,8 +189,7 @@ def f(txn):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)

@cached(num_args=3)
@defer.inlineCallbacks
@cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
Expand Down

0 comments on commit 8049c9a

Please sign in to comment.