diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d3e84209753d..d9a6be43f793 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -92,6 +92,7 @@ async def on_GET( pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -99,7 +100,6 @@ async def on_GET( direction=direction, from_token=from_token, to_token=to_token, - event=event, ) events = await self.store.get_events_as_list( @@ -288,6 +288,7 @@ async def on_GET( result = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -295,7 +296,6 @@ async def on_GET( limit=limit, from_token=from_token, to_token=to_token, - event=event, ) events = await self.store.get_events_as_list( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index cd3a435ae739..be1500092b5b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ def __init__( self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(num_args=9, tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -103,12 +104,12 @@ async def get_relations_for_event( direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - event: Optional[EventBase] = None, ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -118,15 +119,13 @@ async def get_relations_for_event( oldest first (`"f"`). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. - event: The matching EventBase to event_id. This *must* be provided. Returns: List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ - # We don't use `event_id`, its there so that we can cache based on + # We don't use `event_id`, it's there so that we can cache based on # it. The `event_id` must match the `event.event_id`. - assert event is not None assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] @@ -786,7 +785,7 @@ async def _get_bundled_aggregation_for_event( ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f", event=event + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self))