Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert Transation and Edu object to attrs #10542

Merged
merged 8 commits into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10542.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert `Transaction` and `Edu` objects to attrs.
50 changes: 30 additions & 20 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,17 @@ async def on_backfill_request(
origin, room_id, versions, limit
)

res = self._transaction_from_pdus(pdus).get_dict()
res = self._transaction_dict_from_pdus(pdus)

return 200, res

async def on_incoming_transaction(
self, origin: str, transaction_data: JsonDict
) -> Tuple[int, Dict[str, Any]]:
self,
origin: str,
transaction_id: str,
destination: str,
transaction_data: JsonDict,
) -> Tuple[int, JsonDict]:
# If we receive a transaction we should make sure that kick off handling
# any old events in the staging area.
if not self._started_handling_of_staged_events:
Expand All @@ -212,18 +216,22 @@ async def on_incoming_transaction(
# accurate as possible.
request_time = self._clock.time_msec()

transaction = Transaction(**transaction_data)
transaction_id = transaction.transaction_id # type: ignore
transaction = Transaction(
transaction_id=transaction_id,
destination=destination,
origin=origin,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously we pulled origin from transaction_data, but the comment in synapse.federation.transport.server makes me think we should be using the origin we've calculated instead.

origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore
pdus=transaction_data.get("pdus"), # type: ignore
edus=transaction_data.get("edus"),
)

if not transaction_id:
raise Exception("Transaction missing transaction_id")

logger.debug("[%s] Got transaction", transaction_id)

# Reject malformed transactions early: reject if too many PDUs/EDUs
if len(transaction.pdus) > 50 or ( # type: ignore
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
logger.info("Transaction PDU or EDU count too large. Returning 400")
return 400, {}

Expand Down Expand Up @@ -263,7 +271,7 @@ async def _on_incoming_transaction_inner(
# CRITICAL SECTION: the first thing we must do (before awaiting) is
# add an entry to _active_transactions.
assert origin not in self._active_transactions
self._active_transactions[origin] = transaction.transaction_id # type: ignore
self._active_transactions[origin] = transaction.transaction_id

try:
result = await self._handle_incoming_transaction(
Expand Down Expand Up @@ -291,11 +299,11 @@ async def _handle_incoming_transaction(
if response:
logger.debug(
"[%s] We've already responded to this request",
transaction.transaction_id, # type: ignore
transaction.transaction_id,
)
return response

logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
logger.debug("[%s] Transaction is new", transaction.transaction_id)

# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
Expand Down Expand Up @@ -334,15 +342,15 @@ async def _handle_pdus_in_txn(
report back to the sending server.
"""

received_pdus_counter.inc(len(transaction.pdus)) # type: ignore
received_pdus_counter.inc(len(transaction.pdus))

origin_host, _ = parse_server_name(origin)

pdus_by_room: Dict[str, List[EventBase]] = {}

newest_pdu_ts = 0

for p in transaction.pdus: # type: ignore
for p in transaction.pdus:
# FIXME (richardv): I don't think this works:
# https://github.com/matrix-org/synapse/issues/8429
if "unsigned" in p:
Expand Down Expand Up @@ -436,10 +444,10 @@ async def process_pdu(pdu: EventBase) -> JsonDict:

return pdu_results

async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
"""Process the EDUs in a received transaction."""

async def _process_edu(edu_dict):
async def _process_edu(edu_dict: JsonDict) -> None:
received_edus_counter.inc()

edu = Edu(
Expand All @@ -452,7 +460,7 @@ async def _process_edu(edu_dict):

await concurrently_execute(
_process_edu,
getattr(transaction, "edus", []),
transaction.edus,
TRANSACTION_CONCURRENCY_LIMIT,
)

Expand Down Expand Up @@ -538,7 +546,7 @@ async def on_pdu_request(
pdu = await self.handler.get_persisted_pdu(origin, event_id)

if pdu:
return 200, self._transaction_from_pdus([pdu]).get_dict()
return 200, self._transaction_dict_from_pdus([pdu])
else:
return 404, ""

Expand Down Expand Up @@ -879,18 +887,20 @@ async def on_openid_userinfo(self, token: str) -> Optional[str]:
ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)

def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
# Just need a dummy transaction ID and destination since it won't be used.
transaction_id="",
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
destination="",
).get_dict()

async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
"""Process a PDU received in a federation /send/ transaction.
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def have_responded(
`None` if we have not previously responded to this transaction or a
2-tuple of `(int, dict)` representing the response code and response body.
"""
transaction_id = transaction.transaction_id # type: ignore
transaction_id = transaction.transaction_id
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")

Expand All @@ -56,7 +56,7 @@ async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
"""Persist how we responded to a transaction."""
transaction_id = transaction.transaction_id # type: ignore
transaction_id = transaction.transaction_id
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")

Expand Down
9 changes: 5 additions & 4 deletions synapse/federation/sender/transaction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
tags,
whitelisted_homeserver,
)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.metrics import measure_func

Expand Down Expand Up @@ -104,13 +105,13 @@ async def send_new_transaction(
len(edus),
)

transaction = Transaction.create_new(
transaction = Transaction(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self._server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdus=[p.get_pdu_json() for p in pdus],
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
edus=[edu.get_dict() for edu in edus],
)

self._next_txn_id += 1
Expand All @@ -131,7 +132,7 @@ async def send_new_transaction(
# FIXME (richardv): I also believe it no longer works. We (now?) store
# "age_ts" in "unsigned" rather than at the top level. See
# https://github.com/matrix-org/synapse/issues/8429.
def json_data_cb():
def json_data_cb() -> JsonDict:
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def send_transaction(
"""Sends the given Transaction to its destination

Args:
transaction (Transaction)
transaction

Returns:
Succeeds when we get a 2xx HTTP response. The result
Expand Down
11 changes: 1 addition & 10 deletions synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,21 +450,12 @@ async def on_PUT(
len(transaction_data.get("edus", [])),
)

# We should ideally be getting this from the security layer.
# origin = body["origin"]

# Add some extra data to the transaction dict that isn't included
# in the request body.
transaction_data.update(
transaction_id=transaction_id, destination=self.server_name
)

except Exception as e:
logger.exception(e)
return 400, {"error": "Invalid transaction"}

code, response = await self.handler.on_incoming_transaction(
origin, transaction_data
origin, transaction_id, self.server_name, transaction_data
)

return code, response
Expand Down
90 changes: 35 additions & 55 deletions synapse/federation/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,28 @@
"""

import logging
from typing import Optional
from typing import List, Optional

import attr

from synapse.types import JsonDict
from synapse.util.jsonobject import JsonEncodedObject

logger = logging.getLogger(__name__)


@attr.s(slots=True)
class Edu(JsonEncodedObject):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Edu:
"""An Edu represents a piece of data sent from one homeserver to another.

In comparison to Pdus, Edus are not persisted for a long time on disk, are
not meaningful beyond a given pair of homeservers, and don't have an
internal ID or previous references graph.
"""

edu_type = attr.ib(type=str)
content = attr.ib(type=dict)
origin = attr.ib(type=str)
destination = attr.ib(type=str)
edu_type: str
content: dict
origin: str
destination: str

def get_dict(self) -> JsonDict:
return {
Expand All @@ -55,14 +54,21 @@ def get_internal_dict(self) -> JsonDict:
"destination": self.destination,
}

def get_context(self):
def get_context(self) -> str:
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")

def strip_context(self):
def strip_context(self) -> None:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"


class Transaction(JsonEncodedObject):
def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
if edus is None:
return []
return edus


@attr.s(slots=True, frozen=True, auto_attribs=True)
class Transaction:
"""A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata.

Expand All @@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject):

"""

valid_keys = [
"transaction_id",
"origin",
"destination",
"origin_server_ts",
"previous_ids",
"pdus",
"edus",
]

internal_keys = ["transaction_id", "destination"]

required_keys = [
"transaction_id",
"origin",
"destination",
"origin_server_ts",
"pdus",
]

def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs):
"""If we include a list of pdus then we decode then as PDU's
automatically.
"""

# If there's no EDUs then remove the arg
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]

super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs)

@staticmethod
def create_new(pdus, **kwargs):
"""Used to create a new transaction. Will auto fill out
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
raise KeyError("Require 'origin_server_ts' to construct a Transaction")
if "transaction_id" not in kwargs:
raise KeyError("Require 'transaction_id' to construct a Transaction")

kwargs["pdus"] = [p.get_pdu_json() for p in pdus]

return Transaction(**kwargs)
# Required keys.
transaction_id: str
origin: str
destination: str
origin_server_ts: int
pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)

def get_dict(self) -> JsonDict:
"""A JSON-ready dictionary of valid keys which aren't internal."""
result = {
"origin": self.origin,
"origin_server_ts": self.origin_server_ts,
"pdus": self.pdus,
}
if self.edus:
result["edus"] = self.edus
return result
Loading