diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 628012fb275e..3c86adab5650 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -51,7 +51,7 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator @@ -207,82 +207,6 @@ async def get_received_ts(self, event_id: str) -> Optional[int]: desc="get_received_ts", ) - # Inform mypy that if allow_none is False (the default) then get_event - # always returns an EventBase. - @overload - def get_event_txn( - self, - event_id: str, - allow_rejected: bool = False, - allow_none: Literal[False] = False, - ) -> EventBase: - ... - - @overload - def get_event_txn( - self, - event_id: str, - allow_rejected: bool = False, - allow_none: Literal[True] = False, - ) -> Optional[EventBase]: - ... - - def get_event_txn( - self, - txn: LoggingTransaction, - event_id: str, - allow_rejected: bool = False, - allow_none: bool = False, - ) -> Optional[EventBase]: - """Get an event from the database by event_id. - - Args: - txn: Transaction object - - event_id: The event_id of the event to fetch - - get_prev_content: If True and event is a state event, - include the previous states content in the unsigned field. - - allow_rejected: If True, return rejected events. Otherwise, - behave as per allow_none. - - allow_none: If True, return None if no event found, if - False throw a NotFoundError - - check_room_id: if not None, check the room of the found event. - If there is a mismatch, behave as per allow_none. - - Returns: - The event, or None if the event was not found and allow_none=True - - - Raises: - NotFoundError: if the event_id was not found and allow_none=False - """ - event_map = self._fetch_event_rows(txn, [event_id]) - event_info = event_map[event_id] - if event_info is None and not allow_none: - raise NotFoundError("Could not find event %s" % (event_id,)) - - rejected_reason = event_info["rejected_reason"] - if not allow_rejected and rejected_reason: - return - - d = db_to_json(event_info["json"]) - internal_metadata = db_to_json(event_info["internal_metadata"]) - room_version_id = event_info["room_version_id"] - room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) - - event = make_event_from_dict( - event_dict=d, - room_version=room_version, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - return event - # Inform mypy that if allow_none is False (the default) then get_event # always returns an EventBase. @overload diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 0daea5341dd6..8e22da99ae60 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -58,10 +58,8 @@ def __init__(self, database: DatabasePool, db_conn, hs): async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room - Raises: NotFoundError: if the room is unknown - UnsupportedRoomVersionError: if the room uses an unknown room version. Typically this happens if support for the room's version has been removed from Synapse. @@ -76,14 +74,11 @@ def get_room_version_txn( self, txn: LoggingTransaction, room_id: str ) -> RoomVersion: """Get the room_version of a given room - Args: txn: Transaction object room_id: The room_id of the room you are trying to get the version for - Raises: NotFoundError: if the room is unknown - UnsupportedRoomVersionError: if the room uses an unknown room version. Typically this happens if support for the room's version has been removed from Synapse. @@ -102,7 +97,6 @@ def get_room_version_txn( @cached(max_entries=10000) async def get_room_version_id(self, room_id: str) -> str: """Get the room_version of a given room - Raises: NotFoundError: if the room is unknown """ @@ -114,11 +108,9 @@ async def get_room_version_id(self, room_id: str) -> str: def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str: """Get the room_version of a given room - Args: txn: Transaction object room_id: The room_id of the room you are trying to get the version for - Raises: NotFoundError: if the room is unknown """ @@ -138,12 +130,10 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str: allow_none=True, ) - if room_version is not None: - return room_version + if room_version is None: + raise NotFoundError("Could not room_version for %s" % (room_id,)) - # Retrieve the room's create event - create_event = self.get_create_event_for_room_txn(txn, room_id) - return create_event.content.get("room_version", "1") + return room_version async def get_room_predecessor(self, room_id: str) -> Optional[dict]: """Get the predecessor of an upgraded room if it exists. @@ -188,29 +178,7 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase: Raises: NotFoundError if the room is unknown """ - return await self.db_pool.runInteraction( - "get_create_event_for_room_txn", - self.get_create_event_for_room_txn, - room_id, - ) - - def get_create_event_for_room_txn( - self, txn: LoggingTransaction, room_id: str - ) -> EventBase: - """Get the create state event for a room. - - Args: - txn: Transaction object - room_id: The room ID. - - Returns: - The room creation event. - - Raises: - NotFoundError if the room is unknown - """ - - state_ids = self.get_current_state_ids_txn(txn, room_id) + state_ids = await self.get_current_state_ids(room_id) create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end @@ -218,7 +186,7 @@ def get_create_event_for_room_txn( raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return - create_event = self.get_event_txn(txn, create_id) + create_event = await self.get_event(create_id) return create_event @cached(max_entries=100000, iterable=True) @@ -233,35 +201,20 @@ async def get_current_state_ids(self, room_id: str) -> StateMap[str]: The current state of the room. """ - return await self.db_pool.runInteraction( - "get_current_state_ids_txn", - self.get_current_state_ids_txn, - room_id, - ) - - def get_current_state_ids_txn( - self, txn: LoggingTransaction, room_id: str - ) -> StateMap[str]: - """Get the current state event ids for a room based on the - current_state_events table. + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) - Args: - txn: Transaction object - room_id: The room to get the state IDs of. + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - Returns: - The current state of the room. - """ - - txn.execute( - """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """, - (room_id,), + return await self.db_pool.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn ) - return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - # FIXME: how should this be cached? async def get_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None