diff --git a/demo/start.sh b/demo/start.sh index 8b0cc84fe6d3..886d21cfa866 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -32,7 +32,7 @@ for port in 8080 8081 8082; do -D --pid-file "$DIR/$port.pid" \ --manhole $((port + 1000)) \ --tls-dh-params-path "demo/demo.tls.dh" \ - $PARAMS + $PARAMS $SYNAPSE_PARAMS python -m synapse.app.homeserver \ --config-path "demo/etc/$port.config" \ diff --git a/docs/server-server/signing.rst b/docs/server-server/signing.rst index dae10f121b5f..60c701ca9100 100644 --- a/docs/server-server/signing.rst +++ b/docs/server-server/signing.rst @@ -1,13 +1,13 @@ Signing JSON ============ -JSON is signed by encoding the JSON object without ``signatures`` or ``meta`` +JSON is signed by encoding the JSON object without ``signatures`` or ``unsigned`` keys using a canonical encoding. The JSON bytes are then signed using the signature algorithm and the signature encoded using base64 with the padding stripped. The resulting base64 signature is added to an object under the *signing key identifier* which is added to the ``signatures`` object under the name of the server signing it which is added back to the original JSON object -along with the ``meta`` object. +along with the ``unsigned`` object. The *signing key identifier* is the concatenation of the *signing algorithm* and a *key version*. The *signing algorithm* identifies the algorithm used to @@ -15,8 +15,8 @@ sign the JSON. The currently support value for *signing algorithm* is ``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version* is used to distinguish between different signing keys used by the same entity. -The ``meta`` object and the ``signatures`` object are not covered by the -signature. Therefore intermediate servers can add metadata such as time stamps +The ``unsigned`` object and the ``signatures`` object are not covered by the +signature. Therefore intermediate servers can add unsigneddata such as time stamps and additional signatures. @@ -27,7 +27,7 @@ and additional signatures. "signing_keys": { "ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ" }, - "meta": { + "unsigned": { "retrieved_ts_ms": 922834800000 }, "signatures": { @@ -41,7 +41,7 @@ and additional signatures. def sign_json(json_object, signing_key, signing_name): signatures = json_object.pop("signatures", {}) - meta = json_object.pop("meta", None) + unsigned = json_object.pop("unsigned", None) signed = signing_key.sign(encode_canonical_json(json_object)) signature_base64 = encode_base64(signed.signature) @@ -50,8 +50,8 @@ and additional signatures. signatures.setdefault(sigature_name, {})[key_id] = signature_base64 json_object["signatures"] = signatures - if meta is not None: - json_object["meta"] = meta + if unsigned is not None: + json_object["unsigned"] = unsigned return json_object diff --git a/scripts/check_event_hash.py b/scripts/check_event_hash.py new file mode 100644 index 000000000000..7c32f8102a17 --- /dev/null +++ b/scripts/check_event_hash.py @@ -0,0 +1,47 @@ +from synapse.crypto.event_signing import * +from syutil.base64util import encode_base64 + +import argparse +import hashlib +import sys +import json + + +class dictobj(dict): + def __init__(self, *args, **kargs): + dict.__init__(self, *args, **kargs) + self.__dict__ = self + + def get_dict(self): + return dict(self) + + def get_full_dict(self): + return dict(self) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), + default=sys.stdin) + args = parser.parse_args() + logging.basicConfig() + + event_json = dictobj(json.load(args.input_json)) + + algorithms = { + "sha256": hashlib.sha256, + } + + for alg_name in event_json.hashes: + if check_event_content_hash(event_json, algorithms[alg_name]): + print "PASS content hash %s" % (alg_name,) + else: + print "FAIL content hash %s" % (alg_name,) + + for algorithm in algorithms.values(): + name, h_bytes = compute_event_reference_hash(event_json, algorithm) + print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) + +if __name__=="__main__": + main() + diff --git a/scripts/check_signature.py b/scripts/check_signature.py new file mode 100644 index 000000000000..e146e18e2452 --- /dev/null +++ b/scripts/check_signature.py @@ -0,0 +1,73 @@ + +from syutil.crypto.jsonsign import verify_signed_json +from syutil.crypto.signing_key import ( + decode_verify_key_bytes, write_signing_keys +) +from syutil.base64util import decode_base64 + +import urllib2 +import json +import sys +import dns.resolver +import pprint +import argparse +import logging + +def get_targets(server_name): + if ":" in server_name: + target, port = server_name.split(":") + yield (target, int(port)) + return + try: + answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV") + for srv in answers: + yield (srv.target, srv.port) + except dns.resolver.NXDOMAIN: + yield (server_name, 8480) + +def get_server_keys(server_name, target, port): + url = "https://%s:%i/_matrix/key/v1" % (target, port) + keys = json.load(urllib2.urlopen(url)) + verify_keys = {} + for key_id, key_base64 in keys["verify_keys"].items(): + verify_key = decode_verify_key_bytes(key_id, decode_base64(key_base64)) + verify_signed_json(keys, server_name, verify_key) + verify_keys[key_id] = verify_key + return verify_keys + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument("signature_name") + parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), + default=sys.stdin) + + args = parser.parse_args() + logging.basicConfig() + + server_name = args.signature_name + keys = {} + for target, port in get_targets(server_name): + try: + keys = get_server_keys(server_name, target, port) + print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) + write_signing_keys(sys.stdout, keys.values()) + break + except: + logging.exception("Error talking to %s:%s", target, port) + + json_to_check = json.load(args.input_json) + print "Checking JSON:" + for key_id in json_to_check["signatures"][args.signature_name]: + try: + key = keys[key_id] + verify_signed_json(json_to_check, args.signature_name, key) + print "PASS %s" % (key_id,) + except: + logging.exception("Check for key %s failed" % (key_id,)) + print "FAIL %s" % (key_id,) + + +if __name__ == '__main__': + main() + diff --git a/scripts/hash_history.py b/scripts/hash_history.py new file mode 100644 index 000000000000..bdad530af898 --- /dev/null +++ b/scripts/hash_history.py @@ -0,0 +1,69 @@ +from synapse.storage.pdu import PduStore +from synapse.storage.signatures import SignatureStore +from synapse.storage._base import SQLBaseStore +from synapse.federation.units import Pdu +from synapse.crypto.event_signing import ( + add_event_pdu_content_hash, compute_pdu_event_reference_hash +) +from synapse.api.events.utils import prune_pdu +from syutil.base64util import encode_base64, decode_base64 +from syutil.jsonutil import encode_canonical_json +import sqlite3 +import sys + +class Store(object): + _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] + _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] + _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] + _get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] + _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] + _store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] + _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] + _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + + +store = Store() + + +def select_pdus(cursor): + cursor.execute( + "SELECT pdu_id, origin FROM pdus ORDER BY depth ASC" + ) + + ids = cursor.fetchall() + + pdu_tuples = store._get_pdu_tuples(cursor, ids) + + pdus = [Pdu.from_pdu_tuple(p) for p in pdu_tuples] + + reference_hashes = {} + + for pdu in pdus: + try: + if pdu.prev_pdus: + print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus + for pdu_id, origin, hashes in pdu.prev_pdus: + ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] + hashes[ref_alg] = encode_base64(ref_hsh) + store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) + print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus + pdu = add_event_pdu_content_hash(pdu) + ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) + reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) + store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) + + for alg, hsh_base64 in pdu.hashes.items(): + print alg, hsh_base64 + store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) + + except: + print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus + +def main(): + conn = sqlite3.connect(sys.argv[1]) + cursor = conn.cursor() + select_pdus(cursor) + conn.commit() + +if __name__=='__main__': + main() diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e1b1823cd7ea..6c2d3db26ea7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,8 +21,10 @@ from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, + RoomJoinRulesEvent, RoomCreateEvent, ) from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 import logging @@ -35,8 +37,7 @@ def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() - @defer.inlineCallbacks - def check(self, event, snapshot, raises=False): + def check(self, event, raises=False): """ Checks if this event is correctly authed. Returns: @@ -47,43 +48,51 @@ def check(self, event, snapshot, raises=False): """ try: if hasattr(event, "room_id"): - is_state = hasattr(event, "state_key") + if event.old_state_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True + + if hasattr(event, "outlier") and event.outlier is True: + # TODO (erikj): Auth for outliers is done differently. + return True + + if event.type == RoomCreateEvent.TYPE: + # FIXME + return True if event.type == RoomMemberEvent.TYPE: - yield self._can_replace_state(event) - allowed = yield self.is_membership_change_allowed(event) - defer.returnValue(allowed) - return - - self._check_joined_room( - member=snapshot.membership_state, - user_id=snapshot.user_id, - room_id=snapshot.room_id, - ) + allowed = self.is_membership_change_allowed(event) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed - if is_state: - # TODO (erikj): This really only should be called for *new* - # state - yield self._can_add_state(event) - yield self._can_replace_state(event) - else: - yield self._can_send_event(event) + self.check_event_sender_in_room(event) + self._can_send_event(event) if event.type == RoomPowerLevelsEvent.TYPE: - yield self._check_power_levels(event) + self._check_power_levels(event) if event.type == RoomRedactionEvent.TYPE: - yield self._check_redaction(event) + self._check_redaction(event) - defer.returnValue(True) + logger.debug("Allowing! %s", event) + return True else: raise AuthError(500, "Unknown event: %s" % event) except AuthError as e: - logger.info("Event auth check failed on event %s with msg: %s", - event, e.msg) + logger.info( + "Event auth check failed on event %s with msg: %s", + event, e.msg + ) + logger.info("Denying! %s", event) if raises: raise e - defer.returnValue(False) + + return False @defer.inlineCallbacks def check_joined_room(self, room_id, user_id): @@ -98,45 +107,80 @@ def check_joined_room(self, room_id, user_id): pass defer.returnValue(None) + @defer.inlineCallbacks + def check_host_in_room(self, room_id, host): + joined_hosts = yield self.store.get_joined_hosts_for_room(room_id) + + defer.returnValue(host in joined_hosts) + + def check_event_sender_in_room(self, event): + key = (RoomMemberEvent.TYPE, event.user_id, ) + member_event = event.state_events.get(key) + + return self._check_joined_room( + member_event, + event.user_id, + event.room_id + ) + def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( user_id, room_id, repr(member) )) - @defer.inlineCallbacks + @log_function def is_membership_change_allowed(self, event): target_user_id = event.state_key - # does this room even exist - room = yield self.store.get_room(event.room_id) - if not room: - raise AuthError(403, "Room does not exist") - # get info about the caller - try: - caller = yield self.store.get_room_member( - user_id=event.user_id, - room_id=event.room_id) - except: - caller = None - caller_in_room = caller and caller.membership == "join" + key = (RoomMemberEvent.TYPE, event.user_id, ) + caller = event.old_state_events.get(key) + + caller_in_room = caller and caller.membership == Membership.JOIN + caller_invited = caller and caller.membership == Membership.INVITE # get info about the target - try: - target = yield self.store.get_room_member( - user_id=target_user_id, - room_id=event.room_id) - except: - target = None - target_in_room = target and target.membership == "join" + key = (RoomMemberEvent.TYPE, target_user_id, ) + target = event.old_state_events.get(key) + + target_in_room = target and target.membership == Membership.JOIN membership = event.content["membership"] - join_rule = yield self.store.get_room_join_rule(event.room_id) - if not join_rule: + key = (RoomJoinRulesEvent.TYPE, "", ) + join_rule_event = event.old_state_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: join_rule = JoinRules.INVITE + user_level = self._get_power_level_from_event_state( + event, + event.user_id, + ) + + ban_level, kick_level, redact_level = ( + self._get_ops_level_from_event_state( + event + ) + ) + + logger.debug( + "is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "caller_invited": caller_invited, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently # PRIVATE join rules. @@ -153,13 +197,10 @@ def is_membership_change_allowed(self, event): # joined: It's a NOOP if event.user_id != target_user_id: raise AuthError(403, "Cannot force another user to join.") - elif join_rule == JoinRules.PUBLIC or room.is_public: + elif join_rule == JoinRules.PUBLIC: pass elif join_rule == JoinRules.INVITE: - if ( - not caller or caller.membership not in - [Membership.INVITE, Membership.JOIN] - ): + if not caller_in_room and not caller_invited: raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list @@ -171,29 +212,16 @@ def is_membership_change_allowed(self, event): if not caller_in_room: # trying to leave a room you aren't joined raise AuthError(403, "You are not in room %s." % event.room_id) elif target_user_id != event.user_id: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - _, kick_level, _ = yield self.store.get_ops_levels(event.room_id) - if kick_level: kick_level = int(kick_level) else: - kick_level = 50 + kick_level = 50 # FIXME (erikj): What should we do here? if user_level < kick_level: raise AuthError( 403, "You cannot kick user %s." % target_user_id ) elif Membership.BAN == membership: - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - - ban_level, _, _ = yield self.store.get_ops_levels(event.room_id) - if ban_level: ban_level = int(ban_level) else: @@ -204,7 +232,30 @@ def is_membership_change_allowed(self, event): else: raise AuthError(500, "Unknown membership %s" % membership) - defer.returnValue(True) + return True + + def _get_power_level_from_event_state(self, event, user_id): + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + level = None + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level = power_level_event.content.get("users_default", 0) + + return level + + def _get_ops_level_from_event_state(self, event): + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + + if power_level_event: + return ( + power_level_event.content.get("ban", 50), + power_level_event.content.get("kick", 50), + power_level_event.content.get("redact", 50), + ) + return None, None, None, @defer.inlineCallbacks def get_user_by_req(self, request): @@ -229,7 +280,7 @@ def get_user_by_req(self, request): default=[""] )[0] if user and access_token and ip_addr: - self.store.insert_client_ip( + yield self.store.insert_client_ip( user=user, access_token=access_token, device_id=user_info["device_id"], @@ -273,68 +324,81 @@ def is_server_admin(self, user): return self.store.is_server_admin(user) @defer.inlineCallbacks - @log_function - def _can_send_event(self, event): - send_level = yield self.store.get_send_event_level(event.room_id) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - - if user_level: - user_level = int(user_level) - else: - user_level = 0 + def add_auth_events(self, event): + if event.type == RoomCreateEvent.TYPE: + event.auth_events = [] + return - if user_level < send_level: - raise AuthError( - 403, "You don't have permission to post to the room" - ) + auth_events = [] - defer.returnValue(True) + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) - @defer.inlineCallbacks - def _can_add_state(self, event): - add_level = yield self.store.get_add_state_level(event.room_id) + if power_level_event: + auth_events.append(power_level_event.event_id) - if not add_level: - defer.returnValue(True) + key = (RoomJoinRulesEvent.TYPE, "", ) + join_rule_event = event.old_state_events.get(key) - add_level = int(add_level) + key = (RoomMemberEvent.TYPE, event.user_id, ) + member_event = event.old_state_events.get(key) - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, + if join_rule_event: + join_rule = join_rule_event.content.get("join_rule") + is_public = join_rule == JoinRules.PUBLIC if join_rule else False + else: + is_public = False + + if event.type == RoomMemberEvent.TYPE: + e_type = event.content["membership"] + if e_type in [Membership.JOIN, Membership.INVITE]: + if join_rule_event: + auth_events.append(join_rule_event.event_id) + + if member_event and not is_public: + auth_events.append(member_event.event_id) + elif member_event: + if member_event.content["membership"] == Membership.JOIN: + auth_events.append(member_event.event_id) + + hashes = yield self.store.get_event_reference_hashes( + auth_events ) + hashes = [ + { + k: encode_base64(v) for k, v in h.items() + if k == "sha256" + } + for h in hashes + ] + event.auth_events = zip(auth_events, hashes) - user_level = int(user_level) - - if user_level < add_level: - raise AuthError( - 403, "You don't have permission to add state to the room" + @log_function + def _can_send_event(self, event): + key = (RoomPowerLevelsEvent.TYPE, "", ) + send_level_event = event.old_state_events.get(key) + send_level = None + if send_level_event: + send_level = send_level_event.content.get("events", {}).get( + event.type ) + if not send_level: + if hasattr(event, "state_key"): + send_level = send_level_event.content.get( + "state_default", 50 + ) + else: + send_level = send_level_event.content.get( + "events_default", 0 + ) - defer.returnValue(True) - - @defer.inlineCallbacks - def _can_replace_state(self, event): - current_state = yield self.store.get_current_state( - event.room_id, - event.type, - event.state_key, - ) - - if current_state: - current_state = current_state[0] + if send_level: + send_level = int(send_level) + else: + send_level = 0 - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) @@ -343,35 +407,24 @@ def _can_replace_state(self, event): else: user_level = 0 - logger.debug( - "Checking power level for %s, %s", event.user_id, user_level - ) - if current_state and hasattr(current_state, "required_power_level"): - req = current_state.required_power_level + if user_level < send_level: + raise AuthError( + 403, + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level) + ) - logger.debug("Checked power level for %s, %s", event.user_id, req) - if user_level < req: - raise AuthError( - 403, - "You don't have permission to change that state" - ) + return True - @defer.inlineCallbacks def _check_redaction(self, event): - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) - if user_level: - user_level = int(user_level) - else: - user_level = 0 - - _, _, redact_level = yield self.store.get_ops_levels(event.room_id) - - if not redact_level: - redact_level = 50 + _, _, redact_level = self._get_ops_level_from_event_state( + event + ) if user_level < redact_level: raise AuthError( @@ -379,16 +432,10 @@ def _check_redaction(self, event): "You don't have permission to redact events" ) - @defer.inlineCallbacks def _check_power_levels(self, event): - for k, v in event.content.items(): - if k == "default": - continue - - # FIXME (erikj): We don't want hsob_Ts in content. - if k == "hsob_ts": - continue - + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): try: self.hs.parse_userid(k) except: @@ -399,80 +446,68 @@ def _check_power_levels(self, event): except: raise SynapseError(400, "Not a valid power level: %s" % (v,)) - current_state = yield self.store.get_current_state( - event.room_id, - event.type, - event.state_key, - ) + key = (event.type, event.state_key, ) + current_state = event.old_state_events.get(key) if not current_state: return - else: - current_state = current_state[0] - user_level = yield self.store.get_power_level( - event.room_id, + user_level = self._get_power_level_from_event_state( + event, event.user_id, ) - if user_level: - user_level = int(user_level) - else: - user_level = 0 + # Check other levels: + levels_to_check = [ + ("users_default", []), + ("events_default", []), + ("ban", []), + ("redact", []), + ("kick", []), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, ["users"]) + ) - old_list = current_state.content + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, ["events"]) + ) - # FIXME (erikj) - old_people = {k: v for k, v in old_list.items() if k.startswith("@")} - new_people = { - k: v for k, v in event.content.items() - if k.startswith("@") - } + old_state = current_state.content + new_state = event.content - removed = set(old_people.keys()) - set(new_people.keys()) - added = set(new_people.keys()) - set(old_people.keys()) - same = set(old_people.keys()) & set(new_people.keys()) + for level_to_check, dir in levels_to_check: + old_loc = old_state + for d in dir: + old_loc = old_loc.get(d, {}) - for r in removed: - if int(old_list[r]) > user_level: - raise AuthError( - 403, - "You don't have permission to remove user: %s" % (r, ) - ) + new_loc = new_state + for d in dir: + new_loc = new_loc.get(d, {}) - for n in added: - if int(event.content[n]) > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None - for s in same: - if int(event.content[s]) != int(old_list[s]): - if int(event.content[s]) > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None - if "default" in old_list: - old_default = int(old_list["default"]) + if new_level is not None and old_level is not None: + if new_level == old_level: + continue - if old_default > user_level: + if old_level > user_level or new_level > user_level: raise AuthError( 403, - "You don't have permission to add ops level greater than " - "your own" + "You don't have permission to add ops level greater " + "than your own" ) - - if "default" in event.content: - new_default = int(event.content["default"]) - - if new_default > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 38ccb4f9d197..33d15072af19 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -158,3 +158,37 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs): for key, value in kwargs.iteritems(): err[key] = value return err + + +class FederationError(RuntimeError): + """ This class is used to inform remote home servers about erroneous + PDUs they sent us. + + FATAL: The remote server could not interpret the source event. + (e.g., it was missing a required field) + ERROR: The remote server interpreted the event, but it failed some other + check (e.g. auth) + WARN: The remote server accepted the event, but believes some part of it + is wrong (e.g., it referred to an invalid event) + """ + + def __init__(self, level, code, reason, affected, source=None): + if level not in ["FATAL", "ERROR", "WARN"]: + raise ValueError("Level is not valid: %s" % (level,)) + self.level = level + self.code = code + self.reason = reason + self.affected = affected + self.source = source + + msg = "%s %s: %s" % (level, code, reason,) + super(FederationError, self).__init__(msg) + + def get_dict(self): + return { + "level": self.level, + "code": self.code, + "reason": self.reason, + "affected": self.affected, + "source": self.source if self.source else self.affected, + } diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py index f66fea2904c1..1d8bed2906fd 100644 --- a/synapse/api/events/__init__.py +++ b/synapse/api/events/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError, Codes from synapse.util.jsonobject import JsonEncodedObject @@ -56,22 +55,26 @@ class SynapseEvent(JsonEncodedObject): "user_id", # sender/initiator "content", # HTTP body, JSON "state_key", - "required_power_level", "age_ts", "prev_content", - "prev_state", + "replaces_state", "redacted_because", + "origin_server_ts", ] internal_keys = [ "is_state", - "prev_events", "depth", "destinations", "origin", "outlier", - "power_level", "redacted", + "prev_events", + "hashes", + "signatures", + "prev_state", + "auth_events", + "state_hash", ] required_keys = [ @@ -82,8 +85,8 @@ class SynapseEvent(JsonEncodedObject): def __init__(self, raises=True, **kwargs): super(SynapseEvent, self).__init__(**kwargs) - if "content" in kwargs: - self.check_json(self.content, raises=raises) + # if "content" in kwargs: + # self.check_json(self.content, raises=raises) def get_content_template(self): """ Retrieve the JSON template for this event as a dict. @@ -114,66 +117,6 @@ def get_content_template(self): """ raise NotImplementedError("get_content_template not implemented.") - def check_json(self, content, raises=True): - """Checks the given JSON content abides by the rules of the template. - - Args: - content : A JSON object to check. - raises: True to raise a SynapseError if the check fails. - Returns: - True if the content passes the template. Returns False if the check - fails and raises=False. - Raises: - SynapseError if the check fails and raises=True. - """ - # recursively call to inspect each layer - err_msg = self._check_json(content, self.get_content_template()) - if err_msg: - if raises: - raise SynapseError(400, err_msg, Codes.BAD_JSON) - else: - return False - else: - return True - - def _check_json(self, content, template): - """Check content and template matches. - - If the template is a dict, each key in the dict will be validated with - the content, else it will just compare the types of content and - template. This basic type check is required because this function will - be recursively called and could be called with just strs or ints. - - Args: - content: The content to validate. - template: The validation template. - Returns: - str: An error message if the validation fails, else None. - """ - if type(content) != type(template): - return "Mismatched types: %s" % template - - if type(template) == dict: - for key in template: - if key not in content: - return "Missing %s key" % key - - if type(content[key]) != type(template[key]): - return "Key %s is of the wrong type (got %s, want %s)" % ( - key, type(content[key]), type(template[key])) - - if type(content[key]) == dict: - # we must go deeper - msg = self._check_json(content[key], template[key]) - if msg: - return msg - elif type(content[key]) == list: - # make sure each item type in content matches the template - for entry in content[key]: - msg = self._check_json(entry, template[key][0]) - if msg: - return msg - class SynapseStateEvent(SynapseEvent): diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py index 74d0ef77f4b7..a1ec708a8131 100644 --- a/synapse/api/events/factory.py +++ b/synapse/api/events/factory.py @@ -16,11 +16,13 @@ from synapse.api.events.room import ( RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent, InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent, - RoomPowerLevelsEvent, RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, - RoomCreateEvent, RoomAddStateLevelEvent, RoomSendEventLevelEvent, + RoomPowerLevelsEvent, RoomJoinRulesEvent, + RoomCreateEvent, RoomRedactionEvent, ) +from synapse.types import EventID + from synapse.util.stringutils import random_string @@ -37,9 +39,6 @@ class EventFactory(object): RoomPowerLevelsEvent, RoomJoinRulesEvent, RoomCreateEvent, - RoomAddStateLevelEvent, - RoomSendEventLevelEvent, - RoomOpsPowerLevelsEvent, RoomRedactionEvent, ] @@ -51,12 +50,26 @@ def __init__(self, hs): self.clock = hs.get_clock() self.hs = hs + self.event_id_count = 0 + + def create_event_id(self): + i = str(self.event_id_count) + self.event_id_count += 1 + + local_part = str(int(self.clock.time())) + i + random_string(5) + + e_id = EventID.create_local(local_part, self.hs) + + return e_id.to_string() + def create_event(self, etype=None, **kwargs): kwargs["type"] = etype if "event_id" not in kwargs: - kwargs["event_id"] = "%s@%s" % ( - random_string(10), self.hs.hostname - ) + kwargs["event_id"] = self.create_event_id() + kwargs["origin"] = self.hs.hostname + else: + ev_id = self.hs.parse_eventid(kwargs["event_id"]) + kwargs["origin"] = ev_id.domain if "origin_server_ts" not in kwargs: kwargs["origin_server_ts"] = int(self.clock.time_msec()) diff --git a/synapse/api/events/room.py b/synapse/api/events/room.py index cd936074fc68..8c4ac45d02be 100644 --- a/synapse/api/events/room.py +++ b/synapse/api/events/room.py @@ -154,27 +154,6 @@ def get_content_template(self): return {} -class RoomAddStateLevelEvent(SynapseStateEvent): - TYPE = "m.room.add_state_level" - - def get_content_template(self): - return {} - - -class RoomSendEventLevelEvent(SynapseStateEvent): - TYPE = "m.room.send_event_level" - - def get_content_template(self): - return {} - - -class RoomOpsPowerLevelsEvent(SynapseStateEvent): - TYPE = "m.room.ops_levels" - - def get_content_template(self): - return {} - - class RoomAliasesEvent(SynapseStateEvent): TYPE = "m.room.aliases" diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py index c3a32be8c157..802648f8f7c0 100644 --- a/synapse/api/events/utils.py +++ b/synapse/api/events/utils.py @@ -15,21 +15,34 @@ from .room import ( RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent, - RoomAddStateLevelEvent, RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomAliasesEvent, RoomCreateEvent, ) + def prune_event(event): - """ Prunes the given event of all keys we don't know about or think could - potentially be dodgy. + """ Returns a pruned version of the given event, which removes all keys we + don't know about or think could potentially be dodgy. This is used when we "redact" an event. We want to remove all fields that the user has specified, but we do want to keep necessary information like type, state_key etc. """ + event_type = event.type - # Remove all extraneous fields. - event.unrecognized_keys = {} + allowed_keys = [ + "event_id", + "user_id", + "room_id", + "hashes", + "signatures", + "content", + "type", + "state_key", + "depth", + "prev_events", + "prev_state", + "auth_events", + ] new_content = {} @@ -38,27 +51,33 @@ def add_fields(*fields): if field in event.content: new_content[field] = event.content[field] - if event.type == RoomMemberEvent.TYPE: + if event_type == RoomMemberEvent.TYPE: add_fields("membership") - elif event.type == RoomCreateEvent.TYPE: + elif event_type == RoomCreateEvent.TYPE: add_fields("creator") - elif event.type == RoomJoinRulesEvent.TYPE: + elif event_type == RoomJoinRulesEvent.TYPE: add_fields("join_rule") - elif event.type == RoomPowerLevelsEvent.TYPE: - # TODO: Actually check these are valid user_ids etc. - add_fields("default") - for k, v in event.content.items(): - if k.startswith("@") and isinstance(v, (int, long)): - new_content[k] = v - elif event.type == RoomAddStateLevelEvent.TYPE: - add_fields("level") - elif event.type == RoomSendEventLevelEvent.TYPE: - add_fields("level") - elif event.type == RoomOpsPowerLevelsEvent.TYPE: - add_fields("kick_level", "ban_level", "redact_level") - elif event.type == RoomAliasesEvent.TYPE: + elif event_type == RoomPowerLevelsEvent.TYPE: + add_fields( + "users", + "users_default", + "events", + "events_default", + "events_default", + "state_default", + "ban", + "kick", + "redact", + ) + elif event_type == RoomAliasesEvent.TYPE: add_fields("aliases") - event.content = new_content + allowed_fields = { + k: v + for k, v in event.get_full_dict().items() + if k in allowed_keys + } + + allowed_fields["content"] = new_content - return event + return type(event)(**allowed_fields) diff --git a/synapse/api/events/validator.py b/synapse/api/events/validator.py new file mode 100644 index 000000000000..2d4f2a3aa7c2 --- /dev/null +++ b/synapse/api/events/validator.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.errors import SynapseError, Codes + + +class EventValidator(object): + def __init__(self, hs): + pass + + def validate(self, event): + """Checks the given JSON content abides by the rules of the template. + + Args: + content : A JSON object to check. + raises: True to raise a SynapseError if the check fails. + Returns: + True if the content passes the template. Returns False if the check + fails and raises=False. + Raises: + SynapseError if the check fails and raises=True. + """ + # recursively call to inspect each layer + err_msg = self._check_json_template( + event.content, + event.get_content_template() + ) + if err_msg: + raise SynapseError(400, err_msg, Codes.BAD_JSON) + else: + return True + + def _check_json_template(self, content, template): + """Check content and template matches. + + If the template is a dict, each key in the dict will be validated with + the content, else it will just compare the types of content and + template. This basic type check is required because this function will + be recursively called and could be called with just strs or ints. + + Args: + content: The content to validate. + template: The validation template. + Returns: + str: An error message if the validation fails, else None. + """ + if type(content) != type(template): + return "Mismatched types: %s" % template + + if type(template) == dict: + for key in template: + if key not in content: + return "Missing %s key" % key + + if type(content[key]) != type(template[key]): + return "Key %s is of the wrong type (got %s, want %s)" % ( + key, type(content[key]), type(template[key])) + + if type(content[key]) == dict: + # we must go deeper + msg = self._check_json_template( + content[key], + template[key] + ) + if msg: + return msg + elif type(content[key]) == list: + # make sure each item type in content matches the template + for entry in content[key]: + msg = self._check_json_template( + entry, + template[key][0] + ) + if msg: + return msg \ No newline at end of file diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b3dae5da64ba..43164c8d675d 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -236,7 +236,10 @@ def setup(): f.namespace['hs'] = hs reactor.listenTCP(config.manhole, f, interface='127.0.0.1') - hs.start_listening(config.bind_port, config.unsecure_port) + bind_port = config.bind_port + if config.no_tls: + bind_port = None + hs.start_listening(bind_port, config.unsecure_port) if config.daemonize: print config.pid_file diff --git a/synapse/config/server.py b/synapse/config/server.py index 3afda12d5ae3..814a4c349b04 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -30,6 +30,7 @@ def __init__(self, args): self.pid_file = self.abspath(args.pid_file) self.webclient = True self.manhole = args.manhole + self.no_tls = args.no_tls if not args.content_addr: host = args.server_name @@ -67,6 +68,8 @@ def add_arguments(cls, parser): server_group.add_argument("--content-addr", default=None, help="The host and scheme to use for the " "content repository") + server_group.add_argument("--no-tls", action='store_true', + help="Don't bind to the https port.") def read_signing_key(self, signing_key_path): signing_keys = self.read_file(signing_key_path, "signing_key") diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py new file mode 100644 index 000000000000..baa93b0ee4e8 --- /dev/null +++ b/synapse/crypto/event_signing.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- + +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.api.events.utils import prune_event +from syutil.jsonutil import encode_canonical_json +from syutil.base64util import encode_base64, decode_base64 +from syutil.crypto.jsonsign import sign_json + +import hashlib +import logging + +logger = logging.getLogger(__name__) + + +def check_event_content_hash(event, hash_algorithm=hashlib.sha256): + """Check whether the hash for this PDU matches the contents""" + computed_hash = _compute_content_hash(event, hash_algorithm) + if computed_hash.name not in event.hashes: + raise Exception("Algorithm %s not in hashes %s" % ( + computed_hash.name, list(event.hashes) + )) + message_hash_base64 = event.hashes[computed_hash.name] + try: + message_hash_bytes = decode_base64(message_hash_base64) + except: + raise Exception("Invalid base64: %s" % (message_hash_base64,)) + return message_hash_bytes == computed_hash.digest() + + +def _compute_content_hash(event, hash_algorithm): + event_json = event.get_full_dict() + # TODO: We need to sign the JSON that is going out via fedaration. + event_json.pop("age_ts", None) + event_json.pop("unsigned", None) + event_json.pop("signatures", None) + event_json.pop("hashes", None) + event_json_bytes = encode_canonical_json(event_json) + return hash_algorithm(event_json_bytes) + + +def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): + tmp_event = prune_event(event) + event_json = tmp_event.get_dict() + event_json.pop("signatures", None) + event_json.pop("age_ts", None) + event_json.pop("unsigned", None) + event_json_bytes = encode_canonical_json(event_json) + hashed = hash_algorithm(event_json_bytes) + return (hashed.name, hashed.digest()) + + +def compute_event_signature(event, signature_name, signing_key): + tmp_event = prune_event(event) + redact_json = tmp_event.get_full_dict() + redact_json.pop("signatures", None) + redact_json.pop("age_ts", None) + redact_json.pop("unsigned", None) + logger.debug("Signing event: %s", redact_json) + redact_json = sign_json(redact_json, signature_name, signing_key) + return redact_json["signatures"] + + +def add_hashes_and_signatures(event, signature_name, signing_key, + hash_algorithm=hashlib.sha256): + if hasattr(event, "old_state_events"): + state_json_bytes = encode_canonical_json( + [e.event_id for e in event.old_state_events.values()] + ) + hashed = hash_algorithm(state_json_bytes) + event.state_hash = { + hashed.name: encode_base64(hashed.digest()) + } + + hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm) + + if not hasattr(event, "hashes"): + event.hashes = {} + event.hashes[hashed.name] = encode_base64(hashed.digest()) + + event.signatures = compute_event_signature( + event, + signature_name=signature_name, + signing_key=signing_key, + ) diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py index e8180d94fd17..52c84efb5bc2 100644 --- a/synapse/federation/pdu_codec.py +++ b/synapse/federation/pdu_codec.py @@ -18,50 +18,25 @@ import copy -def decode_event_id(event_id, server_name): - parts = event_id.split("@") - if len(parts) < 2: - return (event_id, server_name) - else: - return (parts[0], "".join(parts[1:])) - - -def encode_event_id(pdu_id, origin): - return "%s@%s" % (pdu_id, origin) - - class PduCodec(object): def __init__(self, hs): + self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname self.event_factory = hs.get_event_factory() self.clock = hs.get_clock() + self.hs = hs def event_from_pdu(self, pdu): kwargs = {} - kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) - kwargs["room_id"] = pdu.context - kwargs["etype"] = pdu.pdu_type - kwargs["prev_events"] = [ - encode_event_id(p[0], p[1]) for p in pdu.prev_pdus - ] - - if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): - kwargs["prev_state"] = encode_event_id( - pdu.prev_state_id, pdu.prev_state_origin - ) + kwargs["etype"] = pdu.type kwargs.update({ k: v for k, v in pdu.get_full_dict().items() if k not in [ - "pdu_id", - "context", - "pdu_type", - "prev_pdus", - "prev_state_id", - "prev_state_origin", + "type", ] }) @@ -70,33 +45,10 @@ def event_from_pdu(self, pdu): def pdu_from_event(self, event): d = event.get_full_dict() - d["pdu_id"], d["origin"] = decode_event_id( - event.event_id, self.server_name - ) - d["context"] = event.room_id - d["pdu_type"] = event.type - - if hasattr(event, "prev_events"): - d["prev_pdus"] = [ - decode_event_id(e, self.server_name) - for e in event.prev_events - ] - - if hasattr(event, "prev_state"): - d["prev_state_id"], d["prev_state_origin"] = ( - decode_event_id(event.prev_state, self.server_name) - ) - - if hasattr(event, "state_key"): - d["is_state"] = True - kwargs = copy.deepcopy(event.unrecognized_keys) kwargs.update({ k: v for k, v in d.items() - if k not in ["event_id", "room_id", "type", "prev_events"] }) - if "origin_server_ts" not in kwargs: - kwargs["origin_server_ts"] = int(self.clock.time_msec()) - - return Pdu(**kwargs) + pdu = Pdu(**kwargs) + return pdu diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 7043fcc504cc..73dc844d59f5 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,8 +21,6 @@ from twisted.internet import defer -from .units import Pdu - from synapse.util.logutils import log_function import json @@ -32,76 +30,6 @@ logger = logging.getLogger(__name__) -class PduActions(object): - """ Defines persistence actions that relate to handling PDUs. - """ - - def __init__(self, datastore): - self.store = datastore - - @log_function - def mark_as_processed(self, pdu): - """ Persist the fact that we have fully processed the given `Pdu` - - Returns: - Deferred - """ - return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin) - - @defer.inlineCallbacks - @log_function - def after_transaction(self, transaction_id, destination, origin): - """ Returns all `Pdu`s that we sent to the given remote home server - after a given transaction id. - - Returns: - Deferred: Results in a list of `Pdu`s - """ - results = yield self.store.get_pdus_after_transaction( - transaction_id, - destination - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def get_all_pdus_from_context(self, context): - results = yield self.store.get_all_pdus_from_context(context) - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @defer.inlineCallbacks - @log_function - def backfill(self, context, pdu_list, limit): - """ For a given list of PDU id and origins return the proceeding - `limit` `Pdu`s in the given `context`. - - Returns: - Deferred: Results in a list of `Pdu`s. - """ - results = yield self.store.get_backfill( - context, pdu_list, limit - ) - - defer.returnValue([Pdu.from_pdu_tuple(p) for p in results]) - - @log_function - def is_new(self, pdu): - """ When we receive a `Pdu` from a remote home server, we want to - figure out whether it is `new`, i.e. it is not some historic PDU that - we haven't seen simply because we haven't backfilled back that far. - - Returns: - Deferred: Results in a `bool` - """ - return self.store.is_pdu_new( - pdu_id=pdu.pdu_id, - origin=pdu.origin, - context=pdu.context, - depth=pdu.depth - ) - - class TransactionActions(object): """ Defines persistence actions that relate to handling Transactions. """ @@ -158,7 +86,6 @@ def prepare_to_send(self, transaction): transaction.transaction_id, transaction.destination, transaction.origin_server_ts, - [(p["pdu_id"], p["origin"]) for p in transaction.pdus] ) @log_function diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 092411eaf94c..5c625ddabf2e 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -21,7 +21,7 @@ from .units import Transaction, Pdu, Edu -from .persistence import PduActions, TransactionActions +from .persistence import TransactionActions from synapse.util.logutils import log_function @@ -57,7 +57,7 @@ def __init__(self, hs, transport_layer): self.transport_layer.register_request_handler(self) self.store = hs.get_datastore() - self.pdu_actions = PduActions(self.store) + # self.pdu_actions = PduActions(self.store) self.transaction_actions = TransactionActions(self.store) self._transaction_queue = _TransactionQueue( @@ -81,7 +81,7 @@ def set_handler(self, handler): def register_edu_handler(self, edu_type, handler): if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type)) + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) self.edu_handlers[edu_type] = handler @@ -102,24 +102,17 @@ def register_query_handler(self, query_type, handler): object to encode as JSON. """ if query_type in self.query_handlers: - raise KeyError("Already have a Query handler for %s" % (query_type)) + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) self.query_handlers[query_type] = handler - @defer.inlineCallbacks @log_function def send_pdu(self, pdu): """Informs the replication layer about a new PDU generated within the home server that should be transmitted to others. - This will fill out various attributes on the PDU object, e.g. the - `prev_pdus` key. - - *Note:* The home server should always call `send_pdu` even if it knows - that it does not need to be replicated to other home servers. This is - in case e.g. someone else joins via a remote home server and then - backfills. - TODO: Figure out when we should actually resolve the deferred. Args: @@ -132,18 +125,15 @@ def send_pdu(self, pdu): order = self._order self._order += 1 - logger.debug("[%s] Persisting PDU", pdu.pdu_id) - - # Save *before* trying to send - yield self.store.persist_event(pdu=pdu) - - logger.debug("[%s] Persisted PDU", pdu.pdu_id) - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id) + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) # TODO, add errback, etc. self._transaction_queue.enqueue_pdu(pdu, order) - logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id) + logger.debug( + "[%s] transaction_layer.enqueue_pdu... done", + pdu.event_id + ) @log_function def send_edu(self, destination, edu_type, content): @@ -158,6 +148,11 @@ def send_edu(self, destination, edu_type, content): self._transaction_queue.enqueue_edu(edu) return defer.succeed(None) + @log_function + def send_failure(self, failure, destination): + self._transaction_queue.enqueue_failure(failure, destination) + return defer.succeed(None) + @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=True): @@ -181,7 +176,7 @@ def make_query(self, destination, query_type, args, @defer.inlineCallbacks @log_function - def backfill(self, dest, context, limit): + def backfill(self, dest, context, limit, extremities): """Requests some more historic PDUs for the given context from the given destination server. @@ -189,12 +184,12 @@ def backfill(self, dest, context, limit): dest (str): The remote home server to ask. context (str): The context to backfill. limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context Returns: Deferred: Results in the received PDUs. """ - extremities = yield self.store.get_oldest_pdus_in_context(context) - logger.debug("backfill extrem=%s", extremities) # If there are no extremeties then we've (probably) reached the start. @@ -210,13 +205,13 @@ def backfill(self, dest, context, limit): pdus = [Pdu(outlier=False, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu, backfilled=True) + yield self._handle_new_pdu(dest, pdu, backfilled=True) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): + def get_pdu(self, destination, event_id, outlier=False): """Requests the PDU with given origin and ID from the remote home server. @@ -225,7 +220,7 @@ def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): Args: destination (str): Which home server to query pdu_origin (str): The home server that originally sent the pdu. - pdu_id (str) + event_id (str) outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. Defaults to `False` @@ -234,8 +229,9 @@ def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): Deferred: Results in the requested PDU. """ - transaction_data = yield self.transport_layer.get_pdu( - destination, pdu_origin, pdu_id) + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) transaction = Transaction(**transaction_data) @@ -244,13 +240,13 @@ def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): pdu = None if pdu_list: pdu = pdu_list[0] - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdu) @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context): + def get_state_for_context(self, destination, context, event_id=None): """Requests all of the `current` state PDUs for a given context from a remote home server. @@ -263,29 +259,25 @@ def get_state_for_context(self, destination, context): """ transaction_data = yield self.transport_layer.get_context_state( - destination, context) + destination, + context, + event_id=event_id, + ) transaction = Transaction(**transaction_data) pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] for pdu in pdus: - yield self._handle_new_pdu(pdu) + yield self._handle_new_pdu(destination, pdu) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def on_context_pdus_request(self, context): - pdus = yield self.pdu_actions.get_all_pdus_from_context( - context + def on_backfill_request(self, origin, context, versions, limit): + pdus = yield self.handler.on_backfill_request( + origin, context, versions, limit ) - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) - - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, context, versions, limit): - - pdus = yield self.pdu_actions.backfill(context, versions, limit) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @@ -295,6 +287,10 @@ def on_incoming_transaction(self, transaction_data): transaction = Transaction(**transaction_data) for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] if "age" in p: p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) del p["age"] @@ -315,11 +311,15 @@ def on_incoming_transaction(self, transaction_data): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(pdu)) + dl.append(self._handle_new_pdu(transaction.origin, pdu)) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu(transaction.origin, edu.edu_type, edu.content) + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) results = yield defer.DeferredList(dl) @@ -347,20 +347,22 @@ def received_edu(self, origin, edu_type, content): @defer.inlineCallbacks @log_function - def on_context_state_request(self, context): - results = yield self.store.get_current_state_for_context( - context - ) - - logger.debug("Context returning %d results", len(results)) + def on_context_state_request(self, origin, context, event_id): + if event_id: + pdus = yield self.handler.get_state_for_pdu( + origin, + context, + event_id, + ) + else: + raise NotImplementedError("Specify an event") - pdus = [Pdu.from_pdu_tuple(p) for p in results] defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @log_function - def on_pdu_request(self, pdu_origin, pdu_id): - pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin) + def on_pdu_request(self, origin, event_id): + pdu = yield self._get_persisted_pdu(origin, event_id) if pdu: defer.returnValue( @@ -372,103 +374,191 @@ def on_pdu_request(self, pdu_origin, pdu_id): @defer.inlineCallbacks @log_function def on_pull_request(self, origin, versions): - transaction_id = max([int(v) for v in versions]) + raise NotImplementedError("Pull transacions not implemented") + + @defer.inlineCallbacks + def on_query_request(self, query_type, args): + if query_type in self.query_handlers: + response = yield self.query_handlers[query_type](args) + defer.returnValue((200, response)) + else: + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type, )) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, context, user_id): + pdu = yield self.handler.on_make_join_request(context, user_id) + defer.returnValue({ + "event": pdu.get_dict(), + }) - response = yield self.pdu_actions.after_transaction( - transaction_id, - origin, - self.server_name + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = Pdu(**content) + ret_pdu = yield self.handler.on_invite_request(origin, pdu) + defer.returnValue( + ( + 200, + { + "event": ret_pdu.get_dict(), + } + ) ) - if not response: - response = [] + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + pdu = Pdu(**content) + res_pdus = yield self.handler.on_send_join_request(origin, pdu) + + defer.returnValue((200, { + "state": [p.get_dict() for p in res_pdus["state"]], + "auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]], + })) + @defer.inlineCallbacks + def on_event_auth(self, origin, context, event_id): + auth_pdus = yield self.handler.on_event_auth(event_id) defer.returnValue( - (200, self._transaction_from_pdus(response).get_dict()) + ( + 200, + { + "auth_chain": [a.get_dict() for a in auth_pdus], + } + ) ) @defer.inlineCallbacks - def on_query_request(self, query_type, args): - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue((404, "No handler for Query type '%s'" - % (query_type) - )) + def make_join(self, destination, context, user_id): + ret = yield self.transport_layer.make_join( + destination=destination, + context=context, + user_id=user_id, + ) + + pdu_dict = ret["event"] + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) @defer.inlineCallbacks + def send_join(self, destination, pdu): + _, content = yield self.transport_layer.send_join( + destination, + pdu.room_id, + pdu.event_id, + pdu.get_dict(), + ) + + logger.debug("Got content: %s", content) + state = [Pdu(outlier=True, **p) for p in content.get("state", [])] + for pdu in state: + yield self._handle_new_pdu(destination, pdu) + + auth_chain = [ + Pdu(outlier=True, **p) for p in content.get("auth_chain", []) + ] + for pdu in auth_chain: + yield self._handle_new_pdu(destination, pdu) + + defer.returnValue(state) + + @defer.inlineCallbacks + def send_invite(self, destination, context, event_id, pdu): + code, content = yield self.transport_layer.send_invite( + destination=destination, + context=context, + event_id=event_id, + content=pdu.get_dict(), + ) + + pdu_dict = content["event"] + + logger.debug("Got response to send_invite: %s", pdu_dict) + + defer.returnValue(Pdu(**pdu_dict)) + @log_function - def _get_persisted_pdu(self, pdu_id, pdu_origin): + def _get_persisted_pdu(self, origin, event_id): """ Get a PDU from the database with given origin and id. Returns: Deferred: Results in a `Pdu`. """ - pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin) - - defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple)) + return self.handler.get_persisted_pdu(origin, event_id) def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for transmission. """ pdus = [p.get_dict() for p in pdu_list] + time_now = self._clock.time_msec() for p in pdus: - if "age_ts" in pdus: - p["age"] = int(self.clock.time_msec()) - p["age_ts"] - + if "age_ts" in p: + age = time_now - p["age_ts"] + p.setdefault("unsigned", {})["age"] = int(age) + del p["age_ts"] return Transaction( origin=self.server_name, pdus=pdus, - origin_server_ts=int(self._clock.time_msec()), + origin_server_ts=int(time_now), destination=None, ) @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, pdu, backfilled=False): + def _handle_new_pdu(self, origin, pdu, backfilled=False): # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) + existing = yield self._get_persisted_pdu(origin, pdu.event_id) if existing and (not existing.outlier or pdu.outlier): - logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin) + logger.debug("Already seen pdu %s", pdu.event_id) defer.returnValue({}) return + state = None + # Get missing pdus if necessary. - is_new = yield self.pdu_actions.is_new(pdu) - if is_new and not pdu.outlier: + if not pdu.outlier: # We only backfill backwards to the min depth. - min_depth = yield self.store.get_min_depth_for_context(pdu.context) + min_depth = yield self.handler.get_min_depth_for_context( + pdu.room_id + ) if min_depth and pdu.depth > min_depth: - for pdu_id, origin in pdu.prev_pdus: - exists = yield self._get_persisted_pdu(pdu_id, origin) + for event_id, hashes in pdu.prev_events: + exists = yield self._get_persisted_pdu(origin, event_id) if not exists: - logger.debug("Requesting pdu %s %s", pdu_id, origin) + logger.debug("Requesting pdu %s", event_id) try: yield self.get_pdu( pdu.origin, - pdu_id=pdu_id, - pdu_origin=origin + event_id=event_id, ) - logger.debug("Processed pdu %s %s", pdu_id, origin) + logger.debug("Processed pdu %s", event_id) except: # TODO(erikj): Do some more intelligent retries. logger.exception("Failed to get PDU") - - # Persist the Pdu, but don't mark it as processed yet. - yield self.store.persist_event(pdu=pdu) + else: + # We need to get the state at this event, since we have reached + # a backward extremity edge. + state = yield self.get_state_for_context( + origin, pdu.room_id, pdu.event_id, + ) if not backfilled: - ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) + ret = yield self.handler.on_receive_pdu( + pdu, + backfilled=backfilled, + state=state, + ) else: ret = None - yield self.pdu_actions.mark_as_processed(pdu) + # yield self.pdu_actions.mark_as_processed(pdu) defer.returnValue(ret) @@ -476,14 +566,6 @@ def __str__(self): return "" % self.server_name -class ReplicationHandler(object): - """This defines the methods that the :py:class:`.ReplicationLayer` will - use to communicate with the rest of the home server. - """ - def on_receive_pdu(self, pdu): - raise NotImplementedError("on_receive_pdu") - - class _TransactionQueue(object): """This class makes sure we only have one transaction in flight at a time for a given destination. @@ -509,6 +591,9 @@ def __init__(self, hs, transaction_actions, transport_layer): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = {} + # destination -> list of tuple(failure, deferred) + self.pending_failures_by_dest = {} + # HACK to get unique tx id self._next_txn_id = int(self._clock.time_msec()) @@ -561,6 +646,18 @@ def eb(failure): return deferred + @defer.inlineCallbacks + def enqueue_failure(self, failure, destination): + deferred = defer.Deferred() + + self.pending_failures_by_dest.setdefault( + destination, [] + ).append( + (failure, deferred) + ) + + yield deferred + @defer.inlineCallbacks @log_function def _attempt_new_transaction(self, destination): @@ -570,8 +667,9 @@ def _attempt_new_transaction(self, destination): # list of (pending_pdu, deferred, order) pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - if not pending_pdus and not pending_edus: + if not pending_pdus and not pending_edus and not pending_failures: return logger.debug("TX [%s] Attempting new transaction", destination) @@ -581,7 +679,11 @@ def _attempt_new_transaction(self, destination): pdus = [x[0] for x in pending_pdus] edus = [x[0] for x in pending_edus] - deferreds = [x[1] for x in pending_pdus + pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] try: self.pending_transactions[destination] = 1 @@ -589,12 +691,13 @@ def _attempt_new_transaction(self, destination): logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=self._clock.time_msec(), + origin_server_ts=int(self._clock.time_msec()), transaction_id=str(self._next_txn_id), origin=self.server_name, destination=destination, pdus=pdus, edus=edus, + pdu_failures=failures, ) self._next_txn_id += 1 @@ -614,7 +717,9 @@ def json_data_cb(): if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: - p["age"] = now - int(p["age_ts"]) + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] return data code, response = yield self.transport_layer.send_transaction( diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index e7517cac4da2..95c40c6c1be6 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -72,7 +72,7 @@ def __init__(self, homeserver, server_name, server, client): self.received_handler = None @log_function - def get_context_state(self, destination, context): + def get_context_state(self, destination, context, event_id=None): """ Requests all state for a given context (i.e. room) from the given server. @@ -89,54 +89,62 @@ def get_context_state(self, destination, context): subpath = "/state/%s/" % context - return self._do_request_for_transaction(destination, subpath) + args = {} + if event_id: + args["event_id"] = event_id + + return self._do_request_for_transaction( + destination, subpath, args=args + ) @log_function - def get_pdu(self, destination, pdu_origin, pdu_id): + def get_event(self, destination, event_id): """ Requests the pdu with give id and origin from the given server. Args: destination (str): The host name of the remote home server we want to get the state from. - pdu_origin (str): The home server which created the PDU. - pdu_id (str): The id of the PDU being requested. + event_id (str): The id of the event being requested. Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s", - destination, pdu_origin, pdu_id) + logger.debug("get_pdu dest=%s, event_id=%s", + destination, event_id) - subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id) + subpath = "/event/%s/" % (event_id, ) return self._do_request_for_transaction(destination, subpath) @log_function - def backfill(self, dest, context, pdu_tuples, limit): + def backfill(self, dest, context, event_tuples, limit): """ Requests `limit` previous PDUs in a given context before list of PDUs. Args: dest (str) context (str) - pdu_tuples (list) + event_tuples (list) limt (int) Returns: Deferred: Results in a dict received from the remote homeserver. """ logger.debug( - "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s", - dest, context, repr(pdu_tuples), str(limit) + "backfill dest=%s, context=%s, event_tuples=%s, limit=%s", + dest, context, repr(event_tuples), str(limit) ) - if not pdu_tuples: + if not event_tuples: + # TODO: raise? return - subpath = "/backfill/%s/" % context + subpath = "/backfill/%s/" % (context,) - args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} - args["limit"] = limit + args = { + "v": event_tuples, + "limit": limit, + } return self._do_request_for_transaction( dest, @@ -197,6 +205,72 @@ def make_query(self, destination, query_type, args, retry_on_dns_fail): defer.returnValue(response) + @defer.inlineCallbacks + @log_function + def make_join(self, destination, context, user_id, retry_on_dns_fail=True): + path = PREFIX + "/make_join/%s/%s" % (context, user_id,) + + response = yield self.client.get_json( + destination=destination, + path=path, + retry_on_dns_fail=retry_on_dns_fail, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_join(self, destination, context, event_id, content): + path = PREFIX + "/send_join/%s/%s" % ( + context, + event_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_join", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def send_invite(self, destination, context, event_id, content): + path = PREFIX + "/invite/%s/%s" % ( + context, + event_id, + ) + + code, content = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) + + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, context, event_id): + path = PREFIX + "/event_auth/%s/%s" % ( + context, + event_id, + ) + + response = yield self.client.get_json( + destination=destination, + path=path, + ) + + defer.returnValue(response) + @defer.inlineCallbacks def _authenticate_request(self, request): json_request = { @@ -210,7 +284,7 @@ def _authenticate_request(self, request): origin = None if request.method == "PUT": - #TODO: Handle other method types? other content types? + # TODO: Handle other method types? other content types? try: content_bytes = request.content.read() content = json.loads(content_bytes) @@ -222,11 +296,13 @@ def parse_auth_header(header_str): try: params = auth.split(" ")[1].split(",") param_dict = dict(kv.split("=") for kv in params) + def strip_quotes(value): if value.startswith("\""): return value[1:-1] else: return value + origin = strip_quotes(param_dict["origin"]) key = strip_quotes(param_dict["key"]) sig = strip_quotes(param_dict["sig"]) @@ -247,7 +323,7 @@ def strip_quotes(value): if auth.startswith("X-Matrix"): (origin, key, sig) = parse_auth_header(auth) json_request["origin"] = origin - json_request["signatures"].setdefault(origin,{})[key] = sig + json_request["signatures"].setdefault(origin, {})[key] = sig if not json_request["signatures"]: raise SynapseError( @@ -313,10 +389,10 @@ def register_request_handler(self, handler): # data_id pair. self.server.register_path( "GET", - re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"), + re.compile("^" + PREFIX + "/event/([^/]*)/$"), self._with_authentication( - lambda origin, content, query, pdu_origin, pdu_id: - handler.on_pdu_request(pdu_origin, pdu_id) + lambda origin, content, query, event_id: + handler.on_pdu_request(origin, event_id) ) ) @@ -326,7 +402,11 @@ def register_request_handler(self, handler): re.compile("^" + PREFIX + "/state/([^/]*)/$"), self._with_authentication( lambda origin, content, query, context: - handler.on_context_state_request(context) + handler.on_context_state_request( + origin, + context, + query.get("event_id", [None])[0], + ) ) ) @@ -336,28 +416,63 @@ def register_request_handler(self, handler): self._with_authentication( lambda origin, content, query, context: self._on_backfill_request( - context, query["v"], query["limit"] + origin, context, query["v"], query["limit"] ) ) ) + # This is when we receive a server-server Query self.server.register_path( "GET", - re.compile("^" + PREFIX + "/context/([^/]*)/$"), + re.compile("^" + PREFIX + "/query/([^/]*)$"), self._with_authentication( - lambda origin, content, query, context: - handler.on_context_pdus_request(context) + lambda origin, content, query, query_type: + handler.on_query_request( + query_type, {k: v[0] for k, v in query.items()} + ) ) ) - # This is when we receive a server-server Query self.server.register_path( "GET", - re.compile("^" + PREFIX + "/query/([^/]*)$"), + re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), self._with_authentication( - lambda origin, content, query, query_type: - handler.on_query_request( - query_type, {k: v[0] for k, v in query.items()} + lambda origin, content, query, context, user_id: + self._on_make_join_request( + origin, content, query, context, user_id + ) + ) + ) + + self.server.register_path( + "GET", + re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + handler.on_event_auth( + origin, context, event_id, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_send_join_request( + origin, content, query, + ) + ) + ) + + self.server.register_path( + "PUT", + re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_invite_request( + origin, content, query, ) ) ) @@ -402,7 +517,8 @@ def _on_send_request(self, origin, content, query, transaction_id): return try: - code, response = yield self.received_handler.on_incoming_transaction( + handler = self.received_handler + code, response = yield handler.on_incoming_transaction( transaction_data ) except: @@ -440,7 +556,7 @@ def _do_request_for_transaction(self, destination, subpath, args={}): defer.returnValue(data) @log_function - def _on_backfill_request(self, context, v_list, limits): + def _on_backfill_request(self, origin, context, v_list, limits): if not limits: return defer.succeed( (400, {"error": "Did not include limit param"}) @@ -448,124 +564,34 @@ def _on_backfill_request(self, context, v_list, limits): limit = int(limits[-1]) - versions = [v.split(",", 1) for v in v_list] + versions = v_list return self.request_handler.on_backfill_request( - context, versions, limit) - - -class TransportReceivedHandler(object): - """ Callbacks used when we receive a transaction - """ - def on_incoming_transaction(self, transaction): - """ Called on PUT /send/, or on response to a request - that we sent (e.g. a backfill request) - - Args: - transaction (synapse.transaction.Transaction): The transaction that - was sent to us. - - Returns: - twisted.internet.defer.Deferred: A deferred that gets fired when - the transaction has finished being processed. - - The result should be a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - -class TransportRequestHandler(object): - """ Handlers used when someone want's data from us - """ - def on_pull_request(self, versions): - """ Called on GET /pull/?v=... - - This is hit when a remote home server wants to get all data - after a given transaction. Mainly used when a home server comes back - online and wants to get everything it has missed. - - Args: - versions (list): A list of transaction_ids that should be used to - determine what PDUs the remote side have not yet seen. - - Returns: - Deferred: Resultsin a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_pdu_request(self, pdu_origin, pdu_id): - """ Called on GET /pdu/// - - Someone wants a particular PDU. This PDU may or may not have originated - from us. - - Args: - pdu_origin (str) - pdu_id (str) - - Returns: - Deferred: Resultsin a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_context_state_request(self, context): - """ Called on GET /state// - - Gets hit when someone wants all the *current* state for a given - contexts. - - Args: - context (str): The name of the context that we're interested in. - - Returns: - twisted.internet.defer.Deferred: A deferred that gets fired when - the transaction has finished being processed. - - The result should be a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. - - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass - - def on_backfill_request(self, context, versions, limit): - """ Called on GET /backfill//?v=...&limit=... + origin, context, versions, limit + ) - Gets hit when we want to backfill backwards on a given context from - the given point. + @defer.inlineCallbacks + @log_function + def _on_make_join_request(self, origin, content, query, context, user_id): + content = yield self.request_handler.on_make_join_request( + context, user_id, + ) + defer.returnValue((200, content)) - Args: - context (str): The context to backfill - versions (list): A list of 2-tuples representing where to backfill - from, in the form `(pdu_id, origin)` - limit (int): How many pdus to return. + @defer.inlineCallbacks + @log_function + def _on_send_join_request(self, origin, content, query): + content = yield self.request_handler.on_send_join_request( + origin, content, + ) - Returns: - Deferred: Results in a tuple in the form of - `(response_code, respond_body)`, where `response_body` is a python - dict that will get serialized to JSON. + defer.returnValue((200, content)) - On errors, the dict should have an `error` key with a brief message - of what went wrong. - """ - pass + @defer.inlineCallbacks + @log_function + def _on_invite_request(self, origin, content, query): + content = yield self.request_handler.on_invite_request( + origin, content, + ) - def on_query_request(self): - """ Called on a GET /query/ request. """ + defer.returnValue((200, content)) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b2fb9641805f..f4e7b62bd9dc 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -20,8 +20,6 @@ from synapse.util.jsonobject import JsonEncodedObject import logging -import json -import copy logger = logging.getLogger(__name__) @@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject): A Pdu can be classified as "state". For a given context, we can efficiently retrieve all state pdu's that haven't been clobbered. Clobbering is done - via a unique constraint on the tuple (context, pdu_type, state_key). A pdu + via a unique constraint on the tuple (context, type, state_key). A pdu is a state pdu if `is_state` is True. Example pdu:: { - "pdu_id": "78c", + "event_id": "$78c:example.com", "origin_server_ts": 1404835423000, "origin": "bar", "prev_ids": [ @@ -52,24 +50,21 @@ class Pdu(JsonEncodedObject): """ valid_keys = [ - "pdu_id", - "context", + "event_id", + "room_id", "origin", "origin_server_ts", - "pdu_type", + "type", "destinations", - "transaction_id", - "prev_pdus", + "prev_events", "depth", "content", - "outlier", - "is_state", # Below this are keys valid only for State Pdus. - "state_key", - "power_level", - "prev_state_id", - "prev_state_origin", - "required_power_level", + "hashes", "user_id", + "auth_events", + "signatures", # Below this are keys valid only for State Pdus. + "state_key", + "prev_state", ] internal_keys = [ @@ -79,61 +74,28 @@ class Pdu(JsonEncodedObject): ] required_keys = [ - "pdu_id", - "context", + "event_id", + "room_id", "origin", "origin_server_ts", - "pdu_type", + "type", "content", ] # TODO: We need to make this properly load content rather than # just leaving it as a dict. (OR DO WE?!) - def __init__(self, destinations=[], is_state=False, prev_pdus=[], - outlier=False, **kwargs): - if is_state: - for required_key in ["state_key"]: - if required_key not in kwargs: - raise RuntimeError("Key %s is required" % required_key) - + def __init__(self, destinations=[], prev_events=[], + outlier=False, hashes={}, signatures={}, **kwargs): super(Pdu, self).__init__( destinations=destinations, - is_state=is_state, - prev_pdus=prev_pdus, + prev_events=prev_events, outlier=outlier, + hashes=hashes, + signatures=signatures, **kwargs ) - @classmethod - def from_pdu_tuple(cls, pdu_tuple): - """ Converts a PduTuple to a Pdu - - Args: - pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to - convert - - Returns: - Pdu - """ - if pdu_tuple: - d = copy.copy(pdu_tuple.pdu_entry._asdict()) - d["origin_server_ts"] = d.pop("ts") - - d["content"] = json.loads(d["content_json"]) - del d["content_json"] - - args = {f: d[f] for f in cls.valid_keys if f in d} - if "unrecognized_keys" in d and d["unrecognized_keys"]: - args.update(json.loads(d["unrecognized_keys"])) - - return Pdu( - prev_pdus=pdu_tuple.prev_pdu_list, - **args - ) - else: - return None - def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) @@ -193,6 +155,7 @@ class Transaction(JsonEncodedObject): "edus", "transaction_id", "destination", + "pdu_failures", ] internal_keys = [ @@ -229,7 +192,9 @@ def create_new(pdus, **kwargs): transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") + raise KeyError( + "Require 'origin_server_ts' to construct a Transaction" + ) if "transaction_id" not in kwargs: raise KeyError( "Require 'transaction_id' to construct a Transaction" @@ -241,6 +206,3 @@ def create_new(pdus, **kwargs): kwargs["pdus"] = [p.get_dict() for p in pdus] return Transaction(**kwargs) - - - diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index de4d23bbb3c9..07a8464107ec 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -14,7 +14,18 @@ # limitations under the License. from twisted.internet import defer + from synapse.api.errors import LimitExceededError +from synapse.util.async import run_on_reactor +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.api.events.room import RoomMemberEvent +from synapse.api.constants import Membership + +import logging + + +logger = logging.getLogger(__name__) + class BaseHandler(object): @@ -30,6 +41,9 @@ def __init__(self, hs): self.clock = hs.get_clock() self.hs = hs + self.signing_key = hs.config.signing_key[0] + self.server_name = hs.hostname + def ratelimit(self, user_id): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( @@ -44,16 +58,58 @@ def ratelimit(self, user_id): @defer.inlineCallbacks def _on_new_room_event(self, event, snapshot, extra_destinations=[], - extra_users=[]): + extra_users=[], suppress_auth=False, + do_invite_host=None): + yield run_on_reactor() + snapshot.fill_out_prev_events(event) + yield self.state_handler.annotate_state_groups(event) + + yield self.auth.add_auth_events(event) + + logger.debug("Signing event...") + + add_hashes_and_signatures( + event, self.server_name, self.signing_key + ) + + logger.debug("Signed event.") + + if not suppress_auth: + logger.debug("Authing...") + self.auth.check(event, raises=True) + logger.debug("Authed") + else: + logger.debug("Suppressed auth.") + + if do_invite_host: + federation_handler = self.hs.get_handlers().federation_handler + invite_event = yield federation_handler.send_invite( + do_invite_host, + event + ) + + # FIXME: We need to check if the remote changed anything else + event.signatures = invite_event.signatures + yield self.store.persist_event(event) destinations = set(extra_destinations) # Send a PDU to all hosts who have joined the room. - destinations.update((yield self.store.get_joined_hosts_for_room( - event.room_id - ))) + + for k, s in event.state_events.items(): + try: + if k[0] == RoomMemberEvent.TYPE: + if s.content["membership"] == Membership.JOIN: + destinations.add( + self.hs.parse_userid(s.state_key).domain + ) + except: + logger.warn( + "Failed to get destination from event %s", s.event_id + ) + event.destinations = list(destinations) self.notifier.on_new_room_event(event, extra_users=extra_users) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a56830d52094..164363cdc523 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -147,10 +147,8 @@ def _update_room_alias_events(self, user_id, room_id): content={"aliases": aliases}, ) - snapshot = yield self.store.snapshot_room( - room_id=room_id, - user_id=user_id, - ) + snapshot = yield self.store.snapshot_room(event) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user_id]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user_id], suppress_auth=True + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f52591d2a3c8..c2cd91bb39ee 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,13 +17,15 @@ from ._base import BaseHandler -from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent +from synapse.api.errors import AuthError, FederationError +from synapse.api.events.room import RoomMemberEvent from synapse.api.constants import Membership from synapse.util.logutils import log_function from synapse.federation.pdu_codec import PduCodec -from synapse.api.errors import SynapseError +from synapse.util.async import run_on_reactor +from synapse.crypto.event_signing import compute_event_signature -from twisted.internet import defer, reactor +from twisted.internet import defer import logging @@ -62,6 +64,9 @@ def __init__(self, hs): self.pdu_codec = PduCodec(hs) + # When joining a room we need to queue any events for that room up + self.room_queues = {} + @log_function @defer.inlineCallbacks def handle_new_event(self, event, snapshot): @@ -78,6 +83,8 @@ def handle_new_event(self, event, snapshot): processing. """ + yield run_on_reactor() + pdu = self.pdu_codec.pdu_from_event(event) if not hasattr(pdu, "destinations") or not pdu.destinations: @@ -87,97 +94,88 @@ def handle_new_event(self, event, snapshot): @log_function @defer.inlineCallbacks - def on_receive_pdu(self, pdu, backfilled): + def on_receive_pdu(self, pdu, backfilled, state=None): """ Called by the ReplicationLayer when we have a new pdu. We need to - do auth checks and put it throught the StateHandler. + do auth checks and put it through the StateHandler. """ event = self.pdu_codec.event_from_pdu(pdu) logger.debug("Got event: %s", event.event_id) - with (yield self.lock_manager.lock(pdu.context)): - if event.is_state and not backfilled: - is_new_state = yield self.state_handler.handle_new_state( - pdu - ) - else: - is_new_state = False - # TODO: Implement something in federation that allows us to - # respond to PDU. + if event.room_id in self.room_queues: + self.room_queues[event.room_id].append(pdu) + return - target_is_mine = False - if hasattr(event, "target_host"): - target_is_mine = event.target_host == self.hs.hostname - - if event.type == InviteJoinEvent.TYPE: - if not target_is_mine: - logger.debug("Ignoring invite/join event %s", event) - return - - # If we receive an invite/join event then we need to join the - # sender to the given room. - # TODO: We should probably auth this or some such - content = event.content - content.update({"membership": Membership.JOIN}) - new_event = self.event_factory.create_event( - etype=RoomMemberEvent.TYPE, - state_key=event.user_id, - room_id=event.room_id, - user_id=event.user_id, - membership=Membership.JOIN, - content=content - ) + logger.debug("Processing event: %s", event.event_id) + + if state: + state = [self.pdu_codec.event_from_pdu(p) for p in state] + + is_new_state = yield self.state_handler.annotate_state_groups( + event, + old_state=state + ) - yield self.hs.get_handlers().room_member_handler.change_membership( - new_event, - do_auth=False, + logger.debug("Event: %s", event) + + try: + self.auth.check(event, raises=True) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, ) - else: - with (yield self.room_lock.lock(event.room_id)): - yield self.store.persist_event( - event, - backfilled, - is_new_state=is_new_state - ) + is_new_state = is_new_state and not backfilled - room = yield self.store.get_room(event.room_id) + # TODO: Implement something in federation that allows us to + # respond to PDU. - if not room: - # Huh, let's try and get the current state - try: - yield self.replication_layer.get_state_for_context( - event.origin, event.room_id - ) + yield self.store.persist_event( + event, + backfilled, + is_new_state=is_new_state + ) - hosts = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - if self.hs.hostname in hosts: - try: - yield self.store.store_room( - room_id=event.room_id, - room_creator_user_id="", - is_public=False, - ) - except: - pass - except: - logger.exception( - "Failed to get current state for room %s", - event.room_id - ) + room = yield self.store.get_room(event.room_id) - if not backfilled: - extra_users = [] - if event.type == RoomMemberEvent.TYPE: - target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) - extra_users.append(target_user) + if not room: + # Huh, let's try and get the current state + try: + yield self.replication_layer.get_state_for_context( + event.origin, event.room_id, event.event_id, + ) - yield self.notifier.on_new_room_event( - event, extra_users=extra_users + hosts = yield self.store.get_joined_hosts_for_room( + event.room_id ) + if self.hs.hostname in hosts: + try: + yield self.store.store_room( + room_id=event.room_id, + room_creator_user_id="", + is_public=False, + ) + except: + pass + except: + logger.exception( + "Failed to get current state for room %s", + event.room_id + ) + + if not backfilled: + extra_users = [] + if event.type == RoomMemberEvent.TYPE: + target_user_id = event.state_key + target_user = self.hs.parse_userid(target_user_id) + extra_users.append(target_user) + + yield self.notifier.on_new_room_event( + event, extra_users=extra_users + ) if event.type == RoomMemberEvent.TYPE: if event.membership == Membership.JOIN: @@ -189,79 +187,344 @@ def on_receive_pdu(self, pdu, backfilled): @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit): - pdus = yield self.replication_layer.backfill(dest, room_id, limit) + extremities = yield self.store.get_oldest_events_in_room(room_id) + + pdus = yield self.replication_layer.backfill( + dest, + room_id, + limit, + extremities=extremities, + ) events = [] for pdu in pdus: event = self.pdu_codec.event_from_pdu(pdu) + + # FIXME (erikj): Not sure this actually works :/ + yield self.state_handler.annotate_state_groups(event) + events.append(event) + yield self.store.persist_event(event, backfilled=True) defer.returnValue(events) + @defer.inlineCallbacks + def send_invite(self, target_host, event): + pdu = yield self.replication_layer.send_invite( + destination=target_host, + context=event.room_id, + event_id=event.event_id, + pdu=self.pdu_codec.pdu_from_event(event) + ) + + defer.returnValue(self.pdu_codec.event_from_pdu(pdu)) + + @defer.inlineCallbacks + def on_event_auth(self, event_id): + auth = yield self.store.get_auth_chain(event_id) + defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth]) + @log_function @defer.inlineCallbacks def do_invite_join(self, target_host, room_id, joinee, content, snapshot): - hosts = yield self.store.get_joined_hosts_for_room(room_id) if self.hs.hostname in hosts: # We are already in the room. logger.debug("We're already in the room apparently") defer.returnValue(False) - # First get current state to see if we are already joined. + pdu = yield self.replication_layer.make_join( + target_host, + room_id, + joinee + ) + + logger.debug("Got response to make_join: %s", pdu) + + event = self.pdu_codec.event_from_pdu(pdu) + + # We should assert some things. + assert(event.type == RoomMemberEvent.TYPE) + assert(event.user_id == joinee) + assert(event.state_key == joinee) + assert(event.room_id == room_id) + + event.outlier = False + + self.room_queues[room_id] = [] + try: - yield self.replication_layer.get_state_for_context( - target_host, room_id + event.event_id = self.event_factory.create_event_id() + event.content = content + + state = yield self.replication_layer.send_join( + target_host, + self.pdu_codec.pdu_from_event(event) ) - hosts = yield self.store.get_joined_hosts_for_room(room_id) - if self.hs.hostname in hosts: - # Oh, we were actually in the room already. - logger.debug("We're already in the room apparently") - defer.returnValue(False) - except Exception: - logger.exception("Failed to get current state") - - new_event = self.event_factory.create_event( - etype=InviteJoinEvent.TYPE, - target_host=target_host, - room_id=room_id, - user_id=joinee, - content=content - ) + state = [self.pdu_codec.event_from_pdu(p) for p in state] - new_event.destinations = [target_host] + logger.debug("do_invite_join state: %s", state) - snapshot.fill_out_prev_events(new_event) - yield self.handle_new_event(new_event, snapshot) + is_new_state = yield self.state_handler.annotate_state_groups( + event, + old_state=state + ) - # TODO (erikj): Time out here. - d = defer.Deferred() - self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d) - reactor.callLater(10, d.cancel) + logger.debug("do_invite_join event: %s", event) - try: - yield d - except defer.CancelledError: - raise SynapseError(500, "Unable to join remote room") + try: + yield self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=False + ) + except: + # FIXME + pass - try: - yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + for e in state: + # FIXME: Auth these. + e.outlier = True + + yield self.state_handler.annotate_state_groups( + e, + ) + + yield self.store.persist_event( + e, + backfilled=False, + is_new_state=False + ) + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state ) - except: - pass + finally: + room_queue = self.room_queues[room_id] + del self.room_queues[room_id] + for p in room_queue: + try: + yield self.on_receive_pdu(p, backfilled=False) + except: + pass defer.returnValue(True) + @defer.inlineCallbacks + @log_function + def on_make_join_request(self, context, user_id): + event = self.event_factory.create_event( + etype=RoomMemberEvent.TYPE, + content={"membership": Membership.JOIN}, + room_id=context, + user_id=user_id, + state_key=user_id, + ) + + snapshot = yield self.store.snapshot_room(event) + snapshot.fill_out_prev_events(event) + + yield self.state_handler.annotate_state_groups(event) + yield self.auth.add_auth_events(event) + self.auth.check(event, raises=True) + + pdu = self.pdu_codec.pdu_from_event(event) + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def on_send_join_request(self, origin, pdu): + event = self.pdu_codec.event_from_pdu(pdu) + + event.outlier = False + + is_new_state = yield self.state_handler.annotate_state_groups(event) + self.auth.check(event, raises=True) + + # FIXME (erikj): All this is duplicated above :( + + yield self.store.persist_event( + event, + backfilled=False, + is_new_state=is_new_state + ) + + extra_users = [] + if event.type == RoomMemberEvent.TYPE: + target_user_id = event.state_key + target_user = self.hs.parse_userid(target_user_id) + extra_users.append(target_user) + + yield self.notifier.on_new_room_event( + event, extra_users=extra_users + ) + + if event.type == RoomMemberEvent.TYPE: + if event.membership == Membership.JOIN: + user = self.hs.parse_userid(event.state_key) + self.distributor.fire( + "user_joined_room", user=user, room_id=event.room_id + ) + + new_pdu = self.pdu_codec.pdu_from_event(event) + + destinations = set() + + for k, s in event.state_events.items(): + try: + if k[0] == RoomMemberEvent.TYPE: + if s.content["membership"] == Membership.JOIN: + destinations.add( + self.hs.parse_userid(s.state_key).domain + ) + except: + logger.warn( + "Failed to get destination from event %s", s.event_id + ) + + new_pdu.destinations = list(destinations) + + yield self.replication_layer.send_pdu(new_pdu) + + auth_chain = yield self.store.get_auth_chain(event.event_id) + pdu_auth_chain = [ + self.pdu_codec.pdu_from_event(e) + for e in auth_chain + ] + + defer.returnValue({ + "state": [ + self.pdu_codec.pdu_from_event(e) + for e in event.state_events.values() + ], + "auth_chain": pdu_auth_chain, + }) + + @defer.inlineCallbacks + def on_invite_request(self, origin, pdu): + event = self.pdu_codec.event_from_pdu(pdu) + + event.outlier = True + + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) + ) + + yield self.state_handler.annotate_state_groups(event) + + yield self.store.persist_event( + event, + backfilled=False, + ) + + target_user = self.hs.parse_userid(event.state_key) + yield self.notifier.on_new_room_event( + event, extra_users=[target_user], + ) + + defer.returnValue(self.pdu_codec.pdu_from_event(event)) + + @defer.inlineCallbacks + def get_state_for_pdu(self, origin, room_id, event_id): + yield run_on_reactor() + + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + state_groups = yield self.store.get_state_groups( + [event_id] + ) + + if state_groups: + _, state = state_groups.items().pop() + results = { + (e.type, e.state_key): e for e in state + } + + event = yield self.store.get_event(event_id) + if hasattr(event, "state_key"): + # Get previous state + if hasattr(event, "replaces_state") and event.replaces_state: + prev_event = yield self.store.get_event( + event.replaces_state + ) + results[(event.type, event.state_key)] = prev_event + else: + del results[(event.type, event.state_key)] + + defer.returnValue( + [ + self.pdu_codec.pdu_from_event(s) + for s in results.values() + ] + ) + else: + defer.returnValue([]) + + @defer.inlineCallbacks + @log_function + def on_backfill_request(self, origin, context, pdu_list, limit): + in_room = yield self.auth.check_host_in_room(context, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + events = yield self.store.get_backfill_events( + context, + pdu_list, + limit + ) + + defer.returnValue([ + self.pdu_codec.pdu_from_event(e) + for e in events + ]) + + @defer.inlineCallbacks + @log_function + def get_persisted_pdu(self, origin, event_id): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + event = yield self.store.get_event( + event_id, + allow_none=True, + ) + + if event: + in_room = yield self.auth.check_host_in_room( + event.room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + + defer.returnValue(self.pdu_codec.pdu_from_event(event)) + else: + defer.returnValue(None) + + @log_function + def get_min_depth_for_context(self, context): + return self.store.get_min_depth(context) @log_function def _on_user_joined(self, user, room_id): - waiters = self.waiting_for_join_list.get((user.to_string(), room_id), []) + waiters = self.waiting_for_join_list.get( + (user.to_string(), room_id), + [] + ) while waiters: waiters.pop().callback(None) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72894869ea37..8394013df33a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,12 +81,11 @@ def send_message(self, event=None, suppress_auth=False): user = self.hs.parse_userid(event.user_id) assert user.is_mine, "User must be our own: %s" % (user,) - snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) + snapshot = yield self.store.snapshot_room(event) - if not suppress_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self._on_new_room_event(event, snapshot) + yield self._on_new_room_event( + event, snapshot, suppress_auth=suppress_auth + ) self.hs.get_handlers().presence_handler.bump_presence_active_time( user @@ -142,16 +141,7 @@ def store_room_data(self, event=None): SynapseError if something went wrong. """ - snapshot = yield self.store.snapshot_room( - event.room_id, - event.user_id, - state_type=event.type, - state_key=event.state_key, - ) - - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) + snapshot = yield self.store.snapshot_room(event) yield self._on_new_room_event(event, snapshot) @@ -201,7 +191,7 @@ def get_room_data(self, user_id=None, room_id=None, raise RoomError( 403, "Member does not meet private room rules.") - data = yield self.store.get_current_state( + data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) defer.returnValue(data) @@ -219,9 +209,7 @@ def get_feedback(self, event_id): @defer.inlineCallbacks def send_feedback(self, event): - snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - - yield self.auth.check(event, snapshot, raises=True) + snapshot = yield self.store.snapshot_room(event) # store message in db yield self._on_new_room_event(event, snapshot) @@ -239,7 +227,7 @@ def get_state_events(self, user_id, room_id): yield self.auth.check_joined_room(room_id, user_id) # TODO: This is duplicating logic from snapshot_all_rooms - current_state = yield self.store.get_current_state(room_id) + current_state = yield self.state_handler.get_current_state(room_id) defer.returnValue([self.hs.serialize_event(c) for c in current_state]) @defer.inlineCallbacks @@ -316,7 +304,7 @@ def snapshot_all_rooms(self, user_id=None, pagin_config=None, "end": end_token.to_string(), } - current_state = yield self.store.get_current_state( + current_state = yield self.state_handler.get_current_state( event.room_id ) d["state"] = [self.hs.serialize_event(c) for c in current_state] diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dab9b03f045b..834b37f5f3dc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,7 +17,6 @@ from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import Membership -from synapse.api.events.room import RoomMemberEvent from ._base import BaseHandler @@ -153,10 +152,13 @@ def collect_presencelike_data(self, user, state): if not user.is_mine: defer.returnValue(None) - (displayname, avatar_url) = yield defer.gatherResults([ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ]) + (displayname, avatar_url) = yield defer.gatherResults( + [ + self.store.get_profile_displayname(user.localpart), + self.store.get_profile_avatar_url(user.localpart), + ], + consumeErrors=True + ) state["displayname"] = displayname state["avatar_url"] = avatar_url @@ -196,10 +198,7 @@ def _update_join_states(self, user): ) for j in joins: - snapshot = yield self.store.snapshot_room( - j.room_id, j.state_key, RoomMemberEvent.TYPE, - j.state_key - ) + snapshot = yield self.store.snapshot_room(j) content = { "membership": j.content["membership"], @@ -218,5 +217,6 @@ def _update_join_states(self, user): user_id=j.state_key, ) - yield self.state_handler.handle_new_event(new_event, snapshot) - yield self._on_new_room_event(new_event, snapshot) + yield self._on_new_room_event( + new_event, snapshot, suppress_auth=True + ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 81ce1a5907a7..3642fcfc6db6 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,8 +21,7 @@ from synapse.api.errors import StoreError, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent, - RoomJoinRulesEvent, RoomAddStateLevelEvent, RoomTopicEvent, - RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent, + RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent, ) from synapse.util import stringutils from ._base import BaseHandler @@ -122,15 +121,13 @@ def create_room(self, user_id, room_id, config): @defer.inlineCallbacks def handle_event(event): - snapshot = yield self.store.snapshot_room( - room_id=room_id, - user_id=user_id, - ) + snapshot = yield self.store.snapshot_room(event) logger.debug("Event: %s", event) - yield self.state_handler.handle_new_event(event, snapshot) - yield self._on_new_room_event(event, snapshot, extra_users=[user]) + yield self._on_new_room_event( + event, snapshot, extra_users=[user], suppress_auth=True + ) for event in creation_events: yield handle_event(event) @@ -141,7 +138,6 @@ def handle_event(event): etype=RoomNameEvent.TYPE, room_id=room_id, user_id=user_id, - required_power_level=50, content={"name": name}, ) @@ -153,7 +149,6 @@ def handle_event(event): etype=RoomTopicEvent.TYPE, room_id=room_id, user_id=user_id, - required_power_level=50, content={"topic": topic}, ) @@ -198,7 +193,6 @@ def _create_events_for_new_room(self, creator, room_id, is_public=False): event_keys = { "room_id": room_id, "user_id": creator.to_string(), - "required_power_level": 100, } def create(etype, **content): @@ -215,7 +209,21 @@ def create(etype, **content): power_levels_event = self.event_factory.create_event( etype=RoomPowerLevelsEvent.TYPE, - content={creator.to_string(): 100, "default": 0}, + content={ + "users": { + creator.to_string(): 100, + }, + "users_default": 0, + "events": { + RoomNameEvent.TYPE: 100, + RoomPowerLevelsEvent.TYPE: 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50 + }, **event_keys ) @@ -225,30 +233,10 @@ def create(etype, **content): join_rule=join_rule, ) - add_state_event = create( - etype=RoomAddStateLevelEvent.TYPE, - level=100, - ) - - send_event = create( - etype=RoomSendEventLevelEvent.TYPE, - level=0, - ) - - ops = create( - etype=RoomOpsPowerLevelsEvent.TYPE, - ban_level=50, - kick_level=50, - redact_level=50, - ) - return [ creation_event, power_levels_event, join_rules_event, - add_state_event, - send_event, - ops, ] @@ -363,10 +351,8 @@ def change_membership(self, event=None, do_auth=True): """ target_user_id = event.state_key - snapshot = yield self.store.snapshot_room( - event.room_id, event.user_id, - RoomMemberEvent.TYPE, target_user_id - ) + snapshot = yield self.store.snapshot_room(event) + ## TODO(markjh): get prev state from snapshot. prev_state = yield self.store.get_room_member( target_user_id, event.room_id @@ -375,13 +361,6 @@ def change_membership(self, event=None, do_auth=True): if prev_state: event.content["prev"] = prev_state.membership -# if prev_state and prev_state.membership == event.membership: -# # treat this event as a NOOP. -# if do_auth: # This is mainly to fix a unit test. -# yield self.auth.check(event, raises=True) -# defer.returnValue({}) -# return - room_id = event.room_id # If we're trying to join a room then we have to do this differently @@ -391,29 +370,17 @@ def change_membership(self, event=None, do_auth=True): yield self._do_join(event, snapshot, do_auth=do_auth) else: # This is not a JOIN, so we can handle it normally. - if do_auth: - yield self.auth.check(event, snapshot, raises=True) - - # If we're banning someone, set a req power level - if event.membership == Membership.BAN: - if not hasattr(event, "required_power_level") or event.required_power_level is None: - # Add some default required_power_level - user_level = yield self.store.get_power_level( - event.room_id, - event.user_id, - ) - event.required_power_level = user_level if prev_state and prev_state.membership == event.membership: # double same action, treat this event as a NOOP. defer.returnValue({}) return - yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) defer.returnValue({"room_id": room_id}) @@ -443,10 +410,7 @@ def join_room_alias(self, joinee, room_alias, do_auth=True, content={}): content=content, ) - snapshot = yield self.store.snapshot_room( - room_id, joinee.to_string(), RoomMemberEvent.TYPE, - joinee.to_string() - ) + snapshot = yield self.store.snapshot_room(new_event) yield self._do_join(new_event, snapshot, room_host=host, do_auth=True) @@ -502,14 +466,11 @@ def _do_join(self, event, snapshot, room_host=None, do_auth=True): if not have_joined: logger.debug("Doing normal join") - if do_auth: - yield self.auth.check(event, snapshot, raises=True) - - yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], snapshot=snapshot, + do_auth=do_auth, ) user = self.hs.parse_userid(event.user_id) @@ -553,26 +514,27 @@ def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]): defer.returnValue([r.room_id for r in rooms]) - def _do_local_membership_update(self, event, membership, snapshot): - destinations = [] - + @defer.inlineCallbacks + def _do_local_membership_update(self, event, membership, snapshot, + do_auth): # If we're inviting someone, then we should also send it to that # HS. target_user_id = event.state_key target_user = self.hs.parse_userid(target_user_id) - if membership == Membership.INVITE: - host = target_user.domain - destinations.append(host) - - # Always include target domain - host = target_user.domain - destinations.append(host) - - return self._on_new_room_event( - event, snapshot, extra_destinations=destinations, - extra_users=[target_user] + if membership == Membership.INVITE and not target_user.is_mine: + do_invite_host = target_user.domain + else: + do_invite_host = None + + yield self._on_new_room_event( + event, + snapshot, + extra_users=[target_user], + suppress_auth=(not do_auth), + do_invite_host=do_invite_host, ) + class RoomListHandler(BaseHandler): @defer.inlineCallbacks diff --git a/synapse/rest/base.py b/synapse/rest/base.py index 2e8e3fa7d4a2..79fc4dfb84a2 100644 --- a/synapse/rest/base.py +++ b/synapse/rest/base.py @@ -18,6 +18,11 @@ from synapse.rest.transactions import HttpTransactionStore import re +import logging + + +logger = logging.getLogger(__name__) + def client_path_pattern(path_regex): """Creates a regex compiled client path with the correct client path @@ -62,6 +67,8 @@ def __init__(self, hs): self.auth = hs.get_auth() self.txns = HttpTransactionStore() + self.validator = hs.get_event_validator() + def register(self, http_server): """ Register this servlet with the given HTTP server. """ if hasattr(self, "PATTERN"): diff --git a/synapse/rest/events.py b/synapse/rest/events.py index 097195d7cc37..92ff5e5ca7d2 100644 --- a/synapse/rest/events.py +++ b/synapse/rest/events.py @@ -20,6 +20,12 @@ from synapse.streams.config import PaginationConfig from synapse.rest.base import RestServlet, client_path_pattern +import logging + + +logger = logging.getLogger(__name__) + + class EventStreamRestServlet(RestServlet): PATTERN = client_path_pattern("/events$") @@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): auth_user = yield self.auth.get_user_by_req(request) - - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - chunk = yield handler.get_stream(auth_user.to_string(), pagin_config, - timeout=timeout) + try: + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + chunk = yield handler.get_stream( + auth_user.to_string(), pagin_config, timeout=timeout + ) + except: + logger.exception("Event stream failed") + raise defer.returnValue((200, chunk)) diff --git a/synapse/rest/room.py b/synapse/rest/room.py index 77249670618d..05da0be09014 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/room.py @@ -138,7 +138,7 @@ def on_GET(self, request, room_id, event_type, state_key): raise SynapseError( 404, "Event not found.", errcode=Codes.NOT_FOUND ) - defer.returnValue((200, data[0].get_dict()["content"])) + defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key): @@ -154,6 +154,9 @@ def on_PUT(self, request, room_id, event_type, state_key): user_id=user.to_string(), state_key=urllib.unquote(state_key) ) + + self.validator.validate(event) + if event_type == RoomMemberEvent.TYPE: # membership events are special handler = self.handlers.room_member_handler @@ -188,6 +191,8 @@ def on_POST(self, request, room_id, event_type): content=content ) + self.validator.validate(event) + msg_handler = self.handlers.message_handler yield msg_handler.send_message(event) @@ -253,6 +258,9 @@ def on_POST(self, request, room_identifier): user_id=user.to_string(), state_key=user.to_string() ) + + self.validator.validate(event) + handler = self.handlers.room_member_handler yield handler.change_membership(event) defer.returnValue((200, {})) @@ -424,6 +432,9 @@ def on_POST(self, request, room_id, membership_action): user_id=user.to_string(), state_key=state_key ) + + self.validator.validate(event) + handler = self.handlers.room_member_handler yield handler.change_membership(event) defer.returnValue((200, {})) @@ -461,6 +472,8 @@ def on_POST(self, request, room_id, event_id): redacts=urllib.unquote(event_id), ) + self.validator.validate(event) + msg_handler = self.handlers.message_handler yield msg_handler.send_message(event) diff --git a/synapse/server.py b/synapse/server.py index a4d2d4aba505..da0a44433a28 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -22,13 +22,14 @@ from synapse.federation import initialize_http_replication from synapse.api.events import serialize_event from synapse.api.events.factory import EventFactory +from synapse.api.events.validator import EventValidator from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers from synapse.rest import RestServletFactory from synapse.state import StateHandler from synapse.storage import DataStore -from synapse.types import UserID, RoomAlias, RoomID +from synapse.types import UserID, RoomAlias, RoomID, EventID from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.lockutils import LockManager @@ -80,6 +81,7 @@ def build_DEPENDENCY(self) 'event_sources', 'ratelimiter', 'keyring', + 'event_validator', ] def __init__(self, hostname, **kwargs): @@ -143,6 +145,11 @@ def parse_roomid(self, s): object.""" return RoomID.from_string(s, hs=self) + def parse_eventid(self, s): + """Parse the string given by 's' as a Event ID and return a EventID + object.""" + return EventID.from_string(s, hs=self) + def serialize_event(self, e): return serialize_event(self, e) @@ -218,6 +225,9 @@ def build_ratelimiter(self): def build_keyring(self): return Keyring(self) + def build_event_validator(self): + return EventValidator(self) + def register_servlets(self): """ Register all servlets associated with this HomeServer. """ diff --git a/synapse/state.py b/synapse/state.py index 9db84c9b5cb7..11c54fd38c14 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,11 +16,13 @@ from twisted.internet import defer -from synapse.federation.pdu_codec import encode_event_id, decode_event_id from synapse.util.logutils import log_function +from synapse.util.async import run_on_reactor +from synapse.api.events.room import RoomPowerLevelsEvent from collections import namedtuple +import copy import logging import hashlib @@ -35,230 +37,169 @@ def _get_state_key_from_event(event): class StateHandler(object): - """ Repsonsible for doing state conflict resolution. + """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() - self._replication = hs.get_replication_layer() - self.server_name = hs.hostname @defer.inlineCallbacks @log_function - def handle_new_event(self, event, snapshot): - """ Given an event this works out if a) we have sufficient power level - to update the state and b) works out what the prev_state should be. + def annotate_state_groups(self, event, old_state=None): + yield run_on_reactor() - Returns: - Deferred: Resolved with a boolean indicating if we succesfully - updated the state. + if old_state: + event.state_group = None + event.old_state_events = { + (s.type, s.state_key): s for s in old_state + } + event.state_events = event.old_state_events - Raised: - AuthError - """ - # This needs to be done in a transaction. + if hasattr(event, "state_key"): + event.state_events[(event.type, event.state_key)] = event - if not hasattr(event, "state_key"): + defer.returnValue(False) return - key = KeyStateTuple( - event.room_id, - event.type, - _get_state_key_from_event(event) - ) - - # Now I need to fill out the prev state and work out if it has auth - # (w.r.t. to power levels) - - snapshot.fill_out_prev_events(event) - - event.prev_events = [ - e for e in event.prev_events if e != event.event_id - ] + if hasattr(event, "outlier") and event.outlier: + event.state_group = None + event.old_state_events = None + event.state_events = {} + defer.returnValue(False) + return - current_state = snapshot.prev_state_pdu + ids = [e for e, _ in event.prev_events] - if current_state: - event.prev_state = encode_event_id( - current_state.pdu_id, current_state.origin - ) + ret = yield self.resolve_state_groups(ids) + state_group, new_state = ret - # TODO check current_state to see if the min power level is less - # than the power level of the user - # power_level = self._get_power_level_for_event(event) + event.old_state_events = copy.deepcopy(new_state) - pdu_id, origin = decode_event_id(event.event_id, self.server_name) + if hasattr(event, "state_key"): + key = (event.type, event.state_key) + if key in new_state: + event.replaces_state = new_state[key].event_id + new_state[key] = event + elif state_group: + event.state_group = state_group + event.state_events = new_state + defer.returnValue(False) - yield self.store.update_current_state( - pdu_id=pdu_id, - origin=origin, - context=key.context, - pdu_type=key.type, - state_key=key.state_key - ) + event.state_group = None + event.state_events = new_state - defer.returnValue(True) + defer.returnValue(hasattr(event, "state_key")) @defer.inlineCallbacks - @log_function - def handle_new_state(self, new_pdu): - """ Apply conflict resolution to `new_pdu`. - - This should be called on every new state pdu, regardless of whether or - not there is a conflict. - - This function is safe against the race of it getting called with two - `PDU`s trying to update the same state. - """ - - # This needs to be done in a transaction. - - is_new = yield self._handle_new_state(new_pdu) + def get_current_state(self, room_id, event_type=None, state_key=""): + events = yield self.store.get_latest_events_in_room(room_id) - logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + event_ids = [ + e_id + for e_id, _, _ in events + ] - if is_new: - yield self.store.update_current_state( - pdu_id=new_pdu.pdu_id, - origin=new_pdu.origin, - context=new_pdu.context, - pdu_type=new_pdu.pdu_type, - state_key=new_pdu.state_key - ) + res = yield self.resolve_state_groups(event_ids) - defer.returnValue(is_new) + if event_type: + defer.returnValue(res[1].get((event_type, state_key))) + return - def _get_power_level_for_event(self, event): - # return self._persistence.get_power_level_for_user(event.room_id, - # event.sender) - return event.power_level + defer.returnValue(res[1].values()) @defer.inlineCallbacks @log_function - def _handle_new_state(self, new_pdu): - tree, missing_branch = yield self.store.get_unresolved_state_tree( - new_pdu - ) - new_branch, current_branch = tree - - logger.debug( - "_handle_new_state new=%s, current=%s", - new_branch, current_branch + def resolve_state_groups(self, event_ids): + state_groups = yield self.store.get_state_groups( + event_ids ) - if missing_branch is not None: - # We're missing some PDUs. Fetch them. - # TODO (erikj): Limit this. - missing_prev = tree[missing_branch][-1] - - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin - - is_missing = yield self.store.get_pdu(pdu_id, origin) is None - if not is_missing: - raise Exception("Conflict resolution failed") - - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) - - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) - - if not current_branch: - # There is no current state - defer.returnValue(True) - return - - n = new_branch[-1] - c = current_branch[-1] - - common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin - - if common_ancestor: - # We found a common ancestor! - - if len(current_branch) == 1: - # This is a direct clobber so we can just... - defer.returnValue(True) + group_names = set(state_groups.keys()) + if len(group_names) == 1: + name, state_list = state_groups.items().pop() + state = { + (e.type, e.state_key): e + for e in state_list + } + defer.returnValue((name, state)) + + state = {} + for group, g_state in state_groups.items(): + for s in g_state: + state.setdefault( + (s.type, s.state_key), + {} + )[s.event_id] = s + + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } + + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } + + try: + new_state = {} + new_state.update(unconflicted_state) + for key, events in conflicted_state.items(): + new_state[key] = self._resolve_state_events(events) + except: + logger.exception("Failed to resolve state") + raise + + defer.returnValue((None, new_state)) + + def _get_power_level_from_event_state(self, event, user_id): + if hasattr(event, "old_state_events") and event.old_state_events: + key = (RoomPowerLevelsEvent.TYPE, "", ) + power_level_event = event.old_state_events.get(key) + level = None + if power_level_event: + level = power_level_event.content.get("users", {}).get( + user_id + ) + if not level: + level = power_level_event.content.get("users_default", 0) + return level else: - # We didn't find a common ancestor. This is probably fine. - pass + return 0 - result = yield self._do_conflict_res( - new_branch, current_branch, common_ancestor - ) - defer.returnValue(result) + @log_function + def _resolve_state_events(self, events): + curr_events = events - @defer.inlineCallbacks - def _do_conflict_res(self, new_branch, current_branch, common_ancestor): - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, + new_powers = [ + self._get_power_level_from_event_state(e, e.user_id) + for e in curr_events ] - for algo in conflict_res: - new_res, curr_res = yield defer.maybeDeferred( - algo, - new_branch, current_branch, common_ancestor - ) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - - raise Exception("Conflict resolution failed.") - - @defer.inlineCallbacks - def _do_power_level_conflict_res(self, new_branch, current_branch, - common_ancestor): - new_powers_deferreds = [] - for e in new_branch[:-1] if common_ancestor else new_branch: - if hasattr(e, "user_id"): - new_powers_deferreds.append( - self.store.get_power_level(e.context, e.user_id) - ) - - current_powers_deferreds = [] - for e in current_branch[:-1] if common_ancestor else current_branch: - if hasattr(e, "user_id"): - current_powers_deferreds.append( - self.store.get_power_level(e.context, e.user_id) - ) - - new_powers = yield defer.gatherResults( - new_powers_deferreds, - consumeErrors=True - ) - - current_powers = yield defer.gatherResults( - current_powers_deferreds, - consumeErrors=True - ) + new_powers = [ + int(p) if p else 0 for p in new_powers + ] - max_power_new = max(new_powers) - max_power_current = max(current_powers) + max_power = max(new_powers) - defer.returnValue( - (max_power_new, max_power_current) - ) - - def _do_chain_length_conflict_res(self, new_branch, current_branch, - common_ancestor): - return (len(new_branch), len(current_branch)) + curr_events = [ + z[0] for z in zip(curr_events, new_powers) + if z[1] == max_power + ] - def _do_hash_conflict_res(self, new_branch, current_branch, - common_ancestor): - new_str = "".join([p.pdu_id + p.origin for p in new_branch]) - c_str = "".join([p.pdu_id + p.origin for p in current_branch]) + if not curr_events: + raise RuntimeError("Max didn't get a max?") + elif len(curr_events) == 1: + return curr_events[0] + # TODO: For now, just choose the one with the largest event_id. return ( - hashlib.sha1(new_str).hexdigest(), - hashlib.sha1(c_str).hexdigest() + sorted( + curr_events, + key=lambda e: hashlib.sha1( + e.event_id + e.user_id + e.room_id + e.type + ).hexdigest() + )[0] ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4e9291fdff2a..4034437f6b4a 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -16,14 +16,7 @@ from twisted.internet import defer from synapse.api.events.room import ( - RoomMemberEvent, RoomTopicEvent, FeedbackEvent, -# RoomConfigEvent, - RoomNameEvent, - RoomJoinRulesEvent, - RoomPowerLevelsEvent, - RoomAddStateLevelEvent, - RoomSendEventLevelEvent, - RoomOpsPowerLevelsEvent, + RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent, RoomRedactionEvent, ) @@ -37,9 +30,17 @@ from .room import RoomStore from .roommember import RoomMemberStore from .stream import StreamStore -from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore +from .event_federation import EventFederationStore + +from .state import StateStore +from .signatures import SignatureStore + +from syutil.base64util import decode_base64 + +from synapse.crypto.event_signing import compute_event_reference_hash + import json import logging @@ -51,7 +52,6 @@ SCHEMAS = [ "transactions", - "pdu", "users", "profiles", "presence", @@ -59,6 +59,9 @@ "room_aliases", "keys", "redactions", + "state", + "event_edges", + "event_signatures", ] @@ -73,10 +76,12 @@ class _RollbackButIsFineException(Exception): """ pass + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, - PresenceStore, PduStore, StatePduStore, TransactionStore, - DirectoryStore, KeyStore): + PresenceStore, TransactionStore, + DirectoryStore, KeyStore, StateStore, SignatureStore, + EventFederationStore, ): def __init__(self, hs): super(DataStore, self).__init__(hs) @@ -88,8 +93,7 @@ def __init__(self, hs): @defer.inlineCallbacks @log_function - def persist_event(self, event=None, backfilled=False, pdu=None, - is_new_state=True): + def persist_event(self, event, backfilled=False, is_new_state=True): stream_ordering = None if backfilled: if not self.min_token_deferred.called: @@ -99,8 +103,8 @@ def persist_event(self, event=None, backfilled=False, pdu=None, try: yield self.runInteraction( - self._persist_pdu_event_txn, - pdu=pdu, + "persist_event", + self._persist_event_txn, event=event, backfilled=backfilled, stream_ordering=stream_ordering, @@ -119,7 +123,8 @@ def get_event(self, event_id, allow_none=False): "type", "room_id", "content", - "unrecognized_keys" + "unrecognized_keys", + "depth", ], allow_none=allow_none, ) @@ -130,42 +135,6 @@ def get_event(self, event_id, allow_none=False): event = self._parse_event_from_row(events_dict) defer.returnValue(event) - def _persist_pdu_event_txn(self, txn, pdu=None, event=None, - backfilled=False, stream_ordering=None, - is_new_state=True): - if pdu is not None: - self._persist_event_pdu_txn(txn, pdu) - if event is not None: - return self._persist_event_txn( - txn, event, backfilled, stream_ordering, - is_new_state=is_new_state, - ) - - def _persist_event_pdu_txn(self, txn, pdu): - cols = dict(pdu.__dict__) - unrec_keys = dict(pdu.unrecognized_keys) - del cols["content"] - del cols["prev_pdus"] - cols["content_json"] = json.dumps(pdu.content) - - unrec_keys.update({ - k: v for k, v in cols.items() - if k not in PdusTable.fields - }) - - cols["unrecognized_keys"] = json.dumps(unrec_keys) - - cols["ts"] = cols.pop("origin_server_ts") - - logger.debug("Persisting: %s", repr(cols)) - - if pdu.is_state: - self._persist_state_txn(txn, pdu.prev_pdus, cols) - else: - self._persist_pdu_txn(txn, pdu.prev_pdus, cols) - - self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth) - @log_function def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, is_new_state=True): @@ -177,19 +146,13 @@ def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, self._store_room_name_txn(txn, event) elif event.type == RoomTopicEvent.TYPE: self._store_room_topic_txn(txn, event) - elif event.type == RoomJoinRulesEvent.TYPE: - self._store_join_rule(txn, event) - elif event.type == RoomPowerLevelsEvent.TYPE: - self._store_power_levels(txn, event) - elif event.type == RoomAddStateLevelEvent.TYPE: - self._store_add_state_level(txn, event) - elif event.type == RoomSendEventLevelEvent.TYPE: - self._store_send_event_level(txn, event) - elif event.type == RoomOpsPowerLevelsEvent.TYPE: - self._store_ops_level(txn, event) elif event.type == RoomRedactionEvent.TYPE: self._store_redaction(txn, event) + outlier = False + if hasattr(event, "outlier"): + outlier = event.outlier + vals = { "topological_ordering": event.depth, "event_id": event.event_id, @@ -197,25 +160,33 @@ def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, "room_id": event.room_id, "content": json.dumps(event.content), "processed": True, + "outlier": outlier, + "depth": event.depth, } if stream_ordering is not None: vals["stream_ordering"] = stream_ordering - if hasattr(event, "outlier"): - vals["outlier"] = event.outlier - else: - vals["outlier"] = False - unrec = { k: v for k, v in event.get_full_dict().items() - if k not in vals.keys() and k not in ["redacted", "redacted_because"] + if k not in vals.keys() and k not in [ + "redacted", + "redacted_because", + "signatures", + "hashes", + "prev_events", + ] } vals["unrecognized_keys"] = json.dumps(unrec) try: - self._simple_insert_txn(txn, "events", vals) + self._simple_insert_txn( + txn, + "events", + vals, + or_replace=(not outlier), + ) except: logger.warn( "Failed to persist, probably duplicate: %s", @@ -224,6 +195,16 @@ def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, ) raise _RollbackButIsFineException("_persist_event") + self._handle_prev_events( + txn, + outlier=outlier, + event_id=event.event_id, + prev_events=event.prev_events, + room_id=event.room_id, + ) + + self._store_state_groups_txn(txn, event) + is_state = hasattr(event, "state_key") and event.state_key is not None if is_new_state and is_state: vals = { @@ -233,8 +214,8 @@ def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, "state_key": event.state_key, } - if hasattr(event, "prev_state"): - vals["prev_state"] = event.prev_state + if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state self._simple_insert_txn(txn, "state_events", vals) @@ -249,6 +230,81 @@ def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, } ) + for e_id, h in event.prev_state: + self._simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event.event_id, + "prev_event_id": e_id, + "room_id": event.room_id, + "is_state": 1, + }, + or_ignore=True, + ) + + if not backfilled: + self._simple_insert_txn( + txn, + table="state_forward_extremities", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + ) + + for prev_state_id, _ in event.prev_state: + self._simple_delete_txn( + txn, + table="state_forward_extremities", + keyvalues={ + "event_id": prev_state_id, + } + ) + + for hash_alg, hash_base64 in event.hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_event_content_hash_txn( + txn, event.event_id, hash_alg, hash_bytes, + ) + + if hasattr(event, "signatures"): + signatures = event.signatures.get(event.origin, {}) + + for key_id, signature_base64 in signatures.items(): + signature_bytes = decode_base64(signature_base64) + self._store_event_origin_signature_txn( + txn, event.event_id, event.origin, key_id, signature_bytes, + ) + + for prev_event_id, prev_hashes in event.prev_events: + for alg, hash_base64 in prev_hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_prev_event_hash_txn( + txn, event.event_id, prev_event_id, alg, hash_bytes + ) + + for auth_id, _ in event.auth_events: + self._simple_insert_txn( + txn, + table="event_auth", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + }, + or_ignore=True, + ) + + (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) + self._store_event_reference_hash_txn( + txn, event.event_id, ref_alg, ref_hash_bytes + ) + + self._update_min_depth_for_room_txn(txn, event.room_id, event.depth) + def _store_redaction(self, txn, event): txn.execute( "INSERT OR IGNORE INTO redactions " @@ -319,7 +375,7 @@ def get_user_ip_and_agents(self, user): ], ) - def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): + def snapshot_room(self, event): """Snapshot the room for an update by a user Args: room_id (synapse.types.RoomId): The room to snapshot. @@ -330,29 +386,33 @@ def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): synapse.storage.Snapshot: A snapshot of the state of the room. """ def _snapshot(txn): - membership_state = self._get_room_member(txn, user_id, room_id) - prev_pdus = self._get_latest_pdus_in_context( - txn, room_id + prev_events = self._get_latest_events_in_room( + txn, + event.room_id ) - if state_type is not None and state_key is not None: - prev_state_pdu = self._get_current_state_pdu( - txn, room_id, state_type, state_key + + prev_state = None + state_key = None + if hasattr(event, "state_key"): + state_key = event.state_key + prev_state = self._get_latest_state_in_room( + txn, + event.room_id, + type=event.type, + state_key=state_key, ) - else: - prev_state_pdu = None return Snapshot( store=self, - room_id=room_id, - user_id=user_id, - prev_pdus=prev_pdus, - membership_state=membership_state, - state_type=state_type, + room_id=event.room_id, + user_id=event.user_id, + prev_events=prev_events, + prev_state=prev_state, + state_type=event.type, state_key=state_key, - prev_state_pdu=prev_state_pdu, ) - return self.runInteraction(_snapshot) + return self.runInteraction("snapshot_room", _snapshot) class Snapshot(object): @@ -361,7 +421,7 @@ class Snapshot(object): store (DataStore): The datastore. room_id (RoomId): The room of the snapshot. user_id (UserId): The user this snapshot is for. - prev_pdus (list): The list of PDU ids this snapshot is after. + prev_events (list): The list of event ids this snapshot is after. membership_state (RoomMemberEvent): The current state of the user in the room. state_type (str, optional): State type captured by the snapshot @@ -370,32 +430,30 @@ class Snapshot(object): the previous value of the state type and key in the room. """ - def __init__(self, store, room_id, user_id, prev_pdus, - membership_state, state_type=None, state_key=None, - prev_state_pdu=None): + def __init__(self, store, room_id, user_id, prev_events, + prev_state, state_type=None, state_key=None): self.store = store self.room_id = room_id self.user_id = user_id - self.prev_pdus = prev_pdus - self.membership_state = membership_state + self.prev_events = prev_events + self.prev_state = prev_state self.state_type = state_type self.state_key = state_key - self.prev_state_pdu = prev_state_pdu def fill_out_prev_events(self, event): - if hasattr(event, "prev_events"): - return - - es = [ - "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus - ] - - event.prev_events = [e for e in es if e != event.event_id] + if not hasattr(event, "prev_events"): + event.prev_events = [ + (event_id, hashes) + for event_id, hashes, _ in self.prev_events + ] + + if self.prev_events: + event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 + else: + event.depth = 0 - if self.prev_pdus: - event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1 - else: - event.depth = 0 + if not hasattr(event, "prev_state") and self.prev_state is not None: + event.prev_state = self.prev_state def schema_path(schema): @@ -436,11 +494,13 @@ def prepare_database(db_conn): user_version = row[0] if user_version > SCHEMA_VERSION: - raise ValueError("Cannot use this database as it is too " + + raise ValueError( + "Cannot use this database as it is too " + "new for the server to understand" ) elif user_version < SCHEMA_VERSION: - logging.info("Upgrading database from version %d", + logging.info( + "Upgrading database from version %d", user_version ) @@ -452,13 +512,13 @@ def prepare_database(db_conn): db_conn.commit() else: - sql_script = "BEGIN TRANSACTION;" + sql_script = "BEGIN TRANSACTION;\n" for sql_loc in SCHEMAS: sql_script += read_schema(sql_loc) + sql_script += "\n" sql_script += "COMMIT TRANSACTION;" c.executescript(sql_script) db_conn.commit() c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) c.close() - diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 65a86e905660..a1ee0318f60e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,59 +14,69 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.api.events.utils import prune_event from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 import collections import copy import json +import sys +import time logger = logging.getLogger(__name__) sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging to the .execute() method.""" - __slots__ = ["txn"] + __slots__ = ["txn", "name"] - def __init__(self, txn): + def __init__(self, txn, name): object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) - def __getattribute__(self, name): - if name == "execute": - return object.__getattribute__(self, "execute") - - return getattr(object.__getattribute__(self, "txn"), name) + def __getattr__(self, name): + return getattr(self.txn, name) def __setattr__(self, name, value): - setattr(object.__getattribute__(self, "txn"), name, value) + setattr(self.txn, name, value) def execute(self, sql, *args, **kwargs): # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] %s", sql) + sql_logger.debug("[SQL] {%s} %s", self.name, sql) try: if args and args[0]: values = args[0] - sql_logger.debug("[SQL values] " + - ", ".join(("<%s>",) * len(values)), *values) + sql_logger.debug( + "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), + self.name, + *values + ) except: # Don't let logging failures stop SQL from working pass - # TODO(paul): Here would be an excellent place to put some timing - # measurements, and log (warning?) slow queries. - return object.__getattribute__(self, "txn").execute( - sql, *args, **kwargs - ) + start = time.clock() * 1000 + try: + return self.txn.execute( + sql, *args, **kwargs + ) + except: + logger.exception("[SQL FAIL] {%s}", self.name) + raise + finally: + end = time.clock() * 1000 + sql_logger.debug("[SQL time] {%s} %f", self.name, end - start) class SQLBaseStore(object): + _TXN_ID = 0 def __init__(self, hs): self.hs = hs @@ -74,10 +84,30 @@ def __init__(self, hs): self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() - def runInteraction(self, func, *args, **kwargs): + def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" def inner_func(txn, *args, **kwargs): - return func(LoggingTransaction(txn), *args, **kwargs) + start = time.clock() * 1000 + txn_id = SQLBaseStore._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + try: + return func(LoggingTransaction(txn, name), *args, **kwargs) + except: + logger.exception("[TXN FAIL] {%s}", name) + raise + finally: + end = time.clock() * 1000 + transaction_logger.debug( + "[TXN END] {%s} %f", + name, end - start + ) return self._db_pool.runInteraction(inner_func, *args, **kwargs) @@ -113,7 +143,7 @@ def interaction(txn): else: return cursor.fetchall() - return self.runInteraction(interaction) + return self.runInteraction("_execute", interaction) def _execute_and_decode(self, query, *args): return self._execute(self.cursor_to_dict, query, *args) @@ -130,6 +160,7 @@ def _simple_insert(self, table, values, or_replace=False, or_ignore=False): or_replace : bool; if True performs an INSERT OR REPLACE """ return self.runInteraction( + "_simple_insert", self._simple_insert_txn, table, values, or_replace=or_replace, or_ignore=or_ignore, ) @@ -170,7 +201,6 @@ def _simple_select_one(self, table, keyvalues, retcols, table, keyvalues, retcols=retcols, allow_none=allow_none ) - @defer.inlineCallbacks def _simple_select_one_onecol(self, table, keyvalues, retcol, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -181,19 +211,40 @@ def _simple_select_one_onecol(self, table, keyvalues, retcol, keyvalues : dict of column names and values to select the row with retcol : string giving the name of the column to return """ - ret = yield self._simple_select_one( + return self.runInteraction( + "_simple_select_one_onecol", + self._simple_select_one_onecol_txn, + table, keyvalues, retcol, allow_none=allow_none, + ) + + def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol, + allow_none=False): + ret = self._simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, - retcols=[retcol], - allow_none=allow_none + retcol=retcol, ) if ret: - defer.returnValue(ret[retcol]) + return ret[0] else: - defer.returnValue(None) + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): + sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { + "retcol": retcol, + "table": table, + "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + } + + txn.execute(sql, keyvalues.values()) + + return [r[0] for r in txn.fetchall()] - @defer.inlineCallbacks def _simple_select_onecol(self, table, keyvalues, retcol): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -206,25 +257,33 @@ def _simple_select_onecol(self, table, keyvalues, retcol): Returns: Deferred: Results in a list """ - sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { - "retcol": retcol, - "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), - } - - def func(txn): - txn.execute(sql, keyvalues.values()) - return txn.fetchall() + return self.runInteraction( + "_simple_select_onecol", + self._simple_select_onecol_txn, + table, keyvalues, retcol + ) - res = yield self.runInteraction(func) + def _simple_select_list(self, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. - defer.returnValue([r[0] for r in res]) + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + return self.runInteraction( + "_simple_select_list", + self._simple_select_list_txn, + table, keyvalues, retcols + ) - def _simple_select_list(self, table, keyvalues, retcols): + def _simple_select_list_txn(self, txn, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: + txn : Transaction object table : string giving the table name keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return @@ -232,14 +291,11 @@ def _simple_select_list(self, table, keyvalues, retcols): sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) - def func(txn): - txn.execute(sql, keyvalues.values()) - return self.cursor_to_dict(txn) - - return self.runInteraction(func) + txn.execute(sql, keyvalues.values()) + return self.cursor_to_dict(txn) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -307,7 +363,7 @@ def func(txn): raise StoreError(500, "More than one row matched") return ret - return self.runInteraction(func) + return self.runInteraction("_simple_selectupdate_one", func) def _simple_delete_one(self, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a @@ -319,7 +375,7 @@ def _simple_delete_one(self, table, keyvalues): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) def func(txn): @@ -328,7 +384,25 @@ def func(txn): raise StoreError(404, "No row found") if txn.rowcount > 1: raise StoreError(500, "more than one row matched") - return self.runInteraction(func) + return self.runInteraction("_simple_delete_one", func) + + def _simple_delete(self, table, keyvalues): + """Executes a DELETE query on the named table. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + + return self.runInteraction("_simple_delete", self._simple_delete_txn) + + def _simple_delete_txn(self, txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k, ) for k in keyvalues) + ) + + return txn.execute(sql, keyvalues.values()) def _simple_max_id(self, table): """Executes a SELECT query on the named table, expecting to return the @@ -346,7 +420,7 @@ def func(txn): return 0 return max_id - return self.runInteraction(func) + return self.runInteraction("_simple_max_id", func) def _parse_event_from_row(self, row_dict): d = copy.deepcopy({k: v for k, v in row_dict.items()}) @@ -355,6 +429,10 @@ def _parse_event_from_row(self, row_dict): d.pop("topological_ordering", None) d.pop("processed", None) d["origin_server_ts"] = d.pop("ts", 0) + replaces_state = d.pop("prev_state", None) + + if replaces_state: + d["replaces_state"] = replaces_state d.update(json.loads(row_dict["unrecognized_keys"])) d["content"] = json.loads(d["content"]) @@ -369,23 +447,65 @@ def _parse_event_from_row(self, row_dict): **d ) + def _get_events_txn(self, txn, event_ids): + # FIXME (erikj): This should be batched? + + sql = "SELECT * FROM events WHERE event_id = ?" + + event_rows = [] + for e_id in event_ids: + c = txn.execute(sql, (e_id,)) + event_rows.extend(self.cursor_to_dict(c)) + + return self._parse_events_txn(txn, event_rows) + def _parse_events(self, rows): - return self.runInteraction(self._parse_events_txn, rows) + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) def _parse_events_txn(self, txn, rows): events = [self._parse_event_from_row(r) for r in rows] - sql = "SELECT * FROM events WHERE event_id = ?" + select_event_sql = "SELECT * FROM events WHERE event_id = ?" + + for i, ev in enumerate(events): + signatures = self._get_event_origin_signatures_txn( + txn, ev.event_id, + ) - for ev in events: - if hasattr(ev, "prev_state"): - # Load previous state_content. - # TODO: Should we be pulling this out above? - cursor = txn.execute(sql, (ev.prev_state,)) - prevs = self.cursor_to_dict(cursor) - if prevs: - prev = self._parse_event_from_row(prevs[0]) - ev.prev_content = prev.content + ev.signatures = { + k: encode_base64(v) for k, v in signatures.items() + } + + prevs = self._get_prev_events_and_state(txn, ev.event_id) + + ev.prev_events = [ + (e_id, h) + for e_id, h, is_state in prevs + if is_state == 0 + ] + + ev.auth_events = self._get_auth_events(txn, ev.event_id) + + if hasattr(ev, "state_key"): + ev.prev_state = [ + (e_id, h) + for e_id, h, is_state in prevs + if is_state == 1 + ] + + if hasattr(ev, "replaces_state"): + # Load previous state_content. + # FIXME (erikj): Handle multiple prev_states. + cursor = txn.execute( + select_event_sql, + (ev.replaces_state,) + ) + prevs = self.cursor_to_dict(cursor) + if prevs: + prev = self._parse_event_from_row(prevs[0]) + ev.prev_content = prev.content if not hasattr(ev, "redacted"): logger.debug("Doesn't have redacted key: %s", ev) @@ -393,15 +513,16 @@ def _parse_events_txn(self, txn, rows): if ev.redacted: # Get the redaction event. - sql = "SELECT * FROM events WHERE event_id = ?" - txn.execute(sql, (ev.redacted,)) + select_event_sql = "SELECT * FROM events WHERE event_id = ?" + txn.execute(select_event_sql, (ev.redacted,)) del_evs = self._parse_events_txn( txn, self.cursor_to_dict(txn) ) if del_evs: - prune_event(ev) + ev = prune_event(ev) + events[i] = ev ev.redacted_because = del_evs[0] return events diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 52373a28a672..d6a7113b9c08 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -95,6 +95,7 @@ def create_room_alias_association(self, room_alias, room_id, servers): def delete_room_alias(self, room_alias): return self.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias, ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py new file mode 100644 index 000000000000..a027db386800 --- /dev/null +++ b/synapse/storage/event_federation.py @@ -0,0 +1,377 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from syutil.base64util import encode_base64 + +import logging + + +logger = logging.getLogger(__name__) + + +class EventFederationStore(SQLBaseStore): + + def get_auth_chain(self, event_id): + return self.runInteraction( + "get_auth_chain", + self._get_auth_chain_txn, + event_id + ) + + def _get_auth_chain_txn(self, txn, event_id): + results = self._get_auth_chain_ids_txn(txn, event_id) + + sql = "SELECT * FROM events WHERE event_id = ?" + rows = [] + for ev_id in results: + c = txn.execute(sql, (ev_id,)) + rows.extend(self.cursor_to_dict(c)) + + return self._parse_events_txn(txn, rows) + + def get_auth_chain_ids(self, event_id): + return self.runInteraction( + "get_auth_chain_ids", + self._get_auth_chain_ids_txn, + event_id + ) + + def _get_auth_chain_ids_txn(self, txn, event_id): + results = set() + + base_sql = ( + "SELECT auth_id FROM event_auth WHERE %s" + ) + + front = set([event_id]) + while front: + sql = base_sql % ( + " OR ".join(["event_id=?"] * len(front)), + ) + + txn.execute(sql, list(front)) + front = [r[0] for r in txn.fetchall()] + results.update(front) + + return list(results) + + def get_oldest_events_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_in_room", + self._get_oldest_events_in_room_txn, + room_id, + ) + + def _get_oldest_events_in_room_txn(self, txn, room_id): + return self._simple_select_onecol_txn( + txn, + table="event_backward_extremities", + keyvalues={ + "room_id": room_id, + }, + retcol="event_id", + ) + + def get_latest_events_in_room(self, room_id): + return self.runInteraction( + "get_latest_events_in_room", + self._get_latest_events_in_room, + room_id, + ) + + def _get_latest_events_in_room(self, txn, room_id): + sql = ( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "WHERE f.room_id = ?" + ) + + txn.execute(sql, (room_id, )) + + results = [] + for event_id, depth in txn.fetchall(): + hashes = self._get_event_reference_hashes_txn(txn, event_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((event_id, prev_hashes, depth)) + + return results + + def _get_latest_state_in_room(self, txn, room_id, type, state_key): + event_ids = self._simple_select_onecol_txn( + txn, + table="state_forward_extremities", + keyvalues={ + "room_id": room_id, + "type": type, + "state_key": state_key, + }, + retcol="event_id", + ) + + results = [] + for event_id in event_ids: + hashes = self._get_event_reference_hashes_txn(txn, event_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((event_id, prev_hashes)) + + return results + + def _get_prev_events(self, txn, event_id): + results = self._get_prev_events_and_state( + txn, + event_id, + is_state=0, + ) + + return [(e_id, h, ) for e_id, h, _ in results] + + def _get_prev_state(self, txn, event_id): + results = self._get_prev_events_and_state( + txn, + event_id, + is_state=1, + ) + + return [(e_id, h, ) for e_id, h, _ in results] + + def _get_prev_events_and_state(self, txn, event_id, is_state=None): + keyvalues = { + "event_id": event_id, + } + + if is_state is not None: + keyvalues["is_state"] = is_state + + res = self._simple_select_list_txn( + txn, + table="event_edges", + keyvalues=keyvalues, + retcols=["prev_event_id", "is_state"], + ) + + results = [] + for d in res: + hashes = self._get_event_reference_hashes_txn( + txn, + d["prev_event_id"] + ) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((d["prev_event_id"], prev_hashes, d["is_state"])) + + return results + + def _get_auth_events(self, txn, event_id): + auth_ids = self._simple_select_onecol_txn( + txn, + table="event_auth", + keyvalues={ + "event_id": event_id, + }, + retcol="auth_id", + ) + + results = [] + for auth_id in auth_ids: + hashes = self._get_event_reference_hashes_txn(txn, auth_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((auth_id, prev_hashes)) + + return results + + def get_min_depth(self, room_id): + return self.runInteraction( + "get_min_depth", + self._get_min_depth_interaction, + room_id, + ) + + def _get_min_depth_interaction(self, txn, room_id): + min_depth = self._simple_select_one_onecol_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + retcol="min_depth", + allow_none=True, + ) + + return int(min_depth) if min_depth is not None else None + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + do_insert = depth < min_depth if min_depth else True + + if do_insert: + self._simple_insert_txn( + txn, + table="room_depth", + values={ + "room_id": room_id, + "min_depth": depth, + }, + or_replace=True, + ) + + def _handle_prev_events(self, txn, outlier, event_id, prev_events, + room_id): + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event_id, + "prev_event_id": e_id, + "room_id": room_id, + "is_state": 0, + }, + or_ignore=True, + ) + + # Update the extremities table if this is not an outlier. + if not outlier: + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={ + "event_id": e_id, + "room_id": room_id, + } + ) + + # We only insert as a forward extremity the new pdu if there are + # no other pdus that reference it as a prev pdu + query = ( + "INSERT OR IGNORE INTO %(table)s (event_id, room_id) " + "SELECT ?, ? WHERE NOT EXISTS (" + "SELECT 1 FROM %(event_edges)s WHERE " + "prev_event_id = ? " + ")" + ) % { + "table": "event_forward_extremities", + "event_edges": "event_edges", + } + + logger.debug("query: %s", query) + + txn.execute(query, (event_id, room_id, event_id)) + + # Insert all the prev_pdus as a backwards thing, they'll get + # deleted in a second if they're incorrect anyway. + for e_id, _ in prev_events: + # TODO (erikj): This could be done as a bulk insert + self._simple_insert_txn( + txn, + table="event_backward_extremities", + values={ + "event_id": e_id, + "room_id": room_id, + }, + or_ignore=True, + ) + + # Also delete from the backwards extremities table all ones that + # reference pdus that we have already seen + query = ( + "DELETE FROM event_backward_extremities WHERE EXISTS (" + "SELECT 1 FROM events " + "WHERE " + "event_backward_extremities.event_id = events.event_id " + "AND not events.outlier " + ")" + ) + txn.execute(query) + + def get_backfill_events(self, room_id, event_list, limit): + """Get a list of Events for a given topic that occured before (and + including) the pdus in pdu_list. Return a list of max size `limit`. + + Args: + txn + room_id (str) + event_list (list) + limit (int) + + Return: + list: A list of PduTuples + """ + return self.runInteraction( + "get_backfill_events", + self._get_backfill_events, room_id, event_list, limit + ) + + def _get_backfill_events(self, txn, room_id, event_list, limit): + logger.debug( + "_get_backfill_events: %s, %s, %s", + room_id, repr(event_list), limit + ) + + # We seed the pdu_results with the things from the pdu_list. + event_results = event_list + + front = event_list + + query = ( + "SELECT prev_event_id FROM event_edges " + "WHERE room_id = ? AND event_id = ? " + "LIMIT ?" + ) + + # We iterate through all event_ids in `front` to select their previous + # events. These are dumped in `new_front`. + # We continue until we reach the limit *or* new_front is empty (i.e., + # we've run out of things to select + while front and len(event_results) < limit: + + new_front = [] + for event_id in front: + logger.debug( + "_backfill_interaction: id=%s", + event_id + ) + + txn.execute( + query, + (room_id, event_id, limit - len(event_results)) + ) + + for row in txn.fetchall(): + logger.debug( + "_backfill_interaction: got id=%s", + *row + ) + new_front.append(row[0]) + + front = new_front + event_results += new_front + + # We also want to update the `prev_pdus` attributes before returning. + return self._get_events_txn(txn, event_results) diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py deleted file mode 100644 index d70467dcd670..000000000000 --- a/synapse/storage/pdu.py +++ /dev/null @@ -1,915 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from ._base import SQLBaseStore, Table, JoinHelper - -from synapse.federation.units import Pdu -from synapse.util.logutils import log_function - -from collections import namedtuple - -import logging - -logger = logging.getLogger(__name__) - - -class PduStore(SQLBaseStore): - """A collection of queries for handling PDUs. - """ - - def get_pdu(self, pdu_id, origin): - """Given a pdu_id and origin, get a PDU. - - Args: - txn - pdu_id (str) - origin (str) - - Returns: - PduTuple: If the pdu does not exist in the database, returns None - """ - - return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin - ) - - def _get_pdu_tuple(self, txn, pdu_id, origin): - res = self._get_pdu_tuples(txn, [(pdu_id, origin)]) - return res[0] if res else None - - def _get_pdu_tuples(self, txn, pdu_id_tuples): - results = [] - for pdu_id, origin in pdu_id_tuples: - txn.execute( - PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"), - (pdu_id, origin) - ) - - edges = [ - (r.prev_pdu_id, r.prev_origin) - for r in PduEdgesTable.decode_results(txn.fetchall()) - ] - - query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - row = txn.fetchone() - if row: - results.append(PduTuple(PduEntry(*row), edges)) - - return results - - def get_current_state_for_context(self, context): - """Get a list of PDUs that represent the current state for a given - context - - Args: - context (str) - - Returns: - list: A list of PduTuples - """ - - return self.runInteraction( - self._get_current_state_for_context, - context - ) - - def _get_current_state_for_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s WHERE context = ?" - % CurrentStateTable.table_name - ) - - logger.debug("get_current_state %s, Args=%s", query, context) - txn.execute(query, (context,)) - - res = txn.fetchall() - - logger.debug("get_current_state %d results", len(res)) - - return self._get_pdu_tuples(txn, res) - - def _persist_pdu_txn(self, txn, prev_pdus, cols): - """Inserts a (non-state) PDU into the database. - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable. - """ - entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - - txn.execute(PdusTable.insert_statement(), entry) - - self._handle_prev_pdus( - txn, entry.outlier, entry.pdu_id, entry.origin, - prev_pdus, entry.context - ) - - def mark_pdu_as_processed(self, pdu_id, pdu_origin): - """Mark a received PDU as processed. - - Args: - txn - pdu_id (str) - pdu_origin (str) - """ - - return self.runInteraction( - self._mark_as_processed, pdu_id, pdu_origin - ) - - def _mark_as_processed(self, txn, pdu_id, pdu_origin): - txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name) - - def get_all_pdus_from_context(self, context): - """Get a list of all PDUs for a given context.""" - return self.runInteraction( - self._get_all_pdus_from_context, context, - ) - - def _get_all_pdus_from_context(self, txn, context): - query = ( - "SELECT pdu_id, origin FROM %s " - "WHERE context = ?" - ) % PdusTable.table_name - - txn.execute(query, (context,)) - - return self._get_pdu_tuples(txn, txn.fetchall()) - - def get_backfill(self, context, pdu_list, limit): - """Get a list of Pdus for a given topic that occured before (and - including) the pdus in pdu_list. Return a list of max size `limit`. - - Args: - txn - context (str) - pdu_list (list) - limit (int) - - Return: - list: A list of PduTuples - """ - return self.runInteraction( - self._get_backfill, context, pdu_list, limit - ) - - def _get_backfill(self, txn, context, pdu_list, limit): - logger.debug( - "backfill: %s, %s, %s", - context, repr(pdu_list), limit - ) - - # We seed the pdu_results with the things from the pdu_list. - pdu_results = pdu_list - - front = pdu_list - - query = ( - "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s " - "WHERE context = ? AND pdu_id = ? AND origin = ? " - "LIMIT ?" - ) % { - "edges_table": PduEdgesTable.table_name, - } - - # We iterate through all pdu_ids in `front` to select their previous - # pdus. These are dumped in `new_front`. We continue until we reach the - # limit *or* new_front is empty (i.e., we've run out of things to - # select - while front and len(pdu_results) < limit: - - new_front = [] - for pdu_id, origin in front: - logger.debug( - "_backfill_interaction: i=%s, o=%s", - pdu_id, origin - ) - - txn.execute( - query, - (context, pdu_id, origin, limit - len(pdu_results)) - ) - - for row in txn.fetchall(): - logger.debug( - "_backfill_interaction: got i=%s, o=%s", - *row - ) - new_front.append(row) - - front = new_front - pdu_results += new_front - - # We also want to update the `prev_pdus` attributes before returning. - return self._get_pdu_tuples(txn, pdu_results) - - def get_min_depth_for_context(self, context): - """Get the current minimum depth for a context - - Args: - txn - context (str) - """ - return self.runInteraction( - self._get_min_depth_for_context, context - ) - - def _get_min_depth_for_context(self, txn, context): - return self._get_min_depth_interaction(txn, context) - - def _get_min_depth_interaction(self, txn, context): - txn.execute( - "SELECT min_depth FROM %s WHERE context = ?" - % ContextDepthTable.table_name, - (context,) - ) - - row = txn.fetchone() - - return row[0] if row else None - - def _update_min_depth_for_context_txn(self, txn, context, depth): - """Update the minimum `depth` of the given context, which is the line - on which we stop backfilling backwards. - - Args: - context (str) - depth (int) - """ - min_depth = self._get_min_depth_interaction(txn, context) - - do_insert = depth < min_depth if min_depth else True - - if do_insert: - txn.execute( - "INSERT OR REPLACE INTO %s (context, min_depth) " - "VALUES (?,?)" % ContextDepthTable.table_name, - (context, depth) - ) - - def _get_latest_pdus_in_context(self, txn, context): - """Get's a list of the most current pdus for a given context. This is - used when we are sending a Pdu and need to fill out the `prev_pdus` - key - - Args: - txn - context - """ - query = ( - "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p " - "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id " - "AND f.origin = p.origin " - "WHERE f.context = ?" - ) % { - "pdus": PdusTable.table_name, - "forward": PduForwardExtremitiesTable.table_name, - } - - logger.debug("get_prev query: %s", query) - - txn.execute( - query, - (context, ) - ) - - results = txn.fetchall() - - return [(row[0], row[1], row[2]) for row in results] - - @defer.inlineCallbacks - def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and havent - seen). This list is used when we want to backfill backwards and is the - list we send to the remote server. - - Args: - txn - context (str) - - Returns: - list: A list of PduIdTuple. - """ - results = yield self._execute( - None, - "SELECT pdu_id, origin FROM %(back)s WHERE context = ?" - % {"back": PduBackwardExtremitiesTable.table_name, }, - context - ) - - defer.returnValue([PduIdTuple(i, o) for i, o in results]) - - def is_pdu_new(self, pdu_id, origin, context, depth): - """For a given Pdu, try and figure out if it's 'new', i.e., if it's - not something we got randomly from the past, for example when we - request the current state of the room that will probably return a bunch - of pdus from before we joined. - - Args: - txn - pdu_id (str) - origin (str) - context (str) - depth (int) - - Returns: - bool - """ - - return self.runInteraction( - self._is_pdu_new, - pdu_id=pdu_id, - origin=origin, - context=context, - depth=depth - ) - - def _is_pdu_new(self, txn, pdu_id, origin, context, depth): - # If depth > min depth in back table, then we classify it as new. - # OR if there is nothing in the back table, then it kinda needs to - # be a new thing. - query = ( - "SELECT min(p.depth) FROM %(edges)s as e " - "INNER JOIN %(back)s as b " - "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin " - "INNER JOIN %(pdus)s as p " - "ON e.pdu_id = p.pdu_id AND p.origin = e.origin " - "WHERE p.context = ?" - ) % { - "pdus": PdusTable.table_name, - "edges": PduEdgesTable.table_name, - "back": PduBackwardExtremitiesTable.table_name, - } - - txn.execute(query, (context,)) - - min_depth, = txn.fetchone() - - if not min_depth or depth > int(min_depth): - logger.debug( - "is_new true: id=%s, o=%s, d=%s min_depth=%s", - pdu_id, origin, depth, min_depth - ) - return True - - # If this pdu is in the forwards table, then it also is a new one - query = ( - "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?" - ) % { - "forward": PduForwardExtremitiesTable.table_name, - } - - txn.execute(query, (pdu_id, origin)) - - # Did we get anything? - if txn.fetchall(): - logger.debug( - "is_new true: id=%s, o=%s, d=%s was forward", - pdu_id, origin, depth - ) - return True - - logger.debug( - "is_new false: id=%s, o=%s, d=%s", - pdu_id, origin, depth - ) - - # FINE THEN. It's probably old. - return False - - @staticmethod - @log_function - def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus, - context): - txn.executemany( - PduEdgesTable.insert_statement(), - [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus] - ) - - # Update the extremities table if this is not an outlier. - if not outlier: - - # First, we delete the new one from the forwards extremities table. - query = ( - "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" - % PduForwardExtremitiesTable.table_name - ) - txn.executemany(query, prev_pdus) - - # We only insert as a forward extremety the new pdu if there are no - # other pdus that reference it as a prev pdu - query = ( - "INSERT INTO %(table)s (pdu_id, origin, context) " - "SELECT ?, ?, ? WHERE NOT EXISTS (" - "SELECT 1 FROM %(pdu_edges)s WHERE " - "prev_pdu_id = ? AND prev_origin = ?" - ")" - ) % { - "table": PduForwardExtremitiesTable.table_name, - "pdu_edges": PduEdgesTable.table_name - } - - logger.debug("query: %s", query) - - txn.execute(query, (pdu_id, origin, context, pdu_id, origin)) - - # Insert all the prev_pdus as a backwards thing, they'll get - # deleted in a second if they're incorrect anyway. - txn.executemany( - PduBackwardExtremitiesTable.insert_statement(), - [(i, o, context) for i, o in prev_pdus] - ) - - # Also delete from the backwards extremities table all ones that - # reference pdus that we have already seen - query = ( - "DELETE FROM %(pdu_back)s WHERE EXISTS (" - "SELECT 1 FROM %(pdus)s AS pdus " - "WHERE " - "%(pdu_back)s.pdu_id = pdus.pdu_id " - "AND %(pdu_back)s.origin = pdus.origin " - "AND not pdus.outlier " - ")" - ) % { - "pdu_back": PduBackwardExtremitiesTable.table_name, - "pdus": PdusTable.table_name, - } - txn.execute(query) - - -class StatePduStore(SQLBaseStore): - """A collection of queries for handling state PDUs. - """ - - def _persist_state_txn(self, txn, prev_pdus, cols): - """Inserts a state PDU into the database - - Args: - txn, - prev_pdus (list) - **cols: The columns to insert into the PdusTable and StatePdusTable - """ - pdu_entry = PdusTable.EntryType( - **{k: cols.get(k, None) for k in PdusTable.fields} - ) - state_entry = StatePdusTable.EntryType( - **{k: cols.get(k, None) for k in StatePdusTable.fields} - ) - - logger.debug("Inserting pdu: %s", repr(pdu_entry)) - logger.debug("Inserting state: %s", repr(state_entry)) - - txn.execute(PdusTable.insert_statement(), pdu_entry) - txn.execute(StatePdusTable.insert_statement(), state_entry) - - self._handle_prev_pdus( - txn, - pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus, - pdu_entry.context - ) - - def get_unresolved_state_tree(self, new_state_pdu): - return self.runInteraction( - self._get_unresolved_state_tree, new_state_pdu - ) - - @log_function - def _get_unresolved_state_tree(self, txn, new_pdu): - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - ReturnType = namedtuple( - "StateReturnType", ["new_branch", "current_branch"] - ) - return_value = ReturnType([new_pdu], []) - - if not current: - logger.debug("get_unresolved_state_tree No current state.") - return (return_value, None) - - return_value.current_branch.append(current) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - - missing_branch = None - for branch, prev_state, state in enum_branches: - if state: - return_value[branch].append(state) - else: - # We don't have prev_state :( - missing_branch = branch - break - - return (return_value, missing_branch) - - def update_current_state(self, pdu_id, origin, context, pdu_type, - state_key): - return self.runInteraction( - self._update_current_state, - pdu_id, origin, context, pdu_type, state_key - ) - - def _update_current_state(self, txn, pdu_id, origin, context, pdu_type, - state_key): - query = ( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - ) % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - } - - query_args = CurrentStateTable.EntryType( - pdu_id=pdu_id, - origin=origin, - context=context, - pdu_type=pdu_type, - state_key=state_key - ) - - txn.execute(query, query_args) - - def get_current_state_pdu(self, context, pdu_type, state_key): - """For a given context, pdu_type, state_key 3-tuple, return what is - currently considered the current state. - - Args: - txn - context (str) - pdu_type (str) - state_key (str) - - Returns: - PduEntry - """ - - return self.runInteraction( - self._get_current_state_pdu, context, pdu_type, state_key - ) - - def _get_current_state_pdu(self, txn, context, pdu_type, state_key): - return self._get_current_interaction(txn, context, pdu_type, state_key) - - def _get_current_interaction(self, txn, context, pdu_type, state_key): - logger.debug( - "_get_current_interaction %s %s %s", - context, pdu_type, state_key - ) - - fields = _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s") - - current_query = ( - "SELECT %(fields)s FROM %(state)s as s " - "INNER JOIN %(pdus)s as p " - "ON s.pdu_id = p.pdu_id AND s.origin = p.origin " - "INNER JOIN %(curr)s as c " - "ON s.pdu_id = c.pdu_id AND s.origin = c.origin " - "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? " - ) % { - "fields": fields, - "curr": CurrentStateTable.table_name, - "state": StatePdusTable.table_name, - "pdus": PdusTable.table_name, - } - - txn.execute( - current_query, - (context, pdu_type, state_key) - ) - - row = txn.fetchone() - - result = PduEntry(*row) if row else None - - if not result: - logger.debug("_get_current_interaction not found") - else: - logger.debug( - "_get_current_interaction found %s %s", - result.pdu_id, result.origin - ) - - return result - - def handle_new_state(self, new_pdu): - """Actually perform conflict resolution on the new_pdu on the - assumption we have all the pdus required to perform it. - - Args: - new_pdu - - Returns: - bool: True if the new_pdu clobbered the current state, False if not - """ - return self.runInteraction( - self._handle_new_state, new_pdu - ) - - def _handle_new_state(self, txn, new_pdu): - logger.debug( - "handle_new_state %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - is_current = False - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - # Oh, we don't have any state for this yet. - is_current = True - elif (current.pdu_id == new_pdu.prev_state_id - and current.origin == new_pdu.prev_state_origin): - # Oh! A direct clobber. Just do it. - is_current = True - else: - ## - # Ok, now loop through until we get to a common ancestor. - max_new = int(new_pdu.power_level) - max_current = int(current.power_level) - - enum_branches = self._enumerate_state_branches( - txn, new_pdu, current - ) - for branch, prev_state, state in enum_branches: - if not state: - raise RuntimeError( - "Could not find state_pdu %s %s" % - ( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - ) - - if branch == 0: - max_new = max(int(state.depth), max_new) - else: - max_current = max(int(state.depth), max_current) - - is_current = max_new > max_current - - if is_current: - logger.debug("handle_new_state make current") - - # Right, this is a new thing, so woo, just insert it. - txn.execute( - "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" - % { - "curr": CurrentStateTable.table_name, - "fields": CurrentStateTable.get_fields_string(), - "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) - }, - CurrentStateTable.EntryType( - *(new_pdu.__dict__[k] for k in CurrentStateTable.fields) - ) - ) - else: - logger.debug("handle_new_state not current") - - logger.debug("handle_new_state done") - - return is_current - - @log_function - def _enumerate_state_branches(self, txn, pdu_a, pdu_b): - branch_a = pdu_a - branch_b = pdu_b - - while True: - if (branch_a.pdu_id == branch_b.pdu_id - and branch_a.origin == branch_b.origin): - # Woo! We found a common ancestor - logger.debug("_enumerate_state_branches Found common ancestor") - break - - do_branch_a = ( - hasattr(branch_a, "prev_state_id") and - branch_a.prev_state_id - ) - - do_branch_b = ( - hasattr(branch_b, "prev_state_id") and - branch_b.prev_state_id - ) - - logger.debug( - "do_branch_a=%s, do_branch_b=%s", - do_branch_a, do_branch_b - ) - - if do_branch_a and do_branch_b: - do_branch_a = int(branch_a.depth) > int(branch_b.depth) - - if do_branch_a: - pdu_tuple = PduIdTuple( - branch_a.prev_state_id, - branch_a.prev_state_origin - ) - - prev_branch = branch_a - - logger.debug("getting branch_a prev %s", pdu_tuple) - branch_a = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_a: - branch_a = Pdu.from_pdu_tuple(branch_a) - - logger.debug("branch_a=%s", branch_a) - - yield (0, prev_branch, branch_a) - - if not branch_a: - break - elif do_branch_b: - pdu_tuple = PduIdTuple( - branch_b.prev_state_id, - branch_b.prev_state_origin - ) - - prev_branch = branch_b - - logger.debug("getting branch_b prev %s", pdu_tuple) - branch_b = self._get_pdu_tuple(txn, *pdu_tuple) - if branch_b: - branch_b = Pdu.from_pdu_tuple(branch_b) - - logger.debug("branch_b=%s", branch_b) - - yield (1, prev_branch, branch_b) - - if not branch_b: - break - else: - break - - -class PdusTable(Table): - table_name = "pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "ts", - "depth", - "is_state", - "content_json", - "unrecognized_keys", - "outlier", - "have_processed", - ] - - EntryType = namedtuple("PdusEntry", fields) - - -class PduDestinationsTable(Table): - table_name = "pdu_destinations" - - fields = [ - "pdu_id", - "origin", - "destination", - "delivered_ts", - ] - - EntryType = namedtuple("PduDestinationsEntry", fields) - - -class PduEdgesTable(Table): - table_name = "pdu_edges" - - fields = [ - "pdu_id", - "origin", - "prev_pdu_id", - "prev_origin", - "context" - ] - - EntryType = namedtuple("PduEdgesEntry", fields) - - -class PduForwardExtremitiesTable(Table): - table_name = "pdu_forward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduForwardExtremitiesEntry", fields) - - -class PduBackwardExtremitiesTable(Table): - table_name = "pdu_backward_extremities" - - fields = [ - "pdu_id", - "origin", - "context", - ] - - EntryType = namedtuple("PduBackwardExtremitiesEntry", fields) - - -class ContextDepthTable(Table): - table_name = "context_depth" - - fields = [ - "context", - "min_depth", - ] - - EntryType = namedtuple("ContextDepthEntry", fields) - - -class StatePdusTable(Table): - table_name = "state_pdus" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - "power_level", - "prev_state_id", - "prev_state_origin", - ] - - EntryType = namedtuple("StatePdusEntry", fields) - - -class CurrentStateTable(Table): - table_name = "current_state" - - fields = [ - "pdu_id", - "origin", - "context", - "pdu_type", - "state_key", - ] - - EntryType = namedtuple("CurrentStateEntry", fields) - -_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable) - - -# TODO: These should probably be put somewhere more sensible -PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin")) - -PduEntry = _pdu_state_joiner.EntryType -""" We are always interested in the join of the PdusTable and StatePdusTable, -rather than just the PdusTable. - -This does not include a prev_pdus key. -""" - -PduTuple = namedtuple( - "PduTuple", - ("pdu_entry", "prev_pdu_list") -) -""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent -the `prev_pdus` key of a PDU. -""" diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 719806f82b5c..1f89d7734460 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -62,8 +62,10 @@ def register(self, user_id, token, password_hash): Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction(self._register, user_id, token, - password_hash) + yield self.runInteraction( + "register", + self._register, user_id, token, password_hash + ) def _register(self, txn, user_id, token, password_hash): now = int(self.clock.time()) @@ -100,17 +102,22 @@ def get_user_by_token(self, token): StoreError if no user was found. """ return self.runInteraction( + "get_user_by_token", self._query_for_auth, token ) + @defer.inlineCallbacks def is_server_admin(self, user): - return self._simple_select_one_onecol( + res = yield self._simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", + allow_none=True, ) + defer.returnValue(res if res else False) + def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.admin, access_tokens.device_id " diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8cd46334cf36..cc0513b8d23a 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -132,209 +132,29 @@ def get_rooms(self, is_public): defer.returnValue(ret) - @defer.inlineCallbacks - def get_room_join_rule(self, room_id): - sql = ( - "SELECT join_rule FROM room_join_rules as r " - "INNER JOIN current_state_events as c " - "ON r.event_id = c.event_id " - "WHERE c.room_id = ? " - ) - - rows = yield self._execute(None, sql, room_id) - - if len(rows) == 1: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(None) - - def get_power_level(self, room_id, user_id): - return self.runInteraction( - self._get_power_level, - room_id, user_id, - ) - - def _get_power_level(self, txn, room_id, user_id): - sql = ( - "SELECT level FROM room_power_levels as r " - "INNER JOIN current_state_events as c " - "ON r.event_id = c.event_id " - "WHERE c.room_id = ? AND r.user_id = ? " - ) - - rows = txn.execute(sql, (room_id, user_id,)).fetchall() - - if len(rows) == 1: - return rows[0][0] - - sql = ( - "SELECT level FROM room_default_levels as r " - "INNER JOIN current_state_events as c " - "ON r.event_id = c.event_id " - "WHERE c.room_id = ? " - ) - - rows = txn.execute(sql, (room_id,)).fetchall() - - if len(rows) == 1: - return rows[0][0] - else: - return None - - def get_ops_levels(self, room_id): - return self.runInteraction( - self._get_ops_levels, - room_id, - ) - - def _get_ops_levels(self, txn, room_id): - sql = ( - "SELECT ban_level, kick_level, redact_level " - "FROM room_ops_levels as r " - "INNER JOIN current_state_events as c " - "ON r.event_id = c.event_id " - "WHERE c.room_id = ? " - ) - - rows = txn.execute(sql, (room_id,)).fetchall() - - if len(rows) == 1: - return OpsLevel(rows[0][0], rows[0][1], rows[0][2]) - else: - return OpsLevel(None, None) - - def get_add_state_level(self, room_id): - return self._get_level_from_table("room_add_state_levels", room_id) - - def get_send_event_level(self, room_id): - return self._get_level_from_table("room_send_event_levels", room_id) - - @defer.inlineCallbacks - def _get_level_from_table(self, table, room_id): - sql = ( - "SELECT level FROM %(table)s as r " - "INNER JOIN current_state_events as c " - "ON r.event_id = c.event_id " - "WHERE c.room_id = ? " - ) % {"table": table} - - rows = yield self._execute(None, sql, room_id) - - if len(rows) == 1: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(None) - def _store_room_topic_txn(self, txn, event): - self._simple_insert_txn( - txn, - "topics", - { - "event_id": event.event_id, - "room_id": event.room_id, - "topic": event.topic, - } - ) + if hasattr(event, "topic"): + self._simple_insert_txn( + txn, + "topics", + { + "event_id": event.event_id, + "room_id": event.room_id, + "topic": event.topic, + } + ) def _store_room_name_txn(self, txn, event): - self._simple_insert_txn( - txn, - "room_names", - { - "event_id": event.event_id, - "room_id": event.room_id, - "name": event.name, - } - ) - - def _store_join_rule(self, txn, event): - self._simple_insert_txn( - txn, - "room_join_rules", - { - "event_id": event.event_id, - "room_id": event.room_id, - "join_rule": event.content["join_rule"], - }, - ) - - def _store_power_levels(self, txn, event): - for user_id, level in event.content.items(): - if user_id == "default": - self._simple_insert_txn( - txn, - "room_default_levels", - { - "event_id": event.event_id, - "room_id": event.room_id, - "level": level, - }, - ) - else: - self._simple_insert_txn( - txn, - "room_power_levels", - { - "event_id": event.event_id, - "room_id": event.room_id, - "user_id": user_id, - "level": level - }, - ) - - def _store_default_level(self, txn, event): - self._simple_insert_txn( - txn, - "room_default_levels", - { - "event_id": event.event_id, - "room_id": event.room_id, - "level": event.content["default_level"], - }, - ) - - def _store_add_state_level(self, txn, event): - self._simple_insert_txn( - txn, - "room_add_state_levels", - { - "event_id": event.event_id, - "room_id": event.room_id, - "level": event.content["level"], - }, - ) - - def _store_send_event_level(self, txn, event): - self._simple_insert_txn( - txn, - "room_send_event_levels", - { - "event_id": event.event_id, - "room_id": event.room_id, - "level": event.content["level"], - }, - ) - - def _store_ops_level(self, txn, event): - content = { - "event_id": event.event_id, - "room_id": event.room_id, - } - - if "kick_level" in event.content: - content["kick_level"] = event.content["kick_level"] - - if "ban_level" in event.content: - content["ban_level"] = event.content["ban_level"] - - if "redact_level" in event.content: - content["redact_level"] = event.content["redact_level"] - - self._simple_insert_txn( - txn, - "room_ops_levels", - content, - ) + if hasattr(event, "name"): + self._simple_insert_txn( + txn, + "room_names", + { + "event_id": event.event_id, + "room_id": event.room_id, + "name": event.name, + } + ) class RoomsTable(Table): diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql deleted file mode 100644 index 8a008680659d..000000000000 --- a/synapse/storage/schema/edge_pdus.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS context_edge_pdus( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin) -); - -CREATE TABLE IF NOT EXISTS origin_edge_pdus( - id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this - pdu_id TEXT, - origin TEXT, - CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin) -); - -CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin); -CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin); diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql new file mode 100644 index 000000000000..be1c72a77571 --- /dev/null +++ b/synapse/storage/schema/event_edges.sql @@ -0,0 +1,75 @@ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT NOT NULL, + prev_event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_state INTEGER NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state) +); + +CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); +CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT NOT NULL, + min_depth INTEGER NOT NULL, + CONSTRAINT uniqueness UNIQUE (room_id) +); + +CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); + + +create TABLE IF NOT EXISTS event_destinations( + event_id TEXT NOT NULL, + destination TEXT NOT NULL, + delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered + CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); + + +CREATE TABLE IF NOT EXISTS state_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities( + room_id, type, state_key +); +CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_auth( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id) +); + +CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id); +CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id); \ No newline at end of file diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql new file mode 100644 index 000000000000..5491c7ecec21 --- /dev/null +++ b/synapse/storage/schema/event_signatures.sql @@ -0,0 +1,65 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_content_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_origin_signatures ( + event_id TEXT, + origin TEXT, + key_id TEXT, + signature BLOB, + CONSTRAINT uniqueness UNIQUE (event_id, key_id) +); + +CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures ( + event_id +); + + +CREATE TABLE IF NOT EXISTS event_edge_hashes( + event_id TEXT, + prev_event_id TEXT, + algorithm TEXT, + hash BLOB, + CONSTRAINT uniqueness UNIQUE ( + event_id, prev_event_id, algorithm + ) +); + +CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes( + event_id +); diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index 3aa83f5c8cbd..8ba732a23bc8 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events( unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, + depth INTEGER DEFAULT 0 NOT NULL, CONSTRAINT ev_uniq UNIQUE (event_id) ); @@ -84,80 +85,24 @@ CREATE TABLE IF NOT EXISTS topics( topic TEXT NOT NULL ); +CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id); +CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id); + CREATE TABLE IF NOT EXISTS room_names( event_id TEXT NOT NULL, room_id TEXT NOT NULL, name TEXT NOT NULL ); +CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id); +CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id); + CREATE TABLE IF NOT EXISTS rooms( room_id TEXT PRIMARY KEY NOT NULL, is_public INTEGER, creator TEXT ); -CREATE TABLE IF NOT EXISTS room_join_rules( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - join_rule TEXT NOT NULL -); -CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id); -CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id); - - -CREATE TABLE IF NOT EXISTS room_power_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - user_id TEXT NOT NULL, - level INTEGER NOT NULL -); -CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id); -CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id); -CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id); - - -CREATE TABLE IF NOT EXISTS room_default_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - level INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id); -CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_add_state_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - level INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id); -CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_send_event_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - level INTEGER NOT NULL -); - -CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id); -CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id); - - -CREATE TABLE IF NOT EXISTS room_ops_levels( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - ban_level INTEGER, - kick_level INTEGER, - redact_level INTEGER -); - -CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id); -CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id); - - CREATE TABLE IF NOT EXISTS room_hosts( room_id TEXT NOT NULL, host TEXT NOT NULL, diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql deleted file mode 100644 index 16e111a56c60..000000000000 --- a/synapse/storage/schema/pdu.sql +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2014 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Stores pdus and their content -CREATE TABLE IF NOT EXISTS pdus( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - ts INTEGER, - depth INTEGER DEFAULT 0 NOT NULL, - is_state BOOL, - content_json TEXT, - unrecognized_keys TEXT, - outlier BOOL NOT NULL, - have_processed BOOL, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) -); - --- Stores what the current state pdu is for a given (context, pdu_type, key) tuple -CREATE TABLE IF NOT EXISTS state_pdus( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - state_key TEXT, - power_level TEXT, - prev_state_id TEXT, - prev_state_origin TEXT, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) - CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin) -); - -CREATE TABLE IF NOT EXISTS current_state( - pdu_id TEXT, - origin TEXT, - context TEXT, - pdu_type TEXT, - state_key TEXT, - CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) - CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE -); - --- Stores where each pdu we want to send should be sent and the delivery status. -create TABLE IF NOT EXISTS pdu_destinations( - pdu_id TEXT, - origin TEXT, - destination TEXT, - delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_forward_extremities( - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_backward_extremities( - pdu_id TEXT, - origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE -); - -CREATE TABLE IF NOT EXISTS pdu_edges( - pdu_id TEXT, - origin TEXT, - prev_pdu_id TEXT, - prev_origin TEXT, - context TEXT, - CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context) -); - -CREATE TABLE IF NOT EXISTS context_depth( - context TEXT, - min_depth INTEGER, - CONSTRAINT uniqueness UNIQUE (context) -); - -CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context); - - -CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin); --- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination); - -CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context); -CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin); - -CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context); diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql new file mode 100644 index 000000000000..b44c56b51967 --- /dev/null +++ b/synapse/storage/schema/state.sql @@ -0,0 +1,33 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group INTEGER NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group INTEGER NOT NULL +); \ No newline at end of file diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py new file mode 100644 index 000000000000..84a49088a2a7 --- /dev/null +++ b/synapse/storage/signatures.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class SignatureStore(SQLBaseStore): + """Persistence for event signatures and hashes""" + + def _get_event_content_hashes_txn(self, txn, event_id): + """Get all the hashes for a given Event. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_content_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_content_hash_txn(self, txn, event_id, algorithm, + hash_bytes): + """Store a hash for a Event + Args: + txn (cursor): + event_id (str): Id for the Event. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn( + txn, + "event_content_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) + + def get_event_reference_hashes(self, event_ids): + def f(txn): + return [ + self._get_event_reference_hashes_txn(txn, ev) + for ev in event_ids + ] + + return self.runInteraction( + "get_event_reference_hashes", + f + ) + + def _get_event_reference_hashes_txn(self, txn, event_id): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_reference_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_reference_hash_txn(self, txn, event_id, algorithm, + hash_bytes): + """Store a hash for a PDU + Args: + txn (cursor): + event_id (str): Id for the Event. + algorithm (str): Hashing algorithm. + hash_bytes (bytes): Hash function output bytes. + """ + self._simple_insert_txn( + txn, + "event_reference_hashes", + { + "event_id": event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) + + + def _get_event_origin_signatures_txn(self, txn, event_id): + """Get all the signatures for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict of key_id -> signature_bytes. + """ + query = ( + "SELECT key_id, signature" + " FROM event_origin_signatures" + " WHERE event_id = ? " + ) + txn.execute(query, (event_id, )) + return dict(txn.fetchall()) + + def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id, + signature_bytes): + """Store a signature from the origin server for a PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + origin (str): origin of the Event. + key_id (str): Id for the signing key. + signature (bytes): The signature. + """ + self._simple_insert_txn( + txn, + "event_origin_signatures", + { + "event_id": event_id, + "origin": origin, + "key_id": key_id, + "signature": buffer(signature_bytes), + }, + or_ignore=True, + ) + + def _get_prev_event_hashes_txn(self, txn, event_id): + """Get all the hashes for previous PDUs of a PDU + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. + """ + query = ( + "SELECT prev_event_id, algorithm, hash" + " FROM event_edge_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id, )) + results = {} + for prev_event_id, algorithm, hash_bytes in txn.fetchall(): + hashes = results.setdefault(prev_event_id, {}) + hashes[algorithm] = hash_bytes + return results + + def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, + algorithm, hash_bytes): + self._simple_insert_txn( + txn, + "event_edge_hashes", + { + "event_id": event_id, + "prev_event_id": prev_event_id, + "algorithm": algorithm, + "hash": buffer(hash_bytes), + }, + or_ignore=True, + ) \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py new file mode 100644 index 000000000000..68975969f5c7 --- /dev/null +++ b/synapse/storage/state.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + + +class StateStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_state_groups(self, event_ids): + groups = set() + for event_id in event_ids: + group = yield self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + groups.add(group) + + res = {} + for group in groups: + state_ids = yield self._simple_select_onecol( + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + state = [] + for state_id in state_ids: + s = yield self.get_event( + state_id, + allow_none=True, + ) + if s: + state.append(s) + + res[group] = state + + defer.returnValue(res) + + def store_state_groups(self, event): + return self.runInteraction( + "store_state_groups", + self._store_state_groups_txn, event + ) + + def _store_state_groups_txn(self, txn, event): + if not event.state_events: + return + + state_group = event.state_group + if not state_group: + state_group = self._simple_insert_txn( + txn, + table="state_groups", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + } + ) + + for state in event.state_events.values(): + self._simple_insert_txn( + txn, + table="state_groups_state", + values={ + "state_group": state_group, + "room_id": state.room_id, + "type": state.type, + "state_key": state.state_key, + "event_id": state.event_id, + } + ) + + self._simple_insert_txn( + txn, + table="event_to_state_groups", + values={ + "state_group": state_group, + "event_id": event.event_id, + } + ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d61f90993920..475e7f20a160 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -177,10 +177,9 @@ def get_room_events_stream(self, user_id, from_key, to_key, room_id, sql = ( "SELECT *, (%(redacted)s) AS redacted FROM events AS e WHERE " - "((room_id IN (%(current)s)) OR " + "(e.outlier = 0 AND (room_id IN (%(current)s)) OR " "(event_id IN (%(invites)s))) " "AND e.stream_ordering > ? AND e.stream_ordering <= ? " - "AND e.outlier = 0 " "ORDER BY stream_ordering ASC LIMIT %(limit)d " ) % { "redacted": del_sql, @@ -309,7 +308,10 @@ def get_recent_events_for_room(self, room_id, limit, end_token, defer.returnValue(ret) def get_room_events_max_id(self): - return self.runInteraction(self._get_room_events_max_id_txn) + return self.runInteraction( + "get_room_events_max_id", + self._get_room_events_max_id_txn + ) def _get_room_events_max_id_txn(self, txn): txn.execute( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 2ba8e30efe1f..00d0f4808271 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -14,7 +14,6 @@ # limitations under the License. from ._base import SQLBaseStore, Table -from .pdu import PdusTable from collections import namedtuple @@ -42,6 +41,7 @@ def get_received_txn_response(self, transaction_id, origin): """ return self.runInteraction( + "get_received_txn_response", self._get_received_txn_response, transaction_id, origin ) @@ -73,6 +73,7 @@ def set_received_txn_response(self, transaction_id, origin, code, """ return self.runInteraction( + "set_received_txn_response", self._set_received_txn_response, transaction_id, origin, code, response_dict ) @@ -88,7 +89,7 @@ def _set_received_txn_response(self, txn, transaction_id, origin, code, txn.execute(query, (code, response_json, transaction_id, origin)) def prep_send_transaction(self, transaction_id, destination, - origin_server_ts, pdu_list): + origin_server_ts): """Persists an outgoing transaction and calculates the values for the previous transaction id list. @@ -99,19 +100,19 @@ def prep_send_transaction(self, transaction_id, destination, transaction_id (str) destination (str) origin_server_ts (int) - pdu_list (list) Returns: list: A list of previous transaction ids. """ return self.runInteraction( + "prep_send_transaction", self._prep_send_transaction, - transaction_id, destination, origin_server_ts, pdu_list + transaction_id, destination, origin_server_ts ) def _prep_send_transaction(self, txn, transaction_id, destination, - origin_server_ts, pdu_list): + origin_server_ts): # First we find out what the prev_txs should be. # Since we know that we are only sending one transaction at a time, @@ -139,15 +140,15 @@ def _prep_send_transaction(self, txn, transaction_id, destination, # Update the tx id -> pdu id mapping - values = [ - (transaction_id, destination, pdu[0], pdu[1]) - for pdu in pdu_list - ] - - logger.debug("Inserting: %s", repr(values)) - - query = TransactionsToPduTable.insert_statement() - txn.executemany(query, values) + # values = [ + # (transaction_id, destination, pdu[0], pdu[1]) + # for pdu in pdu_list + # ] + # + # logger.debug("Inserting: %s", repr(values)) + # + # query = TransactionsToPduTable.insert_statement() + # txn.executemany(query, values) return prev_txns @@ -161,6 +162,7 @@ def delivered_txn(self, transaction_id, destination, code, response_dict): response_json (str) """ return self.runInteraction( + "delivered_txn", self._delivered_txn, transaction_id, destination, code, response_dict ) @@ -186,6 +188,7 @@ def get_transactions_after(self, transaction_id, destination): list: A list of `ReceivedTransactionsTable.EntryType` """ return self.runInteraction( + "get_transactions_after", self._get_transactions_after, transaction_id, destination ) @@ -202,49 +205,6 @@ def _get_transactions_after(cls, txn, transaction_id, destination): return ReceivedTransactionsTable.decode_results(txn.fetchall()) - def get_pdus_after_transaction(self, transaction_id, destination): - """For a given local transaction_id that we sent to a given destination - home server, return a list of PDUs that were sent to that destination - after it. - - Args: - txn - transaction_id (str) - destination (str) - - Returns - list: A list of PduTuple - """ - return self.runInteraction( - self._get_pdus_after_transaction, - transaction_id, destination - ) - - def _get_pdus_after_transaction(self, txn, transaction_id, destination): - - # Query that first get's all transaction_ids with an id greater than - # the one given from the `sent_transactions` table. Then JOIN on this - # from the `tx->pdu` table to get a list of (pdu_id, origin) that - # specify the pdus that were sent in those transactions. - query = ( - "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp " - "INNER JOIN %(sent_tx)s as st " - "ON tp.transaction_id = st.transaction_id " - "AND tp.destination = st.destination " - "WHERE st.id > (" - "SELECT id FROM %(sent_tx)s " - "WHERE transaction_id = ? AND destination = ?" - ) % { - "tx_pdu": TransactionsToPduTable.table_name, - "sent_tx": SentTransactions.table_name, - } - - txn.execute(query, (transaction_id, destination)) - - pdus = PdusTable.decode_results(txn.fetchall()) - - return self._get_pdu_tuples(txn, pdus) - class ReceivedTransactionsTable(Table): table_name = "received_transactions" diff --git a/synapse/types.py b/synapse/types.py index c51bc8e4f2c3..649ff2f7d7d2 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -78,6 +78,11 @@ def create_local(cls, localpart, hs): """Create a structure on the local domain""" return cls(localpart=localpart, domain=hs.hostname, is_mine=True) + @classmethod + def create(cls, localpart, domain, hs): + is_mine = domain == hs.hostname + return cls(localpart=localpart, domain=domain, is_mine=is_mine) + class UserID(DomainSpecificString): """Structure representing a user ID.""" @@ -94,6 +99,11 @@ class RoomID(DomainSpecificString): SIGIL = "!" +class EventID(DomainSpecificString): + """Structure representing an event id. """ + SIGIL = "$" + + class StreamToken( namedtuple( "Token", diff --git a/synapse/util/async.py b/synapse/util/async.py index 647ea6142c56..bf578f8bfbb2 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -21,3 +21,10 @@ def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) return d + + +def run_on_reactor(): + """ This will cause the rest of the function to be invoked upon the next + iteration of the main loop + """ + return sleep(0) \ No newline at end of file diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index c91eb897a81b..e79b68f66198 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -80,7 +80,7 @@ def get_dict(self): def get_full_dict(self): d = { - k: v for (k, v) in self.__dict__.items() + k: _encode(v) for (k, v) in self.__dict__.items() if k in self.valid_keys or k in self.internal_keys } d.update(self.unrecognized_keys) diff --git a/tests/events/test_events.py b/tests/events/test_events.py index a4b6cb3afd2b..91d1d44fee93 100644 --- a/tests/events/test_events.py +++ b/tests/events/test_events.py @@ -14,6 +14,8 @@ # limitations under the License. from synapse.api.events import SynapseEvent +from synapse.api.events.validator import EventValidator +from synapse.api.errors import SynapseError from tests import unittest @@ -21,7 +23,7 @@ class SynapseTemplateCheckTestCase(unittest.TestCase): def setUp(self): - pass + self.validator = EventValidator(None) def tearDown(self): pass @@ -38,22 +40,28 @@ def test_top_level_keys(self): } event = MockSynapseEvent(template) - self.assertTrue(event.check_json(content, raises=False)) + event.content = content + self.assertTrue(self.validator.validate(event)) content = { "person": {"name": "bob"}, "friends": ["jill"], "enemies": ["mike"] } - event = MockSynapseEvent(template) - self.assertTrue(event.check_json(content, raises=False)) + event.content = content + self.assertTrue(self.validator.validate(event)) content = { "person": {"name": "bob"}, # missing friends "enemies": ["mike", "jill"] } - self.assertFalse(event.check_json(content, raises=False)) + event.content = content + self.assertRaises( + SynapseError, + self.validator.validate, + event + ) def test_lists(self): template = { @@ -67,13 +75,19 @@ def test_lists(self): } event = MockSynapseEvent(template) - self.assertFalse(event.check_json(content, raises=False)) + event.content = content + self.assertRaises( + SynapseError, + self.validator.validate, + event + ) content = { "person": {"name": "bob"}, "friends": [{"name": "jill"}, {"name": "mike"}] } - self.assertTrue(event.check_json(content, raises=False)) + event.content = content + self.assertTrue(self.validator.validate(event)) def test_nested_lists(self): template = { @@ -103,7 +117,12 @@ def test_nested_lists(self): } event = MockSynapseEvent(template) - self.assertFalse(event.check_json(content, raises=False)) + event.content = content + self.assertRaises( + SynapseError, + self.validator.validate, + event + ) content = { "results": { @@ -117,7 +136,8 @@ def test_nested_lists(self): ] } } - self.assertTrue(event.check_json(content, raises=False)) + event.content = content + self.assertTrue(self.validator.validate(event)) def test_nested_keys(self): template = { @@ -145,7 +165,8 @@ def test_nested_keys(self): } } - self.assertTrue(event.check_json(content, raises=False)) + event.content = content + self.assertTrue(self.validator.validate(event)) content = { "person": { @@ -159,7 +180,12 @@ def test_nested_keys(self): } } - self.assertFalse(event.check_json(content, raises=False)) + event.content = content + self.assertRaises( + SynapseError, + self.validator.validate, + event + ) content = { "person": { @@ -173,7 +199,12 @@ def test_nested_keys(self): } } - self.assertFalse(event.check_json(content, raises=False)) + event.content = content + self.assertRaises( + SynapseError, + self.validator.validate, + event + ) class MockSynapseEvent(SynapseEvent): diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py index 933aa61c776e..eb329eec507a 100644 --- a/tests/federation/test_federation.py +++ b/tests/federation/test_federation.py @@ -24,7 +24,6 @@ from synapse.server import HomeServer from synapse.federation import initialize_http_replication from synapse.federation.units import Pdu -from synapse.storage.pdu import PduTuple, PduEntry def make_pdu(prev_pdus=[], **kwargs): @@ -41,7 +40,7 @@ def make_pdu(prev_pdus=[], **kwargs): } pdu_fields.update(kwargs) - return PduTuple(PduEntry(**pdu_fields), prev_pdus) + return Pdu(prev_pdus=prev_pdus, **pdu_fields) class FederationTestCase(unittest.TestCase): @@ -52,177 +51,185 @@ def setUp(self): "put_json", ]) self.mock_persistence = Mock(spec=[ - "get_current_state_for_context", - "get_pdu", - "persist_event", - "update_min_depth_for_context", "prep_send_transaction", "delivered_txn", "get_received_txn_response", "set_received_txn_response", ]) self.mock_persistence.get_received_txn_response.return_value = ( - defer.succeed(None) + defer.succeed(None) ) self.mock_config = Mock() self.mock_config.signing_key = [MockKey()] self.clock = MockClock() - hs = HomeServer("test", - resource_for_federation=self.mock_resource, - http_client=self.mock_http_client, - db_pool=None, - datastore=self.mock_persistence, - clock=self.clock, - config=self.mock_config, - keyring=Mock(), + hs = HomeServer( + "test", + resource_for_federation=self.mock_resource, + http_client=self.mock_http_client, + db_pool=None, + datastore=self.mock_persistence, + clock=self.clock, + config=self.mock_config, + keyring=Mock(), ) self.federation = initialize_http_replication(hs) self.distributor = hs.get_distributor() @defer.inlineCallbacks def test_get_state(self): - self.mock_persistence.get_current_state_for_context.return_value = ( - defer.succeed([]) - ) + mock_handler = Mock(spec=[ + "get_state_for_pdu", + ]) + + self.federation.set_handler(mock_handler) + + mock_handler.get_state_for_pdu.return_value = defer.succeed([]) # Empty context initially - (code, response) = yield self.mock_resource.trigger("GET", - "/_matrix/federation/v1/state/my-context/", None) + (code, response) = yield self.mock_resource.trigger( + "GET", + "/_matrix/federation/v1/state/my-context/", + None + ) self.assertEquals(200, code) self.assertFalse(response["pdus"]) # Now lets give the context some state - self.mock_persistence.get_current_state_for_context.return_value = ( + mock_handler.get_state_for_pdu.return_value = ( defer.succeed([ make_pdu( - pdu_id="the-pdu-id", + event_id="the-pdu-id", origin="red", - context="my-context", - pdu_type="m.topic", - ts=123456789000, + room_id="my-context", + type="m.topic", + origin_server_ts=123456789000, depth=1, - is_state=True, - content_json='{"topic":"The topic"}', + content={"topic": "The topic"}, state_key="", power_level=1000, - prev_state_id="last-pdu-id", - prev_state_origin="blue", + prev_state="last-pdu-id", ), ]) ) - (code, response) = yield self.mock_resource.trigger("GET", - "/_matrix/federation/v1/state/my-context/", None) + (code, response) = yield self.mock_resource.trigger( + "GET", + "/_matrix/federation/v1/state/my-context/", + None + ) self.assertEquals(200, code) self.assertEquals(1, len(response["pdus"])) @defer.inlineCallbacks def test_get_pdu(self): - self.mock_persistence.get_pdu.return_value = ( + mock_handler = Mock(spec=[ + "get_persisted_pdu", + ]) + + self.federation.set_handler(mock_handler) + + mock_handler.get_persisted_pdu.return_value = ( defer.succeed(None) ) - (code, response) = yield self.mock_resource.trigger("GET", - "/_matrix/federation/v1/pdu/red/abc123def456/", None) + (code, response) = yield self.mock_resource.trigger( + "GET", + "/_matrix/federation/v1/event/abc123def456/", + None + ) self.assertEquals(404, code) # Now insert such a PDU - self.mock_persistence.get_pdu.return_value = ( + mock_handler.get_persisted_pdu.return_value = ( defer.succeed( make_pdu( - pdu_id="abc123def456", + event_id="abc123def456", origin="red", - context="my-context", - pdu_type="m.text", - ts=123456789001, + room_id="my-context", + type="m.text", + origin_server_ts=123456789001, depth=1, - content_json='{"text":"Here is the message"}', + content={"text": "Here is the message"}, ) ) ) - (code, response) = yield self.mock_resource.trigger("GET", - "/_matrix/federation/v1/pdu/red/abc123def456/", None) + (code, response) = yield self.mock_resource.trigger( + "GET", + "/_matrix/federation/v1/event/abc123def456/", + None + ) self.assertEquals(200, code) self.assertEquals(1, len(response["pdus"])) - self.assertEquals("m.text", response["pdus"][0]["pdu_type"]) + self.assertEquals("m.text", response["pdus"][0]["type"]) @defer.inlineCallbacks def test_send_pdu(self): self.mock_http_client.put_json.return_value = defer.succeed( - (200, "OK") + (200, "OK") ) pdu = Pdu( - pdu_id="abc123def456", - origin="red", - destinations=["remote"], - context="my-context", - origin_server_ts=123456789002, - pdu_type="m.test", - content={"testing": "content here"}, - depth=1, + event_id="abc123def456", + origin="red", + room_id="my-context", + type="m.text", + origin_server_ts=123456789001, + depth=1, + content={"text": "Here is the message"}, + destinations=["remote"], ) yield self.federation.send_pdu(pdu) self.mock_http_client.put_json.assert_called_with( - "remote", - path="/_matrix/federation/v1/send/1000000/", - data={ - "origin_server_ts": 1000000, - "origin": "test", - "pdus": [ - { - "origin": "red", - "pdu_id": "abc123def456", - "prev_pdus": [], - "origin_server_ts": 123456789002, - "context": "my-context", - "pdu_type": "m.test", - "is_state": False, - "content": {"testing": "content here"}, - "depth": 1, - }, - ] - }, - json_data_callback=ANY, + "remote", + path="/_matrix/federation/v1/send/1000000/", + data={ + "origin_server_ts": 1000000, + "origin": "test", + "pdus": [ + pdu.get_dict(), + ], + 'pdu_failures': [], + }, + json_data_callback=ANY, ) @defer.inlineCallbacks def test_send_edu(self): self.mock_http_client.put_json.return_value = defer.succeed( - (200, "OK") + (200, "OK") ) yield self.federation.send_edu( - destination="remote", - edu_type="m.test", - content={"testing": "content here"}, + destination="remote", + edu_type="m.test", + content={"testing": "content here"}, ) # MockClock ensures we can guess these timestamps self.mock_http_client.put_json.assert_called_with( - "remote", - path="/_matrix/federation/v1/send/1000000/", - data={ - "origin": "test", - "origin_server_ts": 1000000, - "pdus": [], - "edus": [ - { - # TODO: SYN-103: Remove "origin" and "destination" - "origin": "test", - "destination": "remote", - "edu_type": "m.test", - "content": {"testing": "content here"}, - } - ], - }, - json_data_callback=ANY, + "remote", + path="/_matrix/federation/v1/send/1000000/", + data={ + "origin": "test", + "origin_server_ts": 1000000, + "pdus": [], + "edus": [ + { + # TODO: SYN-103: Remove "origin" and "destination" + "origin": "test", + "destination": "remote", + "edu_type": "m.test", + "content": {"testing": "content here"}, + } + ], + 'pdu_failures': [], + }, + json_data_callback=ANY, ) - @defer.inlineCallbacks def test_recv_edu(self): recv_observer = Mock() @@ -230,24 +237,26 @@ def test_recv_edu(self): self.federation.register_edu_handler("m.test", recv_observer) - yield self.mock_resource.trigger("PUT", - "/_matrix/federation/v1/send/1001000/", - """{ - "origin": "remote", - "origin_server_ts": 1001000, - "pdus": [], - "edus": [ - { - "origin": "remote", - "destination": "test", - "edu_type": "m.test", - "content": {"testing": "reply here"} - } - ] - }""") + yield self.mock_resource.trigger( + "PUT", + "/_matrix/federation/v1/send/1001000/", + """{ + "origin": "remote", + "origin_server_ts": 1001000, + "pdus": [], + "edus": [ + { + "origin": "remote", + "destination": "test", + "edu_type": "m.test", + "content": {"testing": "reply here"} + } + ] + }""" + ) recv_observer.assert_called_with( - "remote", {"testing": "reply here"} + "remote", {"testing": "reply here"} ) @defer.inlineCallbacks @@ -278,8 +287,11 @@ def test_recv_query(self): self.federation.register_query_handler("a-question", recv_handler) - code, response = yield self.mock_resource.trigger("GET", - "/_matrix/federation/v1/query/a-question?three=3&four=4", None) + code, response = yield self.mock_resource.trigger( + "GET", + "/_matrix/federation/v1/query/a-question?three=3&four=4", + None + ) self.assertEquals(200, code) self.assertEquals({"another": "response"}, response) diff --git a/tests/federation/test_pdu_codec.py b/tests/federation/test_pdu_codec.py deleted file mode 100644 index 0754ef92e822..000000000000 --- a/tests/federation/test_pdu_codec.py +++ /dev/null @@ -1,160 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tests import unittest - -from synapse.federation.pdu_codec import ( - PduCodec, encode_event_id, decode_event_id -) -from synapse.federation.units import Pdu -#from synapse.api.events.room import MessageEvent - -from synapse.server import HomeServer - -from mock import Mock - - -class PduCodecTestCase(unittest.TestCase): - def setUp(self): - self.hs = HomeServer("blargle.net") - self.event_factory = self.hs.get_event_factory() - - self.codec = PduCodec(self.hs) - - def test_decode_event_id(self): - self.assertEquals( - ("foo", "bar.com"), - decode_event_id("foo@bar.com", "A") - ) - - self.assertEquals( - ("foo", "bar.com"), - decode_event_id("foo", "bar.com") - ) - - def test_encode_event_id(self): - self.assertEquals("A@B", encode_event_id("A", "B")) - - def test_codec_event_id(self): - event_id = "aa@bb.com" - - self.assertEquals( - event_id, - encode_event_id(*decode_event_id(event_id, None)) - ) - - pdu_id = ("aa", "bb.com") - - self.assertEquals( - pdu_id, - decode_event_id(encode_event_id(*pdu_id), None) - ) - - def test_event_from_pdu(self): - pdu = Pdu( - pdu_id="foo", - context="rooooom", - pdu_type="m.room.message", - origin="bar.com", - origin_server_ts=12345, - depth=5, - prev_pdus=[("alice", "bob.com")], - is_state=False, - content={"msgtype": u"test"}, - ) - - event = self.codec.event_from_pdu(pdu) - - self.assertEquals("foo@bar.com", event.event_id) - self.assertEquals(pdu.context, event.room_id) - self.assertEquals(pdu.is_state, event.is_state) - self.assertEquals(pdu.depth, event.depth) - self.assertEquals(["alice@bob.com"], event.prev_events) - self.assertEquals(pdu.content, event.content) - - def test_pdu_from_event(self): - event = self.event_factory.create_event( - etype="m.room.message", - event_id="gargh_id", - room_id="rooom", - user_id="sender", - content={"msgtype": u"test"}, - ) - - pdu = self.codec.pdu_from_event(event) - - self.assertEquals(event.event_id, pdu.pdu_id) - self.assertEquals(self.hs.hostname, pdu.origin) - self.assertEquals(event.room_id, pdu.context) - self.assertEquals(event.content, pdu.content) - self.assertEquals(event.type, pdu.pdu_type) - - event = self.event_factory.create_event( - etype="m.room.message", - event_id="gargh_id@bob.com", - room_id="rooom", - user_id="sender", - content={"msgtype": u"test"}, - ) - - pdu = self.codec.pdu_from_event(event) - - self.assertEquals("gargh_id", pdu.pdu_id) - self.assertEquals("bob.com", pdu.origin) - self.assertEquals(event.room_id, pdu.context) - self.assertEquals(event.content, pdu.content) - self.assertEquals(event.type, pdu.pdu_type) - - def test_event_from_state_pdu(self): - pdu = Pdu( - pdu_id="foo", - context="rooooom", - pdu_type="m.room.topic", - origin="bar.com", - origin_server_ts=12345, - depth=5, - prev_pdus=[("alice", "bob.com")], - is_state=True, - content={"topic": u"test"}, - state_key="", - ) - - event = self.codec.event_from_pdu(pdu) - - self.assertEquals("foo@bar.com", event.event_id) - self.assertEquals(pdu.context, event.room_id) - self.assertEquals(pdu.is_state, event.is_state) - self.assertEquals(pdu.depth, event.depth) - self.assertEquals(["alice@bob.com"], event.prev_events) - self.assertEquals(pdu.content, event.content) - self.assertEquals(pdu.state_key, event.state_key) - - def test_pdu_from_state_event(self): - event = self.event_factory.create_event( - etype="m.room.topic", - event_id="gargh_id", - room_id="rooom", - user_id="sender", - content={"topic": u"test"}, - ) - - pdu = self.codec.pdu_from_event(event) - - self.assertEquals(event.event_id, pdu.pdu_id) - self.assertEquals(self.hs.hostname, pdu.origin) - self.assertEquals(event.room_id, pdu.context) - self.assertEquals(event.content, pdu.content) - self.assertEquals(event.type, pdu.pdu_type) - self.assertEquals(event.state_key, pdu.state_key) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index e10a49a8acbc..8e164e4be031 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -21,9 +21,8 @@ from synapse.server import HomeServer from synapse.handlers.directory import DirectoryHandler -from synapse.storage.directory import RoomAliasMapping -from tests.utils import SQLiteMemoryDbPool +from tests.utils import SQLiteMemoryDbPool, MockKey class DirectoryHandlers(object): @@ -41,6 +40,7 @@ def setUp(self): ]) self.query_handlers = {} + def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler self.mock_federation.register_query_handler = register_query_handler @@ -48,11 +48,16 @@ def register_query_handler(query_type, handler): db_pool = SQLiteMemoryDbPool() yield db_pool.prepare() - hs = HomeServer("test", + self.mock_config = Mock() + self.mock_config.signing_key = [MockKey()] + + hs = HomeServer( + "test", db_pool=db_pool, http_client=None, resource_for_federation=Mock(), replication_layer=self.mock_federation, + config=self.mock_config, ) hs.handlers = DirectoryHandlers(hs) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 219b2c4c5ee1..a9d6b2bb17b8 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -17,16 +17,15 @@ from tests import unittest from synapse.api.events.room import ( - InviteJoinEvent, MessageEvent, RoomMemberEvent + MessageEvent, ) -from synapse.api.constants import Membership from synapse.handlers.federation import FederationHandler from synapse.server import HomeServer from synapse.federation.units import Pdu from mock import NonCallableMock, ANY -from ..utils import get_mock_call_args, MockKey +from ..utils import MockKey class FederationTestCase(unittest.TestCase): @@ -36,6 +35,14 @@ def setUp(self): self.mock_config = NonCallableMock() self.mock_config.signing_key = [MockKey()] + self.state_handler = NonCallableMock(spec_set=[ + "annotate_state_groups", + ]) + + self.auth = NonCallableMock(spec_set=[ + "check", + ]) + self.hostname = "test" hs = HomeServer( self.hostname, @@ -53,6 +60,8 @@ def setUp(self): "federation_handler", ]), config=self.mock_config, + auth=self.auth, + state_handler=self.state_handler, ) self.datastore = hs.get_datastore() @@ -65,74 +74,35 @@ def setUp(self): @defer.inlineCallbacks def test_msg(self): pdu = Pdu( - pdu_type=MessageEvent.TYPE, - context="foo", + type=MessageEvent.TYPE, + room_id="foo", content={"msgtype": u"fooo"}, origin_server_ts=0, - pdu_id="a", + event_id="$a:b", origin="b", ) - store_id = "ASD" - self.datastore.persist_event.return_value = defer.succeed(store_id) + self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.get_room.return_value = defer.succeed(True) + self.state_handler.annotate_state_groups.return_value = ( + defer.succeed(False) + ) + yield self.handlers.federation_handler.on_receive_pdu(pdu, False) self.datastore.persist_event.assert_called_once_with( ANY, False, is_new_state=False ) - self.notifier.on_new_room_event.assert_called_once_with(ANY, extra_users=[]) - - @defer.inlineCallbacks - def test_invite_join_target_this(self): - room_id = "foo" - user_id = "@bob:red" - pdu = Pdu( - pdu_type=InviteJoinEvent.TYPE, - user_id=user_id, - target_host=self.hostname, - context=room_id, - content={}, - origin_server_ts=0, - pdu_id="a", - origin="b", + self.state_handler.annotate_state_groups.assert_called_once_with( + ANY, + old_state=None, ) - yield self.handlers.federation_handler.on_receive_pdu(pdu, False) + self.auth.check.assert_called_once_with(ANY, raises=True) - mem_handler = self.handlers.room_member_handler - self.assertEquals(1, mem_handler.change_membership.call_count) - call_args = get_mock_call_args( - lambda event, do_auth: None, - mem_handler.change_membership + self.notifier.on_new_room_event.assert_called_once_with( + ANY, + extra_users=[] ) - self.assertEquals(False, call_args["do_auth"]) - - new_event = call_args["event"] - self.assertEquals(RoomMemberEvent.TYPE, new_event.type) - self.assertEquals(room_id, new_event.room_id) - self.assertEquals(user_id, new_event.state_key) - self.assertEquals(Membership.JOIN, new_event.membership) - - @defer.inlineCallbacks - def test_invite_join_target_other(self): - room_id = "foo" - user_id = "@bob:red" - - pdu = Pdu( - pdu_type=InviteJoinEvent.TYPE, - user_id=user_id, - state_key="@red:not%s" % self.hostname, - context=room_id, - content={}, - origin_server_ts=0, - pdu_id="a", - origin="b", - ) - - yield self.handlers.federation_handler.on_receive_pdu(pdu, False) - - mem_handler = self.handlers.room_member_handler - self.assertEquals(0, mem_handler.change_membership.call_count) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index fdc2e8de4ac6..a6af648defaa 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -51,6 +51,7 @@ def _expect_edu(destination, edu_type, content, origin="test"): "content": content, } ], + "pdu_failures": [], } def _make_edu_json(origin, edu_type, content): diff --git a/tests/handlers/test_presencelike.py b/tests/handlers/test_presencelike.py index 047752ad6815..532ecf0f2cec 100644 --- a/tests/handlers/test_presencelike.py +++ b/tests/handlers/test_presencelike.py @@ -21,7 +21,7 @@ from mock import Mock, call, ANY -from ..utils import MockClock +from ..utils import MockClock, MockKey from synapse.server import HomeServer from synapse.api.constants import PresenceState @@ -57,6 +57,9 @@ def __init__(self, hs): class PresenceProfilelikeDataTestCase(unittest.TestCase): def setUp(self): + self.mock_config = Mock() + self.mock_config.signing_key = [MockKey()] + hs = HomeServer("test", clock=MockClock(), db_pool=None, @@ -72,6 +75,7 @@ def setUp(self): resource_for_federation=Mock(), http_client=None, replication_layer=MockReplication(), + config=self.mock_config, ) hs.handlers = PresenceAndProfileHandlers(hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5dc9b456e1b9..1660e7e9285b 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -24,7 +24,7 @@ from synapse.handlers.profile import ProfileHandler from synapse.api.constants import Membership -from tests.utils import SQLiteMemoryDbPool +from tests.utils import SQLiteMemoryDbPool, MockKey class ProfileHandlers(object): @@ -49,12 +49,16 @@ def register_query_handler(query_type, handler): db_pool = SQLiteMemoryDbPool() yield db_pool.prepare() + self.mock_config = Mock() + self.mock_config.signing_key = [MockKey()] + hs = HomeServer("test", db_pool=db_pool, http_client=None, handlers=None, resource_for_federation=Mock(), replication_layer=self.mock_federation, + config=self.mock_config, ) hs.handlers = ProfileHandlers(hs) diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index c88d1c884036..55c9f6e14258 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -18,7 +18,7 @@ from tests import unittest from synapse.api.events.room import ( - InviteJoinEvent, RoomMemberEvent, RoomConfigEvent + RoomMemberEvent, ) from synapse.api.constants import Membership from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler @@ -34,6 +34,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): def setUp(self): self.mock_config = NonCallableMock() self.mock_config.signing_key = [MockKey()] + self.hostname = "red" hs = HomeServer( self.hostname, @@ -57,13 +58,16 @@ def setUp(self): "profile_handler", "federation_handler", ]), - auth=NonCallableMock(spec_set=["check"]), - state_handler=NonCallableMock(spec_set=["handle_new_event"]), + auth=NonCallableMock(spec_set=["check", "add_auth_events"]), + state_handler=NonCallableMock(spec_set=[ + "annotate_state_groups", + ]), config=self.mock_config, ) self.federation = NonCallableMock(spec_set=[ "handle_new_event", + "send_invite", "get_state_for_room", ]) @@ -106,7 +110,6 @@ def test_invite(self): joined = ["red", "green"] - self.state_handler.handle_new_event.return_value = defer.succeed(True) self.datastore.get_joined_hosts_for_room.return_value = ( defer.succeed(joined) ) @@ -114,18 +117,29 @@ def test_invite(self): store_id = "store_id_fooo" self.datastore.persist_event.return_value = defer.succeed(store_id) + self.datastore.get_room_member.return_value = defer.succeed(None) + + event.state_events = { + (RoomMemberEvent.TYPE, "@alice:green"): self._create_member( + user_id="@alice:green", + room_id=room_id, + ), + (RoomMemberEvent.TYPE, "@bob:red"): self._create_member( + user_id="@bob:red", + room_id=room_id, + ), + (RoomMemberEvent.TYPE, target_user_id): event, + } + # Actual invocation yield self.room_member_handler.change_membership(event) - self.state_handler.handle_new_event.assert_called_once_with( - event, self.snapshot, - ) self.federation.handle_new_event.assert_called_once_with( event, self.snapshot, ) self.assertEquals( - set(["blue", "red", "green"]), + set(["red", "green"]), set(event.destinations) ) @@ -144,28 +158,19 @@ def test_simple_join(self): room_id = "!foo:red" user_id = "@bob:red" user = self.hs.parse_userid(user_id) - target_user_id = "@bob:red" - content = {"membership": Membership.JOIN} - event = self.hs.get_event_factory().create_event( - etype=RoomMemberEvent.TYPE, + event = self._create_member( user_id=user_id, - state_key=target_user_id, room_id=room_id, - membership=Membership.JOIN, - content=content, ) joined = ["red", "green"] - self.state_handler.handle_new_event.return_value = defer.succeed(True) - def get_joined(*args): return defer.succeed(joined) self.datastore.get_joined_hosts_for_room.side_effect = get_joined - store_id = "store_id_fooo" self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.get_room.return_value = defer.succeed(1) # Not None. @@ -178,12 +183,17 @@ def get_joined(*args): join_signal_observer = Mock() self.distributor.observe("user_joined_room", join_signal_observer) + event.state_events = { + (RoomMemberEvent.TYPE, "@alice:green"): self._create_member( + user_id="@alice:green", + room_id=room_id, + ), + (RoomMemberEvent.TYPE, user_id): event, + } + # Actual invocation yield self.room_member_handler.change_membership(event) - self.state_handler.handle_new_event.assert_called_once_with( - event, self.snapshot - ) self.federation.handle_new_event.assert_called_once_with( event, self.snapshot ) @@ -197,138 +207,32 @@ def get_joined(*args): event ) self.notifier.on_new_room_event.assert_called_once_with( - event, extra_users=[user]) - - join_signal_observer.assert_called_with( - user=user, room_id=room_id) - - @defer.inlineCallbacks - def STALE_test_invite_join(self): - room_id = "foo" - user_id = "@bob:red" - target_user_id = "@bob:red" - content = {"membership": Membership.JOIN} - - event = self.hs.get_event_factory().create_event( - etype=RoomMemberEvent.TYPE, - user_id=user_id, - target_user_id=target_user_id, - room_id=room_id, - membership=Membership.JOIN, - content=content, - ) - - joined = ["red", "blue", "green"] - - self.state_handler.handle_new_event.return_value = defer.succeed(True) - self.datastore.get_joined_hosts_for_room.return_value = ( - defer.succeed(joined) - ) - - store_id = "store_id_fooo" - self.datastore.store_room_member.return_value = defer.succeed(store_id) - self.datastore.get_room.return_value = defer.succeed(None) - - prev_state = NonCallableMock(name="prev_state") - prev_state.membership = Membership.INVITE - prev_state.sender = "@foo:blue" - self.datastore.get_room_member.return_value = defer.succeed(prev_state) - - # Actual invocation - yield self.room_member_handler.change_membership(event) - - self.datastore.get_room_member.assert_called_once_with( - target_user_id, room_id + event, extra_users=[user] ) - self.assertTrue(self.federation.handle_new_event.called) - args = self.federation.handle_new_event.call_args[0] - invite_join_event = args[0] - - self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE) - self.assertTrue("blue", invite_join_event.target_host) - self.assertTrue(room_id, invite_join_event.room_id) - self.assertTrue(user_id, invite_join_event.user_id) - self.assertFalse(hasattr(invite_join_event, "state_key")) - - self.assertEquals( - set(["blue"]), - set(invite_join_event.destinations) - ) - - self.federation.get_state_for_room.assert_called_once_with( - "blue", room_id + join_signal_observer.assert_called_with( + user=user, room_id=room_id ) - self.assertFalse(self.datastore.store_room_member.called) - - self.assertFalse(self.notifier.on_new_room_event.called) - self.assertFalse(self.state_handler.handle_new_event.called) - - @defer.inlineCallbacks - def STALE_test_invite_join_public(self): - room_id = "#foo:blue" - user_id = "@bob:red" - target_user_id = "@bob:red" - content = {"membership": Membership.JOIN} - - event = self.hs.get_event_factory().create_event( + def _create_member(self, user_id, room_id): + return self.hs.get_event_factory().create_event( etype=RoomMemberEvent.TYPE, user_id=user_id, - target_user_id=target_user_id, + state_key=user_id, room_id=room_id, membership=Membership.JOIN, - content=content, - ) - - joined = ["red", "blue", "green"] - - self.state_handler.handle_new_event.return_value = defer.succeed(True) - self.datastore.get_joined_hosts_for_room.return_value = ( - defer.succeed(joined) - ) - - store_id = "store_id_fooo" - self.datastore.store_room_member.return_value = defer.succeed(store_id) - self.datastore.get_room.return_value = defer.succeed(None) - - prev_state = NonCallableMock(name="prev_state") - prev_state.membership = Membership.INVITE - prev_state.sender = "@foo:blue" - self.datastore.get_room_member.return_value = defer.succeed(prev_state) - - # Actual invocation - yield self.room_member_handler.change_membership(event) - - self.assertTrue(self.federation.handle_new_event.called) - args = self.federation.handle_new_event.call_args[0] - invite_join_event = args[0] - - self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE) - self.assertTrue("blue", invite_join_event.target_host) - self.assertTrue("foo", invite_join_event.room_id) - self.assertTrue(user_id, invite_join_event.user_id) - self.assertFalse(hasattr(invite_join_event, "state_key")) - - self.assertEquals( - set(["blue"]), - set(invite_join_event.destinations) + content={"membership": Membership.JOIN}, ) - self.federation.get_state_for_room.assert_called_once_with( - "blue", "foo" - ) - - self.assertFalse(self.datastore.store_room_member.called) - - self.assertFalse(self.notifier.on_new_room_event.called) - self.assertFalse(self.state_handler.handle_new_event.called) - class RoomCreationTest(unittest.TestCase): def setUp(self): self.hostname = "red" + + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] + hs = HomeServer( self.hostname, db_pool=None, @@ -345,12 +249,14 @@ def setUp(self): "room_member_handler", "federation_handler", ]), - auth=NonCallableMock(spec_set=["check"]), - state_handler=NonCallableMock(spec_set=["handle_new_event"]), + auth=NonCallableMock(spec_set=["check", "add_auth_events"]), + state_handler=NonCallableMock(spec_set=[ + "annotate_state_groups", + ]), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.federation = NonCallableMock(spec_set=[ @@ -373,6 +279,11 @@ def setUp(self): ]) self.room_member_handler = self.handlers.room_member_handler + def annotate(event): + event.state_events = {} + return defer.succeed(None) + self.state_handler.annotate_state_groups.side_effect = annotate + def hosts(room): return defer.succeed([]) self.datastore.get_joined_hosts_for_room.side_effect = hosts @@ -400,6 +311,6 @@ def test_room_creation(self): self.assertEquals(user_id, join_event.user_id) self.assertEquals(user_id, join_event.state_key) - self.assertTrue(self.state_handler.handle_new_event.called) + self.assertTrue(self.state_handler.annotate_state_groups.called) self.assertTrue(self.federation.handle_new_event.called) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f1d3b27f741b..07acda5eeec3 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -40,6 +40,7 @@ def _expect_edu(destination, edu_type, content, origin="test"): "content": content, } ], + "pdu_failures": [], } diff --git a/tests/rest/test_events.py b/tests/rest/test_events.py index 79b371c04dff..4a3234c3320b 100644 --- a/tests/rest/test_events.py +++ b/tests/rest/test_events.py @@ -25,10 +25,7 @@ from synapse.server import HomeServer -# python imports -import json - -from ..utils import MockHttpResource, MemoryDataStore +from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey from .utils import RestTestCase from mock import Mock, NonCallableMock @@ -49,7 +46,7 @@ def setUp(self): def tearDown(self): pass - def test_long_poll(self): + def TODO_test_long_poll(self): # stream from 'end' key, send (self+other) message, expect message. # stream from 'END', send (self+other) message, expect message. @@ -64,7 +61,7 @@ def test_long_poll(self): pass - def test_stream_forward(self): + def TODO_test_stream_forward(self): # stream from START, expect injected items # stream from 'start' key, expect same content @@ -80,14 +77,14 @@ def test_stream_forward(self): # returned as end key pass - def test_limits(self): + def TODO_test_limits(self): # stream from a key, expect limit_num items # stream from START, expect limit_num items pass - def test_range(self): + def TODO_test_range(self): # stream from key to key, expect X items # stream from key to END, expect X items @@ -97,7 +94,7 @@ def test_range(self): # stream from START to END, expect all items pass - def test_direction(self): + def TODO_test_direction(self): # stream from END to START and fwds, expect newest first # stream from END to START and bwds, expect oldest first @@ -116,19 +113,20 @@ class EventStreamPermissionsTestCase(RestTestCase): def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service.get_latest_pdus_in_context.return_value = [] + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] + + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() + hs = HomeServer( "test", - db_pool=None, + db_pool=db_pool, http_client=None, replication_layer=Mock(), - state_handler=state_handler, - datastore=MemoryDataStore(), persistence_service=persistence_service, clock=Mock(spec=[ "call_later", @@ -139,7 +137,7 @@ def setUp(self): ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -148,6 +146,7 @@ def setUp(self): hs.get_handlers().federation_handler = Mock() hs.get_clock().time_msec.return_value = 1000000 + hs.get_clock().time.return_value = 1000 synapse.rest.register.register_servlets(hs, self.mock_resource) synapse.rest.events.register_servlets(hs, self.mock_resource) @@ -172,12 +171,14 @@ def tearDown(self): def test_stream_basic_permissions(self): # invalid token, expect 403 (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s" % ("invalid" + self.token)) + "/events?access_token=%s" % ("invalid" + self.token, ) + ) self.assertEquals(403, code, msg=str(response)) # valid token, expect content (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s&timeout=0" % (self.token)) + "/events?access_token=%s&timeout=0" % (self.token,) + ) self.assertEquals(200, code, msg=str(response)) self.assertTrue("chunk" in response) self.assertTrue("start" in response) @@ -185,15 +186,23 @@ def test_stream_basic_permissions(self): @defer.inlineCallbacks def test_stream_room_permissions(self): - room_id = yield self.create_room_as(self.other_user, - tok=self.other_token) + room_id = yield self.create_room_as( + self.other_user, + tok=self.other_token + ) yield self.send(room_id, tok=self.other_token) # invited to room (expect no content for room) - yield self.invite(room_id, src=self.other_user, targ=self.user_id, - tok=self.other_token) + yield self.invite( + room_id, + src=self.other_user, + targ=self.user_id, + tok=self.other_token + ) + (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s&timeout=0" % (self.token)) + "/events?access_token=%s&timeout=0" % (self.token,) + ) self.assertEquals(200, code, msg=str(response)) self.assertEquals(0, len(response["chunk"])) @@ -203,7 +212,7 @@ def test_stream_room_permissions(self): # left to room (expect no content for room) - def test_stream_items(self): + def TODO_test_stream_items(self): # new user, no content # join room, expect 1 item (join) diff --git a/tests/rest/test_profile.py b/tests/rest/test_profile.py index b0f48e7fd8bf..3a0d1e700a9d 100644 --- a/tests/rest/test_profile.py +++ b/tests/rest/test_profile.py @@ -18,9 +18,9 @@ from tests import unittest from twisted.internet import defer -from mock import Mock +from mock import Mock, NonCallableMock -from ..utils import MockHttpResource +from ..utils import MockHttpResource, MockKey from synapse.api.errors import SynapseError, AuthError from synapse.server import HomeServer @@ -41,6 +41,9 @@ def setUp(self): "set_avatar_url", ]) + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] + hs = HomeServer("test", db_pool=None, http_client=None, @@ -48,6 +51,7 @@ def setUp(self): federation=Mock(), replication_layer=Mock(), datastore=None, + config=self.mock_config, ) def _get_user_by_req(request=None): diff --git a/tests/rest/test_rooms.py b/tests/rest/test_rooms.py index 1ce9b8a83dfd..61b01d369dcf 100644 --- a/tests/rest/test_rooms.py +++ b/tests/rest/test_rooms.py @@ -23,11 +23,14 @@ from synapse.server import HomeServer +from tests import unittest + # python imports import json import urllib +import types -from ..utils import MockHttpResource, MemoryDataStore +from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey from .utils import RestTestCase from mock import Mock, NonCallableMock @@ -44,24 +47,21 @@ class RoomPermissionsTestCase(RestTestCase): def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) - persistence_service.get_latest_pdus_in_context.return_value = [] + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() hs = HomeServer( "red", - db_pool=None, + db_pool=db_pool, http_client=None, - datastore=MemoryDataStore(), replication_layer=Mock(), - state_handler=state_handler, - persistence_service=persistence_service, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -76,6 +76,10 @@ def _get_user_by_token(token=None): } hs.get_auth().get_user_by_token = _get_user_by_token + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + self.auth_user_id = self.rmcreator_id synapse.rest.room.register_servlets(hs, self.mock_resource) @@ -147,38 +151,55 @@ def tearDown(self): @defer.inlineCallbacks def test_send_message(self): msg_content = '{"msgtype":"m.text","body":"hello"}' - send_msg_path = ("/rooms/%s/send/m.room.message/mid1" % - (self.created_rmid)) + send_msg_path = ( + "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,) + ) # send message in uncreated room, expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % - (self.uncreated_rmid), msg_content) + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content + ) self.assertEquals(403, code, msg=str(response)) # send message in created room not joined (no state), expect 403 (code, response) = yield self.mock_resource.trigger( - "PUT", send_msg_path, msg_content) + "PUT", + send_msg_path, + msg_content + ) self.assertEquals(403, code, msg=str(response)) # send message in created room and invited, expect 403 - yield self.invite(room=self.created_rmid, src=self.rmcreator_id, - targ=self.user_id) + yield self.invite( + room=self.created_rmid, + src=self.rmcreator_id, + targ=self.user_id + ) (code, response) = yield self.mock_resource.trigger( - "PUT", send_msg_path, msg_content) + "PUT", + send_msg_path, + msg_content + ) self.assertEquals(403, code, msg=str(response)) # send message in created room and joined, expect 200 yield self.join(room=self.created_rmid, user=self.user_id) (code, response) = yield self.mock_resource.trigger( - "PUT", send_msg_path, msg_content) + "PUT", + send_msg_path, + msg_content + ) self.assertEquals(200, code, msg=str(response)) # send message in created room and left, expect 403 yield self.leave(room=self.created_rmid, user=self.user_id) (code, response) = yield self.mock_resource.trigger( - "PUT", send_msg_path, msg_content) + "PUT", + send_msg_path, + msg_content + ) self.assertEquals(403, code, msg=str(response)) @defer.inlineCallbacks @@ -215,9 +236,14 @@ def test_topic_perms(self): # set/get topic in created PRIVATE room and joined, expect 200 yield self.join(room=self.created_rmid, user=self.user_id) + + # Only room ops can set topic by default + self.auth_user_id = self.rmcreator_id (code, response) = yield self.mock_resource.trigger( "PUT", topic_path, topic_content) self.assertEquals(200, code, msg=str(response)) + self.auth_user_id = self.user_id + (code, response) = yield self.mock_resource.trigger_get(topic_path) self.assertEquals(200, code, msg=str(response)) self.assert_dict(json.loads(topic_content), response) @@ -381,45 +407,55 @@ def test_leave_permissions(self): # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 403s for usr in [self.user_id, self.rmcreator_id]: - yield self.change_membership(room=room, src=self.user_id, - targ=usr, - membership=Membership.INVITE, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=usr, - membership=Membership.JOIN, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=usr, - membership=Membership.LEAVE, - expect_code=403) + yield self.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.INVITE, + expect_code=403 + ) + + yield self.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.JOIN, + expect_code=403 + ) + + # It is always valid to LEAVE if you've already left (currently.) + yield self.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403 + ) class RoomsMemberListTestCase(RestTestCase): """ Tests /rooms/$room_id/members/list REST events.""" user_id = "@sid1:red" + @defer.inlineCallbacks def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) - persistence_service.get_latest_pdus_in_context.return_value = [] + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() hs = HomeServer( "red", - db_pool=None, + db_pool=db_pool, http_client=None, - datastore=MemoryDataStore(), replication_layer=Mock(), - state_handler=state_handler, - persistence_service=persistence_service, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -436,6 +472,10 @@ def _get_user_by_token(token=None): } hs.get_auth().get_user_by_token = _get_user_by_token + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + synapse.rest.room.register_servlets(hs, self.mock_resource) def tearDown(self): @@ -487,28 +527,26 @@ class RoomsCreateTestCase(RestTestCase): """ Tests /rooms and /rooms/$room_id REST events. """ user_id = "@sid1:red" + @defer.inlineCallbacks def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.auth_user_id = self.user_id - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) - persistence_service.get_latest_pdus_in_context.return_value = [] + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() hs = HomeServer( "red", - db_pool=None, + db_pool=db_pool, http_client=None, - datastore=MemoryDataStore(), replication_layer=Mock(), - state_handler=state_handler, - persistence_service=persistence_service, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -523,6 +561,10 @@ def _get_user_by_token(token=None): } hs.get_auth().get_user_by_token = _get_user_by_token + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + synapse.rest.room.register_servlets(hs, self.mock_resource) def tearDown(self): @@ -592,24 +634,21 @@ def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.auth_user_id = self.user_id - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) - persistence_service.get_latest_pdus_in_context.return_value = [] + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() hs = HomeServer( "red", - db_pool=None, + db_pool=db_pool, http_client=None, - datastore=MemoryDataStore(), replication_layer=Mock(), - state_handler=state_handler, - persistence_service=persistence_service, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -622,13 +661,18 @@ def _get_user_by_token(token=None): "admin": False, "device_id": None, } + hs.get_auth().get_user_by_token = _get_user_by_token + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + synapse.rest.room.register_servlets(hs, self.mock_resource) # create the room self.room_id = yield self.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % self.room_id + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) def tearDown(self): pass @@ -706,24 +750,21 @@ def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.auth_user_id = self.user_id - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) - persistence_service.get_latest_pdus_in_context.return_value = [] + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() hs = HomeServer( "red", - db_pool=None, + db_pool=db_pool, http_client=None, - datastore=MemoryDataStore(), replication_layer=Mock(), - state_handler=state_handler, - persistence_service=persistence_service, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -736,13 +777,12 @@ def _get_user_by_token(token=None): "admin": False, "device_id": None, } - return { - "user": hs.parse_userid(self.auth_user_id), - "admin": False, - "device_id": None, - } hs.get_auth().get_user_by_token = _get_user_by_token + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + synapse.rest.room.register_servlets(hs, self.mock_resource) self.room_id = yield self.create_room_as(self.user_id) @@ -847,24 +887,21 @@ def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.auth_user_id = self.user_id - state_handler = Mock(spec=["handle_new_event"]) - state_handler.handle_new_event.return_value = True + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] - persistence_service = Mock(spec=["get_latest_pdus_in_context"]) - persistence_service.get_latest_pdus_in_context.return_value = [] + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() hs = HomeServer( "red", - db_pool=None, + db_pool=db_pool, http_client=None, - datastore=MemoryDataStore(), replication_layer=Mock(), - state_handler=state_handler, - persistence_service=persistence_service, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), - config=NonCallableMock(), + config=self.mock_config, ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -879,6 +916,10 @@ def _get_user_by_token(token=None): } hs.get_auth().get_user_by_token = _get_user_by_token + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + synapse.rest.room.register_servlets(hs, self.mock_resource) self.room_id = yield self.create_room_as(self.user_id) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3ad9a4b0c015..fabd364be9c0 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -74,7 +74,7 @@ def test_insert_3cols(self): @defer.inlineCallbacks def test_select_one_1col(self): self.mock_txn.rowcount = 1 - self.mock_txn.fetchone.return_value = ("Value",) + self.mock_txn.fetchall.return_value = [("Value",)] value = yield self.datastore._simple_select_one_onecol( table="tablename", diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index dae1641ea173..adfe64a98098 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -61,6 +61,7 @@ def inject_room_member(self, room, user, membership, prev_state=None, membership=membership, content={"membership": membership}, depth=self.depth, + prev_events=[], ) event.content.update(extra_content) @@ -68,6 +69,11 @@ def inject_room_member(self, room, user, membership, prev_state=None, if prev_state: event.prev_state = prev_state + event.state_events = None + event.hashes = {} + event.prev_state = [] + event.auth_events = [] + # Have to create a join event using the eventfactory yield self.store.persist_event( event @@ -85,8 +91,13 @@ def inject_message(self, room, user, body): room_id=room.to_string(), content={"body": body, "msgtype": u"message"}, depth=self.depth, + prev_events=[], ) + event.state_events = None + event.hashes = {} + event.auth_events = [] + yield self.store.persist_event( event ) @@ -102,8 +113,13 @@ def inject_redaction(self, room, event_id, user, reason): content={"reason": reason}, depth=self.depth, redacts=event_id, + prev_events=[], ) + event.state_events = None + event.hashes = {} + event.auth_events = [] + yield self.store.persist_event( event ) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 369a73d91776..4ff02c306bd7 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -127,7 +127,7 @@ def inject_room_event(self, **kwargs): ) @defer.inlineCallbacks - def test_room_name(self): + def STALE_test_room_name(self): name = u"A-Room-Name" yield self.inject_room_event( @@ -150,7 +150,7 @@ def test_room_name(self): ) @defer.inlineCallbacks - def test_room_name(self): + def STALE_test_room_topic(self): topic = u"A place for things" yield self.inject_room_event( diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index eae278ee8d37..8614e5ca9d8b 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -51,16 +51,24 @@ def setUp(self): @defer.inlineCallbacks def inject_room_member(self, room, user, membership): # Have to create a join event using the eventfactory + event = self.event_factory.create_event( + etype=RoomMemberEvent.TYPE, + user_id=user.to_string(), + state_key=user.to_string(), + room_id=room.to_string(), + membership=membership, + content={"membership": membership}, + depth=1, + prev_events=[], + ) + + event.state_events = None + event.hashes = {} + event.prev_state = {} + event.auth_events = {} + yield self.store.persist_event( - self.event_factory.create_event( - etype=RoomMemberEvent.TYPE, - user_id=user.to_string(), - state_key=user.to_string(), - room_id=room.to_string(), - membership=membership, - content={"membership": membership}, - depth=1, - ) + event ) @defer.inlineCallbacks diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index ab30e6ea2547..5038546aeeaa 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -48,7 +48,7 @@ def setUp(self): self.depth = 1 @defer.inlineCallbacks - def inject_room_member(self, room, user, membership, prev_state=None): + def inject_room_member(self, room, user, membership, replaces_state=None): self.depth += 1 event = self.event_factory.create_event( @@ -59,10 +59,17 @@ def inject_room_member(self, room, user, membership, prev_state=None): membership=membership, content={"membership": membership}, depth=self.depth, + prev_events=[], ) - if prev_state: - event.prev_state = prev_state + event.state_events = None + event.hashes = {} + event.prev_state = [] + event.auth_events = [] + + if replaces_state: + event.prev_state = [(replaces_state, "hash")] + event.replaces_state = replaces_state # Have to create a join event using the eventfactory yield self.store.persist_event( @@ -75,15 +82,22 @@ def inject_room_member(self, room, user, membership, prev_state=None): def inject_message(self, room, user, body): self.depth += 1 + event = self.event_factory.create_event( + etype=MessageEvent.TYPE, + user_id=user.to_string(), + room_id=room.to_string(), + content={"body": body, "msgtype": u"message"}, + depth=self.depth, + prev_events=[], + ) + + event.state_events = None + event.hashes = {} + event.auth_events = [] + # Have to create a join event using the eventfactory yield self.store.persist_event( - self.event_factory.create_event( - etype=MessageEvent.TYPE, - user_id=user.to_string(), - room_id=room.to_string(), - content={"body": body, "msgtype": u"message"}, - depth=self.depth, - ) + event ) @defer.inlineCallbacks @@ -206,7 +220,7 @@ def test_event_stream_prev_content(self): event2 = yield self.inject_room_member( self.room1, self.u_alice, Membership.JOIN, - prev_state=event1.event_id, + replaces_state=event1.event_id, ) end = yield self.store.get_room_events_max_id() @@ -223,4 +237,7 @@ def test_event_stream_prev_content(self): event = results[0] - self.assertTrue(hasattr(event, "prev_content"), msg="No prev_content key") + self.assertTrue( + hasattr(event, "prev_content"), + msg="No prev_content key" + ) diff --git a/tests/test_state.py b/tests/test_state.py index 4b1feaf4106e..3cc358be329a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -15,599 +15,258 @@ from tests import unittest from twisted.internet import defer -from twisted.python.log import PythonLoggingObserver from synapse.state import StateHandler -from synapse.storage.pdu import PduEntry -from synapse.federation.pdu_codec import encode_event_id -from synapse.federation.units import Pdu - -from collections import namedtuple from mock import Mock -import mock - - -ReturnType = namedtuple( - "StateReturnType", ["new_branch", "current_branch"] -) - - -def _gen_get_power_level(power_level_list): - def get_power_level(room_id, user_id): - return defer.succeed(power_level_list.get(user_id, None)) - return get_power_level class StateTestCase(unittest.TestCase): def setUp(self): - self.persistence = Mock(spec=[ - "get_unresolved_state_tree", - "update_current_state", - "get_latest_pdus_in_context", - "get_current_state_pdu", - "get_pdu", - "get_power_level", - ]) - self.replication = Mock(spec=["get_pdu"]) - - hs = Mock(spec=["get_datastore", "get_replication_layer"]) - hs.get_datastore.return_value = self.persistence - hs.get_replication_layer.return_value = self.replication - hs.hostname = "bob.com" - - self.state = StateHandler(hs) - - @defer.inlineCallbacks - def test_new_state_key(self): - # We've never seen anything for this state before - new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({}) - - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu], []), None) - ) - - is_new = yield self.state.handle_new_state(new_pdu) - - self.assertTrue(is_new) - - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu - ) - - self.assertEqual(1, self.persistence.update_current_state.call_count) - - self.assertFalse(self.replication.get_pdu.called) - - @defer.inlineCallbacks - def test_direct_overwrite(self): - # We do a direct overwriting of the old state, i.e., the new state - # points to the old state. - - old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1") - new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 5, - }) - - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu, old_pdu], [old_pdu]), None) - ) - - is_new = yield self.state.handle_new_state(new_pdu) - - self.assertTrue(is_new) - - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu + self.store = Mock( + spec_set=[ + "get_state_groups", + ] ) + hs = Mock(spec=["get_datastore"]) + hs.get_datastore.return_value = self.store - self.assertEqual(1, self.persistence.update_current_state.call_count) - - self.assertFalse(self.replication.get_pdu.called) + self.state = StateHandler(hs) + self.event_id = 0 @defer.inlineCallbacks - def test_overwrite(self): - old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") - old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2") - new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 5, - "u3": 0, - }) - - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None) - ) + def test_annotate_with_old_message(self): + event = self.create_event(type="test_message", name="event") - is_new = yield self.state.handle_new_state(new_pdu) + old_state = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test1", state_key="2"), + self.create_event(type="test2", state_key=""), + ] - self.assertTrue(is_new) + yield self.state.annotate_state_groups(event, old_state=old_state) - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu - ) + for k, v in event.old_state_events.items(): + type, state_key = k + self.assertEqual(type, v.type) + self.assertEqual(state_key, v.state_key) - self.assertEqual(1, self.persistence.update_current_state.call_count) + self.assertEqual(set(old_state), set(event.old_state_events.values())) + self.assertDictEqual(event.old_state_events, event.state_events) - self.assertFalse(self.replication.get_pdu.called) + self.assertIsNone(event.state_group) @defer.inlineCallbacks - def test_power_level_fail(self): - # We try to update the state based on an outdated state, and have a - # too low power level. - - old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") - old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") - new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 5, - }) - - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) - ) - - is_new = yield self.state.handle_new_state(new_pdu) - - self.assertFalse(is_new) - - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu - ) - - self.assertEqual(0, self.persistence.update_current_state.call_count) + def test_annotate_with_old_state(self): + event = self.create_event(type="state", state_key="", name="event") - self.assertFalse(self.replication.get_pdu.called) - - @defer.inlineCallbacks - def test_power_level_succeed(self): - # We try to update the state based on an outdated state, but have - # sufficient power level to force the update. - - old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") - old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") - new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 15, - }) - - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) - ) + old_state = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test1", state_key="2"), + self.create_event(type="test2", state_key=""), + ] - is_new = yield self.state.handle_new_state(new_pdu) + yield self.state.annotate_state_groups(event, old_state=old_state) - self.assertTrue(is_new) + for k, v in event.old_state_events.items(): + type, state_key = k + self.assertEqual(type, v.type) + self.assertEqual(state_key, v.state_key) - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu + self.assertEqual( + set(old_state + [event]), + set(event.old_state_events.values()) ) - self.assertEqual(1, self.persistence.update_current_state.call_count) + self.assertDictEqual(event.old_state_events, event.state_events) - self.assertFalse(self.replication.get_pdu.called) + self.assertIsNone(event.state_group) @defer.inlineCallbacks - def test_power_level_equal_same_len(self): - # We try to update the state based on an outdated state, the power - # levels are the same and so are the branch lengths - - old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") - old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") - new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 10, - }) - - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) - ) - - is_new = yield self.state.handle_new_state(new_pdu) + def test_trivial_annotate_message(self): + event = self.create_event(type="test_message", name="event") + event.prev_events = [] + + old_state = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test1", state_key="2"), + self.create_event(type="test2", state_key=""), + ] - self.assertTrue(is_new) + group_name = "group_name_1" - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu - ) + self.store.get_state_groups.return_value = { + group_name: old_state, + } - self.assertEqual(1, self.persistence.update_current_state.call_count) + yield self.state.annotate_state_groups(event) - self.assertFalse(self.replication.get_pdu.called) + for k, v in event.old_state_events.items(): + type, state_key = k + self.assertEqual(type, v.type) + self.assertEqual(state_key, v.state_key) - @defer.inlineCallbacks - def test_power_level_equal_diff_len(self): - # We try to update the state based on an outdated state, the power - # levels are the same but the branch length of the new one is longer. - - old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") - old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") - old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3") - new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4") - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 10, - "u4": 10, - }) - - self.persistence.get_unresolved_state_tree.return_value = ( - ( - ReturnType( - [new_pdu, old_pdu_3, old_pdu_1], - [old_pdu_2, old_pdu_1] - ), - None - ) + self.assertEqual( + set([e.event_id for e in old_state]), + set([e.event_id for e in event.old_state_events.values()]) ) - is_new = yield self.state.handle_new_state(new_pdu) - - self.assertTrue(is_new) - - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu + self.assertDictEqual( + { + k: v.event_id + for k, v in event.old_state_events.items() + }, + { + k: v.event_id + for k, v in event.state_events.items() + } ) - self.assertEqual(1, self.persistence.update_current_state.call_count) - - self.assertFalse(self.replication.get_pdu.called) + self.assertEqual(group_name, event.state_group) @defer.inlineCallbacks - def test_missing_pdu(self): - # We try to update state against a PDU we haven't yet seen, - # triggering a get_pdu request - - # The pdu we haven't seen - old_pdu_1 = new_fake_pdu( - "A", "test", "mem", "x", None, "u1", depth=0 - ) - - old_pdu_2 = new_fake_pdu( - "B", "test", "mem", "x", "A", "u2", depth=1 - ) - new_pdu = new_fake_pdu( - "C", "test", "mem", "x", "A", "u3", depth=2 - ) - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 20, - }) - - # The return_value of `get_unresolved_state_tree`, which changes after - # the call to get_pdu - tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)] - - def return_tree(p): - return tree_to_return[0] - - def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): - tree_to_return[0] = ( - ReturnType( - [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] - ), - None - ) - return defer.succeed(None) - - self.persistence.get_unresolved_state_tree.side_effect = return_tree + def test_trivial_annotate_state(self): + event = self.create_event(type="state", state_key="", name="event") + event.prev_events = [] + + old_state = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test1", state_key="2"), + self.create_event(type="test2", state_key=""), + ] - self.replication.get_pdu.side_effect = set_return_tree + group_name = "group_name_1" - self.persistence.get_pdu.return_value = None + self.store.get_state_groups.return_value = { + group_name: old_state, + } - is_new = yield self.state.handle_new_state(new_pdu) + yield self.state.annotate_state_groups(event) - self.assertTrue(is_new) + for k, v in event.old_state_events.items(): + type, state_key = k + self.assertEqual(type, v.type) + self.assertEqual(state_key, v.state_key) - self.replication.get_pdu.assert_called_with( - destination=new_pdu.origin, - pdu_origin=old_pdu_1.origin, - pdu_id=old_pdu_1.pdu_id, - outlier=True + self.assertEqual( + set([e.event_id for e in old_state]), + set([e.event_id for e in event.old_state_events.values()]) ) - self.persistence.get_unresolved_state_tree.assert_called_with( - new_pdu + self.assertEqual( + set([e.event_id for e in old_state] + [event.event_id]), + set([e.event_id for e in event.state_events.values()]) ) - self.assertEquals( - 2, self.persistence.get_unresolved_state_tree.call_count + new_state = { + k: v.event_id + for k, v in event.state_events.items() + } + old_state = { + k: v.event_id + for k, v in event.old_state_events.items() + } + old_state[(event.type, event.state_key)] = event.event_id + self.assertDictEqual( + old_state, + new_state ) - self.assertEqual(1, self.persistence.update_current_state.call_count) + self.assertIsNone(event.state_group) @defer.inlineCallbacks - def test_missing_pdu_depth_1(self): - # We try to update state against a PDU we haven't yet seen, - # triggering a get_pdu request - - # The pdu we haven't seen - old_pdu_1 = new_fake_pdu( - "A", "test", "mem", "x", None, "u1", depth=0 - ) - - old_pdu_2 = new_fake_pdu( - "B", "test", "mem", "x", "A", "u2", depth=2 - ) - old_pdu_3 = new_fake_pdu( - "C", "test", "mem", "x", "B", "u3", depth=3 - ) - new_pdu = new_fake_pdu( - "D", "test", "mem", "x", "A", "u4", depth=4 - ) - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 10, - "u4": 20, - }) - - # The return_value of `get_unresolved_state_tree`, which changes after - # the call to get_pdu - tree_to_return = [ - ( - ReturnType([new_pdu], [old_pdu_3]), - 0 - ), - ( - ReturnType( - [new_pdu, old_pdu_1], [old_pdu_3] - ), - 1 - ), - ( - ReturnType( - [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] - ), - None - ), + def test_resolve_message_conflict(self): + event = self.create_event(type="test_message", name="event") + event.prev_events = [] + + old_state_1 = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test1", state_key="2"), + self.create_event(type="test2", state_key=""), ] - to_return = [0] - - def return_tree(p): - return tree_to_return[to_return[0]] - - def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): - to_return[0] += 1 - return defer.succeed(None) - - self.persistence.get_unresolved_state_tree.side_effect = return_tree - - self.replication.get_pdu.side_effect = set_return_tree - - self.persistence.get_pdu.return_value = None - - is_new = yield self.state.handle_new_state(new_pdu) + old_state_2 = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test3", state_key="2"), + self.create_event(type="test4", state_key=""), + ] - self.assertTrue(is_new) + group_name_1 = "group_name_1" + group_name_2 = "group_name_2" - self.assertEqual(2, self.replication.get_pdu.call_count) + self.store.get_state_groups.return_value = { + group_name_1: old_state_1, + group_name_2: old_state_2, + } - self.replication.get_pdu.assert_has_calls( - [ - mock.call( - destination=new_pdu.origin, - pdu_origin=old_pdu_1.origin, - pdu_id=old_pdu_1.pdu_id, - outlier=True - ), - mock.call( - destination=old_pdu_3.origin, - pdu_origin=old_pdu_2.origin, - pdu_id=old_pdu_2.pdu_id, - outlier=True - ), - ] - ) + yield self.state.annotate_state_groups(event) - self.persistence.get_unresolved_state_tree.assert_called_with( - new_pdu - ) + self.assertEqual(len(event.old_state_events), 5) - self.assertEquals( - 3, self.persistence.get_unresolved_state_tree.call_count + self.assertEqual( + set([e.event_id for e in event.state_events.values()]), + set([e.event_id for e in event.old_state_events.values()]) ) - self.assertEqual(1, self.persistence.update_current_state.call_count) + self.assertIsNone(event.state_group) @defer.inlineCallbacks - def test_missing_pdu_depth_2(self): - # We try to update state against a PDU we haven't yet seen, - # triggering a get_pdu request - - # The pdu we haven't seen - old_pdu_1 = new_fake_pdu( - "A", "test", "mem", "x", None, "u1", depth=0 - ) - - old_pdu_2 = new_fake_pdu( - "B", "test", "mem", "x", "A", "u2", depth=2 - ) - old_pdu_3 = new_fake_pdu( - "C", "test", "mem", "x", "B", "u3", depth=3 - ) - new_pdu = new_fake_pdu( - "D", "test", "mem", "x", "A", "u4", depth=1 - ) - - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 10, - "u2": 10, - "u3": 10, - "u4": 20, - }) - - # The return_value of `get_unresolved_state_tree`, which changes after - # the call to get_pdu - tree_to_return = [ - ( - ReturnType([new_pdu], [old_pdu_3]), - 1, - ), - ( - ReturnType( - [new_pdu], [old_pdu_3, old_pdu_2] - ), - 0, - ), - ( - ReturnType( - [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] - ), - None - ), + def test_resolve_state_conflict(self): + event = self.create_event(type="test4", state_key="", name="event") + event.prev_events = [] + + old_state_1 = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test1", state_key="2"), + self.create_event(type="test2", state_key=""), ] - to_return = [0] - - def return_tree(p): - return tree_to_return[to_return[0]] - - def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): - to_return[0] += 1 - return defer.succeed(None) - - self.persistence.get_unresolved_state_tree.side_effect = return_tree - - self.replication.get_pdu.side_effect = set_return_tree - - self.persistence.get_pdu.return_value = None - - is_new = yield self.state.handle_new_state(new_pdu) - - self.assertTrue(is_new) - - self.assertEqual(2, self.replication.get_pdu.call_count) - - self.replication.get_pdu.assert_has_calls( - [ - mock.call( - destination=old_pdu_3.origin, - pdu_origin=old_pdu_2.origin, - pdu_id=old_pdu_2.pdu_id, - outlier=True - ), - mock.call( - destination=new_pdu.origin, - pdu_origin=old_pdu_1.origin, - pdu_id=old_pdu_1.pdu_id, - outlier=True - ), - ] - ) - - self.persistence.get_unresolved_state_tree.assert_called_with( - new_pdu - ) - - self.assertEquals( - 3, self.persistence.get_unresolved_state_tree.call_count - ) - - self.assertEqual(1, self.persistence.update_current_state.call_count) - - @defer.inlineCallbacks - def test_no_common_ancestor(self): - # We do a direct overwriting of the old state, i.e., the new state - # points to the old state. + old_state_2 = [ + self.create_event(type="test1", state_key="1"), + self.create_event(type="test3", state_key="2"), + self.create_event(type="test4", state_key=""), + ] - old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1") - new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2") + group_name_1 = "group_name_1" + group_name_2 = "group_name_2" - self.persistence.get_power_level.side_effect = _gen_get_power_level({ - "u1": 5, - "u2": 10, - }) + self.store.get_state_groups.return_value = { + group_name_1: old_state_1, + group_name_2: old_state_2, + } - self.persistence.get_unresolved_state_tree.return_value = ( - (ReturnType([new_pdu], [old_pdu]), None) - ) + yield self.state.annotate_state_groups(event) - is_new = yield self.state.handle_new_state(new_pdu) + self.assertEqual(len(event.old_state_events), 5) - self.assertTrue(is_new) + expected_new = event.old_state_events + expected_new[(event.type, event.state_key)] = event - self.persistence.get_unresolved_state_tree.assert_called_once_with( - new_pdu + self.assertEqual( + set([e.event_id for e in expected_new.values()]), + set([e.event_id for e in event.state_events.values()]), ) - self.assertEqual(1, self.persistence.update_current_state.call_count) - - self.assertFalse(self.replication.get_pdu.called) - - @defer.inlineCallbacks - def test_new_event(self): - event = Mock() - event.event_id = "12123123@test" + self.assertIsNone(event.state_group) - state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20) + def create_event(self, name=None, type=None, state_key=None): + self.event_id += 1 + event_id = str(self.event_id) - snapshot = Mock() - snapshot.prev_state_pdu = state_pdu - event_id = "pdu_id@origin.com" + if not name: + if state_key is not None: + name = "<%s-%s>" % (type, state_key) + else: + name = "<%s>" % (type, ) - def fill_out_prev_events(event): - event.prev_events = [event_id] - event.depth = 6 - snapshot.fill_out_prev_events = fill_out_prev_events + event = Mock(name=name, spec=[]) + event.type = type - yield self.state.handle_new_event(event, snapshot) - - self.assertLess(5, event.depth) - - self.assertEquals(1, len(event.prev_events)) - - prev_id = event.prev_events[0] - - self.assertEqual(event_id, prev_id) - - self.assertEqual( - encode_event_id(state_pdu.pdu_id, state_pdu.origin), - event.prev_state - ) + if state_key is not None: + event.state_key = state_key + event.event_id = event_id + event.user_id = "@user_id:example.com" + event.room_id = "!room_id:example.com" -def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id, - user_id, depth=0): - new_pdu = Pdu( - pdu_id=pdu_id, - pdu_type=pdu_type, - state_key=state_key, - user_id=user_id, - prev_state_id=prev_state_id, - origin="example.com", - context="context", - origin_server_ts=1405353060021, - depth=depth, - content_json="{}", - unrecognized_keys="{}", - outlier=True, - is_state=True, - prev_state_origin="example.com", - have_processed=True, - content={}, - ) - - return new_pdu + return event diff --git a/tests/utils.py b/tests/utils.py index 60fd6085ac6f..d8be73dba89d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -118,13 +118,14 @@ def register_path(self, method, path_pattern, callback): class MockKey(object): alg = "mock_alg" version = "mock_version" + signature = b"\x9a\x87$" @property def verify_key(self): return self def sign(self, message): - return b"\x9a\x87$" + return self def verify(self, message, sig): assert sig == b"\x9a\x87$"