Split on_receive_pdu in half (#10640)
Here we split on_receive_pdu into two functions: on_receive_pdu, which now handles only events pushed to us via a federation /send/ transaction, and _process_pulled_event, which handles events we have pulled ourselves (for example, missing prev_events). Previously both cases were handled in the same method. There is a little overlap between the two, but not much.
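For orientation, here is a minimal, self-contained sketch of the shape the code ends up in. This is not Synapse code: the Event dataclass, the FederationHandlerSketch class, and its helpers (fetch_missing_prev_events, resolve_state_at_missing_prevs, process_received_pdu) are invented stand-ins, and only the overall control flow mirrors the diff below.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Event:
    event_id: str
    depth: int
    prev_event_ids: List[str] = field(default_factory=list)


class FederationHandlerSketch:
    """Illustrative stand-in for the push/pull split; not the real handler."""

    def __init__(self) -> None:
        self.seen_events = set()

    def on_receive_pdu(self, origin: str, pdu: Event) -> None:
        """Handle an event pushed to us via a federation /send/ transaction."""
        missing = [p for p in pdu.prev_event_ids if p not in self.seen_events]
        if missing:
            # For pushed events we try to backfill the gap, and reject the
            # event outright if its prev_events still cannot be found.
            self.fetch_missing_prev_events(origin, missing)
            if any(p not in self.seen_events for p in missing):
                raise RuntimeError(f"rejecting {pdu.event_id}: missing prev_events")
        self.process_received_pdu(origin, pdu, state=None)

    def process_pulled_events(self, origin: str, events: List[Event]) -> None:
        """Handle a batch of events we pulled ourselves (e.g. missing prev_events)."""
        # Sort by depth so we process (and notify clients) in order.
        for ev in sorted(events, key=lambda e: e.depth):
            self.process_pulled_event(origin, ev)

    def process_pulled_event(self, origin: str, event: Event) -> None:
        # For pulled events we do not recurse into get_missing_events; if
        # prev_events are missing we just resolve the state at the event.
        state = self.resolve_state_at_missing_prevs(origin, event)
        self.process_received_pdu(origin, event, state=state)

    # Stand-in helpers; the real code talks to the federation client and store.
    def fetch_missing_prev_events(self, origin: str, missing: List[str]) -> None:
        pass

    def resolve_state_at_missing_prevs(self, origin: str, event: Event) -> Optional[list]:
        return None

    def process_received_pdu(
        self, origin: str, event: Event, state: Optional[list]
    ) -> None:
        self.seen_events.add(event.event_id)
```

The important difference between the two paths is that a pushed event whose prev_events cannot be fetched is rejected outright, whereas a pulled event has the state at its missing prev_events resolved and is then processed as usual.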
richvdh authored Aug 19, 2021
1 parent 50af1ef commit e81d620
Showing 4 changed files with 142 additions and 109 deletions.
1 change: 1 addition & 0 deletions changelog.d/10640.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
4 changes: 1 addition & 3 deletions synapse/federation/federation_server.py
@@ -1005,9 +1005,7 @@ async def _process_incoming_pdus_in_room_inner(
            async with lock:
                logger.info("handling received PDU: %s", event)
                try:
-                    await self.handler.on_receive_pdu(
-                        origin, event, sent_to_us_directly=True
-                    )
+                    await self.handler.on_receive_pdu(origin, event)
                except FederationError as e:
                    # XXX: Ideally we'd inform the remote we failed to process
                    # the event, but we can't return an error in the transaction
236 changes: 138 additions & 98 deletions synapse/handlers/federation.py
@@ -203,18 +203,13 @@ def __init__(self, hs: "HomeServer"):

        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages

-    async def on_receive_pdu(
-        self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
-    ) -> None:
-        """Process a PDU received via a federation /send/ transaction, or
-        via backfill of missing prev_events
+    async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+        """Process a PDU received via a federation /send/ transaction

        Args:
            origin: server which initiated the /send/ transaction. Will
                be used to fetch missing events or state.
            pdu: received PDU
-            sent_to_us_directly: True if this event was pushed to us; False if
-                we pulled it as the result of a missing prev_event.
        """

        room_id = pdu.room_id
@@ -276,92 +271,79 @@ async def on_receive_pdu(
            )
            return None

-        state = None
-
        # Check that the event passes auth based on the state at the event. This is
        # done for events that are to be added to the timeline (non-outliers).
        #
        # Get missing pdus if necessary:
        #  - Fetching any missing prev events to fill in gaps in the graph
        #  - Fetching state if we have a hole in the graph
        if not pdu.internal_metadata.is_outlier():
-            if sent_to_us_directly:
-                prevs = set(pdu.prev_event_ids())
-                seen = await self.store.have_events_in_timeline(prevs)
-                missing_prevs = prevs - seen
-
-                if missing_prevs:
-                    # We only backfill backwards to the min depth.
-                    min_depth = await self.get_min_depth_for_context(pdu.room_id)
-                    logger.debug("min_depth: %d", min_depth)
-
-                    if min_depth is not None and pdu.depth > min_depth:
-                        # If we're missing stuff, ensure we only fetch stuff one
-                        # at a time.
-                        logger.info(
-                            "Acquiring room lock to fetch %d missing prev_events: %s",
-                            len(missing_prevs),
-                            shortstr(missing_prevs),
-                        )
-                        with (await self._room_pdu_linearizer.queue(pdu.room_id)):
-                            logger.info(
-                                "Acquired room lock to fetch %d missing prev_events",
-                                len(missing_prevs),
-                            )
-
-                            try:
-                                await self._get_missing_events_for_pdu(
-                                    origin, pdu, prevs, min_depth
-                                )
-                            except Exception as e:
-                                raise Exception(
-                                    "Error fetching missing prev_events for %s: %s"
-                                    % (event_id, e)
-                                ) from e
-
-                        # Update the set of things we've seen after trying to
-                        # fetch the missing stuff
-                        seen = await self.store.have_events_in_timeline(prevs)
-                        missing_prevs = prevs - seen
-
-                        if not missing_prevs:
-                            logger.info("Found all missing prev_events")
-
-                    if missing_prevs:
-                        # since this event was pushed to us, it is possible for it to
-                        # become the only forward-extremity in the room, and we would then
-                        # trust its state to be the state for the whole room. This is very
-                        # bad. Further, if the event was pushed to us, there is no excuse
-                        # for us not to have all the prev_events. (XXX: apart from
-                        # min_depth?)
-                        #
-                        # We therefore reject any such events.
-                        logger.warning(
-                            "Rejecting: failed to fetch %d prev events: %s",
-                            len(missing_prevs),
-                            shortstr(missing_prevs),
-                        )
-                        raise FederationError(
-                            "ERROR",
-                            403,
-                            (
-                                "Your server isn't divulging details about prev_events "
-                                "referenced in this event."
-                            ),
-                            affected=pdu.event_id,
-                        )
-
-            else:
-                state = await self._resolve_state_at_missing_prevs(origin, pdu)
+            prevs = set(pdu.prev_event_ids())
+            seen = await self.store.have_events_in_timeline(prevs)
+            missing_prevs = prevs - seen
+
+            if missing_prevs:
+                # We only backfill backwards to the min depth.
+                min_depth = await self.get_min_depth_for_context(pdu.room_id)
+                logger.debug("min_depth: %d", min_depth)
+
+                if min_depth is not None and pdu.depth > min_depth:
+                    # If we're missing stuff, ensure we only fetch stuff one
+                    # at a time.
+                    logger.info(
+                        "Acquiring room lock to fetch %d missing prev_events: %s",
+                        len(missing_prevs),
+                        shortstr(missing_prevs),
+                    )
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+                        logger.info(
+                            "Acquired room lock to fetch %d missing prev_events",
+                            len(missing_prevs),
+                        )
+
+                        try:
+                            await self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
+                            )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            ) from e
+
+                    # Update the set of things we've seen after trying to
+                    # fetch the missing stuff
+                    seen = await self.store.have_events_in_timeline(prevs)
+                    missing_prevs = prevs - seen
+
+                    if not missing_prevs:
+                        logger.info("Found all missing prev_events")
+
+                if missing_prevs:
+                    # since this event was pushed to us, it is possible for it to
+                    # become the only forward-extremity in the room, and we would then
+                    # trust its state to be the state for the whole room. This is very
+                    # bad. Further, if the event was pushed to us, there is no excuse
+                    # for us not to have all the prev_events. (XXX: apart from
+                    # min_depth?)
+                    #
+                    # We therefore reject any such events.
+                    logger.warning(
+                        "Rejecting: failed to fetch %d prev events: %s",
+                        len(missing_prevs),
+                        shortstr(missing_prevs),
+                    )
+                    raise FederationError(
+                        "ERROR",
+                        403,
+                        (
+                            "Your server isn't divulging details about prev_events "
+                            "referenced in this event."
+                        ),
+                        affected=pdu.event_id,
+                    )

        # A second round of checks for all events. Check that the event passes auth
        # based on `auth_events`, this allows us to assert that the event would
        # have been allowed at some point. If an event passes this check its OK
        # for it to be used as part of a returned `/state` request, as either
        # a) we received the event as part of the original join and so trust it, or
        # b) we'll do a state resolution with existing state before it becomes
        # part of the "current state", which adds more protection.
-        await self._process_received_pdu(origin, pdu, state=state)
+        await self._process_received_pdu(origin, pdu, state=None)

    async def _get_missing_events_for_pdu(
        self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@@ -461,24 +443,7 @@ async def _get_missing_events_for_pdu(
            return

        logger.info("Got %d prev_events", len(missing_events))
-
-        # We want to sort these by depth so we process them and
-        # tell clients about them in order.
-        missing_events.sort(key=lambda x: x.depth)
-
-        for ev in missing_events:
-            logger.info("Handling received prev_event %s", ev)
-            with nested_logging_context(ev.event_id):
-                try:
-                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
-                except FederationError as e:
-                    if e.code == 403:
-                        logger.warning(
-                            "Received prev_event %s failed history check.",
-                            ev.event_id,
-                        )
-                    else:
-                        raise
+        await self._process_pulled_events(origin, missing_events)

    async def _get_state_for_room(
        self,
@@ -1395,6 +1360,81 @@ async def get_event(event_id: str):
            event_infos,
        )

+    async def _process_pulled_events(
+        self, origin: str, events: Iterable[EventBase]
+    ) -> None:
+        """Process a batch of events we have pulled from a remote server
+
+        Pulls in any events required to auth the events, persists the received events,
+        and notifies clients, if appropriate.
+
+        Assumes the events have already had their signatures and hashes checked.
+
+        Params:
+            origin: The server we received these events from
+            events: The received events.
+        """
+
+        # We want to sort these by depth so we process them and
+        # tell clients about them in order.
+        sorted_events = sorted(events, key=lambda x: x.depth)
+
+        for ev in sorted_events:
+            with nested_logging_context(ev.event_id):
+                await self._process_pulled_event(origin, ev)
+
+    async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+        """Process a single event that we have pulled from a remote server
+
+        Pulls in any events required to auth the event, persists the received event,
+        and notifies clients, if appropriate.
+
+        Assumes the event has already had its signatures and hashes checked.
+
+        This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+        logic in the case that we are missing prev_events (in particular, it just
+        requests the state at that point, rather than triggering a get_missing_events) -
+        so is appropriate when we have pulled the event from a remote server, rather
+        than having it pushed to us.
+
+        Params:
+            origin: The server we received this event from
+            events: The received event
+        """
+        logger.info("Processing pulled event %s", event)
+
+        # these should not be outliers.
+        assert not event.internal_metadata.is_outlier()
+
+        event_id = event.event_id
+
+        existing = await self.store.get_event(
+            event_id, allow_none=True, allow_rejected=True
+        )
+        if existing:
+            if not existing.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received event %s which we have already seen",
+                    event_id,
+                )
+                return
+            logger.info("De-outliering event %s", event_id)
+
+        try:
+            self._sanity_check_event(event)
+        except SynapseError as err:
+            logger.warning("Event %s failed sanity check: %s", event_id, err)
+            return
+
+        try:
+            state = await self._resolve_state_at_missing_prevs(origin, event)
+            await self._process_received_pdu(origin, event, state=state)
+        except FederationError as e:
+            if e.code == 403:
+                logger.warning("Pulled event %s failed history check.", event_id)
+            else:
+                raise
+
    async def _resolve_state_at_missing_prevs(
        self, dest: str, event: EventBase
    ) -> Optional[Iterable[EventBase]]:
@@ -1780,7 +1820,7 @@ async def _handle_queued_pdus(
                    p,
                )
                with nested_logging_context(p.event_id):
-                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p)
            except Exception as e:
                logger.warning(
                    "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
10 changes: 2 additions & 8 deletions tests/test_federation.py
@@ -85,11 +85,7 @@ def setUp(self):

        # Send the join, it should return None (which is not an error)
        self.assertEqual(
-            self.get_success(
-                self.handler.on_receive_pdu(
-                    "test.serv", join_event, sent_to_us_directly=True
-                )
-            ),
+            self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
            None,
        )

@@ -135,9 +131,7 @@ async def post_json(destination, path, data, headers=None, timeout=0):

        with LoggingContext("test-context"):
            failure = self.get_failure(
-                self.handler.on_receive_pdu(
-                    "test.serv", lying_event, sent_to_us_directly=True
-                ),
+                self.handler.on_receive_pdu("test.serv", lying_event),
                FederationError,
            )

