Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #2898 from matrix-org/erikj/split_push_rules_store
Browse files Browse the repository at this point in the history
Split PushRulesStore
  • Loading branch information
erikjohnston authored Feb 23, 2018
2 parents 199dba6 + 7e6cf89 commit d095775
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 45 deletions.
24 changes: 7 additions & 17 deletions synapse/replication/slave/storage/push_rule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,36 +16,25 @@

from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage.push_rule import PushRulesWorkerStore


class SlavedPushRuleStore(SlavedEventStore):
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)

get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
super(SlavedPushRuleStore, self).__init__(db_conn, hs)

def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)

def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
Expand Down
13 changes: 1 addition & 12 deletions synapse/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -169,18 +170,6 @@ def __init__(self, db_conn, hs):
prefilled_cache=presence_cache_prefill
)

push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)

self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)

max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
Expand Down
72 changes: 56 additions & 16 deletions synapse/storage/push_rule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,10 +16,12 @@

from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes
from twisted.internet import defer

import abc
import logging
import simplejson as json

Expand Down Expand Up @@ -48,7 +51,39 @@ def _load_rules(rawrules, enabled_map):
return rules


class PushRuleStore(SQLBaseStore):
class PushRulesWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""

# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta

def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)

push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)

self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)

@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()

@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
Expand Down Expand Up @@ -89,6 +124,24 @@ def get_push_rules_enabled_for_user(self, user_id):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})

def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)


class PushRuleStore(PushRulesWorkerStore):
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
Expand Down Expand Up @@ -526,21 +579,8 @@ def get_push_rules_stream_token(self):
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()

def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
def get_max_push_rules_stream_id(self):
return self.get_push_rules_stream_token()[0]


class RuleNotFoundException(Exception):
Expand Down

0 comments on commit d095775

Please sign in to comment.