Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update EventContext get_current_event_ids and get_prev_event_ids to accept state filters and update calls where possible #12791

Merged
merged 10 commits into from
May 20, 2022
1 change: 1 addition & 0 deletions changelog.d/12791.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible.
19 changes: 15 additions & 4 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter


@attr.s(slots=True, auto_attribs=True)
Expand Down Expand Up @@ -196,14 +197,19 @@ def state_group(self) -> Optional[int]:

return self._state_group

async def get_current_state_ids(self) -> Optional[StateMap[str]]:
async def get_current_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``

It is an error to access this for a rejected event, since rejected state should
not make it into the room state. This method will raise an exception if
``rejected`` is set.

Arg:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules

Returns:
Returns None if state_group is None, which happens when the associated
event is an outlier.
Expand All @@ -216,20 +222,25 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:

assert self._state_delta_due_to_event is not None

prev_state_ids = await self.get_prev_state_ids()
prev_state_ids = await self.get_prev_state_ids(state_filter)

if self._state_delta_due_to_event:
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update(self._state_delta_due_to_event)

return prev_state_ids

async def get_prev_state_ids(self) -> StateMap[str]:
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
"""
Gets the room state map, excluding this event.

For a non-state event, this will be the same as get_current_state_ids().

Args:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules

Returns:
Returns {} if state_group is None, which happens when the associated
event is an outlier.
Expand All @@ -239,7 +250,7 @@ async def get_prev_state_ids(self) -> StateMap[str]:
"""
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event
self.state_group_before_event, state_filter
)


Expand Down
9 changes: 7 additions & 2 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
Expand Down Expand Up @@ -1259,7 +1260,9 @@ async def add_display_name_to_third_party_invite(
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
)
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = await self.store.get_event(
Expand Down Expand Up @@ -1308,7 +1311,9 @@ async def _check_signature(self, event: EventBase, context: EventContext) -> Non
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]

prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
)
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))

invite_event = None
Expand Down
8 changes: 7 additions & 1 deletion synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from prometheus_client import Counter

from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
Expand Down Expand Up @@ -63,6 +64,7 @@
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
Expand Down Expand Up @@ -1500,7 +1502,11 @@ async def _check_event_auth(
return context

# now check auth against what we think the auth events *should* be.
prev_state_ids = await context.get_prev_state_ids()
event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)

auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand Down
14 changes: 11 additions & 3 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,9 @@ async def create_event(
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
await self.store.get_event(prev_event_id, allow_none=True)
Expand Down Expand Up @@ -761,7 +763,9 @@ async def deduplicate_state_event(
# This can happen due to out of band memberships
return None

prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(event.type, None)])
)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return None
Expand Down Expand Up @@ -1547,7 +1551,11 @@ async def persist_and_notify_client_event(
"Redacting MSC2716 events is not supported in this room version",
)

prev_state_ids = await context.get_prev_state_ids()
event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)

auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand Down
5 changes: 4 additions & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ async def _upgrade_room(
context=tombstone_context,
)

old_room_state = await tombstone_context.get_current_state_ids()
state_filter = StateFilter.from_types(
[(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
)
old_room_state = await tombstone_context.get_current_state_ids(state_filter)

# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
Expand Down
9 changes: 7 additions & 2 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
Expand Down Expand Up @@ -362,7 +363,9 @@ async def _local_membership_update(
historical=historical,
)

prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)

prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)

Expand Down Expand Up @@ -1160,7 +1163,9 @@ async def send_membership_event(
else:
requester = types.create_requester(target_user)

prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.GuestAccess, None)])
)
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = await self._can_guest_join(prev_state_ids)
Expand Down
9 changes: 7 additions & 2 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from prometheus_client import Counter

from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
Expand All @@ -31,6 +31,7 @@
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import measure_func

from ..storage.state import StateFilter
from .push_rule_evaluator import PushRuleEvaluatorForEvent

if TYPE_CHECKING:
Expand Down Expand Up @@ -168,8 +169,12 @@ def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
event_types = auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
pl_event_id = prev_state_ids.get(POWER_KEY)

if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
Expand Down
7 changes: 5 additions & 2 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,16 +634,19 @@ async def get_state_groups_ids(

return group_to_state

async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
async def get_state_ids_for_group(
self, state_group: int, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group

Args:
state_group: A state group for which we want to get the state IDs.
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules

Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,))
group_to_state = await self.get_state_for_groups((state_group,), state_filter)

return group_to_state[state_group]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def get_state_groups_ids(self, room_id, event_ids):

return groups

async def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group, state_filter=None):
return self._group_to_state[state_group]

async def store_state_group(
Expand Down