diff --git a/docs/admin_api/purge_history_api.rst b/docs/admin_api/purge_history_api.rst
index a3a17e9f9fe9..acf1bc574948 100644
--- a/docs/admin_api/purge_history_api.rst
+++ b/docs/admin_api/purge_history_api.rst
@@ -8,9 +8,9 @@ Depending on the amount of history being purged a call to the API may take
 several minutes or longer. During this period users will not be able to
 paginate further back in the room from the point being purged from.
 
-The API is simply:
+The API is:
 
-``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
+``POST /_matrix/client/r0/admin/purge_history/<room_id>[/<event_id>]``
 
 including an ``access_token`` of a server admin.
 
@@ -25,3 +25,10 @@ To delete local events as well, set ``delete_local_events`` in the body:
    {
       "delete_local_events": true
    }
+
+The caller must specify the point in the room to purge up to. This can be
+specified by including an event_id in the URI, or by setting a
+``purge_up_to_event_id`` or ``purge_up_to_ts`` in the request body. If an event
+id is given, that event (and others at the same graph depth) will be retained.
+If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
+in milliseconds.
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index b2ce399258fa..fc0b9e8c04e1 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -27,10 +27,14 @@
 from synapse.http.site import SynapseSite
 from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.tcp.client import ReplicationClientHandler
@@ -48,6 +52,10 @@
 
 class EventCreatorSlavedStore(
+    SlavedAccountDataStore,
+    SlavedPusherStore,
+    SlavedReceiptsStore,
+    SlavedPushRuleStore,
     SlavedDeviceStore,
     SlavedClientIpStore,
     SlavedApplicationServiceStore,
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 32ccea3f1376..98a4a7c62c9b 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -32,7 +32,6 @@
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import create_engine
-from synapse.storage.roommember import RoomMemberStore
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, preserve_fn
 from synapse.util.manhole import manhole
@@ -75,10 +74,6 @@ class PusherSlaveStore(
         DataStore.get_profile_displayname.__func__
     )
 
-    who_forgot_in_room = (
-        RoomMemberStore.__dict__["who_forgot_in_room"]
-    )
-
 
 class PusherServer(HomeServer):
     def setup(self):
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index f87531f1b64e..abe91dcfbddb 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -62,8 +62,6 @@
 
 class SynchrotronSlavedStore(
-    SlavedPushRuleStore,
-    SlavedEventStore,
     SlavedReceiptsStore,
     SlavedAccountDataStore,
     SlavedApplicationServiceStore,
@@ -73,14 +71,12 @@ class SynchrotronSlavedStore(
     SlavedGroupServerStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
+    SlavedPushRuleStore,
+    SlavedEventStore,
     SlavedClientIpStore,
     RoomStore,
     BaseSlavedStore,
 ):
-    who_forgot_in_room = (
-        RoomMemberStore.__dict__["who_forgot_in_room"]
-    )
-
     did_forget = (
         RoomMemberStore.__dict__["did_forget"]
     )
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 53213cdccf85..8f8fd82eb0b3 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -17,7 +17,6 @@
 from .room import (
     RoomCreationHandler, RoomContextHandler,
 )
-from .room_member import RoomMemberHandler
 from .message import MessageHandler
 from .federation import FederationHandler
 from .directory import DirectoryHandler
@@ -49,7 +48,6 @@ def __init__(self, hs):
         self.registration_handler = RegistrationHandler(hs)
         self.message_handler = MessageHandler(hs)
         self.room_creation_handler = RoomCreationHandler(hs)
-        self.room_member_handler = RoomMemberHandler(hs)
         self.federation_handler = FederationHandler(hs)
         self.directory_handler = DirectoryHandler(hs)
         self.admin_handler = AdminHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index faa5609c0c92..e089e66fde16 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -158,7 +158,7 @@ def kick_guest_users(self, current_state):
             # homeserver.
             requester = synapse.types.create_requester(
                 target_user, is_guest=True)
-            handler = self.hs.get_handlers().room_member_handler
+            handler = self.hs.get_room_member_handler()
             yield handler.update_membership(
                 requester,
                 target_user,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 258cc345dc6b..a5365c4fe450 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -863,8 +863,10 @@ def validate_hash(self, password, stored_hash):
         """
 
         def _do_validate_hash():
-            return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
-                                 stored_hash.encode('utf8')) == stored_hash
+            return bcrypt.checkpw(
+                password.encode('utf8') + self.hs.config.password_pepper,
+                stored_hash.encode('utf8')
+            )
 
         if stored_hash:
             return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 46bcf8b08130..520612683ee6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1447,16 +1447,24 @@ def _handle_new_event(self, origin, event, state=None, auth_events=None,
             auth_events=auth_events,
         )
 
-        if not event.internal_metadata.is_outlier() and not backfilled:
-            yield self.action_generator.handle_push_actions_for_event(
-                event, context
-            )
+        try:
+            if not event.internal_metadata.is_outlier() and not backfilled:
+                yield self.action_generator.handle_push_actions_for_event(
+                    event, context
+                )
 
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-            backfilled=backfilled,
-        )
+            event_stream_id, max_stream_id = yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
+            )
+        except:  # noqa: E722, as we reraise the exception this is fine.
+ # Ensure that we actually remove the entries in the push actions + # staging area + logcontext.preserve_fn( + self.store.remove_push_actions_from_staging + )(event.event_id) + raise if not backfilled: # this intentionally does not yield: we don't care about the result @@ -2145,7 +2153,7 @@ def exchange_third_party_invite( raise e yield self._check_signature(event, context) - member_handler = self.hs.get_handlers().room_member_handler + member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) @@ -2189,7 +2197,7 @@ def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) - member_handler = self.hs.get_handlers().room_member_handler + member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d99d8049b3f2..dd00d8a86cb7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -52,16 +52,12 @@ def __init__(self, hs): self.pagination_lock = ReadWriteLock() @defer.inlineCallbacks - def purge_history(self, room_id, event_id, delete_local_events=False): - event = yield self.store.get_event(event_id) - - if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") - - depth = event.depth - + def purge_history(self, room_id, topological_ordering, + delete_local_events=False): with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history(room_id, depth, delete_local_events) + yield self.store.purge_history( + room_id, topological_ordering, delete_local_events, + ) @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, @@ -553,24 +549,21 @@ def handle_new_client_event( event, context, ratelimit=True, - extra_users=[] + extra_users=[], ): - # We now need to go and hit out to wherever we need to hit out to. - - # If we're a worker we need to hit out to the master. - if self.config.worker_app: - yield send_event_to_master( - self.http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, - requester=requester, - event=event, - context=context, - ) - return + """Processes a new event. This includes checking auth, persisting it, + notifying users, sending to remote servers, etc. - if ratelimit: - yield self.base_handler.ratelimit(requester) + If called from a worker will hit out to the master process for final + processing. + + Args: + requester (Requester) + event (FrozenEvent) + context (EventContext) + ratelimit (bool) + extra_users (list(str)): Any extra users to notify about event + """ try: yield self.auth.check_from_context(event, context) @@ -586,6 +579,57 @@ def handle_new_client_event( logger.exception("Failed to encode content: %r", event.content) raise + yield self.action_generator.handle_push_actions_for_event( + event, context + ) + + try: + # If we're a worker we need to hit out to the master. 
+ if self.config.worker_app: + yield send_event_to_master( + self.http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + requester=requester, + event=event, + context=context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + return + + yield self.persist_and_notify_client_event( + requester, + event, + context, + ratelimit=ratelimit, + extra_users=extra_users, + ) + except: # noqa: E722, as we reraise the exception this is fine. + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id) + raise + + @defer.inlineCallbacks + def persist_and_notify_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[], + ): + """Called when we have fully built the event, have already + calculated the push actions for the event, and checked auth. + + This should only be run on master. + """ + assert not self.config.worker_app + + if ratelimit: + yield self.base_handler.ratelimit(requester) + yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: @@ -679,20 +723,10 @@ def is_inviter_member_event(e): "Changing the room create event is forbidden", ) - yield self.action_generator.handle_push_actions_for_event( - event, context + (event_stream_id, max_stream_id) = yield self.store.persist_event( + event, context=context ) - try: - (event_stream_id, max_stream_id) = yield self.store.persist_event( - event, context=context - ) - except: # noqa: E722, as we reraise the exception this is fine. - # Ensure that we actually remove the entries in the push actions - # staging area - preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id) - raise - # this intentionally does not yield: we don't care about the result # and don't need to wait for it. preserve_fn(self.pusher_pool.on_new_notifications)( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 9800e244533e..c9c28790385e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -233,7 +233,7 @@ def _update_join_states(self, requester, target_user): ) for room_id in room_ids: - handler = self.hs.get_handlers().room_member_handler + handler = self.hs.get_room_member_handler() try: # Assume the target_user isn't a guest, # because we don't let guests set profile or avatar data. 
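A note on the ``validate_hash`` change in ``synapse/handlers/auth.py`` above: ``bcrypt.checkpw`` re-derives the hash from the candidate password plus the salt embedded in ``stored_hash`` and compares in constant time, which is also why ``python_dependencies.py`` later in this diff raises the requirement to ``bcrypt>=3.1.0``, the first release to ship ``checkpw``. A minimal sketch of the pattern (the pepper value here is a made-up stand-in for ``hs.config.password_pepper``)::

    import bcrypt

    pepper = "myPepper"  # hypothetical stand-in for hs.config.password_pepper

    stored_hash = bcrypt.hashpw(("s3kr1t" + pepper).encode("utf8"), bcrypt.gensalt())

    def validate(candidate):
        # checkpw re-hashes using the salt embedded in stored_hash and
        # compares the results in constant time
        return bcrypt.checkpw((candidate + pepper).encode("utf8"), stored_hash)

    assert validate("s3kr1t")
    assert not validate("wrong")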
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index b5b0303d54e0..5142ae153da4 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -41,9 +41,9 @@ def received_client_read_marker(self, room_id, user_id, event_id): """ with (yield self.read_marker_linearizer.queue((room_id, user_id))): - account_data = yield self.store.get_account_data_for_room(user_id, room_id) - - existing_read_marker = account_data.get("m.fully_read", None) + existing_read_marker = yield self.store.get_account_data_for_room_and_type( + user_id, room_id, "m.fully_read", + ) should_update = True diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6ab020bf4160..8df8fcbbadc6 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -165,7 +165,7 @@ def create_room(self, requester, config, ratelimit=True): creation_content = config.get("creation_content", {}) - room_member_handler = self.hs.get_handlers().room_member_handler + room_member_handler = self.hs.get_room_member_handler() yield self._send_events_for_new_room( requester, @@ -224,7 +224,7 @@ def create_room(self, requester, config, ratelimit=True): id_server = invite_3pid["id_server"] address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_handlers().room_member_handler.do_3pid_invite( + yield self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -475,12 +475,9 @@ def get_new_events( user.to_string() ) if app_service: - events, end_key = yield self.store.get_appservice_room_stream( - service=app_service, - from_key=from_key, - to_key=to_key, - limit=limit, - ) + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() else: room_events = yield self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 37dc5e99ab90..ed3b97730d88 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -30,24 +30,32 @@ from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room -from ._base import BaseHandler logger = logging.getLogger(__name__) id_server_scheme = "https://" -class RoomMemberHandler(BaseHandler): +class RoomMemberHandler(object): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns # ought to be separated out a lot better. 
def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) - + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.state_handler = hs.get_state_handler() + self.config = hs.config + self.simple_http_client = hs.get_simple_http_client() + + self.federation_handler = hs.get_handlers().federation_handler + self.directory_handler = hs.get_handlers().directory_handler + self.registration_handler = hs.get_handlers().registration_handler self.profile_handler = hs.get_profile_handler() self.event_creation_hander = hs.get_event_creation_handler() + self.replication_layer = hs.get_replication_layer() self.member_linearizer = Linearizer(name="member") @@ -138,7 +146,7 @@ def remote_join(self, remote_room_hosts, room_id, user, content): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - yield self.hs.get_handlers().federation_handler.do_invite_join( + yield self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), @@ -204,8 +212,7 @@ def _update_membership( # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( + yield self.replication_layer.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -226,7 +233,7 @@ def _update_membership( requester.user, ) if not is_requester_admin: - if self.hs.config.block_non_admin_invites: + if self.config.block_non_admin_invites: logger.info( "Blocking invite: user is not admin and non-admin " "invites disabled" @@ -321,7 +328,7 @@ def _update_membership( else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - fed_handler = self.hs.get_handlers().federation_handler + fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( remote_room_hosts, @@ -477,7 +484,7 @@ def lookup_room_alias(self, room_alias): Raises: SynapseError if room alias could not be found. """ - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.directory_handler mapping = yield directory_handler.get_association(room_alias) if not mapping: @@ -508,7 +515,7 @@ def do_3pid_invite( requester, txn_id ): - if self.hs.config.block_non_admin_invites: + if self.config.block_non_admin_invites: is_requester_admin = yield self.auth.is_server_admin( requester.user, ) @@ -555,7 +562,7 @@ def _lookup_3pid(self, id_server, medium, address): str: the matrix ID of the 3pid, or None if it is not recognized. 
""" try: - data = yield self.hs.get_simple_http_client().get_json( + data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), { "medium": medium, @@ -566,7 +573,7 @@ def _lookup_3pid(self, id_server, medium, address): if "mxid" in data: if "signatures" not in data: raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) + yield self.verify_any_signature(data, id_server) defer.returnValue(data["mxid"]) except IOError as e: @@ -578,7 +585,7 @@ def verify_any_signature(self, data, server_hostname): if server_hostname not in data["signatures"]: raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( + key_data = yield self.simple_http_client.get_json( "%s%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_scheme, server_hostname, key_name,), ) @@ -603,7 +610,7 @@ def _make_and_store_3pid_invite( user, txn_id ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) + room_state = yield self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -727,15 +734,15 @@ def _ask_id_server_for_third_party_invite( "sender_avatar_url": inviter_avatar_url, } - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler + if self.config.invite_3pid_guest: + registration_handler = self.registration_handler guest_access_token = yield registration_handler.guest_access_token_for( medium=medium, address=address, inviter_user_id=inviter_user_id, ) - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( + guest_user_info = yield self.auth.get_user_by_access_token( guest_access_token ) @@ -744,7 +751,7 @@ def _ask_id_server_for_third_party_invite( "guest_user_id": guest_user_info["user"].to_string(), }) - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + data = yield self.simple_http_client.post_urlencoded_get_json( is_url, invite_config ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b12988f3c99f..56b86356f2bc 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -998,8 +998,9 @@ def _have_rooms_changed(self, sync_result_builder): app_service = self.store.get_app_service_by_user_id(user_id) if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() else: joined_room_ids = yield self.store.get_rooms_for_user(user_id) @@ -1030,8 +1031,9 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): app_service = self.store.get_app_service_by_user_id(user_id) if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) + # We no longer support AS users using /sync directly. 
+ # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() else: joined_room_ids = yield self.store.get_rooms_for_user(user_id) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bf4f1c5836d4..7c680659b6f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -144,6 +144,7 @@ def action_for_event_by_user(self, event, context): Deferred """ rules_by_user = yield self._get_rules_for_event(event, context) + actions_by_user = {} room_members = yield self.store.get_joined_users_from_context( event, context @@ -189,14 +190,17 @@ def action_for_event_by_user(self, event, context): if matches: actions = [x for x in rule['actions'] if x != 'dont_notify'] if actions and 'notify' in actions: - # Push rules say we should notify the user of this event, - # so we mark it in the DB in the staging area. (This - # will then get handled when we persist the event) - yield self.store.add_push_actions_to_staging( - event.event_id, uid, actions, - ) + # Push rules say we should notify the user of this event + actions_by_user[uid] = actions break + # Mark in the DB staging area the push actions for users who should be + # notified for this event. (This will then get handled when we persist + # the event) + yield self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, + ) + def _condition_checker(evaluator, conditions, uid, display_name, cache): for cond in conditions: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 906d84373916..3c4f4d351c83 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -24,14 +24,14 @@ "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], - "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], + "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], "daemonize": ["daemonize"], - "bcrypt": ["bcrypt"], + "bcrypt": ["bcrypt>=3.1.0"], "pillow": ["PIL"], "pydenticon": ["pydenticon"], "ujson": ["ujson"], diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 468f4b68f434..70f2fe456a3a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -15,10 +15,15 @@ from twisted.internet import defer -from synapse.api.errors import SynapseError, MatrixCodeMessageException +from synapse.api.errors import ( + SynapseError, MatrixCodeMessageException, CodeMessageException, +) from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.util.async import sleep +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.metrics import Measure from synapse.types import Requester @@ -29,7 +34,8 @@ @defer.inlineCallbacks -def send_event_to_master(client, host, port, requester, event, context): +def send_event_to_master(client, host, port, requester, event, context, + ratelimit, extra_users): """Send event to be handled on the master Args: @@ -39,8 +45,12 @@ def send_event_to_master(client, host, port, requester, event, context): requester 
(Requester) event (FrozenEvent) context (EventContext) + ratelimit (bool) + extra_users (list(str)): Any extra users to notify about event """ - uri = "http://%s:%s/_synapse/replication/send_event" % (host, port,) + uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( + host, port, event.event_id, + ) payload = { "event": event.get_pdu_json(), @@ -48,10 +58,27 @@ def send_event_to_master(client, host, port, requester, event, context): "rejected_reason": event.rejected_reason, "context": context.serialize(event), "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": extra_users, } try: - result = yield client.post_json_get_json(uri, payload) + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed on + # the master, and so whether we should clean up or not. + while True: + try: + result = yield client.put_json(uri, payload) + break + except CodeMessageException as e: + if e.code != 504: + raise + + logger.warn("send_event request timed out") + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + yield sleep(1) except MatrixCodeMessageException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And @@ -66,7 +93,7 @@ class ReplicationSendEventRestServlet(RestServlet): The API looks like: - POST /_synapse/replication/send_event + POST /_synapse/replication/send_event/:event_id { "event": { .. serialized event .. }, @@ -74,9 +101,11 @@ class ReplicationSendEventRestServlet(RestServlet): "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, "requester": { .. serialized requester .. 
},
+        "ratelimit": true,
+        "extra_users": [],
        }
    """
-    PATTERNS = [re.compile("^/_synapse/replication/send_event$")]
+    PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
 
     def __init__(self, hs):
         super(ReplicationSendEventRestServlet, self).__init__()
@@ -85,8 +114,23 @@ def __init__(self, hs):
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
+        # The responses are tiny, so we may as well cache them for a while
+        self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000)
+
+    def on_PUT(self, request, event_id):
+        result = self.response_cache.get(event_id)
+        if not result:
+            result = self.response_cache.set(
+                event_id,
+                self._handle_request(request)
+            )
+        else:
+            logger.warn("Returning cached response")
+        return make_deferred_yieldable(result)
+
+    @preserve_fn
     @defer.inlineCallbacks
-    def on_POST(self, request):
+    def _handle_request(self, request):
         with Measure(self.clock, "repl_send_event_parse"):
             content = parse_json_object_from_request(request)
@@ -98,6 +142,9 @@
             requester = Requester.deserialize(self.store, content["requester"])
             context = yield EventContext.deserialize(self.store, content["context"])
 
+            ratelimit = content["ratelimit"]
+            extra_users = content["extra_users"]
+
         if requester.user:
             request.authenticated_entity = requester.user.to_string()
@@ -106,8 +153,10 @@
             event.event_id, event.room_id,
         )
 
-        yield self.event_creation_handler.handle_new_client_event(
+        yield self.event_creation_handler.persist_and_notify_client_event(
             requester, event, context,
+            ratelimit=ratelimit,
+            extra_users=extra_users,
         )
 
         defer.returnValue((200, {}))
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index efbd87918ec6..d9ba6d69b10f 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,50 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.account_data import AccountDataStore -from synapse.storage.tags import TagsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.tags import TagsWorkerStore -class SlavedAccountDataStore(BaseSlavedStore): +class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedAccountDataStore, self).__init__(db_conn, hs) self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id", ) - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", - self._account_data_id_gen.get_current_token(), - ) - - get_account_data_for_user = ( - AccountDataStore.__dict__["get_account_data_for_user"] - ) - - get_global_account_data_by_type_for_users = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] - ) - get_global_account_data_by_type_for_user = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] - ) - - get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] - get_tags_for_room = ( - DataStore.get_tags_for_room.__func__ - ) - get_account_data_for_room = ( - DataStore.get_account_data_for_room.__func__ - ) - - get_updated_tags = DataStore.get_updated_tags.__func__ - get_updated_account_data_for_user = ( - DataStore.get_updated_account_data_for_user.__func__ - ) + super(SlavedAccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() @@ -85,6 +56,10 @@ def process_replication_rows(self, stream_name, token, rows): (row.data_type, row.user_id,) ) self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type,), + ) self._account_data_stream_cache.entity_has_changed( row.user_id, token ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index 0d3f31a50c24..8cae3076f418 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,33 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.config.appservice import load_appservices -from synapse.storage.appservice import _make_exclusive_regex +from synapse.storage.appservice import ( + ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore, +) -class SlavedApplicationServiceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedApplicationServiceStore, self).__init__(db_conn, hs) - self.services_cache = load_appservices( - hs.config.server_name, - hs.config.app_service_config_files - ) - self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - - get_app_service_by_token = DataStore.get_app_service_by_token.__func__ - get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ - get_app_services = DataStore.get_app_services.__func__ - get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ - create_appservice_txn = DataStore.create_appservice_txn.__func__ - get_appservices_by_state = DataStore.get_appservices_by_state.__func__ - get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ - _get_last_txn = DataStore._get_last_txn.__func__ - complete_appservice_txn = DataStore.complete_appservice_txn.__func__ - get_appservice_state = DataStore.get_appservice_state.__func__ - set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ - set_appservice_state = DataStore.set_appservice_state.__func__ - get_if_app_services_interested_in_user = ( - DataStore.get_if_app_services_interested_in_user.__func__ - ) +class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, + ApplicationServiceWorkerStore): + pass diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 7301d885f26f..6deecd396352 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -14,10 +14,8 @@ # limitations under the License. from ._base import BaseSlavedStore -from synapse.storage.directory import DirectoryStore +from synapse.storage.directory import DirectoryWorkerStore -class DirectoryStore(BaseSlavedStore): - get_aliases_for_room = DirectoryStore.__dict__[ - "get_aliases_for_room" - ] +class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index f8c164b48b70..b1f64ef0d8ed 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
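The slave-storage rewrites above (and those that follow) all apply the same refactor: instead of copying individual methods off ``DataStore`` via ``__func__``/``__dict__``, the read-only methods move into ``*WorkerStore`` base classes that the master store and the slaved store both inherit. A toy before/after illustration, with invented names rather than Synapse code::

    class WorkerStore(object):
        """Read-only methods, usable on both master and workers."""
        def get_thing(self, thing_id):
            return self._db.get(thing_id)  # hypothetical db accessor

    class MasterStore(WorkerStore):
        """Adds the write paths, which only the master may use."""
        def store_thing(self, thing_id, value):
            self._db[thing_id] = value

    class SlavedStore(WorkerStore):
        # Previously this would have been spelled:
        #     get_thing = MasterStore.__dict__["get_thing"]
        pass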
@@ -15,14 +16,13 @@ import logging from synapse.api.constants import EventTypes -from synapse.storage import DataStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.event_push_actions import EventPushActionsStore -from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationWorkerStore +from synapse.storage.event_push_actions import EventPushActionsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore from synapse.storage.state import StateGroupWorkerStore -from synapse.storage.stream import StreamStore -from synapse.storage.signatures import SignatureStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.stream import StreamWorkerStore +from synapse.storage.signatures import SignatureWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -38,157 +38,33 @@ # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore): +class SlavedEventStore(EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedEventStore, self).__init__(db_conn, hs) self._stream_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", ) self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - self.stream_ordering_month_ago = 0 - self._stream_order_on_start = self.get_room_max_stream_ordering() + super(SlavedEventStore, self).__init__(db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. 
- get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] - get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] - get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"] - get_users_who_share_room_with_user = ( - RoomMemberStore.__dict__["get_users_who_share_room_with_user"] - ) - get_latest_event_ids_in_room = EventFederationStore.__dict__[ - "get_latest_event_ids_in_room" - ] - get_invited_rooms_for_user = RoomMemberStore.__dict__[ - "get_invited_rooms_for_user" - ] - get_unread_event_push_actions_by_room_for_user = ( - EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] - ) - _get_unread_counts_by_receipt_txn = ( - DataStore._get_unread_counts_by_receipt_txn.__func__ - ) - _get_unread_counts_by_pos_txn = ( - DataStore._get_unread_counts_by_pos_txn.__func__ - ) - get_recent_event_ids_for_room = ( - StreamStore.__dict__["get_recent_event_ids_for_room"] - ) - _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] - has_room_changed_since = DataStore.has_room_changed_since.__func__ - - get_unread_push_actions_for_user_in_range_for_http = ( - DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ - ) - get_unread_push_actions_for_user_in_range_for_email = ( - DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ - ) - get_push_action_users_in_range = ( - DataStore.get_push_action_users_in_range.__func__ - ) - get_event = DataStore.get_event.__func__ - get_events = DataStore.get_events.__func__ - get_rooms_for_user_where_membership_is = ( - DataStore.get_rooms_for_user_where_membership_is.__func__ - ) - get_membership_changes_for_user = ( - DataStore.get_membership_changes_for_user.__func__ - ) - get_room_events_max_id = DataStore.get_room_events_max_id.__func__ - get_room_events_stream_for_room = ( - DataStore.get_room_events_stream_for_room.__func__ - ) - get_events_around = DataStore.get_events_around.__func__ - get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ - get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ - _get_joined_users_from_context = ( - RoomMemberStore.__dict__["_get_joined_users_from_context"] - ) - - get_joined_hosts = DataStore.get_joined_hosts.__func__ - _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] - - get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ - get_room_events_stream_for_rooms = ( - DataStore.get_room_events_stream_for_rooms.__func__ - ) - is_host_joined = RoomMemberStore.__dict__["is_host_joined"] - get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ - - _set_before_and_after = staticmethod(DataStore._set_before_and_after) - - _get_events = DataStore._get_events.__func__ - _get_events_from_cache = DataStore._get_events_from_cache.__func__ - - _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ - _enqueue_events = DataStore._enqueue_events.__func__ - _do_fetch = DataStore._do_fetch.__func__ - _fetch_event_rows = DataStore._fetch_event_rows.__func__ - _get_event_from_row = DataStore._get_event_from_row.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) - _get_events_around_txn = DataStore._get_events_around_txn.__func__ - - get_backfill_events = DataStore.get_backfill_events.__func__ - _get_backfill_events = DataStore._get_backfill_events.__func__ - get_missing_events = DataStore.get_missing_events.__func__ - 
_get_missing_events = DataStore._get_missing_events.__func__ - - get_auth_chain = DataStore.get_auth_chain.__func__ - get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ - _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ - - get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__ - - get_forward_extremeties_for_room = ( - DataStore.get_forward_extremeties_for_room.__func__ - ) - _get_forward_extremeties_for_room = ( - EventFederationStore.__dict__["_get_forward_extremeties_for_room"] - ) - - get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ - - get_federation_out_pos = DataStore.get_federation_out_pos.__func__ - update_federation_out_pos = DataStore.update_federation_out_pos.__func__ - - get_latest_event_ids_and_hashes_in_room = ( - DataStore.get_latest_event_ids_and_hashes_in_room.__func__ - ) - _get_latest_event_ids_and_hashes_in_room = ( - DataStore._get_latest_event_ids_and_hashes_in_room.__func__ - ) - _get_event_reference_hashes_txn = ( - DataStore._get_event_reference_hashes_txn.__func__ - ) - add_event_hashes = ( - DataStore.add_event_hashes.__func__ - ) - get_event_reference_hashes = ( - SignatureStore.__dict__["get_event_reference_hashes"] - ) - get_event_reference_hash = ( - SignatureStore.__dict__["get_event_reference_hash"] - ) + + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 83e880fdd2a2..bb2c40b6e3e9 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
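The block deleted above was the largest instance of the borrowing hack, and the comment retained in the file ("Cached functions can't be accessed through a class instance...") refers to the Python descriptor subtlety sketched here with a toy stand-in for Synapse's ``@cached`` decorator (invented for illustration)::

    class cached(object):
        """Toy stand-in for synapse's @cached descriptor."""
        def __init__(self, func):
            self.func = func

        def __get__(self, obj, owner):
            # a real implementation would memoize per instance; binding is
            # enough to show the effect
            return self.func.__get__(obj, owner)

    class DataStore(object):
        @cached
        def get_thing(self):
            return "thing"

    # Attribute access goes through __get__ and hands back a plain function,
    # losing the descriptor; reading __dict__ returns the raw `cached` object,
    # which is what the old slave stores copied across.
    assert isinstance(DataStore.__dict__["get_thing"], cached)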
@@ -15,29 +16,15 @@ from .events import SlavedEventStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.push_rule import PushRuleStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.push_rule import PushRulesWorkerStore -class SlavedPushRuleStore(SlavedEventStore): +class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): def __init__(self, db_conn, hs): - super(SlavedPushRuleStore, self).__init__(db_conn, hs) self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", ) - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_current_token(), - ) - - get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] - get_push_rules_enabled_for_user = ( - PushRuleStore.__dict__["get_push_rules_enabled_for_user"] - ) - have_push_rules_changed_for_user = ( - DataStore.have_push_rules_changed_for_user.__func__ - ) + super(SlavedPushRuleStore, self).__init__(db_conn, hs) def get_push_rules_stream_token(self): return ( @@ -45,6 +32,9 @@ def get_push_rules_stream_token(self): self._stream_id_gen.get_current_token(), ) + def get_max_push_rules_stream_id(self): + return self._push_rules_stream_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedPushRuleStore, self).stream_positions() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 4e8d68ece9dc..a7cd5a729111 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,10 +17,10 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore +from synapse.storage.pusher import PusherWorkerStore -class SlavedPusherStore(BaseSlavedStore): +class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) @@ -28,13 +29,6 @@ def __init__(self, db_conn, hs): extra_tables=[("deleted_pushers", "stream_id")], ) - get_all_pushers = DataStore.get_all_pushers.__func__ - get_pushers_by = DataStore.get_pushers_by.__func__ - get_pushers_by_app_id_and_pushkey = ( - DataStore.get_pushers_by_app_id_and_pushkey.__func__ - ) - _decode_pushers_rows = DataStore._decode_pushers_rows.__func__ - def stream_positions(self): result = super(SlavedPusherStore, self).stream_positions() result["pushers"] = self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index b371574ece56..1647072f659a 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
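Note the ordering in ``SlavedPushRuleStore.__init__`` above: the ``SlavedIdTracker`` is now created *before* ``super().__init__`` runs, and the receipts store below spells out why — the ``*WorkerStore`` constructors may call methods (such as ``get_max_receipt_stream_id``) that the slaved class implements in terms of the tracker. A reduced sketch with invented names::

    class IdTracker(object):
        def __init__(self, token):
            self._token = token

        def get_current_token(self):
            return self._token

    class WorkerStore(object):
        def __init__(self):
            # e.g. prefills a stream-change cache from the current token
            self._cache_position = self.get_max_stream_id()

    class SlavedStore(WorkerStore):
        def __init__(self):
            self._id_gen = IdTracker(42)         # must exist first...
            super(SlavedStore, self).__init__()  # ...because this calls get_max_stream_id()

        def get_max_stream_id(self):
            return self._id_gen.get_current_token()

    assert SlavedStore().get_max_stream_id() == 42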
@@ -16,9 +17,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.receipts import ReceiptsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.receipts import ReceiptsWorkerStore # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the @@ -29,36 +28,19 @@ # the method descriptor on the DataStore and chuck them into our class. -class SlavedReceiptsStore(BaseSlavedStore): +class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedReceiptsStore, self).__init__(db_conn, hs) - + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() - ) - - get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] - get_linearized_receipts_for_room = ( - ReceiptsStore.__dict__["get_linearized_receipts_for_room"] - ) - _get_linearized_receipts_for_rooms = ( - ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] - ) - get_last_receipt_event_id_for_user = ( - ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] - ) - - get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ - get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + super(SlavedReceiptsStore, self).__init__(db_conn, hs) - get_linearized_receipts_for_rooms = ( - DataStore.get_linearized_receipts_for_rooms.__func__ - ) + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() @@ -71,6 +53,8 @@ def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, token, rows): if stream_name == "receipts": diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index e27c7332d214..7323bf0f1ed4 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -14,20 +14,8 @@ # limitations under the License. 
 from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.registration import RegistrationStore
+from synapse.storage.registration import RegistrationWorkerStore
 
 
-class SlavedRegistrationStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedRegistrationStore, self).__init__(db_conn, hs)
-
-    # TODO: use the cached version and invalidate deleted tokens
-    get_user_by_access_token = RegistrationStore.__dict__[
-        "get_user_by_access_token"
-    ]
-
-    _query_for_auth = DataStore._query_for_auth.__func__
-    get_user_by_id = RegistrationStore.__dict__[
-        "get_user_by_id"
-    ]
+class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
+    pass
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index f5103840333f..5ae167015745 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,32 +14,19 @@
 # limitations under the License.
 
 from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.room import RoomStore
+from synapse.storage.room import RoomWorkerStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
-class RoomStore(BaseSlavedStore):
+class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def __init__(self, db_conn, hs):
         super(RoomStore, self).__init__(db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
 
-    get_public_room_ids = DataStore.get_public_room_ids.__func__
-    get_current_public_room_stream_id = (
-        DataStore.get_current_public_room_stream_id.__func__
-    )
-    get_public_room_ids_at_stream_id = (
-        RoomStore.__dict__["get_public_room_ids_at_stream_id"]
-    )
-    get_public_room_ids_at_stream_id_txn = (
-        DataStore.get_public_room_ids_at_stream_id_txn.__func__
-    )
-    get_published_at_stream_id_txn = (
-        DataStore.get_published_at_stream_id_txn.__func__
-    )
-    get_public_room_changes = DataStore.get_public_room_changes.__func__
+    def get_current_public_room_stream_id(self):
+        return self._public_room_id_gen.get_current_token()
 
     def stream_positions(self):
         result = super(RoomStore, self).stream_positions()
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 6073cc6fa2ea..dcf6215dadda 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -17,7 +17,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, SynapseError, Codes
 from synapse.types import UserID, create_requester
 from synapse.http.servlet import parse_json_object_from_request
 
@@ -114,12 +114,18 @@ def on_POST(self, request):
 
 class PurgeHistoryRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns(
-        "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+        "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
     )
 
     def __init__(self, hs):
+        """
+
+        Args:
+            hs (synapse.server.HomeServer)
+        """
         super(PurgeHistoryRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id):
@@ -133,8 +139,54 @@ def on_POST(self, request, room_id, event_id):
 
         delete_local_events = bool(body.get("delete_local_events", False))
 
+        # establish the topological ordering we should keep events from. The
+        # user can provide an event_id in the URL or the request body, or can
+        # provide a timestamp in the request body.
+ if event_id is None: + event_id = body.get('purge_up_to_event_id') + + if event_id is not None: + event = yield self.store.get_event(event_id) + + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") + + depth = event.depth + logger.info( + "[purge] purging up to depth %i (event_id %s)", + depth, event_id, + ) + elif 'purge_up_to_ts' in body: + ts = body['purge_up_to_ts'] + if not isinstance(ts, int): + raise SynapseError( + 400, "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, + ) + + stream_ordering = ( + yield self.store.find_first_stream_ordering_after_ts(ts) + ) + + (_, depth, _) = ( + yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + ) + logger.info( + "[purge] purging up to depth %i (received_ts %i => " + "stream_ordering %i)", + depth, ts, stream_ordering, + ) + else: + raise SynapseError( + 400, + "must specify purge_up_to_event_id or purge_up_to_ts", + errcode=Codes.BAD_JSON, + ) + yield self.handlers.message_handler.purge_history( - room_id, event_id, + room_id, depth, delete_local_events=delete_local_events, ) @@ -180,6 +232,7 @@ def __init__(self, hs): self.handlers = hs.get_handlers() self.state = hs.get_state_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() @defer.inlineCallbacks def on_POST(self, request, room_id): @@ -238,7 +291,7 @@ def on_POST(self, request, room_id): logger.info("Kicking %r from %r...", user_id, room_id) target_requester = create_requester(user_id) - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -247,9 +300,9 @@ def on_POST(self, request, room_id): ratelimit=False ) - yield self.handlers.room_member_handler.forget(target_requester.user, room_id) + yield self.room_member_handler.forget(target_requester.user, room_id) - yield self.handlers.room_member_handler.update_membership( + yield self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=new_room_id, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 817fd4784255..9d745174c746 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -84,6 +84,7 @@ def __init__(self, hs): super(RoomStateEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.event_creation_hander = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # /room/$roomid/state/$eventtype @@ -156,7 +157,7 @@ def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.handlers.room_member_handler.update_membership( + event = yield self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -229,7 +230,7 @@ def on_PUT(self, request, room_id, event_type, txn_id): class JoinRoomAliasServlet(ClientV1RestServlet): def __init__(self, hs): super(JoinRoomAliasServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.room_member_handler = hs.get_room_member_handler() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -257,7 +258,7 @@ def on_POST(self, request, room_identifier, txn_id=None): except Exception: remote_room_hosts = None elif 
RoomAlias.is_valid(room_identifier):
-            handler = self.handlers.room_member_handler
+            handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
             room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
             room_id = room_id.to_string()
@@ -266,7 +267,7 @@ def on_POST(self, request, room_identifier, txn_id=None):
                 room_identifier,
             ))
 
-        yield self.handlers.room_member_handler.update_membership(
+        yield self.room_member_handler.update_membership(
             requester=requester,
             target=requester.user,
             room_id=room_id,
@@ -562,7 +563,7 @@ def on_GET(self, request, room_id, event_id):
 class RoomForgetRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomForgetRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.room_member_handler = hs.get_room_member_handler()
 
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -575,7 +576,7 @@ def on_POST(self, request, room_id, txn_id=None):
             allow_guest=False,
         )
 
-        yield self.handlers.room_member_handler.forget(
+        yield self.room_member_handler.forget(
             user=requester.user,
             room_id=room_id,
         )
@@ -593,7 +594,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(RoomMembershipRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.room_member_handler = hs.get_room_member_handler()
 
     def register(self, http_server):
         # /rooms/$roomid/[invite|join|leave]
@@ -622,7 +623,7 @@ def on_POST(self, request, room_id, membership_action, txn_id=None):
             content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
-            yield self.handlers.room_member_handler.do_3pid_invite(
+            yield self.room_member_handler.do_3pid_invite(
                 room_id,
                 requester.user,
                 content["medium"],
@@ -644,7 +645,7 @@ def on_POST(self, request, room_id, membership_action, txn_id=None):
         if 'reason' in content and membership_action in ['kick', 'ban']:
             event_content = {'reason': content['reason']}
 
-        yield self.handlers.room_member_handler.update_membership(
+        yield self.room_member_handler.update_membership(
             requester=requester,
             target=target,
             room_id=room_id,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c6f4680a76bb..0ba62bddc1b7 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -183,7 +183,7 @@ def __init__(self, hs):
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
-        self.room_member_handler = hs.get_handlers().room_member_handler
+        self.room_member_handler = hs.get_room_member_handler()
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3f8d4b9c2253..83471b3173a5 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -58,23 +58,13 @@ def store_file(self, source, file_info):
         Returns:
             Deferred[str]: the file path written to in the primary media store
         """
-        path = self._file_info_to_path(file_info)
-        fname = os.path.join(self.local_media_directory, path)
-        dirname = os.path.dirname(fname)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-
-        # Write to the main repository
-        yield make_deferred_yieldable(threads.deferToThread(
-            _write_file_synchronously, source, fname,
-        ))
-
-        # Tell the storage providers about the new file.
They'll decide - # if they should upload it and whether to do so synchronously - # or not. - for provider in self.storage_providers: - yield provider.store_file(path, file_info) + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + yield make_deferred_yieldable(threads.deferToThread( + _write_file_synchronously, source, f, + )) + yield finish_cb() defer.returnValue(fname) @@ -240,21 +230,16 @@ def _file_info_to_path(self, file_info): ) -def _write_file_synchronously(source, fname): - """Write `source` to the path `fname` synchronously. Should be called +def _write_file_synchronously(source, dest): + """Write `source` to the file like `dest` synchronously. Should be called from a thread. Args: - source: A file like object to be written - fname (str): Path to write to + source: A file like object that's to be written + dest: A file like object to be written to """ - dirname = os.path.dirname(fname) - if not os.path.exists(dirname): - os.makedirs(dirname) - source.seek(0) # Ensure we read from the start of the file - with open(fname, "wb") as f: - shutil.copyfileobj(source, f) + shutil.copyfileobj(source, dest) class FileResponder(Responder): diff --git a/synapse/server.py b/synapse/server.py index fbd602d40ebb..5b6effbe31d3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -45,6 +45,7 @@ from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.room_list import RoomListHandler +from synapse.handlers.room_member import RoomMemberHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler @@ -145,6 +146,7 @@ def build_DEPENDENCY(self) 'groups_attestation_signing', 'groups_attestation_renewer', 'spam_checker', + 'room_member_handler', ] def __init__(self, hostname, **kwargs): @@ -382,6 +384,9 @@ def build_groups_attestation_renewer(self): def build_spam_checker(self): return SpamChecker(self) + def build_room_member_handler(self): + return RoomMemberHandler(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f8fbd02ceb5a..de00cae44750 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
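The `server.py` hunk above registers `room_member_handler` with the homeserver's dependency machinery. Below is a condensed, illustrative sketch of that pattern as I understand it: each name in the dependency list gets a cached `get_<name>()` accessor that lazily invokes the matching `build_<name>()` method, which is why adding the list entry plus `build_room_member_handler` is all that `hs.get_room_member_handler()` needs. `MiniHomeServer` and the stand-in handler class are assumptions for illustration, not Synapse's actual code.

```python
class RoomMemberHandler(object):
    """Stand-in for synapse.handlers.room_member.RoomMemberHandler."""
    def __init__(self, hs):
        self.hs = hs


class MiniHomeServer(object):
    DEPENDENCIES = ["room_member_handler"]

    def __init__(self):
        self._cache = {}

    def _get(self, depname):
        # Build each dependency at most once, then reuse it.
        if depname not in self._cache:
            builder = getattr(self, "build_%s" % (depname,))
            self._cache[depname] = builder()
        return self._cache[depname]

    def build_room_member_handler(self):
        return RoomMemberHandler(self)


def _make_getter(depname):
    def getter(self):
        return self._get(depname)
    return getter


# Generate a get_<name>() accessor for every registered dependency.
for _dep in MiniHomeServer.DEPENDENCIES:
    setattr(MiniHomeServer, "get_%s" % (_dep,), _make_getter(_dep))

hs = MiniHomeServer()
assert hs.get_room_member_handler() is hs.get_room_member_handler()
```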
@@ -19,7 +20,6 @@ from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) -from ._base import LoggingTransaction from .directory import DirectoryStore from .events import EventsStore from .presence import PresenceStore, UserPresenceState @@ -104,12 +104,6 @@ def __init__(self, db_conn, hs): db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -146,27 +140,6 @@ def __init__(self, db_conn, hs): else: self._cache_id_gen = None - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - - account_max = self._account_data_id_gen.get_current_token() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, - ) - self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( @@ -180,18 +153,6 @@ def __init__(self, db_conn, hs): prefilled_cache=presence_cache_prefill ) - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_current_token()[0], - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, - prefilled_cache=push_rules_prefill, - ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( db_conn, "device_inbox", @@ -226,6 +187,7 @@ def __init__(self, db_conn, hs): "DeviceListFederationStreamChangeCache", device_list_max, ) + events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -250,20 +212,6 @@ def __init__(self, db_conn, hs): prefilled_cache=_group_updates_prefill, ) - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - after_callbacks=[], - final_callbacks=[], - ) - self._find_stream_orderings_for_times_txn(cur) - cur.close() - - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 10 * 60 * 1000 - ) - self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 68125006eb9d..2fbebd4907cc 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -48,16 +48,16 @@ class LoggingTransaction(object): passed to the constructor. 
Adds logging and metrics to the .execute() method.""" __slots__ = [ - "txn", "name", "database_engine", "after_callbacks", "final_callbacks", + "txn", "name", "database_engine", "after_callbacks", "exception_callbacks", ] def __init__(self, txn, name, database_engine, after_callbacks, - final_callbacks): + exception_callbacks): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "final_callbacks", final_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the @@ -66,8 +66,8 @@ def call_after(self, callback, *args, **kwargs): """ self.after_callbacks.append((callback, args, kwargs)) - def call_finally(self, callback, *args, **kwargs): - self.final_callbacks.append((callback, args, kwargs)) + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) def __getattr__(self, name): return getattr(self.txn, name) @@ -215,7 +215,7 @@ def loop(): self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, final_callbacks, + def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, logging_context, func, *args, **kwargs): start = time.time() * 1000 txn_id = self._TXN_ID @@ -236,7 +236,7 @@ def _new_transaction(self, conn, desc, after_callbacks, final_callbacks, txn = conn.cursor() txn = LoggingTransaction( txn, name, self.database_engine, after_callbacks, - final_callbacks, + exception_callbacks, ) r = func(txn, *args, **kwargs) conn.commit() @@ -308,11 +308,11 @@ def runInteraction(self, desc, func, *args, **kwargs): current_context = LoggingContext.current_context() after_callbacks = [] - final_callbacks = [] + exception_callbacks = [] def inner_func(conn, *args, **kwargs): return self._new_transaction( - conn, desc, after_callbacks, final_callbacks, current_context, + conn, desc, after_callbacks, exception_callbacks, current_context, func, *args, **kwargs ) @@ -321,9 +321,10 @@ def inner_func(conn, *args, **kwargs): for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) - finally: - for after_callback, after_args, after_kwargs in final_callbacks: + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: after_callback(*after_args, **after_kwargs) + raise defer.returnValue(result) @@ -1000,7 +1001,8 @@ def _invalidate_cache_and_stream(self, txn, cache_func, keys): # __exit__ called after the transaction finishes. ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() - txn.call_finally(ctx.__exit__, None, None, None) + txn.call_on_exception(ctx.__exit__, None, None, None) + txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) self._simple_insert_txn( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 56a0bde5495c..e70c9423e375 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
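The `_base.py` changes above replace the unconditional `final_callbacks` with `exception_callbacks` that run only when the transaction fails, after which the exception is re-raised. A stripped-down, synchronous sketch of those semantics (the real `runInteraction` is Deferred-based; the names here are simplified):

```python
def run_interaction(conn, func, *args, **kwargs):
    """Toy model of the callback contract: after_callbacks fire only on
    commit, exception_callbacks fire only on failure, and the exception
    is always re-raised."""
    after_callbacks = []       # run iff the transaction succeeds
    exception_callbacks = []   # run iff the transaction raises

    txn = conn.cursor()
    try:
        result = func(txn, after_callbacks, exception_callbacks,
                      *args, **kwargs)
        conn.commit()
    except Exception:
        for callback, cb_args, cb_kwargs in exception_callbacks:
            callback(*cb_args, **cb_kwargs)
        raise
    for callback, cb_args, cb_kwargs in after_callbacks:
        callback(*cb_args, **cb_kwargs)
    return result
```

This pairs with the `_invalidate_cache_and_stream` change, which now registers `ctx.__exit__` on both lists so the stream-ID context is closed whether the transaction commits or fails.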
@@ -13,18 +14,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator + +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +import abc import ujson as json import logging logger = logging.getLogger(__name__) -class AccountDataStore(SQLBaseStore): +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): @@ -104,6 +133,7 @@ def get_global_account_data_by_type_for_users(self, data_type, user_ids): for row in rows }) + @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. @@ -127,6 +157,38 @@ def get_account_data_for_room_txn(txn): "get_account_data_for_room", get_account_data_for_room_txn ) + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. 
+ """ + def get_account_data_for_room_and_type_txn(txn): + content_json = self._simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True + ) + + return json.loads(content_json) if content_json else None + + return self.runInteraction( + "get_account_data_for_room_and_type", + get_account_data_for_room_and_type_txn, + ) + def get_all_updated_account_data(self, last_global_id, last_room_id, current_id, limit): """Get all the client account_data that has changed on the server @@ -209,6 +271,36 @@ def get_updated_account_data_for_user_txn(txn): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + defer.returnValue(False) + + defer.returnValue( + ignored_user_id in ignored_account_data.get("ignored_users", {}) + ) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. @@ -251,6 +343,10 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id,)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type,), content, + ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @@ -321,16 +417,3 @@ def _update(txn): "update_account_data_max_stream_id", _update, ) - - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", ignorer_user_id, - on_invalidate=cache_context.invalidate, - ) - if not ignored_account_data: - defer.returnValue(False) - - defer.returnValue( - ignored_user_id in ignored_account_data.get("ignored_users", {}) - ) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 79673b427353..12ea8a158c33 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,10 +18,9 @@ import simplejson as json from twisted.internet import defer -from synapse.api.constants import Membership from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.roommember import RoomsForUser +from synapse.storage.events import EventsWorkerStore from ._base import SQLBaseStore @@ -46,17 +46,16 @@ def _make_exclusive_regex(services_cache): return exclusive_user_regex -class ApplicationServiceStore(SQLBaseStore): - +class ApplicationServiceWorkerStore(SQLBaseStore): def __init__(self, db_conn, hs): - super(ApplicationServiceStore, self).__init__(db_conn, hs) - self.hostname = hs.hostname self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + def get_app_services(self): return self.services_cache @@ -112,83 +111,17 @@ def get_app_service_by_id(self, as_id): return service return None - def get_app_service_rooms(self, service): - """Get a list of RoomsForUser for this application service. - - Application services may be "interested" in lots of rooms depending on - the room ID, the room aliases, or the members in the room. This function - takes all of these into account and returns a list of RoomsForUser which - represent the entire list of room IDs that this application service - wants to know about. - - Args: - service: The application service to get a room list for. - Returns: - A list of RoomsForUser. - """ - return self.runInteraction( - "get_app_service_rooms", - self._get_app_service_rooms_txn, - service, - ) - - def _get_app_service_rooms_txn(self, txn, service): - # get all rooms matching the room ID regex. - room_entries = self._simple_select_list_txn( - txn=txn, table="rooms", keyvalues=None, retcols=["room_id"] - ) - matching_room_list = set([ - r["room_id"] for r in room_entries if - service.is_interested_in_room(r["room_id"]) - ]) - - # resolve room IDs for matching room alias regex. - room_alias_mappings = self._simple_select_list_txn( - txn=txn, table="room_aliases", keyvalues=None, - retcols=["room_id", "room_alias"] - ) - matching_room_list |= set([ - r["room_id"] for r in room_alias_mappings if - service.is_interested_in_alias(r["room_alias"]) - ]) - - # get all rooms for every user for this AS. This is scoped to users on - # this HS only. - user_list = self._simple_select_list_txn( - txn=txn, table="users", keyvalues=None, retcols=["name"] - ) - user_list = [ - u["name"] for u in user_list if - service.is_interested_in_user(u["name"]) - ] - rooms_for_user_matching_user_id = set() # RoomsForUser list - for user_id in user_list: - # FIXME: This assumes this store is linked with RoomMemberStore :( - rooms_for_user = self._get_rooms_for_user_where_membership_is_txn( - txn=txn, - user_id=user_id, - membership_list=[Membership.JOIN] - ) - rooms_for_user_matching_user_id |= set(rooms_for_user) - - # make RoomsForUser tuples for room ids and aliases which are not in the - # main rooms_for_user_list - e.g. they are rooms which do not have AS - # registered users in it. 
- known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id] - missing_rooms_for_user = [ - RoomsForUser(r, service.sender, "join") for r in - matching_room_list if r not in known_room_ids - ] - rooms_for_user_matching_user_id |= set(missing_rooms_for_user) - return rooms_for_user_matching_user_id +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass -class ApplicationServiceTransactionStore(SQLBaseStore): - - def __init__(self, db_conn, hs): - super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs) - +class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, + EventsWorkerStore): @defer.inlineCallbacks def get_appservices_by_state(self, state): """Get a list of application services based on their state. @@ -433,3 +366,11 @@ def get_new_events_for_appservice_txn(txn): events = yield self._get_events(event_ids) defer.returnValue((upper_bound, events)) + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 79e7c540ad88..d0c0059757ce 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -29,8 +29,7 @@ ) -class DirectoryStore(SQLBaseStore): - +class DirectoryWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): """ Get's the room_id and server list for a given room_alias @@ -69,6 +68,28 @@ def get_association_from_room_alias(self, room_alias): RoomAliasMapping(room_id, room_alias.to_string(), servers) ) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + + @cached(max_entries=5000) + def get_aliases_for_room(self, room_id): + return self._simple_select_onecol( + "room_aliases", + {"room_id": room_id}, + "room_alias", + desc="get_aliases_for_room", + ) + + +class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers @@ -116,17 +137,6 @@ def alias_txn(txn): ) defer.returnValue(ret) - def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( - table="room_aliases", - keyvalues={ - "room_alias": room_alias, - }, - retcol="creator", - desc="get_room_alias_creator", - allow_none=True - ) - @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( @@ -135,7 +145,6 @@ def delete_room_alias(self, room_alias): room_alias, ) - self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): @@ -160,17 +169,12 @@ def _delete_room_alias_txn(self, txn, room_alias): (room_alias.to_string(),) ) - return room_id - - @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( - 
"room_aliases", - {"room_id": room_id}, - "room_alias", - desc="get_aliases_for_room", + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) ) + return room_id + def update_aliases_for_room(self, old_room_id, new_room_id, creator): def _update_aliases_for_room_txn(txn): sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 55a05c59d502..00ee82d3006a 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -15,7 +15,10 @@ from twisted.internet import defer -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from synapse.storage.signatures import SignatureWorkerStore + from synapse.api.errors import StoreError from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 @@ -27,30 +30,8 @@ logger = logging.getLogger(__name__) -class EventFederationStore(SQLBaseStore): - """ Responsible for storing and serving up the various graphs associated - with an event. Including the main event graph and the auth chains for an - event. - - Also has methods for getting the front (latest) and back (oldest) edges - of the event graphs. These are used to generate the parents for new events - and backfilling from another server respectively. - """ - - EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) - - self.register_background_update_handler( - self.EVENT_AUTH_STATE_ONLY, - self._background_delete_non_state_event_auth, - ) - - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) - +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, + SQLBaseStore): def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. @@ -228,88 +209,6 @@ def _get_min_depth_interaction(self, txn, room_id): return int(min_depth) if min_depth is not None else None - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - if min_depth and depth >= min_depth: - return - - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={ - "room_id": room_id, - }, - values={ - "min_depth": depth, - }, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id, _ in ev.prev_events - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? 
WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany(query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ]) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -371,28 +270,6 @@ def get_forward_extremeties_for_room_txn(txn): get_forward_extremeties_for_room_txn ) - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = (""" - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? - """) - txn.execute( - sql, - (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) - ) - return self.runInteraction( - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn - ) - def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -522,6 +399,135 @@ def _get_missing_events(self, txn, room_id, earliest_events, latest_events, return event_results + +class EventFederationStore(EventFederationWorkerStore): + """ Responsible for storing and serving up the various graphs associated + with an event. Including the main event graph and the auth chains for an + event. + + Also has methods for getting the front (latest) and back (oldest) edges + of the event graphs. These are used to generate the parents for new events + and backfilling from another server respectively. + """ + + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + + def __init__(self, db_conn, hs): + super(EventFederationStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, + self._background_delete_non_state_event_auth, + ) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + if min_depth and depth >= min_depth: + return + + self._simple_upsert_txn( + txn, + table="room_depth", + keyvalues={ + "room_id": room_id, + }, + values={ + "min_depth": depth, + }, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. 
+ """ + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id, _ in ev.prev_events + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) + + txn.executemany(query, [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ]) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) for ev in events + if not ev.internal_metadata.is_outlier() + ] + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = (""" + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """) + txn.execute( + sql, + (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + ) + return self.runInteraction( + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn + ) + def clean_room_for_join(self, room_id): return self.runInteraction( "clean_room_for_join", diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index f787431b7a7f..01f833982525 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, LoggingTransaction from twisted.internet import defer from synapse.util.async import sleep from synapse.util.caches.descriptors import cachedInlineCallbacks @@ -62,77 +63,28 @@ def _deserialize_action(actions, is_highlight): return DEFAULT_NOTIF_ACTION -class EventPushActionsStore(SQLBaseStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - +class EventPushActionsWorkerStore(SQLBaseStore): def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1" - ) - - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 60 * 1000 - ) - - def _set_push_actions_for_event_and_users_txn(self, txn, event): - """ - Args: - event: the event set actions for - tuples: list of tuples of (user_id, actions) - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? - """ + super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) - txn.execute(sql, ( - event.room_id, event.internal_metadata.stream_ordering, - event.depth, event.event_id, - )) + # These get correctly set by _find_stream_orderings_for_times_txn + self.stream_ordering_month_ago = None + self.stream_ordering_day_ago = None - user_ids = self._simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={ - "event_id": event.event_id, - }, - retcol="user_id", + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + after_callbacks=[], + exception_callbacks=[], ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() - self._simple_delete_txn( - txn, - table="event_push_actions_staging", - keyvalues={ - "event_id": event.event_id, - }, + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 10 * 60 * 1000 ) - for uid in user_ids: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid,) - ) - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id @@ -449,6 +401,280 @@ def get_no_receipt(txn): # Now return the first `limit` defer.returnValue(notifs[:limit]) + def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. + + Args: + event_id (str) + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. + + Returns: + Deferred + """ + + if not user_id_actions: + return + + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. 
+ def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use _simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) + """ + + txn.executemany(sql, ( + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.iteritems() + )) + + return self.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn + ) + + def remove_push_actions_from_staging(self, event_id): + """Called if we failed to persist the event to ensure that stale push + actions don't build up in the DB + + Args: + event_id (str) + """ + + return self._simple_delete( + table="event_push_actions_staging", + keyvalues={ + "event_id": event_id, + }, + desc="remove_push_actions_from_staging", + ) + + @defer.inlineCallbacks + def _find_stream_orderings_for_times(self): + yield self.runInteraction( + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", + self.stream_ordering_month_ago + ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", + self.stream_ordering_day_ago + ) + + def find_first_stream_ordering_after_ts(self, ts): + """Gets the stream ordering corresponding to a given timestamp. + + Specifically, finds the stream_ordering of the first event that was + received on or after the timestamp. This is done by a binary search on + the events table, since there is no index on received_ts, so is + relatively slow. + + Args: + ts (int): timestamp in millis + + Returns: + Deferred[int]: stream ordering of the first event received on/after + the timestamp + """ + return self.runInteraction( + "_find_first_stream_ordering_after_ts_txn", + self._find_first_stream_ordering_after_ts_txn, + ts, + ) + + @staticmethod + def _find_first_stream_ordering_after_ts_txn(txn, ts): + """ + Find the stream_ordering of the first event that was received on or + after a given timestamp. This is relatively slow as there is no index + on received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + + Args: + txn (twisted.enterprise.adbapi.Transaction): + ts (int): timestamp to search for + + Returns: + int: stream ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + # We want the first stream_ordering in which received_ts is greater + # than or equal to ts. Call this point X. 
+ # + # We maintain the invariants: + # + # range_start <= X <= range_end + # + range_start = 0 + range_end = max_stream_ordering + 1 + + # Given a stream_ordering, look up the timestamp at that + # stream_ordering. + # + # The array may be sparse (we may be missing some stream_orderings). + # We treat the gaps as the same as having the same value as the + # preceding entry, because we will pick the lowest stream_ordering + # which satisfies our requirement of received_ts >= ts. + # + # For example, if our array of events indexed by stream_ordering is + # [10, , 20], we should treat this as being equivalent to + # [10, 10, 20]. + # + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering <= ?" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + + while range_end - range_start > 0: + middle = (range_end + range_start) // 2 + txn.execute(sql, (middle,)) + row = txn.fetchone() + if row is None: + # no rows with stream_ordering<=middle + range_start = middle + 1 + continue + + middle_ts = row[0] + if ts > middle_ts: + # we got a timestamp lower than the one we were looking for. + # definitely need to look higher: X > middle. + range_start = middle + 1 + else: + # we got a timestamp higher than (or the same as) the one we + # were looking for. We aren't yet sure about the point we + # looked up, but we can be sure that X <= middle. + range_end = middle + + return range_end + + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + + def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, + all_events_and_contexts): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? 
+ """ + + if events_and_contexts: + txn.executemany(sql, ( + ( + event.room_id, event.internal_metadata.stream_ordering, + event.depth, event.event_id, + ) + for event, _ in events_and_contexts + )) + + for event, _ in events_and_contexts: + user_ids = self._simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + }, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid,) + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ( + (event.event_id,) + for event, _ in all_events_and_contexts + ) + ) + @defer.inlineCallbacks def get_push_actions_for_user(self, user_id, before=None, limit=50, only_highlight=False): @@ -567,69 +793,6 @@ def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? """, (room_id, user_id, stream_ordering)) - @defer.inlineCallbacks - def _find_stream_orderings_for_times(self): - yield self.runInteraction( - "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn - ) - - def _find_stream_orderings_for_times_txn(self, txn): - logger.info("Searching for stream ordering 1 month ago") - self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 month ago: it's %d", - self.stream_ordering_month_ago - ) - logger.info("Searching for stream ordering 1 day ago") - self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 day ago: it's %d", - self.stream_ordering_day_ago - ) - - def _find_first_stream_ordering_after_ts_txn(self, txn, ts): - """ - Find the stream_ordering of the first event that was received after - a given timestamp. This is relatively slow as there is no index on - received_ts but we can then use this to delete push actions before - this. - - received_ts must necessarily be in the same order as stream_ordering - and stream_ordering is indexed, so we manually binary search using - stream_ordering - """ - txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] - - if max_stream_ordering is None: - return 0 - - range_start = 0 - range_end = max_stream_ordering - - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering > ?" - " ORDER BY stream_ordering" - " LIMIT 1" - ) - - while range_end - range_start > 1: - middle = int((range_end + range_start) / 2) - txn.execute(sql, (middle,)) - middle_ts = txn.fetchone()[0] - if ts > middle_ts: - range_start = middle - else: - range_end = middle - - return range_end - @defer.inlineCallbacks def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: @@ -755,50 +918,6 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): (rotate_to_stream_ordering,) ) - def add_push_actions_to_staging(self, event_id, user_id, actions): - """Add the push actions for the user and event to the push - action staging area. - - Args: - event_id (str) - user_id (str) - actions (list[dict|str]): An action can either be a string or - dict. 
- - Returns: - Deferred - """ - - is_highlight = 1 if _action_has_highlight(actions) else 0 - - return self._simple_insert( - table="event_push_actions_staging", - values={ - "event_id": event_id, - "user_id": user_id, - "actions": _serialize_action(actions, is_highlight), - "notif": 1, - "highlight": is_highlight, - }, - desc="add_push_actions_to_staging", - ) - - def remove_push_actions_from_staging(self, event_id): - """Called if we failed to persist the event to ensure that stale push - actions don't build up in the DB - - Args: - event_id (str) - """ - - return self._simple_delete( - table="event_push_actions_staging", - keyvalues={ - "event_id": event_id, - }, - desc="remove_push_actions_from_staging", - ) - def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 73177e0bc23a..057b1be4d592 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +13,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer, reactor +from synapse.storage.events_worker import EventsWorkerStore -from synapse.events import FrozenEvent, USE_FROZEN_DICTS -from synapse.events.utils import prune_event +from twisted.internet import defer + +from synapse.events import USE_FROZEN_DICTS from synapse.util.async import ObservableDeferred from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, make_deferred_yieldable + PreserveLoggingContext, make_deferred_yieldable ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.types import get_domain_from_id from canonicaljson import encode_canonical_json @@ -61,16 +62,6 @@ def encode_json(json_object): return json.dumps(json_object, ensure_ascii=False) -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events - - class _EventPeristenceQueue(object): """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. 
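Before the `events.py` hunks continue, here is a self-contained model of the binary search added above in `_find_first_stream_ordering_after_ts_txn`. It operates on hypothetical in-memory data rather than the `events` table, but preserves the invariant `range_start <= X <= range_end` and the gap rule (a missing stream ordering inherits the preceding entry's timestamp); the real code answers each probe with a single indexed SQL query rather than a scan.

```python
def first_stream_ordering_after(received_ts_by_ordering, max_ordering, ts):
    """Smallest stream ordering whose received_ts is >= ts."""
    def ts_at_or_before(ordering):
        # Models "SELECT received_ts FROM events WHERE stream_ordering <= ?
        # ORDER BY stream_ordering DESC LIMIT 1" (linear scan for clarity).
        for o in range(ordering, -1, -1):
            if o in received_ts_by_ordering:
                return received_ts_by_ordering[o]
        return None  # no rows with stream_ordering <= ordering

    range_start, range_end = 0, max_ordering + 1
    while range_end - range_start > 0:
        middle = (range_start + range_end) // 2
        middle_ts = ts_at_or_before(middle)
        if middle_ts is None or ts > middle_ts:
            # X must lie strictly above middle.
            range_start = middle + 1
        else:
            # X <= middle; keep middle in range.
            range_end = middle
    return range_end


# Events at orderings 1, 2 and 5, received at ts 10, 10 and 20: orderings
# 3 and 4 inherit ts 10, so the first ordering with received_ts >= 15 is 5.
assert first_stream_ordering_after({1: 10, 2: 10, 5: 20}, 5, 15) == 5
```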
@@ -199,13 +190,12 @@ def f(self, *args, **kwargs): return f -class EventsStore(SQLBaseStore): +class EventsStore(EventsWorkerStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self._clock = hs.get_clock() self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) @@ -609,62 +599,6 @@ def _calculate_state_delta(self, room_id, current_state): defer.returnValue((to_delete, to_insert)) - @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False): - """Get an event from the database by event_id. - - Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw an exception. - - Returns: - Deferred : A FrozenEvent. - """ - events = yield self._get_events( - [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - if not events and not allow_none: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - defer.returnValue(events[0] if events else None) - - @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - """Get events from the database - - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - - Returns: - Deferred : Dict from event_id to event. - """ - events = yield self._get_events( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - defer.returnValue({e.event_id: e for e in events}) - @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, delete_existing=False, state_delta_for_room={}, @@ -693,6 +627,8 @@ def _persist_events_txn(self, txn, events_and_contexts, backfilled, list of the event ids which are the forward extremities. 
""" + all_events_and_contexts = events_and_contexts + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) @@ -755,6 +691,7 @@ def _persist_events_txn(self, txn, events_and_contexts, backfilled, self._update_metadata_tables_txn( txn, events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, backfilled=backfilled, ) @@ -1152,26 +1089,33 @@ def _store_rejected_events_txn(self, txn, events_and_contexts): ec for ec in events_and_contexts if ec[0] not in to_remove ] - def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled): + def _update_metadata_tables_txn(self, txn, events_and_contexts, + all_events_and_contexts, backfilled): """Update all the miscellaneous tables for new events Args: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. backfilled (bool): True if the events were backfilled """ + # Insert all the push actions into the event_push_actions table. + self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + if not events_and_contexts: # nothing to do here return for event, context in events_and_contexts: - # Insert all the push actions into the event_push_actions table. - self._set_push_actions_for_event_and_users_txn( - txn, event, - ) - if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. 
@@ -1375,292 +1319,6 @@ def f(txn): "have_events", f, ) - @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - defer.returnValue([]) - - event_id_list = event_ids - event_ids = set(event_ids) - - event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - allow_rejected=allow_rejected, - ) - - event_entry_map.update(missing_events) - - events = [] - for event_id in event_id_list: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if allow_rejected or not entry.event.rejected_reason: - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - defer.returnValue(events) - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches - - Args: - events (list(str)): list of event_ids to fetch - allow_rejected (bool): Whether to teturn events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics - - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` - """ - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get( - (event_id,), None, - update_metrics=update_metrics, - ) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the _event_fetch_list queue. 
- """ - event_list = [] - i = 0 - while True: - try: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - event_id_lists = zip(*event_list)[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] - - rows = self._new_transaction( - conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids - ) - - row_dict = { - r["event_id"]: r - for r in rows - } - - # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) - except Exception: - logger.exception("Failed to callback") - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list, row_dict) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(e) - - if event_list: - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list) - - @defer.inlineCallbacks - def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. - """ - if not events: - defer.returnValue({}) - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - with PreserveLoggingContext(): - self.runWithConnection( - self._do_fetch - ) - - logger.debug("Loading %d events", len(events)) - with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = yield make_deferred_yieldable(defer.gatherResults( - [ - preserve_fn(self._get_event_from_row)( - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - ) - for row in rows - ], - consumeErrors=True - )) - - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) - - def _fetch_event_rows(self, txn, events): - rows = [] - N = 200 - for i in range(1 + len(events) / N): - evs = events[i * N:(i + 1) * N] - if not evs: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " e.internal_metadata," - " e.json," - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) - - txn.execute(sql, evs) - rows.extend(self.cursor_to_dict(txn)) - - return rows - - @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - rejected_reason=None): - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = 
json.loads(internal_metadata)
-
-            if rejected_reason:
-                rejected_reason = yield self._simple_select_one_onecol(
-                    table="rejections",
-                    keyvalues={"event_id": rejected_reason},
-                    retcol="reason",
-                    desc="_get_event_from_row_rejected_reason",
-                )
-
-            original_ev = FrozenEvent(
-                d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = None
-            if redacted:
-                redacted_event = prune_event(original_ev)
-
-                redaction_id = yield self._simple_select_one_onecol(
-                    table="redactions",
-                    keyvalues={"redacts": redacted_event.event_id},
-                    retcol="event_id",
-                    desc="_get_event_from_row_redactions",
-                )
-
-                redacted_event.unsigned["redacted_by"] = redaction_id
-                # Get the redaction event.
-
-                because = yield self.get_event(
-                    redaction_id,
-                    check_redacted=False,
-                    allow_none=True,
-                )
-
-                if because:
-                    # It's fine to do add the event directly, since get_pdu_json
-                    # will serialise this field correctly
-                    redacted_event.unsigned["redacted_because"] = because
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev,
-                redacted_event=redacted_event,
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-            defer.returnValue(cache_entry)
-
     @defer.inlineCallbacks
     def count_daily_messages(self):
         """
@@ -2375,7 +2033,7 @@ def is_event_after(self, event_id1, event_id2):
         to_2, so_2 = yield self._get_event_ordering(event_id2)
         defer.returnValue((to_1, so_1) > (to_2, so_2))
 
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
         res = yield self._simple_select_one(
             table="events",
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 000000000000..86c3b48ad40a
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import ujson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase  # noqa: F401
+from synapse.events.snapshot import EventContext  # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+
+
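+# A cache entry pairs an event as stored with its redacted form (if the event
+# has been redacted), so a single cache slot can serve both views of the event.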
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+
+    @defer.inlineCallbacks
+    def get_event(self, event_id, check_redacted=True,
+                  get_prev_content=False, allow_rejected=False,
+                  allow_none=False):
+        """Get an event from the database by event_id.
+
+        Args:
+            event_id (str): The event_id of the event to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous state's content in the unsigned field.
+            allow_rejected (bool): If True, return rejected events.
+            allow_none (bool): If True, return None if no event is found; if
+                False, raise an exception.
+
+        Returns:
+            Deferred: A FrozenEvent.
+        """
+        events = yield self._get_events(
+            [event_id],
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        if not events and not allow_none:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue(events[0] if events else None)
+
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous state's content in the unsigned field.
+            allow_rejected (bool): If True, return rejected events.
+
+        Returns:
+            Deferred: Dict from event_id to event.
+ """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_id_list = event_ids + event_ids = set(event_ids) + + event_entry_map = self._get_events_from_cache( + event_ids, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + allow_rejected=allow_rejected, + ) + + event_entry_map.update(missing_events) + + events = [] + for event_id in event_id_list: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if allow_rejected or not entry.event.rejected_reason: + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + defer.returnValue(events) + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (list(str)): list of event_ids to fetch + allow_rejected (bool): Whether to teturn events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, + update_metrics=update_metrics, + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. 
+ """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + with PreserveLoggingContext(): + d.callback([ + res[i] + for i in ids + if i in res + ]) + except Exception: + logger.exception("Failed to callback") + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(e) + + if event_list: + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + with PreserveLoggingContext(): + self.runWithConnection( + self._do_fetch + ) + + logger.debug("Loading %d events", len(events)) + with PreserveLoggingContext(): + rows = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield make_deferred_yieldable(defer.gatherResults( + [ + preserve_fn(self._get_event_from_row)( + row["internal_metadata"], row["json"], row["redacts"], + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + )) + + defer.returnValue({ + e.event.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i * N:(i + 1) * N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"] * len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + rejected_reason=None): + with Measure(self._clock, "_get_event_from_row"): + d = json.loads(js) + internal_metadata = 
json.loads(internal_metadata)
+
+            if rejected_reason:
+                rejected_reason = yield self._simple_select_one_onecol(
+                    table="rejections",
+                    keyvalues={"event_id": rejected_reason},
+                    retcol="reason",
+                    desc="_get_event_from_row_rejected_reason",
+                )
+
+            original_ev = FrozenEvent(
+                d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            redacted_event = None
+            if redacted:
+                redacted_event = prune_event(original_ev)
+
+                redaction_id = yield self._simple_select_one_onecol(
+                    table="redactions",
+                    keyvalues={"redacts": redacted_event.event_id},
+                    retcol="event_id",
+                    desc="_get_event_from_row_redactions",
+                )
+
+                redacted_event.unsigned["redacted_by"] = redaction_id
+                # Get the redaction event.
+
+                because = yield self.get_event(
+                    redaction_id,
+                    check_redacted=False,
+                    allow_none=True,
+                )
+
+                if because:
+                    # It's fine to add the event directly, since get_pdu_json
+                    # will serialise this field correctly
+                    redacted_event.unsigned["redacted_because"] = because
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev,
+                redacted_event=redacted_event,
+            )
+
+            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+            defer.returnValue(cache_entry)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c780..04a0b59a3946 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,11 +15,17 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.push.baserules import list_with_base_rules
 from synapse.api.constants import EventTypes
 from twisted.internet import defer
 
+import abc
 import logging
 import simplejson as json
 
@@ -48,7 +55,43 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+                           ReceiptsWorkerStore,
+                           PusherWorkerStore,
+                           RoomMemberWorkerStore,
+                           SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id`, which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+ + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( @@ -89,6 +132,22 @@ def get_push_rules_enabled_for_user(self, user_id): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): @@ -228,6 +287,8 @@ def bulk_get_push_rules_enabled(self, user_ids): results.setdefault(row['user_name'], {})[row['rule_id']] = enabled defer.returnValue(results) + +class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def add_push_rule( self, user_id, rule_id, priority_class, conditions, actions, @@ -526,21 +587,8 @@ def get_push_rules_stream_token(self): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_current_token() - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] class RuleNotFoundException(Exception): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 3d8b4d5d5b3f..307660b99ad3 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
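The storage refactors in this diff all repeat one pattern: each store is split into a read-only `*WorkerStore` base class that worker processes can load, and a master `*Store` subclass that owns the ID generators and write paths. Where the worker half's initializer needs a stream position, as in `PushRulesWorkerStore` above, the getter is declared abstract so that the master subclass must supply it before `__init__` runs. A minimal sketch of the pattern, using illustrative names that are not part of the patch:

    import abc


    class ExampleWorkerStore(object):
        """Read-only half of a store: safe to load on a worker process."""

        # ABCMeta stops the worker half being instantiated until the
        # abstract stream-position getter below has been implemented.
        __metaclass__ = abc.ABCMeta  # Python 2 spelling, as in this codebase

        def __init__(self):
            # The initializer may call the getter, so subclasses must have
            # made it work before calling up to here.
            self._cache_position = self.get_max_stream_id()

        @abc.abstractmethod
        def get_max_stream_id(self):
            """Return the current maximum stream ID (int)."""
            raise NotImplementedError()


    class ExampleStore(ExampleWorkerStore):
        """Master half: owns the ID source, so it can answer the query."""

        def __init__(self):
            # Set up *before* the superclass initializer, which calls
            # get_max_stream_id() (compare the ReceiptsStore constructor
            # later in this diff).
            self._current_token = 42  # stands in for a StreamIdGenerator
            super(ExampleStore, self).__init__()

        def get_max_stream_id(self):
            return self._current_token


    store = ExampleStore()
    assert store.get_max_stream_id() == 42

The ordering concern in the two constructors is the point: the master class must arrange its ID source first, then let the worker-store initializer prime its caches through the abstract getter.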
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) -class PusherStore(SQLBaseStore): +class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: dataJson = r['data'] @@ -102,9 +103,6 @@ def get_pushers(txn): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - def get_all_updated_pushers(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed(([], [])) @@ -198,6 +196,11 @@ def get_if_users_have_pushers(self, user_ids): defer.returnValue(result) + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, @@ -230,14 +233,18 @@ def add_pusher(self, user_id, access_token, kind, app_id, ) if newly_inserted: - # get_if_user_has_pusher only cares if the user has - # at least *one* pusher. - self.get_if_user_has_pusher.invalidate(user_id,) + self.runInteraction( + "add_pusher", + self._invalidate_cache_and_stream, + self.get_if_user_has_pusher, (user_id,) + ) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_txn(txn, stream_id): - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + self._invalidate_cache_and_stream( + txn, self.get_if_user_has_pusher, (user_id,) + ) self._simple_delete_one_txn( txn, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 12b3cc7f5fb5..eac8694e0f98 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +15,13 @@ # limitations under the License. from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer +import abc import logging import ujson as json @@ -26,39 +29,36 @@ logger = logging.getLogger(__name__) -class ReceiptsStore(SQLBaseStore): +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. 
+ __metaclass__ = abc.ABCMeta + def __init__(self, db_conn, hs): - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") defer.returnValue(set(r['user_id'] for r in receipts)) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): - if receipt_type != "m.read": - return - - # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False, - ) - - if res: - if isinstance(res, defer.Deferred) and res.called: - res = res.result - if user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -270,6 +270,59 @@ def f(txn): } defer.returnValue(results) + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns an ObservableDeferred + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) + + if res: + if isinstance(res, defer.Deferred) and res.called: + res = res.result + if user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(db_conn, hs) + def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -457,25 +510,3 @@ def insert_graph_receipt_txn(self, txn, room_id, receipt_type, "data": json.dumps(data), } ) - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" 
- " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return txn.fetchall() - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 95f75d6df1cb..d809b2ba4670 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -19,10 +19,70 @@ from synapse.api.errors import StoreError, Codes from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(background_updates.BackgroundUpdateStore): +class RegistrationWorkerStore(SQLBaseStore): + @cached() + def get_user_by_id(self, user_id): + return self._simple_select_one( + table="users", + keyvalues={ + "name": user_id, + }, + retcols=["name", "password_hash", "is_guest"], + allow_none=True, + desc="get_user_by_id", + ) + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. + """ + return self.runInteraction( + "get_user_by_access_token", + self._query_for_auth, + token + ) + + @defer.inlineCallbacks + def is_server_admin(self, user): + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + defer.returnValue(res if res else False) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + +class RegistrationStore(RegistrationWorkerStore, + background_updates.BackgroundUpdateStore): def __init__(self, db_conn, hs): super(RegistrationStore, self).__init__(db_conn, hs) @@ -187,18 +247,6 @@ def _register( ) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - def get_user_by_id(self, user_id): - return self._simple_select_one( - table="users", - keyvalues={ - "name": user_id, - }, - retcols=["name", "password_hash", "is_guest"], - allow_none=True, - desc="get_user_by_id", - ) - def get_users_by_id_case_insensitive(self, user_id): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. @@ -304,34 +352,6 @@ def f(txn): return self.runInteraction("delete_access_token", f) - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. 
- """ - return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token - ) - - @defer.inlineCallbacks - def is_server_admin(self, user): - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - defer.returnValue(res if res else False) - @cachedInlineCallbacks() def is_guest(self, user_id): res = yield self._simple_select_one_onecol( @@ -344,22 +364,6 @@ def is_guest(self, user_id): defer.returnValue(res if res else False) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): yield self._simple_upsert("user_threepids", { diff --git a/synapse/storage/room.py b/synapse/storage/room.py index fff6652e0534..7f2c08d7a6b8 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore from synapse.storage.search import SearchStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -38,7 +39,126 @@ ) -class RoomStore(SearchStore): +class RoomWorkerStore(SQLBaseStore): + def get_public_room_ids(self): + return self._simple_select_onecol( + table="rooms", + keyvalues={ + "is_public": True, + }, + retcol="room_id", + desc="get_public_room_ids", + ) + + @cached(num_args=2, max_entries=100) + def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): + """Get pulbic rooms for a particular list, or across all lists. + + Args: + stream_id (int) + network_tuple (ThirdPartyInstanceID): The list to use (None, None) + means the main list, None means all lsits. + """ + return self.runInteraction( + "get_public_room_ids_at_stream_id", + self.get_public_room_ids_at_stream_id_txn, + stream_id, network_tuple=network_tuple + ) + + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, + network_tuple): + return { + rm + for rm, vis in self.get_published_at_stream_id_txn( + txn, stream_id, network_tuple=network_tuple + ).items() + if vis + } + + def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): + if network_tuple: + # We want to get from a particular list. No aggregation required. + + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? %s + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + if network_tuple.appservice_id is not None: + txn.execute( + sql % ("AND appservice_id = ? AND network_id = ?",), + (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + ) + else: + txn.execute( + sql % ("AND appservice_id IS NULL",), + (stream_id,) + ) + return dict(txn) + else: + # We want to get from all lists, so we need to aggregate the results + + logger.info("Executing full list") + + sql = (""" + SELECT room_id, visibility + FROM public_room_list_stream + INNER JOIN ( + SELECT + room_id, max(stream_id) AS stream_id, appservice_id, + network_id + FROM public_room_list_stream + WHERE stream_id <= ? 
+            for room_id, visibility in txn:
+                results[room_id] = bool(visibility) or results.get(room_id, False)
+
+            return results
+
+    def get_public_room_changes(self, prev_stream_id, new_stream_id,
+                                network_tuple):
+        def get_public_room_changes_txn(txn):
+            then_rooms = self.get_public_room_ids_at_stream_id_txn(
+                txn, prev_stream_id, network_tuple
+            )
+
+            now_rooms_dict = self.get_published_at_stream_id_txn(
+                txn, new_stream_id, network_tuple
+            )
+
+            now_rooms_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if vis
+            )
+            now_rooms_not_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if not vis
+            )
+
+            newly_visible = now_rooms_visible - then_rooms
+            newly_unpublished = now_rooms_not_visible & then_rooms
+
+            return newly_visible, newly_unpublished
+
+        return self.runInteraction(
+            "get_public_room_changes", get_public_room_changes_txn
+        )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
+
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
@@ -225,16 +345,6 @@ def set_room_is_public_appservice_txn(txn, next_id):
         )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_public_room_ids(self):
-        return self._simple_select_onecol(
-            table="rooms",
-            keyvalues={
-                "is_public": True,
-            },
-            retcol="room_id",
-            desc="get_public_room_ids",
-        )
-
     def get_room_count(self):
         """Retrieve a list of all rooms
         """
@@ -326,113 +436,6 @@ def add_event_report(self, room_id, event_id, user_id, reason, content,
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    @cached(num_args=2, max_entries=100)
-    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
-        """Get pulbic rooms for a particular list, or across all lists.
-
-        Args:
-            stream_id (int)
-            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
-                means the main list, None means all lsits.
-        """
-        return self.runInteraction(
-            "get_public_room_ids_at_stream_id",
-            self.get_public_room_ids_at_stream_id_txn,
-            stream_id, network_tuple=network_tuple
-        )
-
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
-                                             network_tuple):
-        return {
-            rm
-            for rm, vis in self.get_published_at_stream_id_txn(
-                txn, stream_id, network_tuple=network_tuple
-            ).items()
-            if vis
-        }
-
-    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
-        if network_tuple:
-            # We want to get from a particular list. No aggregation required.
-
-            sql = ("""
-                SELECT room_id, visibility FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT room_id, max(stream_id) AS stream_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ? %s
-                    GROUP BY room_id
-                ) grouped USING (room_id, stream_id)
-            """)
-
-            if network_tuple.appservice_id is not None:
-                txn.execute(
-                    sql % ("AND appservice_id = ? AND network_id = ?",),
-                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
-                )
-            else:
-                txn.execute(
-                    sql % ("AND appservice_id IS NULL",),
-                    (stream_id,)
-                )
-            return dict(txn)
-        else:
-            # We want to get from all lists, so we need to aggregate the results
-
-            logger.info("Executing full list")
-
-            sql = ("""
-                SELECT room_id, visibility
-                FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT
-                        room_id, max(stream_id) AS stream_id, appservice_id,
-                        network_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """) - - txn.execute( - sql, - (stream_id,) - ) - - results = {} - # A room is visible if its visible on any list. - for room_id, visibility in txn: - results[room_id] = bool(visibility) or results.get(room_id, False) - - return results - - def get_public_room_changes(self, prev_stream_id, new_stream_id, - network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) - - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) - - now_rooms_visible = set( - rm for rm, vis in now_rooms_dict.items() if vis - ) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() if not vis - ) - - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms - - return newly_visible, newly_unpublished - - return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn - ) - def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(txn): sql = (""" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3e77fd3901df..d79877dac729 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +18,7 @@ from collections import namedtuple -from ._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore from synapse.util.async import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -48,97 +49,7 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" -class RoomMemberStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ] - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, event.internal_metadata.stream_ordering - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. - # The only current event that can also be an outlier is if its an - # invite that has come in across federation. 
- is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_invite_from_remote() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - } - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute(sql, ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - )) - - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, ( - stream_ordering, - True, - room_id, - user_id, - )) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - +class RoomMemberWorkerStore(EventsWorkerStore): @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) def get_hosts_in_room(self, room_id, cache_context): """Returns the set of all hosts currently in the room @@ -295,89 +206,6 @@ def get_users_who_share_room_with_user(self, user_id, cache_context): defer.returnValue(user_who_share_room) - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - txn.call_after(self.was_forgotten_at.invalidate_all) - txn.call_after(self.did_forget.invalidate, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.who_forgot_in_room, (room_id,) - ) - return self.runInteraction("forget_membership", f) - - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): - """Returns whether user_id has elected to discard history for room_id. - - Returns False if they have since re-joined.""" - def f(txn): - sql = ( - "SELECT" - " COUNT(*)" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " forgotten = 0" - ) - txn.execute(sql, (user_id, room_id)) - rows = txn.fetchall() - return rows[0][0] - count = yield self.runInteraction("did_forget_membership", f) - defer.returnValue(count == 0) - - @cachedInlineCallbacks(num_args=3) - def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at - event_id. - - event_id must be a membership event.""" - def f(txn): - sql = ( - "SELECT" - " forgotten" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " event_id = ?" 
-            )
-            txn.execute(sql, (user_id, room_id, event_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        forgot = yield self.runInteraction("did_forget_membership_at", f)
-        defer.returnValue(forgot == 1)
-
-    @cached()
-    def who_forgot_in_room(self, room_id):
-        return self._simple_select_list(
-            table="room_memberships",
-            retcols=("user_id", "event_id"),
-            keyvalues={
-                "room_id": room_id,
-                "forgotten": 1,
-            },
-            desc="who_forgot"
-        )
-
     def get_joined_users_from_context(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -600,6 +428,185 @@ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry
 
         defer.returnValue(joined_hosts)
 
+    @cached(max_entries=10000, iterable=True)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+    @cached()
+    def who_forgot_in_room(self, room_id):
+        return self._simple_select_list(
+            table="room_memberships",
+            retcols=("user_id", "event_id"),
+            keyvalues={
+                "room_id": room_id,
+                "forgotten": 1,
+            },
+            desc="who_forgot"
+        )
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store room memberships in the database.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ]
+        )
+
+        for event in events:
+            txn.call_after(
+                self._membership_stream_cache.entity_has_changed,
+                event.state_key, event.internal_metadata.stream_ordering
+            )
+            txn.call_after(
+                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., it's something that has just happened.
+            # The only current event that can also be an outlier is if it's an
+            # invite that has come in across federation.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_invite_from_remote()
+            )
+            is_mine = self.hs.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self._simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        }
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(sql, (
+                        event.internal_metadata.stream_ordering,
+                        event.event_id,
+                        event.room_id,
+                        event.state_key,
+                    ))
+
+    @defer.inlineCallbacks
+    def locally_reject_invite(self, user_id, room_id):
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ?
AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + txn.call_after(self.was_forgotten_at.invalidate_all) + txn.call_after(self.did_forget.invalidate, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.who_forgot_in_room, (room_id,) + ) + return self.runInteraction("forget_membership", f) + + @cachedInlineCallbacks(num_args=2) + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + count = yield self.runInteraction("did_forget_membership", f) + defer.returnValue(count == 0) + + @cachedInlineCallbacks(num_args=3) + def was_forgotten_at(self, user_id, room_id, event_id): + """Returns whether user_id has elected to discard history for room_id at + event_id. + + event_id must be a membership event.""" + def f(txn): + sql = ( + "SELECT" + " forgotten" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " event_id = ?" 
+ ) + txn.execute(sql, (user_id, room_id, event_id)) + rows = txn.fetchall() + return rows[0][0] + forgot = yield self.runInteraction("did_forget_membership_at", f) + defer.returnValue(forgot == 1) + @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( @@ -675,10 +682,6 @@ def add_membership_profile_txn(txn): defer.returnValue(result) - @cached(max_entries=10000, iterable=True) - def _get_joined_hosts_cache(self, room_id): - return _JoinedHostsCache(self, room_id) - class _JoinedHostsCache(object): """Cache for joined hosts in a room that is optimised to handle updates diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 67d5d9969a24..9e6eaaa532e3 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -22,12 +22,12 @@ from synapse.util.caches.descriptors import cached, cachedList -class SignatureStore(SQLBaseStore): - """Persistence for event signatures and hashes""" - +class SignatureWorkerStore(SQLBaseStore): @cached() def get_event_reference_hash(self, event_id): - return self._get_event_reference_hashes_txn(event_id) + # This is a dummy function to allow get_event_reference_hashes + # to use its cache + raise NotImplementedError() @cachedList(cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1) @@ -74,6 +74,10 @@ def _get_event_reference_hashes_txn(self, txn, event_id): txn.execute(query, (event_id, )) return {k: v for k, v in txn} + +class SignatureStore(SignatureWorkerStore): + """Persistence for event signatures and hashes""" + def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU Args: diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 52bdce5be254..2956c3b3e0df 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,13 +35,16 @@ from twisted.internet import defer -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore + from synapse.util.caches.descriptors import cached -from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.storage.engines import PostgresEngine, Sqlite3Engine +import abc import logging @@ -143,81 +146,41 @@ def filter_to_clause(event_filter): return " AND ".join(clauses), args -class StreamStore(SQLBaseStore): - @defer.inlineCallbacks - def get_appservice_room_stream(self, service, from_key, to_key, limit=0): - # NB this lives here instead of appservice.py so we can reuse the - # 'private' StreamToken class in this file. - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - defer.returnValue(([], to_key)) - return - - # select all the events between from/to with a sensible limit - sql = ( - "SELECT e.event_id, e.room_id, e.type, s.state_key, " - "e.stream_ordering FROM events AS e " - "LEFT JOIN state_events as s ON " - "e.event_id = s.event_id " - "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? 
" - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "limit": limit - } - - def f(txn): - # pull out all the events between the tokens - txn.execute(sql, (from_id.stream, to_id.stream,)) - rows = self.cursor_to_dict(txn) - - # Logic: - # - We want ALL events which match the AS room_id regex - # - We want ALL events which match the rooms represented by the AS - # room_alias regex - # - We want ALL events for rooms that AS users have joined. - # This is currently supported via get_app_service_rooms (which is - # used for the Notifier listener rooms). We can't reasonably make a - # SQL query for these room IDs, so we'll pull all the events between - # from/to and filter in python. - rooms_for_as = self._get_app_service_rooms_txn(txn, service) - room_ids_for_as = [r.room_id for r in rooms_for_as] - - def app_service_interested(row): - if row["room_id"] in room_ids_for_as: - return True - - if row["type"] == EventTypes.Member: - if service.is_interested_in_user(row.get("state_key")): - return True - return False +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_room_max_stream_ordering` and `get_room_min_stream_ordering` + which can be called in the initializer. + """ - return [r for r in rows if app_service_interested(r)] + __metaclass__ = abc.ABCMeta - rows = yield self.runInteraction("get_appservice_room_stream", f) + def __init__(self, db_conn, hs): + super(StreamWorkerStore, self).__init__(db_conn, hs) - ret = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True + events_max = self.get_room_max_stream_ordering() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._stream_order_on_start = self.get_room_max_stream_ordering() - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key + @abc.abstractmethod + def get_room_max_stream_ordering(self): + raise NotImplementedError() - defer.returnValue((ret, key)) + @abc.abstractmethod + def get_room_min_stream_ordering(self): + raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, @@ -380,88 +343,6 @@ def f(txn): defer.returnValue(ret) - @defer.inlineCallbacks - def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1, event_filter=None): - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. 
- args = [False, room_id] - if direction == 'b': - order = "DESC" - bounds = upper_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, lower_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - else: - order = "ASC" - bounds = lower_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, upper_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - - filter_clause, filter_args = filter_to_clause(event_filter) - - if filter_clause: - bounds += " AND " + filter_clause - args.extend(filter_args) - - if int(limit) > 0: - args.append(int(limit)) - limit_str = " LIMIT ?" - else: - limit_str = "" - - sql = ( - "SELECT * FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s %(limit)s" - ) % { - "bounds": bounds, - "order": order, - "limit": limit_str - } - - def f(txn): - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - if rows: - topo = rows[-1]["topological_ordering"] - toke = rows[-1]["stream_ordering"] - if direction == 'b': - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - toke -= 1 - next_token = str(RoomStreamToken(topo, toke)) - else: - # TODO (erikj): We should work out what to do here instead. - next_token = to_key if to_key else from_key - - return rows, next_token, - - rows, token = yield self.runInteraction("paginate_room_events", f) - - events = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(events, rows) - - defer.returnValue((events, token)) - @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): rows, token = yield self.get_recent_event_ids_for_room( @@ -534,6 +415,33 @@ def get_recent_events_for_room_txn(txn): "get_recent_events_for_room", get_recent_events_for_room_txn ) + def get_room_event_after_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or after a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering >= ?" + " AND NOT outlier" + " ORDER BY stream_ordering" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering, )) + return txn.fetchone() + + return self.runInteraction( + "get_room_event_after_stream_ordering", _f, + ) + @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): """Returns the current token for rooms stream. @@ -542,7 +450,7 @@ def get_room_events_max_id(self, room_id=None): `room_id` causes it to return the current room specific topological token. 
""" - token = yield self._stream_id_gen.get_current_token() + token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: @@ -552,12 +460,6 @@ def get_room_events_max_id(self, room_id=None): ) defer.returnValue("t%d-%d" % (topo, token)) - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() - def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -832,3 +734,93 @@ def update_federation_out_pos(self, typ, stream_id): def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() + + @defer.inlineCallbacks + def paginate_room_events(self, room_id, from_key, to_key=None, + direction='b', limit=-1, event_filter=None): + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + args = [False, room_id] + if direction == 'b': + order = "DESC" + bounds = upper_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) + if to_key: + bounds = "%s AND %s" % (bounds, lower_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + else: + order = "ASC" + bounds = lower_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) + if to_key: + bounds = "%s AND %s" % (bounds, upper_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + if int(limit) > 0: + args.append(int(limit)) + limit_str = " LIMIT ?" + else: + limit_str = "" + + sql = ( + "SELECT * FROM events" + " WHERE outlier = ? AND room_id = ? AND %(bounds)s" + " ORDER BY topological_ordering %(order)s," + " stream_ordering %(order)s %(limit)s" + ) % { + "bounds": bounds, + "order": order, + "limit": limit_str + } + + def f(txn): + txn.execute(sql, args) + + rows = self.cursor_to_dict(txn) + + if rows: + topo = rows[-1]["topological_ordering"] + toke = rows[-1]["stream_ordering"] + if direction == 'b': + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = str(RoomStreamToken(topo, toke)) + else: + # TODO (erikj): We should work out what to do here instead. 
+                next_token = to_key if to_key else from_key
+
+            return rows, next_token,
+
+        rows, token = yield self.runInteraction("paginate_room_events", f)
+
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f0414..fc46bf7bb33f 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
 from synapse.util.caches.descriptors import cached
 
 from twisted.internet import defer
@@ -23,15 +25,7 @@
 logger = logging.getLogger(__name__)
 
 
-class TagsStore(SQLBaseStore):
-    def get_max_account_data_stream_id(self):
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            A deferred int.
-        """
-        return self._account_data_id_gen.get_current_token()
-
+class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     def get_tags_for_user(self, user_id):
         """Get all the tags for a user.
@@ -170,6 +164,8 @@ def get_tags_for_room(self, user_id, room_id):
             row["tag"]: json.loads(row["content"])
             for row in rows
         })
 
+
+class TagsStore(TagsWorkerStore):
     @defer.inlineCallbacks
     def add_tag_to_room(self, user_id, room_id, tag, content):
         """Add a tag to a room for a user.
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 94fa7cac9805..a8dea15c1b4a 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -299,10 +299,6 @@ def preserve_fn(f):
     Useful for wrapping functions that return a deferred which you don't
     yield on.
     """
-    def reset_context(result):
-        LoggingContext.set_current_context(LoggingContext.sentinel)
-        return result
-
     def g(*args, **kwargs):
         current = LoggingContext.current_context()
         res = f(*args, **kwargs)
@@ -323,12 +319,11 @@ def g(*args, **kwargs):
             # which is supposed to have a single entry and exit point. But
             # by spawning off another deferred, we are effectively
             # adding a new exit point.)
-            res.addBoth(reset_context)
+            res.addBoth(_set_context_cb, LoggingContext.sentinel)
         return res
     return g
 
 
-@defer.inlineCallbacks
 def make_deferred_yieldable(deferred):
     """Given a deferred, make it follow the Synapse logcontext rules:
 
@@ -342,9 +337,20 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to preserve_fn.)
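+
+    Returns:
+        The given deferred (or plain value), which is safe to yield on once
+        this has been called.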
""" - with PreserveLoggingContext(): - r = yield deferred - defer.returnValue(r) + if isinstance(deferred, defer.Deferred) and not deferred.called: + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) + return deferred + + +def _set_context_cb(result, context): + """A callback function which just sets the logging context""" + LoggingContext.set_current_context(context) + return result # modules to ignore in `logcontext_tracer` diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 4780f2ab7202..cb058d314224 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -230,10 +230,12 @@ def persist( state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) - for user_id, actions in push_actions: - yield self.master_store.add_push_actions_to_staging( - event.event_id, user_id, actions, - ) + yield self.master_store.add_push_actions_to_staging( + event.event_id, { + user_id: actions + for user_id, actions in push_actions + }, + ) ordering = None if backfill: diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index a269e6f56e14..e46534cd3506 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -95,7 +95,7 @@ def fetch_room_distributions_into( else: if remotedomains is not None: remotedomains.add(member.domain) - hs.get_handlers().room_member_handler.fetch_room_distributions_into = ( + hs.get_room_member_handler().fetch_room_distributions_into = ( fetch_room_distributions_into ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index d483e7cf9e76..575374c6a6ae 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -71,11 +71,11 @@ def _inject_actions(stream, action): event.depth = stream yield self.store.add_push_actions_to_staging( - event.event_id, user_id, action, + event.event_id, {user_id: action}, ) yield self.store.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, - event, + [(event, None)], [(event, None)], ) def _rotate(stream): @@ -127,3 +127,70 @@ def _mark_read(stream, depth): yield _assert_counts(1, 1) yield _rotate(10) yield _assert_counts(1, 1) + + @tests.unittest.DEBUG + @defer.inlineCallbacks + def test_find_first_stream_ordering_after_ts(self): + def add_event(so, ts): + return self.store._simple_insert("events", { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }) + + # start with the base case where there are no events in the table + r = yield self.store.find_first_stream_ordering_after_ts(11) + self.assertEqual(r, 0) + + # now with one event + yield add_event(2, 10) + r = yield self.store.find_first_stream_ordering_after_ts(9) + self.assertEqual(r, 2) + r = yield self.store.find_first_stream_ordering_after_ts(10) + self.assertEqual(r, 2) + r = yield self.store.find_first_stream_ordering_after_ts(11) + self.assertEqual(r, 3) + + # add a bunch of dummy events to the events table + for (stream_ordering, ts) in ( + (3, 110), + (4, 120), + (5, 120), + (10, 130), + (20, 140), + ): + yield add_event(stream_ordering, ts) + + r = yield self.store.find_first_stream_ordering_after_ts(110) + 
+        self.assertEqual(r, 3,
+                         "First event after 110ms should be 3, was %i" % r)
+
+        # 4 and 5 are both after 120: we want 4 rather than 5
+        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        self.assertEqual(r, 4,
+                         "First event after 120ms should be 4, was %i" % r)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        self.assertEqual(r, 10,
+                         "First event after 129ms should be 10, was %i" % r)
+
+        # check we can get the last event
+        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        self.assertEqual(r, 20,
+                         "First event after 140ms should be 20, was %i" % r)
+
+        # off the end
+        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        self.assertEqual(r, 21)
+
+        # check we can find an event at ordering zero
+        yield add_event(0, 5)
+        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        self.assertEqual(r, 0)
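
A note on the new test: the store method it exercises, find_first_stream_ordering_after_ts, does not appear in the hunks shown here. The following self-contained sketch is purely illustrative of the behaviour the assertions pin down: it bisects an in-memory list of (stream_ordering, received_ts) pairs rather than running SQL in a transaction, and, like any such search, it assumes received_ts is (at least approximately) non-decreasing in stream_ordering. The function name and list-based signature are for illustration only, not the store's API.

def find_first_stream_ordering_after_ts(events, ts):
    """Return the lowest stream ordering whose event was received at or
    after `ts` (milliseconds since the epoch).

    `events` is a list of (stream_ordering, received_ts) pairs, sorted by
    stream_ordering. Returns one past the highest ordering if every event
    predates `ts`, and 0 for an empty list, matching the tests above.
    """
    lo, hi = 0, len(events)
    while lo < hi:
        mid = (lo + hi) // 2
        if events[mid][1] >= ts:
            hi = mid        # mid is a candidate: look for an earlier one
        else:
            lo = mid + 1    # everything up to mid was received too early
    if lo == len(events):
        # all events were received before ts: point just past the end
        return events[-1][0] + 1 if events else 0
    return events[lo][0]

# The boundary cases from the test:
events = [(2, 10), (3, 110), (4, 120), (5, 120), (10, 130), (20, 140)]
assert find_first_stream_ordering_after_ts([], 11) == 0
assert find_first_stream_ordering_after_ts(events, 120) == 4   # not 5
assert find_first_stream_ordering_after_ts(events, 129) == 10
assert find_first_stream_ordering_after_ts(events, 160) == 21  # off the end

The two edge behaviours the test relies on (0 for an empty table, and one past the end when every event is too old) fall out of treating the result as an insertion point rather than an index into the existing events.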