diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.yml b/.github/ISSUE_TEMPLATE/BUG_REPORT.yml index abe0f656a28b..79578eeaaaff 100644 --- a/.github/ISSUE_TEMPLATE/BUG_REPORT.yml +++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.yml @@ -129,7 +129,7 @@ body: attributes: label: Relevant log output description: | - Please copy and paste any relevant log output, ideally at INFO or DEBUG log level. + Please copy and paste any relevant log output as text (not images), ideally at INFO or DEBUG log level. This will be automatically formatted into code, so there is no need for backticks (`\``). Please be careful to remove any personal or private data. diff --git a/.gitignore b/.gitignore index 9d037f28e758..8cf504324b9e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,9 +15,10 @@ _trial_temp*/ .DS_Store __pycache__/ -# We do want the poetry and cargo lockfile. +# We do want poetry, cargo and flake lockfiles. !poetry.lock !Cargo.lock +!flake.lock # stuff that is likely to exist when you run a server locally /*.db @@ -38,6 +39,9 @@ __pycache__/ /.envrc .direnv/ +# For nix/devenv users +.devenv/ + # IDEs /.idea/ /.ropeproject/ diff --git a/CHANGES.md b/CHANGES.md index b2cc138ee302..9c200bfb7be7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,62 @@ +Synapse 1.83.0 (2023-05-09) +=========================== + +No significant changes since 1.83.0rc1. + + +Synapse 1.83.0rc1 (2023-05-02) +============================== + +Features +-------- + +- Experimental support to recursively provide relations per [MSC3981](https://github.com/matrix-org/matrix-spec-proposals/pull/3981). ([\#15315](https://github.com/matrix-org/synapse/issues/15315)) +- Experimental support for [MSC3970](https://github.com/matrix-org/matrix-spec-proposals/pull/3970): Scope transaction IDs to devices. ([\#15318](https://github.com/matrix-org/synapse/issues/15318)) +- Add an [admin API endpoint](https://matrix-org.github.io/synapse/v1.83/admin_api/experimental_features.html) to support per-user feature flags. ([\#15344](https://github.com/matrix-org/synapse/issues/15344)) +- Add a module API to send an HTTP push notification. ([\#15387](https://github.com/matrix-org/synapse/issues/15387)) +- Add an [admin API endpoint](https://matrix-org.github.io/synapse/v1.83/admin_api/statistics.html#get-largest-rooms-by-size-in-database) to query the largest rooms by disk space used in the database. ([\#15482](https://github.com/matrix-org/synapse/issues/15482)) + + +Bugfixes +-------- + +- Disable push rule evaluation for rooms excluded from sync. ([\#15361](https://github.com/matrix-org/synapse/issues/15361)) +- Fix a long-standing bug where cached server key results which were directly fetched would not be properly re-used. ([\#15417](https://github.com/matrix-org/synapse/issues/15417)) +- Fix a bug introduced in Synapse 1.73.0 where some experimental push rules were returned by default. ([\#15494](https://github.com/matrix-org/synapse/issues/15494)) + + +Improved Documentation +---------------------- + +- Add Nginx loadbalancing example with sticky mxid for workers. ([\#15411](https://github.com/matrix-org/synapse/issues/15411)) +- Update outdated development docs that mention restrictions in versions of SQLite that we no longer support. ([\#15498](https://github.com/matrix-org/synapse/issues/15498)) + + +Internal Changes +---------------- + +- Speedup tests by caching HomeServerConfig instances. ([\#15284](https://github.com/matrix-org/synapse/issues/15284)) +- Add denormalised event stream ordering column to membership state tables for future use. Contributed by Nick @ Beeper (@fizzadar). ([\#15356](https://github.com/matrix-org/synapse/issues/15356)) +- Always use multi-user device resync replication endpoints. ([\#15418](https://github.com/matrix-org/synapse/issues/15418)) +- Add column `full_user_id` to tables `profiles` and `user_filters`. ([\#15458](https://github.com/matrix-org/synapse/issues/15458)) +- Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request. ([\#15462](https://github.com/matrix-org/synapse/issues/15462)) +- Improve type hints. ([\#15465](https://github.com/matrix-org/synapse/issues/15465), [\#15496](https://github.com/matrix-org/synapse/issues/15496), [\#15497](https://github.com/matrix-org/synapse/issues/15497)) +- Support claiming more than one OTK at a time. ([\#15468](https://github.com/matrix-org/synapse/issues/15468)) +- Bump types-pyyaml from 6.0.12.8 to 6.0.12.9. ([\#15471](https://github.com/matrix-org/synapse/issues/15471)) +- Bump pyasn1-modules from 0.2.8 to 0.3.0. ([\#15473](https://github.com/matrix-org/synapse/issues/15473)) +- Bump cryptography from 40.0.1 to 40.0.2. ([\#15474](https://github.com/matrix-org/synapse/issues/15474)) +- Bump types-netaddr from 0.8.0.7 to 0.8.0.8. ([\#15475](https://github.com/matrix-org/synapse/issues/15475)) +- Bump types-jsonschema from 4.17.0.6 to 4.17.0.7. ([\#15476](https://github.com/matrix-org/synapse/issues/15476)) +- Ask bug reporters to provide logs as text. ([\#15479](https://github.com/matrix-org/synapse/issues/15479)) +- Add a Nix flake for use as a development environment. ([\#15495](https://github.com/matrix-org/synapse/issues/15495)) +- Bump anyhow from 1.0.70 to 1.0.71. ([\#15507](https://github.com/matrix-org/synapse/issues/15507)) +- Bump types-pillow from 9.4.0.19 to 9.5.0.2. ([\#15508](https://github.com/matrix-org/synapse/issues/15508)) +- Bump packaging from 23.0 to 23.1. ([\#15510](https://github.com/matrix-org/synapse/issues/15510)) +- Bump types-requests from 2.28.11.16 to 2.29.0.0. ([\#15511](https://github.com/matrix-org/synapse/issues/15511)) +- Bump setuptools-rust from 1.5.2 to 1.6.0. ([\#15512](https://github.com/matrix-org/synapse/issues/15512)) +- Update the check_schema_delta script to account for when the schema version has been bumped locally. ([\#15466](https://github.com/matrix-org/synapse/issues/15466)) + + Synapse 1.82.0 (2023-04-25) =========================== diff --git a/Cargo.lock b/Cargo.lock index f661eb532cdc..1085673c7263 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" [[package]] name = "arc-swap" diff --git a/debian/changelog b/debian/changelog index f6e8720e5894..15ff7e82c31b 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,15 @@ +matrix-synapse-py3 (1.83.0) stable; urgency=medium + + * New Synapse release 1.83.0. + + -- Synapse Packaging team Tue, 09 May 2023 18:13:37 +0200 + +matrix-synapse-py3 (1.83.0~rc1) stable; urgency=medium + + * New Synapse release 1.83.0rc1. + + -- Synapse Packaging team Tue, 02 May 2023 15:56:38 +0100 + matrix-synapse-py3 (1.82.0) stable; urgency=medium * New Synapse release 1.82.0. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index ade77d49261c..a8e5ddad9d48 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -57,6 +57,7 @@ - [Account Validity](admin_api/account_validity.md) - [Background Updates](usage/administration/admin_api/background_updates.md) - [Event Reports](admin_api/event_reports.md) + - [Experimental Features](admin_api/experimental_features.md) - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - [Register Users](admin_api/register_api.md) diff --git a/docs/admin_api/experimental_features.md b/docs/admin_api/experimental_features.md new file mode 100644 index 000000000000..c1aebe4b01a8 --- /dev/null +++ b/docs/admin_api/experimental_features.md @@ -0,0 +1,54 @@ +# Experimental Features API + +This API allows a server administrator to enable or disable some experimental features on a per-user +basis. Currently supported features are [msc3026](https://github.com/matrix-org/matrix-spec-proposals/pull/3026): busy +presence state enabled, [msc2654](https://github.com/matrix-org/matrix-spec-proposals/pull/2654): enable unread counts, +[msc3881](https://github.com/matrix-org/matrix-spec-proposals/pull/3881): enable remotely toggling push notifications +for another client, and [msc3967](https://github.com/matrix-org/matrix-spec-proposals/pull/3967): do not require +UIA when first uploading cross-signing keys. + + +To use it, you will need to authenticate by providing an `access_token` +for a server admin: see [Admin API](../usage/administration/admin_api/). + +## Enabling/Disabling Features + +This API allows a server administrator to enable experimental features for a given user. The request must +provide a body containing the user id and listing the features to enable/disable in the following format: +```json +{ + "features": { + "msc3026":true, + "msc2654":true + } +} +``` +where true is used to enable the feature, and false is used to disable the feature. + + +The API is: + +``` +PUT /_synapse/admin/v1/experimental_features/ +``` + +## Listing Enabled Features + +To list which features are enabled/disabled for a given user send a request to the following API: + +``` +GET /_synapse/admin/v1/experimental_features/ +``` + +It will return a list of possible features and indicate whether they are enabled or disabled for the +user like so: +```json +{ + "features": { + "msc3026": true, + "msc2654": true, + "msc3881": false, + "msc3967": false + } +} +``` \ No newline at end of file diff --git a/docs/admin_api/statistics.md b/docs/admin_api/statistics.md index 03b3621e5595..2bd417e90031 100644 --- a/docs/admin_api/statistics.md +++ b/docs/admin_api/statistics.md @@ -81,3 +81,52 @@ The following fields are returned in the JSON response body: - `user_id` - string - Fully-qualified user ID (ex. `@user:server.com`). * `next_token` - integer - Opaque value used for pagination. See above. * `total` - integer - Total number of users after filtering. + + +# Get largest rooms by size in database + +Returns the 10 largest rooms and an estimate of how much space in the database +they are taking. + +This does not include the size of any associated media associated with the room. + +Returns an error on SQLite. + +*Note:* This uses the planner statistics from PostgreSQL to do the estimates, +which means that the returned information can vary widely from reality. However, +it should be enough to get a rough idea of where database disk space is going. + + +The API is: + +``` +GET /_synapse/admin/v1/statistics/statistics/database/rooms +``` + +A response body like the following is returned: + +```json +{ + "rooms": [ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "estimated_size": 47325417353 + } + ], +} +``` + + + +**Response** + +The following fields are returned in the JSON response body: + +* `rooms` - An array of objects, sorted by largest room first. Objects contain + the following fields: + - `room_id` - string - The room ID. + - `estimated_size` - integer - Estimated disk space used in bytes by the room + in the database. + + +*Added in Synapse 1.83.0* diff --git a/docs/development/database_schema.md b/docs/development/database_schema.md index 29945c264ee8..e231be21ddd2 100644 --- a/docs/development/database_schema.md +++ b/docs/development/database_schema.md @@ -155,43 +155,11 @@ def run_upgrade( Boolean columns require special treatment, since SQLite treats booleans the same as integers. -There are three separate aspects to this: - - * Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in +Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in `synapse/_scripts/synapse_port_db.py`. This tells the port script to cast the integer value from SQLite to a boolean before writing the value to the postgres database. - * Before SQLite 3.23, `TRUE` and `FALSE` were not recognised as constants by - SQLite, and the `IS [NOT] TRUE`/`IS [NOT] FALSE` operators were not - supported. This makes it necessary to avoid using `TRUE` and `FALSE` - constants in SQL commands. - - For example, to insert a `TRUE` value into the database, write: - - ```python - txn.execute("INSERT INTO tbl(col) VALUES (?)", (True, )) - ``` - - * Default values for new boolean columns present a particular - difficulty. Generally it is best to create separate schema files for - Postgres and SQLite. For example: - - ```sql - # in 00delta.sql.postgres: - ALTER TABLE tbl ADD COLUMN col BOOLEAN DEFAULT FALSE; - ``` - - ```sql - # in 00delta.sql.sqlite: - ALTER TABLE tbl ADD COLUMN col BOOLEAN DEFAULT 0; - ``` - - Note that there is a particularly insidious failure mode here: the Postgres - flavour will be accepted by SQLite 3.22, but will give a column whose - default value is the **string** `"FALSE"` - which, when cast back to a boolean - in Python, evaluates to `True`. - ## `event_id` global uniqueness diff --git a/docs/workers.md b/docs/workers.md index 6192a46e0950..765f03c2635b 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -325,8 +325,7 @@ load balancing can be done in different ways. For `/sync` and `/initialSync` requests it will be more efficient if all requests from a particular user are routed to a single instance. This can -be done e.g. in nginx via IP `hash $http_x_forwarded_for;` or via -`hash $http_authorization consistent;` which contains the users access token. +be done in reverse proxy by extracting username part from the users access token. Admins may additionally wish to separate out `/sync` requests that have a `since` query parameter from those that don't (and @@ -335,6 +334,69 @@ when a user logs in on a new device and can be *very* resource intensive, so isolating these requests will stop them from interfering with other users ongoing syncs. +Example `nginx` configuration snippet that handles the cases above. This is just an +example and probably requires some changes according to your particular setup: + +```nginx +# Choose sync worker based on the existence of "since" query parameter +map $arg_since $sync { + default synapse_sync; + '' synapse_initial_sync; +} + +# Extract username from access token passed as URL parameter +map $arg_access_token $accesstoken_from_urlparam { + # Defaults to just passing back the whole accesstoken + default $arg_access_token; + # Try to extract username part from accesstoken URL parameter + "~syt_(?.*?)_.*" $username; +} + +# Extract username from access token passed as authorization header +map $http_authorization $mxid_localpart { + # Defaults to just passing back the whole accesstoken + default $http_authorization; + # Try to extract username part from accesstoken header + "~Bearer syt_(?.*?)_.*" $username; + # if no authorization-header exist, try mapper for URL parameter "access_token" + "" $accesstoken_from_urlparam; +} + +upstream synapse_initial_sync { + # Use the username mapper result for hash key + hash $mxid_localpart consistent; + server 127.0.0.1:8016; + server 127.0.0.1:8036; +} + +upstream synapse_sync { + # Use the username mapper result for hash key + hash $mxid_localpart consistent; + server 127.0.0.1:8013; + server 127.0.0.1:8037; + server 127.0.0.1:8038; + server 127.0.0.1:8039; +} + +# Sync initial/normal +location ~ ^/_matrix/client/(r0|v3)/sync$ { + proxy_pass http://$sync; +} + +# Normal sync +location ~ ^/_matrix/client/(api/v1|r0|v3)/events$ { + proxy_pass http://synapse_sync; +} + +# Initial_sync +location ~ ^/_matrix/client/(api/v1|r0|v3)/initialSync$ { + proxy_pass http://synapse_initial_sync; +} +location ~ ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$ { + proxy_pass http://synapse_initial_sync; +} +``` + Federation and client requests can be balanced via simple round robin. The inbound federation transaction request `^/_matrix/federation/v1/send/` diff --git a/flake.lock b/flake.lock new file mode 100644 index 000000000000..85886b730f54 --- /dev/null +++ b/flake.lock @@ -0,0 +1,274 @@ +{ + "nodes": { + "devenv": { + "inputs": { + "flake-compat": "flake-compat", + "nix": "nix", + "nixpkgs": "nixpkgs", + "pre-commit-hooks": "pre-commit-hooks" + }, + "locked": { + "lastModified": 1682534083, + "narHash": "sha256-lBgFaLNHRQtD3InZbBXzIS8HgZUgcPJ6jiqGa4FJPrk=", + "owner": "anoadragon453", + "repo": "devenv", + "rev": "9694bd0a845dd184d4468cc3d3461089aace787a", + "type": "github" + }, + "original": { + "owner": "anoadragon453", + "ref": "anoa/fix_languages_python", + "repo": "devenv", + "type": "github" + } + }, + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1682490133, + "narHash": "sha256-tR2Qx0uuk97WySpSSk4rGS/oH7xb5LykbjATcw1vw1I=", + "owner": "nix-community", + "repo": "fenix", + "rev": "4e9412753ab75ef0e038a5fe54a062fb44c27c6a", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1673956053, + "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "locked": { + "lastModified": 1667395993, + "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1660459072, + "narHash": "sha256-8DFJjXG8zqoONA1vXtgeKXy68KdJL5UaXR8NtVMUbx8=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "a20de23b925fd8264fd7fad6454652e142fd7f73", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "lowdown-src": { + "flake": false, + "locked": { + "lastModified": 1633514407, + "narHash": "sha256-Dw32tiMjdK9t3ETl5fzGrutQTzh2rufgZV4A/BbxuD4=", + "owner": "kristapsdz", + "repo": "lowdown", + "rev": "d2c2b44ff6c27b936ec27358a2653caaef8f73b8", + "type": "github" + }, + "original": { + "owner": "kristapsdz", + "repo": "lowdown", + "type": "github" + } + }, + "nix": { + "inputs": { + "lowdown-src": "lowdown-src", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-regression": "nixpkgs-regression" + }, + "locked": { + "lastModified": 1676545802, + "narHash": "sha256-EK4rZ+Hd5hsvXnzSzk2ikhStJnD63odF7SzsQ8CuSPU=", + "owner": "domenkozar", + "repo": "nix", + "rev": "7c91803598ffbcfe4a55c44ac6d49b2cf07a527f", + "type": "github" + }, + "original": { + "owner": "domenkozar", + "ref": "relaxed-flakes", + "repo": "nix", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1678875422, + "narHash": "sha256-T3o6NcQPwXjxJMn2shz86Chch4ljXgZn746c2caGxd8=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "126f49a01de5b7e35a43fd43f891ecf6d3a51459", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-regression": { + "locked": { + "lastModified": 1643052045, + "narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + } + }, + "nixpkgs-stable": { + "locked": { + "lastModified": 1673800717, + "narHash": "sha256-SFHraUqLSu5cC6IxTprex/nTsI81ZQAtDvlBvGDWfnA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2f9fd351ec37f5d479556cd48be4ca340da59b8f", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-22.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1682519441, + "narHash": "sha256-Vsq/8NOtvW1AoC6shCBxRxZyMQ+LhvPuJT6ltbzuv+Y=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "7a32a141db568abde9bc389845949dc2a454dfd3", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "master", + "repo": "nixpkgs", + "type": "github" + } + }, + "pre-commit-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-utils": "flake-utils", + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-stable": "nixpkgs-stable" + }, + "locked": { + "lastModified": 1678376203, + "narHash": "sha256-3tyYGyC8h7fBwncLZy5nCUjTJPrHbmNwp47LlNLOHSM=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "1a20b9708962096ec2481eeb2ddca29ed747770a", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "fenix": "fenix", + "nixpkgs": "nixpkgs_2", + "systems": "systems" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1682426789, + "narHash": "sha256-UqnLmJESRZE0tTEaGbRAw05Hm19TWIPA+R3meqi5I4w=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "943d2a8a1ca15e8b28a1f51f5a5c135e3728da04", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 000000000000..91916d9abb51 --- /dev/null +++ b/flake.nix @@ -0,0 +1,204 @@ +# A nix flake that sets up a complete Synapse development environment. Dependencies +# for the SyTest (https://github.com/matrix-org/sytest) and Complement +# (https://github.com/matrix-org/complement) Matrix homeserver test suites are also +# installed automatically. +# +# You must have already installed nix (https://nixos.org) on your system to use this. +# nix can be installed on Linux or MacOS; NixOS is not required. Windows is not +# directly supported, but nix can be installed inside of WSL2 or even Docker +# containers. Please refer to https://nixos.org/download for details. +# +# You must also enable support for flakes in Nix. See the following for how to +# do so permanently: https://nixos.wiki/wiki/Flakes#Enable_flakes +# +# Usage: +# +# With nix installed, navigate to the directory containing this flake and run +# `nix develop --impure`. The `--impure` is necessary in order to store state +# locally from "services", such as PostgreSQL and Redis. +# +# You should now be dropped into a new shell with all programs and dependencies +# availabile to you! +# +# You can start up pre-configured, local PostgreSQL and Redis instances by +# running: `devenv up`. To stop them, use Ctrl-C. +# +# A PostgreSQL database called 'synapse' will be set up for you, along with +# a PostgreSQL user named 'synapse_user'. +# The 'host' can be found by running `echo $PGHOST` with the development +# shell activated. Use these values to configure your Synapse to connect +# to the local PostgreSQL database. You do not need to specify a password. +# https://matrix-org.github.io/synapse/latest/postgres +# +# All state (the venv, postgres and redis data and config) are stored in +# .devenv/state. Deleting a file from here and then re-entering the shell +# will recreate these files from scratch. +# +# You can exit the development shell by typing `exit`, or using Ctrl-D. +# +# If you would like this development environment to activate automatically +# upon entering this directory in your terminal, first install `direnv` +# (https://direnv.net/). Then run `echo 'use flake . --impure' >> .envrc` at +# the root of the Synapse repo. Finally, run `direnv allow .` to allow the +# contents of '.envrc' to run every time you enter this directory. VoilĂ ! + +{ + inputs = { + # Use the master/unstable branch of nixpkgs. The latest stable, 22.11, + # does not contain 'perl536Packages.NetAsyncHTTP', needed by Sytest. + nixpkgs.url = "github:NixOS/nixpkgs/master"; + # Output a development shell for x86_64/aarch64 Linux/Darwin (MacOS). + systems.url = "github:nix-systems/default"; + # A development environment manager built on Nix. See https://devenv.sh. + # This is temporarily overridden to a fork that fixes a quirk between + # devenv's service and python language features. This can be removed + # when https://github.com/cachix/devenv/pull/559 is merged upstream. + devenv.url = "github:anoadragon453/devenv/anoa/fix_languages_python"; + #devenv.url = "github:cachix/devenv/main"; + # Rust toolchains and rust-analyzer nightly. + fenix = { + url = "github:nix-community/fenix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = { self, nixpkgs, devenv, systems, ... } @ inputs: + let + forEachSystem = nixpkgs.lib.genAttrs (import systems); + in { + devShells = forEachSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in { + # Everything is configured via devenv - a nix module for creating declarative + # developer environments. See https://devenv.sh/reference/options/ for a list + # of all possible options. + default = devenv.lib.mkShell { + inherit inputs pkgs; + modules = [ + { + # Make use of the Starship command prompt when this development environment + # is manually activated (via `nix develop --impure`). + # See https://starship.rs/ for details on the prompt itself. + starship.enable = true; + + # Configure packages to install. + # Search for package names at https://search.nixos.org/packages?channel=unstable + packages = with pkgs; [ + # Native dependencies for running Synapse. + icu + libffi + libjpeg + libpqxx + libwebp + libxml2 + libxslt + sqlite + + # Native dependencies for unit tests (SyTest also requires OpenSSL). + openssl + + # Native dependencies for running Complement. + olm + ]; + + # Install Python and manage a virtualenv with Poetry. + languages.python.enable = true; + languages.python.poetry.enable = true; + # Automatically activate the poetry virtualenv upon entering the shell. + languages.python.poetry.activate.enable = true; + # Install all extra Python dependencies; this is needed to run the unit + # tests and utilitise all Synapse features. + languages.python.poetry.install.arguments = ["--extras all"]; + # Install the 'matrix-synapse' package from the local checkout. + languages.python.poetry.install.installRootPackage = true; + + # This is a work-around for NixOS systems. NixOS is special in + # that you can have multiple versions of packages installed at + # once, including your libc linker! + # + # Some binaries built for Linux expect those to be in a certain + # filepath, but that is not the case on NixOS. In that case, we + # force compiling those binaries locally instead. + env.POETRY_INSTALLER_NO_BINARY = "ruff"; + + # Install dependencies for the additional programming languages + # involved with Synapse development. + # + # * Rust is used for developing and running Synapse. + # * Golang is needed to run the Complement test suite. + # * Perl is needed to run the SyTest test suite. + languages.go.enable = true; + languages.rust.enable = true; + languages.rust.version = "stable"; + languages.perl.enable = true; + + # Postgres is needed to run Synapse with postgres support and + # to run certain unit tests that require postgres. + services.postgres.enable = true; + + # On the first invocation of `devenv up`, create a database for + # Synapse to store data in. + services.postgres.initdbArgs = ["--locale=C" "--encoding=UTF8"]; + services.postgres.initialDatabases = [ + { name = "synapse"; } + ]; + # Create a postgres user called 'synapse_user' which has ownership + # over the 'synapse' database. + services.postgres.initialScript = '' + CREATE USER synapse_user; + ALTER DATABASE synapse OWNER TO synapse_user; + ''; + + # Redis is needed in order to run Synapse in worker mode. + services.redis.enable = true; + + # Define the perl modules we require to run SyTest. + # + # This list was compiled by cross-referencing https://metacpan.org/ + # with the modules defined in './cpanfile' and then finding the + # corresponding nix packages on https://search.nixos.org/packages. + # + # This was done until `./install-deps.pl --dryrun` produced no output. + env.PERL5LIB = "${with pkgs.perl536Packages; makePerlPath [ + DBI + ClassMethodModifiers + CryptEd25519 + DataDump + DBDPg + DigestHMAC + DigestSHA1 + EmailAddressXS + EmailMIME + EmailSimple # required by Email::Mime + EmailMessageID # required by Email::Mime + EmailMIMEContentType # required by Email::Mime + TextUnidecode # required by Email::Mime + ModuleRuntime # required by Email::Mime + EmailMIMEEncodings # required by Email::Mime + FilePath + FileSlurper + Future + GetoptLong + HTTPMessage + IOAsync + IOAsyncSSL + IOSocketSSL + NetSSLeay + JSON + ListUtilsBy + ScalarListUtils + ModulePluggable + NetAsyncHTTP + MetricsAny # required by Net::Async::HTTP + NetAsyncHTTPServer + StructDumb + URI + YAMLLibYAML + ]}"; + } + ]; + }; + }); + }; +} diff --git a/mypy.ini b/mypy.ini index 945f7925cb2c..5e7057cfb7b1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -21,26 +21,7 @@ files = tests/, build_rust.py -# Note: Better exclusion syntax coming in mypy > 0.910 -# https://github.com/python/mypy/pull/11329 -# -# For now, set the (?x) flag enable "verbose" regexes -# https://docs.python.org/3/library/re.html#re.X -exclude = (?x) - ^( - |synapse/storage/databases/__init__.py - |synapse/storage/databases/main/cache.py - |synapse/storage/schema/ - )$ - -[mypy-synapse.federation.transport.client] -disallow_untyped_defs = False - -[mypy-synapse.http.matrixfederationclient] -disallow_untyped_defs = False - [mypy-synapse.metrics._reactor_metrics] -disallow_untyped_defs = False # This module imports select.epoll. That exists on Linux, but doesn't on macOS. # See https://github.com/matrix-org/synapse/pull/11771. warn_unused_ignores = False diff --git a/poetry.lock b/poetry.lock index 19bf47b934b3..daafa355cdde 100644 --- a/poetry.lock +++ b/poetry.lock @@ -481,31 +481,31 @@ files = [ [[package]] name = "cryptography" -version = "40.0.1" +version = "40.0.2" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." category = "main" optional = false python-versions = ">=3.6" files = [ - {file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:918cb89086c7d98b1b86b9fdb70c712e5a9325ba6f7d7cfb509e784e0cfc6917"}, - {file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:9618a87212cb5200500e304e43691111570e1f10ec3f35569fdfcd17e28fd797"}, - {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4805a4ca729d65570a1b7cac84eac1e431085d40387b7d3bbaa47e39890b88"}, - {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63dac2d25c47f12a7b8aa60e528bfb3c51c5a6c5a9f7c86987909c6c79765554"}, - {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:0a4e3406cfed6b1f6d6e87ed243363652b2586b2d917b0609ca4f97072994405"}, - {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1e0af458515d5e4028aad75f3bb3fe7a31e46ad920648cd59b64d3da842e4356"}, - {file = "cryptography-40.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d8aa3609d337ad85e4eb9bb0f8bcf6e4409bfb86e706efa9a027912169e89122"}, - {file = "cryptography-40.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cf91e428c51ef692b82ce786583e214f58392399cf65c341bc7301d096fa3ba2"}, - {file = "cryptography-40.0.1-cp36-abi3-win32.whl", hash = "sha256:650883cc064297ef3676b1db1b7b1df6081794c4ada96fa457253c4cc40f97db"}, - {file = "cryptography-40.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:a805a7bce4a77d51696410005b3e85ae2839bad9aa38894afc0aa99d8e0c3160"}, - {file = "cryptography-40.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cd033d74067d8928ef00a6b1327c8ea0452523967ca4463666eeba65ca350d4c"}, - {file = "cryptography-40.0.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d36bbeb99704aabefdca5aee4eba04455d7a27ceabd16f3b3ba9bdcc31da86c4"}, - {file = "cryptography-40.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:32057d3d0ab7d4453778367ca43e99ddb711770477c4f072a51b3ca69602780a"}, - {file = "cryptography-40.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:f5d7b79fa56bc29580faafc2ff736ce05ba31feaa9d4735048b0de7d9ceb2b94"}, - {file = "cryptography-40.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7c872413353c70e0263a9368c4993710070e70ab3e5318d85510cc91cce77e7c"}, - {file = "cryptography-40.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:28d63d75bf7ae4045b10de5413fb1d6338616e79015999ad9cf6fc538f772d41"}, - {file = "cryptography-40.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6f2bbd72f717ce33100e6467572abaedc61f1acb87b8d546001328d7f466b778"}, - {file = "cryptography-40.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cc3a621076d824d75ab1e1e530e66e7e8564e357dd723f2533225d40fe35c60c"}, - {file = "cryptography-40.0.1.tar.gz", hash = "sha256:2803f2f8b1e95f614419926c7e6f55d828afc614ca5ed61543877ae668cc3472"}, + {file = "cryptography-40.0.2-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:8f79b5ff5ad9d3218afb1e7e20ea74da5f76943ee5edb7f76e56ec5161ec782b"}, + {file = "cryptography-40.0.2-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:05dc219433b14046c476f6f09d7636b92a1c3e5808b9a6536adf4932b3b2c440"}, + {file = "cryptography-40.0.2-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4df2af28d7bedc84fe45bd49bc35d710aede676e2a4cb7fc6d103a2adc8afe4d"}, + {file = "cryptography-40.0.2-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dcca15d3a19a66e63662dc8d30f8036b07be851a8680eda92d079868f106288"}, + {file = "cryptography-40.0.2-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:a04386fb7bc85fab9cd51b6308633a3c271e3d0d3eae917eebab2fac6219b6d2"}, + {file = "cryptography-40.0.2-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:adc0d980fd2760c9e5de537c28935cc32b9353baaf28e0814df417619c6c8c3b"}, + {file = "cryptography-40.0.2-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d5a1bd0e9e2031465761dfa920c16b0065ad77321d8a8c1f5ee331021fda65e9"}, + {file = "cryptography-40.0.2-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:a95f4802d49faa6a674242e25bfeea6fc2acd915b5e5e29ac90a32b1139cae1c"}, + {file = "cryptography-40.0.2-cp36-abi3-win32.whl", hash = "sha256:aecbb1592b0188e030cb01f82d12556cf72e218280f621deed7d806afd2113f9"}, + {file = "cryptography-40.0.2-cp36-abi3-win_amd64.whl", hash = "sha256:b12794f01d4cacfbd3177b9042198f3af1c856eedd0a98f10f141385c809a14b"}, + {file = "cryptography-40.0.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:142bae539ef28a1c76794cca7f49729e7c54423f615cfd9b0b1fa90ebe53244b"}, + {file = "cryptography-40.0.2-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:956ba8701b4ffe91ba59665ed170a2ebbdc6fc0e40de5f6059195d9f2b33ca0e"}, + {file = "cryptography-40.0.2-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4f01c9863da784558165f5d4d916093737a75203a5c5286fde60e503e4276c7a"}, + {file = "cryptography-40.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:3daf9b114213f8ba460b829a02896789751626a2a4e7a43a28ee77c04b5e4958"}, + {file = "cryptography-40.0.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48f388d0d153350f378c7f7b41497a54ff1513c816bcbbcafe5b829e59b9ce5b"}, + {file = "cryptography-40.0.2-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c0764e72b36a3dc065c155e5b22f93df465da9c39af65516fe04ed3c68c92636"}, + {file = "cryptography-40.0.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cbaba590180cba88cb99a5f76f90808a624f18b169b90a4abb40c1fd8c19420e"}, + {file = "cryptography-40.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7a38250f433cd41df7fcb763caa3ee9362777fdb4dc642b9a349721d2bf47404"}, + {file = "cryptography-40.0.2.tar.gz", hash = "sha256:c33c0d32b8594fa647d2e01dbccc303478e16fdd7cf98652d5b3ed11aa5e5c99"}, ] [package.dependencies] @@ -1660,14 +1660,14 @@ tests = ["Sphinx", "doubles", "flake8", "flake8-quotes", "gevent", "mock", "pyte [[package]] name = "packaging" -version = "23.0" +version = "23.1" description = "Core utilities for Python packages" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"}, - {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, + {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, + {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] [[package]] @@ -1916,18 +1916,18 @@ files = [ [[package]] name = "pyasn1-modules" -version = "0.2.8" -description = "A collection of ASN.1-based protocols modules." +version = "0.3.0" +description = "A collection of ASN.1-based protocols modules" category = "main" optional = false -python-versions = "*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"}, - {file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"}, + {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, + {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, ] [package.dependencies] -pyasn1 = ">=0.4.6,<0.5.0" +pyasn1 = ">=0.4.6,<0.6.0" [[package]] name = "pycparser" @@ -2511,14 +2511,14 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( [[package]] name = "setuptools-rust" -version = "1.5.2" +version = "1.6.0" description = "Setuptools Rust extension plugin" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-rust-1.5.2.tar.gz", hash = "sha256:d8daccb14dc0eae1b6b6eb3ecef79675bd37b4065369f79c35393dd5c55652c7"}, - {file = "setuptools_rust-1.5.2-py3-none-any.whl", hash = "sha256:8eb45851e34288f2296cd5ab9e924535ac1757318b730a13fe6836867843f206"}, + {file = "setuptools-rust-1.6.0.tar.gz", hash = "sha256:c86e734deac330597998bfbc08da45187e6b27837e23bd91eadb320732392262"}, + {file = "setuptools_rust-1.6.0-py3-none-any.whl", hash = "sha256:e28ae09fb7167c44ab34434eb49279307d611547cb56cb9789955cdb54a1aed9"}, ] [package.dependencies] @@ -3083,14 +3083,14 @@ files = [ [[package]] name = "types-netaddr" -version = "0.8.0.7" +version = "0.8.0.8" description = "Typing stubs for netaddr" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-netaddr-0.8.0.7.tar.gz", hash = "sha256:3362864fa0258782d449b91707f37e55f62290b4f438974a08758b498169e109"}, - {file = "types_netaddr-0.8.0.7-py3-none-any.whl", hash = "sha256:a540cdfb2f858a0509ce5a4e4fcc80ef11b19f10a2473e48d32217af517818c0"}, + {file = "types-netaddr-0.8.0.8.tar.gz", hash = "sha256:db7e8cd16b1244e7c4541edd0df99d1039fc05fd5387c21840f0b958fc52aabc"}, + {file = "types_netaddr-0.8.0.8-py3-none-any.whl", hash = "sha256:6741b3824e2ec3f7a74842b394439b71107c7675f8ae42bb2b5e7a8ebfe8cf18"}, ] [[package]] @@ -3107,14 +3107,14 @@ files = [ [[package]] name = "types-pillow" -version = "9.4.0.19" +version = "9.5.0.2" description = "Typing stubs for Pillow" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-Pillow-9.4.0.19.tar.gz", hash = "sha256:a04401181979049977e318dae4523ab5ae8246314fc68fcf50b043ac885a5468"}, - {file = "types_Pillow-9.4.0.19-py3-none-any.whl", hash = "sha256:b55f2508be21e68a39f0a41830f1f1725aba0888e727e2eccd253c78cd5357a5"}, + {file = "types-Pillow-9.5.0.2.tar.gz", hash = "sha256:b3f9f621f259566c19c1deca21901017c8b1e3e200ed2e49e0a2d83c0a5175db"}, + {file = "types_Pillow-9.5.0.2-py3-none-any.whl", hash = "sha256:58fdebd0ffa2353ecccdd622adde23bce89da5c0c8b96c34f2d1eca7b7e42d0e"}, ] [[package]] @@ -3158,14 +3158,14 @@ files = [ [[package]] name = "types-requests" -version = "2.28.11.17" +version = "2.29.0.0" description = "Typing stubs for requests" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-requests-2.28.11.17.tar.gz", hash = "sha256:0d580652ce903f643f8c3b494dd01d29367ea57cea0c7ad7f65cf3169092edb0"}, - {file = "types_requests-2.28.11.17-py3-none-any.whl", hash = "sha256:cc1aba862575019306b2ed134eb1ea994cab1c887a22e18d3383e6dd42e9789b"}, + {file = "types-requests-2.29.0.0.tar.gz", hash = "sha256:c86f4a955d943d2457120dbe719df24ef0924e11177164d10a0373cf311d7b4d"}, + {file = "types_requests-2.29.0.0-py3-none-any.whl", hash = "sha256:4cf6e323e856c779fbe8815bb977a5bf5d6c5034713e4c17ff2a9a20610f5b27"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 111e27fda8f7..099c2f95be00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml" [tool.poetry] name = "matrix-synapse" -version = "1.82.0" +version = "1.83.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "Apache-2.0" diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 7648938c9dac..66b2f2ed4338 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -574,7 +574,10 @@ impl FilteredPushRules { .filter(|rule| { // Ignore disabled experimental push rules - if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + if !self.msc1767_enabled + && (rule.rule_id.contains("org.matrix.msc1767") + || rule.rule_id.contains("org.matrix.msc3933")) + { return false; } diff --git a/scripts-dev/check_schema_delta.py b/scripts-dev/check_schema_delta.py index 32fe7f50deea..fee4a8bd3d5b 100755 --- a/scripts-dev/check_schema_delta.py +++ b/scripts-dev/check_schema_delta.py @@ -40,10 +40,32 @@ def main(force_colors: bool) -> None: exec(r, locals) current_schema_version = locals["SCHEMA_VERSION"] - click.secho(f"Current schema version: {current_schema_version}") - diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None) + # Get the schema version of the local file to check against current schema on develop + with open("synapse/storage/schema/__init__.py", "r") as file: + local_schema = file.read() + new_locals: Dict[str, Any] = {} + exec(local_schema, new_locals) + local_schema_version = new_locals["SCHEMA_VERSION"] + + if local_schema_version != current_schema_version: + # local schema version must be +/-1 the current schema version on develop + if abs(local_schema_version - current_schema_version) != 1: + click.secho( + "The proposed schema version has diverged more than one version from develop, please fix!", + fg="red", + bold=True, + color=force_colors, + ) + click.get_current_context().exit(1) + + # right, we've changed the schema version within the allowable tolerance so + # let's now use the local version as the canonical version + current_schema_version = local_schema_version + + click.secho(f"Current schema version: {current_schema_version}") + seen_deltas = False bad_files = [] for diff in diffs: diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index a58ae2a308c6..27fee3d9a934 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -54,7 +54,7 @@ ) from synapse.notifier import ReplicationNotifier from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn -from synapse.storage.databases.main import PushRuleStore +from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore @@ -69,6 +69,7 @@ MediaRepositoryBackgroundUpdateStore, ) from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore +from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.pusher import ( PusherBackgroundUpdatesStore, PusherWorkerStore, @@ -124,6 +125,7 @@ "users": ["shadow_banned", "approved"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], + "per_user_experimental_features": ["enabled"], } @@ -229,6 +231,8 @@ class Store( EndToEndRoomKeyBackgroundStore, StatsStore, AccountDataWorkerStore, + FilteringWorkerStore, + ProfileWorkerStore, PushRuleStore, PusherWorkerStore, PusherBackgroundUpdatesStore, diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index b9f432cc234a..de7c56bc0fa3 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -170,11 +170,9 @@ async def get_user_filter( result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(self._hs, result) - def add_user_filter( - self, user_localpart: str, user_filter: JsonDict - ) -> Awaitable[int]: + def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]: self.check_valid_filter(user_filter) - return self.store.add_user_filter(user_localpart, user_filter) + return self.store.add_user_filter(user_id, user_filter) # TODO(paul): surely we should probably add a delete_user_filter or # replace_user_filter at some point? There's no REST API specified for diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 86ddb1bb289e..024098e9cbb0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -442,8 +442,10 @@ async def push_bulk( return False async def claim_client_keys( - self, service: "ApplicationService", query: List[Tuple[str, str, str]] - ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + self, service: "ApplicationService", query: List[Tuple[str, str, str, int]] + ) -> Tuple[ + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + ]: """Claim one time keys from an application service. Note that any error (including a timeout) is treated as the application @@ -469,8 +471,10 @@ async def claim_client_keys( # Create the expected payload shape. body: Dict[str, Dict[str, List[str]]] = {} - for user_id, device, algorithm in query: - body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) + for user_id, device, algorithm, count in query: + body.setdefault(user_id, {}).setdefault(device, []).extend( + [algorithm] * count + ) uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" try: @@ -493,11 +497,20 @@ async def claim_client_keys( # or if some are still missing. # # TODO This places a lot of faith in the response shape being correct. - missing = [ - (user_id, device, algorithm) - for user_id, device, algorithm in query - if algorithm not in response.get(user_id, {}).get(device, []) - ] + missing = [] + for user_id, device, algorithm, count in query: + # Count the number of keys in the response for this algorithm by + # checking which key IDs start with the algorithm. This uses that + # True == 1 in Python to generate a count. + response_count = sum( + key_id.startswith(f"{algorithm}:") + for key_id in response.get(user_id, {}).get(device, {}) + ) + count -= response_count + # If the appservice responds with fewer keys than requested, then + # consider the request unfulfilled. + if count > 0: + missing.append((user_id, device, algorithm, count)) return response, missing diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d616b3dcfd5c..5bf75e737df9 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -203,3 +203,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # MSC2659: Application service ping endpoint self.msc2659_enabled = experimental.get("msc2659_enabled", False) + + # MSC3981: Recurse relations + self.msc3981_recurse_relations = experimental.get( + "msc3981_recurse_relations", False + ) + + # MSC3970: Scope transaction IDs to devices + self.msc3970_enabled = experimental.get("msc3970_enabled", False) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d2f99dc2acd1..afdf6863d6d1 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -150,18 +150,19 @@ class Keyring: def __init__( self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None ): - self.clock = hs.get_clock() - if key_fetchers is None: - key_fetchers = ( - # Fetch keys from the database. - StoreKeyFetcher(hs), - # Fetch keys from a configured Perspectives server. - PerspectivesKeyFetcher(hs), - # Fetch keys from the origin server directly. - ServerKeyFetcher(hs), - ) - self._key_fetchers = key_fetchers + # Always fetch keys from the database. + mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)] + # Fetch keys from configured trusted key servers, if any exist. + key_servers = hs.config.key.key_servers + if key_servers: + mutable_key_fetchers.append(PerspectivesKeyFetcher(hs)) + # Finally, fetch keys from the origin server directly. + mutable_key_fetchers.append(ServerKeyFetcher(hs)) + + self._key_fetchers: Iterable[KeyFetcher] = tuple(mutable_key_fetchers) + else: + self._key_fetchers = key_fetchers self._fetch_keys_queue: BatchingQueue[ _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] @@ -510,7 +511,7 @@ async def _fetch_keys( for key_id in queue_value.key_ids ) - res = await self.store.get_server_verify_keys(key_ids_to_fetch) + res = await self.store.get_server_keys_json(key_ids_to_fetch) keys: Dict[str, Dict[str, FetchKeyResult]] = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key @@ -522,7 +523,6 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastores().main - self.config = hs.config async def process_v2_response( self, from_server: str, response_json: JsonDict, time_added_ms: int @@ -626,7 +626,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - self.key_servers = self.config.key.key_servers + self.key_servers = hs.config.key.key_servers async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] @@ -775,7 +775,7 @@ async def get_server_verify_key_v2_indirect( keys.setdefault(server_name, {}).update(processed_response) - await self.store.store_server_verify_keys( + await self.store.store_server_signature_keys( perspective_name, time_now_ms, added_keys ) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 4501518cf0e5..de7e5be42bec 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -198,9 +198,16 @@ def __init__(self, internal_metadata_dict: JsonDict): soft_failed: DictProperty[bool] = DictProperty("soft_failed") proactively_send: DictProperty[bool] = DictProperty("proactively_send") redacted: DictProperty[bool] = DictProperty("redacted") + historical: DictProperty[bool] = DictProperty("historical") + txn_id: DictProperty[str] = DictProperty("txn_id") + """The transaction ID, if it was set when the event was created.""" + token_id: DictProperty[int] = DictProperty("token_id") - historical: DictProperty[bool] = DictProperty("historical") + """The access token ID of the user who sent this event, if any.""" + + device_id: DictProperty[str] = DictProperty("device_id") + """The device ID of the user who sent this event, if any.""" # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 6413389a0569..b2c5f9882cd5 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -339,6 +339,7 @@ def serialize_event( time_now_ms: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + msc3970_enabled: bool = False, ) -> JsonDict: """Serialize event for clients @@ -346,6 +347,8 @@ def serialize_event( e time_now_ms config: Event serialization config + msc3970_enabled: Whether MSC3970 is enabled. It changes whether we should + include the `transaction_id` in the event's `unsigned` section. Returns: The serialized event dictionary. @@ -368,27 +371,43 @@ def serialize_event( if "redacted_because" in e.unsigned: d["unsigned"]["redacted_because"] = serialize_event( - e.unsigned["redacted_because"], time_now_ms, config=config + e.unsigned["redacted_because"], + time_now_ms, + config=config, + msc3970_enabled=msc3970_enabled, ) # If we have a txn_id saved in the internal_metadata, we should include it in the # unsigned section of the event if it was sent by the same session as the one # requesting the event. - # There is a special case for guests, because they only have one access token - # without associated access_token_id, so we always include the txn_id for events - # they sent. - txn_id = getattr(e.internal_metadata, "txn_id", None) + txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None) if txn_id is not None and config.requester is not None: - event_token_id = getattr(e.internal_metadata, "token_id", None) - if config.requester.user.to_string() == e.sender and ( - ( - event_token_id is not None - and config.requester.access_token_id is not None - and event_token_id == config.requester.access_token_id + # For the MSC3970 rules to be applied, we *need* to have the device ID in the + # event internal metadata. Since we were not recording them before, if it hasn't + # been recorded, we fallback to the old behaviour. + event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None) + if msc3970_enabled and event_device_id is not None: + if event_device_id == config.requester.device_id: + d["unsigned"]["transaction_id"] = txn_id + + else: + # The pre-MSC3970 behaviour is to only include the transaction ID if the + # event was sent from the same access token. For regular users, we can use + # the access token ID to determine this. For guests, we can't, but since + # each guest only has one access token, we can just check that the event was + # sent by the same user as the one requesting the event. + event_token_id: Optional[int] = getattr( + e.internal_metadata, "token_id", None ) - or config.requester.is_guest - ): - d["unsigned"]["transaction_id"] = txn_id + if config.requester.user.to_string() == e.sender and ( + ( + event_token_id is not None + and config.requester.access_token_id is not None + and event_token_id == config.requester.access_token_id + ) + or config.requester.is_guest + ): + d["unsigned"]["transaction_id"] = txn_id # Beeper: include internal stream ordering as HS order unsigned hint stream_ordering = getattr(e.internal_metadata, "stream_ordering", None) @@ -424,13 +443,8 @@ class EventClientSerializer: clients. """ - def __init__(self, inhibit_replacement_via_edits: bool = False): - """ - Args: - inhibit_replacement_via_edits: If this is set to True, then events are - never replaced by their edits. - """ - self._inhibit_replacement_via_edits = inhibit_replacement_via_edits + def __init__(self, *, msc3970_enabled: bool = False): + self._msc3970_enabled = msc3970_enabled def serialize_event( self, @@ -456,7 +470,9 @@ def serialize_event( if not isinstance(event, EventBase): return event - serialized_event = serialize_event(event, time_now, config=config) + serialized_event = serialize_event( + event, time_now, config=config, msc3970_enabled=self._msc3970_enabled + ) # Check if there are any bundled aggregations to include with the event. if bundle_aggregations: @@ -514,7 +530,9 @@ def _inject_bundled_aggregations( # `sender` of the edit; however MSC3925 proposes extending it to the whole # of the edit, which is what we do here. serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event( - event_aggregations.replace, time_now, config=config + event_aggregations.replace, + time_now, + config=config, ) # Include any threaded replies to this event. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4cf4957a42ca..0b2d1a78f7b5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -235,7 +235,10 @@ async def query_user_devices( ) async def claim_client_keys( - self, destination: str, content: JsonDict, timeout: Optional[int] + self, + destination: str, + query: Dict[str, Dict[str, Dict[str, int]]], + timeout: Optional[int], ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. @@ -247,6 +250,50 @@ async def claim_client_keys( The JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() + + # Convert the query with counts into a stable and unstable query and check + # if attempting to claim more than 1 OTK. + content: Dict[str, Dict[str, str]] = {} + unstable_content: Dict[str, Dict[str, List[str]]] = {} + use_unstable = False + for user_id, one_time_keys in query.items(): + for device_id, algorithms in one_time_keys.items(): + if any(count > 1 for count in algorithms.values()): + use_unstable = True + if algorithms: + # For the stable query, choose only the first algorithm. + content.setdefault(user_id, {})[device_id] = next(iter(algorithms)) + # For the unstable query, repeat each algorithm by count, then + # splat those into chain to get a flattened list of all algorithms. + # + # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"]. + unstable_content.setdefault(user_id, {})[device_id] = list( + itertools.chain( + *( + itertools.repeat(algorithm, count) + for algorithm, count in algorithms.items() + ) + ) + ) + + if use_unstable: + try: + return await self.transport_layer.claim_client_keys_unstable( + destination, unstable_content, timeout + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error + # and raise. + if not is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't claim client keys with the unstable API, falling back to the v1 API" + ) + else: + logger.debug("Skipping unstable claim client keys API") + return await self.transport_layer.claim_client_keys( destination, content, timeout ) @@ -280,15 +327,11 @@ async def backfill( logger.debug("backfill transaction_data=%r", transaction_data) if not isinstance(transaction_data, dict): - # TODO we probably want an exception type specific to federation - # client validation. - raise TypeError("Backfill transaction_data is not a dict.") + raise InvalidResponseError("Backfill transaction_data is not a dict.") transaction_data_pdus = transaction_data.get("pdus") if not isinstance(transaction_data_pdus, list): - # TODO we probably want an exception type specific to federation - # client validation. - raise TypeError("transaction_data.pdus is not a list.") + raise InvalidResponseError("transaction_data.pdus is not a list.") room_version = await self.store.get_room_version(room_id) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d7740eb3b448..ca43c7bfc0d1 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1005,15 +1005,12 @@ async def on_query_user_devices( @trace async def on_claim_client_keys( - self, origin: str, content: JsonDict + self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool ) -> Dict[str, Any]: - query = [] - for user_id, device_keys in content.get("one_time_keys", {}).items(): - for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm)) - log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = await self._e2e_keys_handler.claim_local_one_time_keys(query) + results = await self._e2e_keys_handler.claim_local_one_time_keys( + query, always_include_fallback_keys=always_include_fallback_keys + ) json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for result in results: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index c05d598b70cf..bc70b94f6820 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -16,6 +16,7 @@ import logging import urllib from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -42,18 +43,21 @@ ) from synapse.events import EventBase, make_event_from_dict from synapse.federation.units import Transaction -from synapse.http.matrixfederationclient import ByteParser +from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser from synapse.http.types import QueryParams from synapse.types import JsonDict from synapse.util import ExceptionBundle +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class TransportLayerClient: """Sends federation HTTP requests to other servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.client = hs.get_federation_http_client() self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled @@ -133,7 +137,7 @@ async def get_event( async def backfill( self, destination: str, room_id: str, event_tuples: Collection[str], limit: int - ) -> Optional[JsonDict]: + ) -> Optional[Union[JsonDict, list]]: """Requests `limit` previous PDUs in a given context before list of PDUs. @@ -388,6 +392,7 @@ async def send_leave_v1( # server was just having a momentary blip, the room will be out of # sync. ignore_backoff=True, + parser=LegacyJsonSendParser(), ) async def send_leave_v2( @@ -445,7 +450,11 @@ async def send_invite_v1( path = _create_v1_path("/invite/%s/%s", room_id, event_id) return await self.client.put_json( - destination=destination, path=path, data=content, ignore_backoff=True + destination=destination, + path=path, + data=content, + ignore_backoff=True, + parser=LegacyJsonSendParser(), ) async def send_invite_v2( @@ -641,10 +650,10 @@ async def claim_client_keys( Response: { - "device_keys": { + "one_time_keys": { "": { "": { - ":": "" + ":": } } } @@ -660,7 +669,50 @@ async def claim_client_keys( path = _create_v1_path("/user/keys/claim") return await self.client.post_json( - destination=destination, path=path, data=query_content, timeout=timeout + destination=destination, + path=path, + data={"one_time_keys": query_content}, + timeout=timeout, + ) + + async def claim_client_keys_unstable( + self, destination: str, query_content: JsonDict, timeout: Optional[int] + ) -> JsonDict: + """Claim one-time keys for a list of devices hosted on a remote server. + + Request: + { + "one_time_keys": { + "": { + "": {"": } + } + } + } + + Response: + { + "one_time_keys": { + "": { + "": { + ":": + } + } + } + } + + Args: + destination: The server to query. + query_content: The user ids to query. + Returns: + A dict containing the one-time keys. + """ + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim") + + return await self.client.post_json( + destination=destination, + path=path, + data={"one_time_keys": query_content}, + timeout=timeout, ) async def get_missing_events( diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 753372fc5476..55d2cd0a9aa2 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -25,6 +25,7 @@ from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, + FederationUnstableClientKeysClaimServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -298,6 +299,11 @@ def register_servlets( and not hs.config.experimental.msc3720_enabled ): continue + if ( + servletclass == FederationUnstableClientKeysClaimServlet + and not hs.config.experimental.msc3983_appservice_otk_claims + ): + continue servletclass( hs=hs, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index ec5b5eeafa29..36b0362504f5 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import Counter from typing import ( TYPE_CHECKING, Dict, @@ -577,7 +578,43 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: - response = await self.handler.on_claim_client_keys(origin, content) + # Generate a count for each algorithm, which is hard-coded to 1. + key_query: List[Tuple[str, str, str, int]] = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + key_query.append((user_id, device_id, algorithm, 1)) + + response = await self.handler.on_claim_client_keys( + key_query, always_include_fallback_keys=False + ) + return 200, response + + +class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): + """ + Identical to the stable endpoint (FederationClientKeysClaimServlet) except + it allows for querying for multiple OTKs at once and always includes fallback + keys in the response. + """ + + PREFIX = FEDERATION_UNSTABLE_PREFIX + PATH = "/user/keys/claim" + CATEGORY = "Federation requests" + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + # Generate a count for each algorithm. + key_query: List[Tuple[str, str, str, int]] = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithms in device_keys.items(): + counts = Counter(algorithms) + for algorithm, count in counts.items(): + key_query.append((user_id, device_id, algorithm, count)) + + response = await self.handler.on_claim_client_keys( + key_query, always_include_fallback_keys=True + ) return 200, response @@ -784,6 +821,7 @@ async def on_POST( FederationClientKeysQueryServlet, FederationUserDevicesQueryServlet, FederationClientKeysClaimServlet, + FederationUnstableClientKeysClaimServlet, FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, FederationVersionServlet, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index da887647d4d8..6429545c98d5 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -841,9 +841,9 @@ async def _check_user_exists(self, user_id: str) -> bool: return True async def claim_e2e_one_time_keys( - self, query: Iterable[Tuple[str, str, str]] + self, query: Iterable[Tuple[str, str, str, int]] ) -> Tuple[ - Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] ]: """Claim one time keys from application services. @@ -856,7 +856,7 @@ async def claim_e2e_one_time_keys( Returns: A tuple of: - An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + A map of user ID -> a map device ID -> a map of key ID -> JSON. A copy of the input which has not been fulfilled (either because they are not appservice users or the appservice does not support @@ -865,18 +865,18 @@ async def claim_e2e_one_time_keys( services = self.store.get_app_services() # Partition the users by appservice. - query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} + query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {} missing = [] - for user_id, device, algorithm in query: + for user_id, device, algorithm, count in query: if not self.store.get_if_app_services_interested_in_user(user_id): - missing.append((user_id, device, algorithm)) + missing.append((user_id, device, algorithm, count)) continue # Find the associated appservice. for service in services: if service.is_exclusive_user(user_id): query_by_appservice.setdefault(service.id, []).append( - (user_id, device, algorithm) + (user_id, device, algorithm, count) ) continue @@ -897,12 +897,11 @@ async def claim_e2e_one_time_keys( ) # Patch together the results -- they are all independent (since they - # require exclusive control over the users). They get returned as a list - # and the caller combines them. - claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] + # require exclusive control over the users, which is the outermost key). + claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for success, result in results: if success: - claimed_keys.append(result[0]) + claimed_keys.update(result[0]) missing.extend(result[1]) return claimed_keys, missing diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ae1d9337adc5..b9d3b7fbc67b 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -921,12 +920,8 @@ class DeviceListWorkerUpdater: def __init__(self, hs: "HomeServer"): from synapse.replication.http.devices import ( ReplicationMultiUserDevicesResyncRestServlet, - ReplicationUserDevicesResyncRestServlet, ) - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) self._multi_user_device_resync_client = ( ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) ) @@ -948,37 +943,7 @@ async def multi_user_device_resync( # Shortcut empty requests return {} - try: - return await self._multi_user_device_resync_client(user_ids=user_ids) - except SynapseError as err: - if not ( - err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED - ): - raise - - # Fall back to single requests - result: Dict[str, Optional[JsonDict]] = {} - for user_id in user_ids: - result[user_id] = await self._user_device_resync_client(user_id=user_id) - return result - - async def user_device_resync( - self, user_id: str, mark_failed_as_stale: bool = True - ) -> Optional[JsonDict]: - """Fetches all devices for a user and updates the device cache with them. - - Args: - user_id: The user's id whose device_list will be updated. - mark_failed_as_stale: Whether to mark the user's device list as stale - if the attempt to resync failed. - Returns: - A dict with device info as under the "devices" in the result of this - request: - https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid - None when we weren't able to fetch the device info for some reason, - e.g. due to a connection problem. - """ - return (await self.multi_user_device_resync([user_id]))[user_id] + return await self._multi_user_device_resync_client(user_ids=user_ids) class DeviceListUpdater(DeviceListWorkerUpdater): @@ -1131,7 +1096,7 @@ async def _handle_device_updates(self, user_id: str) -> None: ) if resync: - await self.user_device_resync(user_id) + await self.multi_user_device_resync([user_id]) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) @@ -1198,10 +1163,9 @@ async def _maybe_retry_device_resync(self) -> None: for user_id in need_resync: try: # Try to resync the current user's devices list. - result = await self.user_device_resync( - user_id=user_id, - mark_failed_as_stale=False, - ) + result = (await self.multi_user_device_resync([user_id], False))[ + user_id + ] # user_device_resync only returns a result if it managed to # successfully resync and update the database. Updating the table @@ -1260,18 +1224,6 @@ async def multi_user_device_resync( return result - async def user_device_resync( - self, user_id: str, mark_failed_as_stale: bool = True - ) -> Optional[JsonDict]: - result, failed = await self._user_device_resync_returning_failed(user_id) - - if failed and mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_users_device_caches_as_stale((user_id,)) - - return result - async def _user_device_resync_returning_failed( self, user_id: str ) -> Tuple[Optional[JsonDict], bool]: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 00c403db4925..3caf9b31cc8b 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -25,7 +25,9 @@ log_kv, set_tag, ) -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.replication.http.devices import ( + ReplicationMultiUserDevicesResyncRestServlet, +) from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -71,12 +73,12 @@ def __init__(self, hs: "HomeServer"): # sync. We do all device list resyncing on the master instance, so if # we're on a worker we hit the device resync replication API. if hs.config.worker.worker_app is None: - self._user_device_resync = ( - hs.get_device_handler().device_list_updater.user_device_resync + self._multi_user_device_resync = ( + hs.get_device_handler().device_list_updater.multi_user_device_resync ) else: - self._user_device_resync = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) + self._multi_user_device_resync = ( + ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) ) # a rate limiter for room key requests. The keys are @@ -198,7 +200,7 @@ async def _check_for_unknown_devices( await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,)) # Immediately attempt a resync in the background - run_in_background(self._user_device_resync, user_id=sender_user_id) + run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id]) async def send_device_message( self, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 007366747014..24741b667bb9 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -563,7 +563,9 @@ async def on_federation_query_client_keys( return ret async def claim_local_one_time_keys( - self, local_query: List[Tuple[str, str, str]] + self, + local_query: List[Tuple[str, str, str, int]], + always_include_fallback_keys: bool, ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: """Claim one time keys for local users. @@ -573,43 +575,104 @@ async def claim_local_one_time_keys( Args: local_query: An iterable of tuples of (user ID, device ID, algorithm). + always_include_fallback_keys: True to always include fallback keys. Returns: An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. """ + # Cap the number of OTKs that can be claimed at once to avoid abuse. + local_query = [ + (user_id, device_id, algorithm, min(count, 5)) + for user_id, device_id, algorithm, count in local_query + ] + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) # If the application services have not provided any keys via the C-S # API, query it directly for one-time keys. if self._query_appservices_for_otks: + # TODO Should this query for fallback keys of uploaded OTKs if + # always_include_fallback_keys is True? The MSC is ambiguous. ( appservice_results, not_found, ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) else: - appservice_results = [] + appservice_results = {} + + # Calculate which user ID / device ID / algorithm tuples to get fallback + # keys for. This can be either only missing results *or* all results + # (which don't already have a fallback key). + if always_include_fallback_keys: + # Build the fallback query as any part of the original query where + # the appservice didn't respond with a fallback key. + fallback_query = [] + + # Iterate each item in the original query and search the results + # from the appservice for that user ID / device ID. If it is found, + # check if any of the keys match the requested algorithm & are a + # fallback key. + for user_id, device_id, algorithm, _count in local_query: + # Check if the appservice responded for this query. + as_result = appservice_results.get(user_id, {}).get(device_id, {}) + found_otk = False + for key_id, key_json in as_result.items(): + if key_id.startswith(f"{algorithm}:"): + # A OTK or fallback key was found for this query. + found_otk = True + # A fallback key was found for this query, no need to + # query further. + if key_json.get("fallback", False): + break + + else: + # No fallback key was found from appservices, query for it. + # Only mark the fallback key as used if no OTK was found + # (from either the database or appservices). + mark_as_used = not found_otk and not any( + key_id.startswith(f"{algorithm}:") + for key_id in otk_results.get(user_id, {}) + .get(device_id, {}) + .keys() + ) + # Note that it doesn't make sense to request more than 1 fallback key + # per (user_id, device_id, algorithm). + fallback_query.append((user_id, device_id, algorithm, mark_as_used)) + + else: + # All fallback keys get marked as used. + fallback_query = [ + # Note that it doesn't make sense to request more than 1 fallback key + # per (user_id, device_id, algorithm). + (user_id, device_id, algorithm, True) + for user_id, device_id, algorithm, count in not_found + ] # For each user that does not have a one-time keys available, see if # there is a fallback key. - fallback_results = await self.store.claim_e2e_fallback_keys(not_found) + fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query) # Return the results in order, each item from the input query should # only appear once in the combined list. - return (otk_results, *appservice_results, fallback_results) + return (otk_results, appservice_results, fallback_results) @trace async def claim_one_time_keys( - self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] + self, + query: Dict[str, Dict[str, Dict[str, int]]], + timeout: Optional[int], + always_include_fallback_keys: bool, ) -> JsonDict: - local_query: List[Tuple[str, str, str]] = [] - remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} + local_query: List[Tuple[str, str, str, int]] = [] + remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {} - for user_id, one_time_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in one_time_keys.items(): - local_query.append((user_id, device_id, algorithm)) + for device_id, algorithms in one_time_keys.items(): + for algorithm, count in algorithms.items(): + local_query.append((user_id, device_id, algorithm, count)) else: domain = get_domain_from_id(user_id) remote_queries.setdefault(domain, {})[user_id] = one_time_keys @@ -617,7 +680,9 @@ async def claim_one_time_keys( set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) - results = await self.claim_local_one_time_keys(local_query) + results = await self.claim_local_one_time_keys( + local_query, always_include_fallback_keys + ) # A map of user ID -> device ID -> key ID -> key. json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} @@ -625,7 +690,9 @@ async def claim_one_time_keys( for user_id, device_keys in result.items(): for device_id, keys in device_keys.items(): for key_id, key in keys.items(): - json_result.setdefault(user_id, {})[device_id] = {key_id: key} + json_result.setdefault(user_id, {}).setdefault( + device_id, {} + ).update({key_id: key}) # Remote failures. failures: Dict[str, JsonDict] = {} @@ -636,7 +703,7 @@ async def claim_client_keys(destination: str) -> None: device_keys = remote_queries[destination] try: remote_result = await self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys}, timeout=timeout + destination, device_keys, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 8d5be81a9207..06609fab93af 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -70,7 +70,9 @@ trace, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.replication.http.devices import ( + ReplicationMultiUserDevicesResyncRestServlet, +) from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) @@ -167,8 +169,8 @@ def __init__(self, hs: "HomeServer"): self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) if hs.config.worker.worker_app: - self._user_device_resync = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) + self._multi_user_device_resync = ( + ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) ) else: self._device_list_updater = hs.get_device_handler().device_list_updater @@ -1487,9 +1489,11 @@ async def _resync_device(self, sender: str) -> None: # Immediately attempt a resync in the background if self._config.worker.worker_app: - await self._user_device_resync(user_id=sender) + await self._multi_user_device_resync(user_ids=[sender]) else: - await self._device_list_updater.user_device_resync(sender) + await self._device_list_updater.multi_user_device_resync( + user_ids=[sender] + ) except Exception: logger.exception("Failed to resync device for %s", sender) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 521ad9df5def..9c7a431a818d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -561,6 +561,8 @@ def __init__(self, hs: "HomeServer"): expiry_ms=30 * 60 * 1000, ) + self._msc3970_enabled = hs.config.experimental.msc3970_enabled + async def create_event( self, requester: Requester, @@ -701,9 +703,16 @@ async def create_event( if require_consent and not is_exempt: await self.assert_accepted_privacy_policy(requester) + # Save the access token ID, the device ID and the transaction ID in the event + # internal metadata. This is useful to determine if we should echo the + # transaction_id in events. + # See `synapse.events.utils.EventClientSerializer.serialize_event` if requester.access_token_id is not None: builder.internal_metadata.token_id = requester.access_token_id + if requester.device_id is not None: + builder.internal_metadata.device_id = requester.device_id + if txn_id is not None: builder.internal_metadata.txn_id = txn_id @@ -897,12 +906,31 @@ async def get_event_from_transaction( Returns: An event if one could be found, None otherwise. """ + + if self._msc3970_enabled and requester.device_id: + # When MSC3970 is enabled, we lookup for events sent by the same device first, + # and fallback to the old behaviour if none were found. + existing_event_id = ( + await self.store.get_event_id_from_transaction_id_and_device_id( + room_id, + requester.user.to_string(), + requester.device_id, + txn_id, + ) + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + # Pre-MSC3970, we looked up for events that were sent by the same session by + # using the access token ID. if requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, + existing_event_id = ( + await self.store.get_event_id_from_transaction_id_and_token_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) ) if existing_event_id: return await self.store.get_event(existing_event_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 9a81a77cbd3b..440d3f4acd64 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,9 +178,7 @@ async def set_displayname( authenticated_entity=requester.authenticated_entity, ) - await self.store.set_profile_displayname( - target_user.localpart, displayname_to_set - ) + await self.store.set_profile_displayname(target_user, displayname_to_set) profile = await self.store.get_profileinfo(target_user.localpart) await self.user_directory_handler.handle_local_profile_change( @@ -272,9 +270,7 @@ async def set_avatar_url( target_user, authenticated_entity=requester.authenticated_entity ) - await self.store.set_profile_avatar_url( - target_user.localpart, avatar_url_to_set - ) + await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) profile = await self.store.get_profileinfo(target_user.localpart) await self.user_directory_handler.handle_local_profile_change( diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 1d09fdf13519..48246351625a 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -85,6 +85,7 @@ async def get_relations( event_id: str, room_id: str, pagin_config: PaginationConfig, + recurse: bool, include_original_event: bool, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -98,6 +99,7 @@ async def get_relations( event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. pagin_config: The pagination config rules to apply, if any. + recurse: Whether to recursively find relations. include_original_event: Whether to include the parent event. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -132,6 +134,7 @@ async def get_relations( direction=pagin_config.direction, from_token=pagin_config.from_token, to_token=pagin_config.to_token, + recurse=recurse, ) events = await self._main_store.get_events_as_list( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ec317e60239a..ed805d6ec87e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -169,6 +169,8 @@ def __init__(self, hs: "HomeServer"): self.request_ratelimiter = hs.get_request_ratelimiter() hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room) + self._msc3970_enabled = hs.config.experimental.msc3970_enabled + def _on_user_joined_room(self, event_id: str, room_id: str) -> None: """Notify the rate limiter that a room join has occurred. @@ -399,13 +401,30 @@ async def _local_membership_update( # Check if we already have an event with a matching transaction ID. (We # do this check just before we persist an event as well, but may as well # do it up front for efficiency.) - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, - ) + if txn_id: + existing_event_id = None + if self._msc3970_enabled and requester.device_id: + # When MSC3970 is enabled, we lookup for events sent by the same device + # first, and fallback to the old behaviour if none were found. + existing_event_id = ( + await self.store.get_event_id_from_transaction_id_and_device_id( + room_id, + requester.user.to_string(), + requester.device_id, + txn_id, + ) + ) + + if requester.access_token_id and not existing_event_id: + existing_event_id = ( + await self.store.get_event_id_from_transaction_id_and_token_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + ) + if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) return existing_event_id, event_pos.stream diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3302d4e48a0f..634882487c06 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,7 +17,6 @@ import logging import random import sys -import typing import urllib.parse from http import HTTPStatus from io import BytesIO, StringIO @@ -30,9 +29,11 @@ Generic, List, Optional, + TextIO, Tuple, TypeVar, Union, + cast, overload, ) @@ -183,20 +184,61 @@ def get_json(self) -> Optional[JsonDict]: return self.json -class JsonParser(ByteParser[Union[JsonDict, list]]): +class _BaseJsonParser(ByteParser[T]): """A parser that buffers the response and tries to parse it as JSON.""" CONTENT_TYPE = "application/json" - def __init__(self) -> None: + def __init__( + self, validator: Optional[Callable[[Optional[object]], bool]] = None + ) -> None: + """ + Args: + validator: A callable which takes the parsed JSON value and returns + true if the value is valid. + """ self._buffer = StringIO() self._binary_wrapper = BinaryIOWrapper(self._buffer) + self._validator = validator def write(self, data: bytes) -> int: return self._binary_wrapper.write(data) - def finish(self) -> Union[JsonDict, list]: - return json_decoder.decode(self._buffer.getvalue()) + def finish(self) -> T: + result = json_decoder.decode(self._buffer.getvalue()) + if self._validator is not None and not self._validator(result): + raise ValueError( + f"Received incorrect JSON value: {result.__class__.__name__}" + ) + return result + + +class JsonParser(_BaseJsonParser[JsonDict]): + """A parser that buffers the response and tries to parse it as a JSON object.""" + + def __init__(self) -> None: + super().__init__(self._validate) + + @staticmethod + def _validate(v: Any) -> bool: + return isinstance(v, dict) + + +class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]): + """Ensure the legacy responses of /send_join & /send_leave are correct.""" + + def __init__(self) -> None: + super().__init__(self._validate) + + @staticmethod + def _validate(v: Any) -> bool: + # Match [integer, JSON dict] + return ( + isinstance(v, list) + and len(v) == 2 + and type(v[0]) == int + and isinstance(v[1], dict) + ) async def _handle_response( @@ -313,9 +355,7 @@ async def _handle_response( class BinaryIOWrapper: """A wrapper for a TextIO which converts from bytes on the fly.""" - def __init__( - self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict" - ): + def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"): self.decoder = codecs.getincrementaldecoder(encoding)(errors) self.file = file @@ -793,7 +833,7 @@ async def put_json( backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: ... @overload @@ -825,8 +865,8 @@ async def put_json( ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser] = None, - ): + parser: Optional[ByteParser[T]] = None, + ) -> Union[JsonDict, T]: """Sends the specified json data using PUT Args: @@ -902,7 +942,7 @@ async def put_json( _sec_timeout = self.default_timeout if parser is None: - parser = JsonParser() + parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( self.reactor, @@ -924,7 +964,7 @@ async def post_json( timeout: Optional[int] = None, ignore_backoff: bool = False, args: Optional[QueryParams] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: """Sends the specified json data using POST Args: @@ -998,7 +1038,7 @@ async def get_json( ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: ... @overload @@ -1024,8 +1064,8 @@ async def get_json( timeout: Optional[int] = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser] = None, - ): + parser: Optional[ByteParser[T]] = None, + ) -> Union[JsonDict, T]: """GETs some json from the given host homeserver and path Args: @@ -1091,7 +1131,7 @@ async def get_json( _sec_timeout = self.default_timeout if parser is None: - parser = JsonParser() + parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( self.reactor, @@ -1112,7 +1152,7 @@ async def delete_json( timeout: Optional[int] = None, ignore_backoff: bool = False, args: Optional[QueryParams] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: """Send a DELETE request to the remote expecting some json response Args: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index eeafea74d15c..90eff030b573 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -105,6 +105,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK, SpamCheckerModuleApiCallbacks, ) +from synapse.push.httppusher import HttpPusher from synapse.rest.client.login import LoginResponse from synapse.storage import DataStore from synapse.storage.background_updates import ( @@ -248,6 +249,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._pusherpool = hs.get_pusherpool() self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory self._callbacks = hs.get_module_api_callbacks() @@ -1225,6 +1227,50 @@ async def sleep(self, seconds: float) -> None: await self._clock.sleep(seconds) + async def send_http_push_notification( + self, + user_id: str, + device_id: Optional[str], + content: JsonDict, + tweaks: Optional[JsonMapping] = None, + default_payload: Optional[JsonMapping] = None, + ) -> Dict[str, bool]: + """Send an HTTP push notification that is forwarded to the registered push gateway + for the specified user/device. + + Added in Synapse v1.82.0. + + Args: + user_id: The user ID to send the push notification to. + device_id: The device ID of the device where to send the push notification. If `None`, + the notification will be sent to all registered HTTP pushers of the user. + content: A dict of values that will be put in the `notification` field of the push + (cf Push Gateway spec). `devices` field will be overrided if included. + tweaks: A dict of `tweaks` that will be inserted in the `devices` section, cf spec. + default_payload: default payload to add in `devices[0].data.default_payload`. + This will be merged (and override if some matching values already exist there) + with existing `default_payload`. + + Returns: + a dict reprensenting the status of the push per device ID + """ + status = {} + if user_id in self._pusherpool.pushers: + for p in self._pusherpool.pushers[user_id].values(): + if isinstance(p, HttpPusher) and ( + not device_id or p.device_id == device_id + ): + res = await p.dispatch_push(content, tweaks, default_payload) + # Check if the push was successful and no pushers were rejected. + sent = res is not False and not res + + # This is mainly to accomodate mypy + # device_id should never be empty after the `set_device_id_for_pushers` + # background job has been properly run. + if p.device_id: + status[p.device_id] = sent + return status + async def send_mail( self, recipient: str, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7021c2c8aa30..b5a2d9994c5f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -327,6 +327,7 @@ async def _action_for_event_by_user( not event.internal_metadata.is_notifiable() or event.internal_metadata.is_historical() or event.content.get(EventContentFields.MSC2716_HISTORICAL) + or event.room_id in self.hs.config.server.rooms_to_exclude_from_sync ): # Push rules for events that aren't notifiable can't be processed by this and # we want to skip push notification actions for historical messages diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index d95fe06d4128..25914340515f 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from prometheus_client import Counter @@ -27,6 +27,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction +from synapse.types import JsonDict, JsonMapping from . import push_tools @@ -56,7 +57,7 @@ ) -def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> JsonMapping: """ Converts a list of actions into a `tweaks` dict (which can then be passed to the push gateway). @@ -101,6 +102,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self._storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name + self.device_id = pusher_config.device_id self.pushkey_ts = pusher_config.ts self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC @@ -334,7 +336,7 @@ async def _process_one(self, push_action: HttpPushAction) -> bool: event = await self.store.get_event(push_action.event_id, allow_none=True) if event is None: return True # It's been redacted - rejected = await self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push_event(event, tweaks, badge) if rejected is False: return False @@ -352,10 +354,84 @@ async def _process_one(self, push_action: HttpPushAction) -> bool: await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id) return True - async def _build_notification_dict( - self, event: EventBase, tweaks: Dict[str, bool], badge: int - ) -> Dict[str, Any]: - priority = "high" # Beeper: always use high priority notifications + async def dispatch_push( + self, + content: JsonDict, + tweaks: Optional[JsonMapping] = None, + default_payload: Optional[JsonMapping] = None, + ) -> Union[bool, List[str]]: + """Send a notification to the registered push gateway, with `content` being + the content of the `notification` top property specified in the spec. + Note that the `devices` property will be added with device-specific + information for this pusher. + + Args: + content: the content + tweaks: tweaks to add into the `devices` section + default_payload: default payload to add in `devices[0].data.default_payload`. + This will be merged (and override if some matching values already exist there) + with existing `default_payload`. + + Returns: + False if an error occured when calling the push gateway, or an array of + rejected push keys otherwise. If this array is empty, the push fully + succeeded. + """ + content = content.copy() + + data = self.data_minus_url.copy() + if default_payload: + data.setdefault("default_payload", {}).update(default_payload) + + device = { + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": int(self.pushkey_ts / 1000), + "data": data, + } + if tweaks: + device["tweaks"] = tweaks + + content["devices"] = [device] + + try: + resp = await self.http_client.post_json_get_json( + self.url, {"notification": content} + ) + except Exception as e: + logger.warning( + "Failed to push data to %s: %s %s", + self.name, + type(e), + e, + ) + return False + rejected = [] + if "rejected" in resp: + rejected = resp["rejected"] + return rejected + + async def dispatch_push_event( + self, + event: EventBase, + tweaks: JsonMapping, + badge: int, + ) -> Union[bool, List[str]]: + """Send a notification to the registered push gateway by building it + from an event. + + Args: + event: the event + tweaks: tweaks to add into the `devices` section, used to decide the + push priority + badge: unread count to send with the push notification + + Returns: + False if an error occured when calling the push gateway, or an array of + rejected push keys otherwise. If this array is empty, the push fully + succeeded. + """ + priority = "low" if ( event.type == EventTypes.Encrypted or tweaks.get("highlight") @@ -368,34 +444,24 @@ async def _build_notification_dict( # This was checked in the __init__, but mypy doesn't seem to know that. assert self.data is not None if self.data.get("format") == "event_id_only": - d: Dict[str, Any] = { - "notification": { - "event_id": event.event_id, - "room_id": event.room_id, - "counts": { - "unread": badge, - "com.beeper.server_type": "synapse", - }, - "com.beeper.user_id": self.user_id, - "prio": priority, - "devices": [ - { - "app_id": self.app_id, - "pushkey": self.pushkey, - "pushkey_ts": int(self.pushkey_ts / 1000), - "data": self.data_minus_url, - } - ], - } + content: JsonDict = { + "event_id": event.event_id, + "room_id": event.room_id, + "counts": { + "unread": badge, + "com.beeper.server_type": "synapse", + }, + "com.beeper.user_id": self.user_id, + "prio": priority, } - return d - - ctx = await push_tools.get_context_for_event( - self._storage_controllers, event, self.user_id - ) + # event_id_only doesn't include the tweaks, so override them. + tweaks = {} + else: + ctx = await push_tools.get_context_for_event( + self._storage_controllers, event, self.user_id + ) - d = { - "notification": { + content = { "id": event.event_id, # deprecated: remove soon "event_id": event.event_id, "room_id": event.room_id, @@ -408,57 +474,27 @@ async def _build_notification_dict( # 'missed_calls': 2 }, "com.beeper.user_id": self.user_id, - "devices": [ - { - "app_id": self.app_id, - "pushkey": self.pushkey, - "pushkey_ts": int(self.pushkey_ts / 1000), - "data": self.data_minus_url, - "tweaks": tweaks, - } - ], } - } - if event.type == "m.room.member" and event.is_state(): - d["notification"]["membership"] = event.content["membership"] - d["notification"]["user_is_target"] = event.state_key == self.user_id - if self.hs.config.push.push_include_content and event.content: - d["notification"]["content"] = event.content - - # We no longer send aliases separately, instead, we send the human - # readable name of the room, which may be an alias. - if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0: - d["notification"]["sender_display_name"] = ctx["sender_display_name"] - if "name" in ctx and len(ctx["name"]) > 0: - d["notification"]["room_name"] = ctx["name"] - - return d - - async def dispatch_push( - self, event: EventBase, tweaks: Dict[str, bool], badge: int - ) -> Union[bool, Iterable[str]]: - notification_dict = await self._build_notification_dict(event, tweaks, badge) - if not notification_dict: - return [] - try: - resp = await self.http_client.post_json_get_json( - self.url, notification_dict - ) - except Exception as e: - logger.warning( - "Failed to push event %s to %s: %s %s", - event.event_id, - self.name, - type(e), - e, - ) - return False - rejected = [] - if "rejected" in resp: - rejected = resp["rejected"] - if not rejected: + if event.type == "m.room.member" and event.is_state(): + content["membership"] = event.content["membership"] + content["user_is_target"] = event.state_key == self.user_id + if self.hs.config.push.push_include_content and event.content: + content["content"] = event.content + + # We no longer send aliases separately, instead, we send the human + # readable name of the room, which may be an alias. + if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0: + content["sender_display_name"] = ctx["sender_display_name"] + if "name" in ctx and len(ctx["name"]) > 0: + content["room_name"] = ctx["name"] + + res = await self.dispatch_push(content, tweaks) + + # If the push is successful and none are rejected, update the badge count. + if res is not False and not res: self.badge_count_last_call = badge - return rejected + + return res async def _send_badge(self, badge: int) -> None: """ diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index cc3929dcf565..f874f072f901 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -28,62 +28,6 @@ logger = logging.getLogger(__name__) -class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): - """Ask master to resync the device list for a user by contacting their - server. - - This must happen on master so that the results can be correctly cached in - the database and streamed to workers. - - Request format: - - POST /_synapse/replication/user_device_resync/:user_id - - {} - - Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id` - response, e.g.: - - { - "user_id": "@alice:example.org", - "devices": [ - { - "device_id": "JLAFKJWSCS", - "keys": { ... }, - "device_display_name": "Alice's Mobile Phone" - } - ] - } - """ - - NAME = "user_device_resync" - PATH_ARGS = ("user_id",) - CACHE = False - - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - from synapse.handlers.device import DeviceHandler - - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self.device_list_updater = handler.device_list_updater - - self.store = hs.get_datastores().main - self.clock = hs.get_clock() - - @staticmethod - async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] - return {} - - async def _handle_request( # type: ignore[override] - self, request: Request, content: JsonDict, user_id: str - ) -> Tuple[int, Optional[JsonDict]]: - user_devices = await self.device_list_updater.user_device_resync(user_id) - - return 200, user_devices - - class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): """Ask master to resync the device list for multiple users from the same remote server by contacting their server. @@ -216,6 +160,5 @@ async def _handle_request( # type: ignore[override] def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - ReplicationUserDevicesResyncRestServlet(hs).register(http_server) ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server) ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 79f22a59f195..c729364839c0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -39,6 +39,7 @@ EventReportDetailRestServlet, EventReportsRestServlet, ) +from synapse.rest.admin.experimental_features import ExperimentalFeaturesRestServlet from synapse.rest.admin.federation import ( DestinationMembershipRestServlet, DestinationResetConnectionRestServlet, @@ -68,7 +69,10 @@ RoomTimestampToEventRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet -from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet +from synapse.rest.admin.statistics import ( + LargestRoomsStatistics, + UserMediaStatisticsRestServlet, +) from synapse.rest.admin.username_available import UsernameAvailableRestServlet from synapse.rest.admin.users import ( AccountDataRestServlet, @@ -259,6 +263,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) + LargestRoomsStatistics(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) AccountDataRestServlet(hs).register(http_server) @@ -288,6 +293,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: BackgroundUpdateEnabledRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server) BackgroundUpdateStartJobRestServlet(hs).register(http_server) + ExperimentalFeaturesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py new file mode 100644 index 000000000000..1d409ac2b7b0 --- /dev/null +++ b/synapse/rest/admin/experimental_features.py @@ -0,0 +1,119 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from enum import Enum +from http import HTTPStatus +from typing import TYPE_CHECKING, Dict, Tuple + +from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.admin import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class ExperimentalFeature(str, Enum): + """ + Currently supported per-user features + """ + + MSC3026 = "msc3026" + MSC2654 = "msc2654" + MSC3881 = "msc3881" + MSC3967 = "msc3967" + + +class ExperimentalFeaturesRestServlet(RestServlet): + """ + Enable or disable experimental features for a user or determine which features are enabled + for a given user + """ + + PATTERNS = admin_patterns("/experimental_features/(?P[^/]*)") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.is_mine = hs.is_mine + + async def on_GET( + self, + request: SynapseRequest, + user_id: str, + ) -> Tuple[int, JsonDict]: + """ + List which features are enabled for a given user + """ + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "User must be local to check what experimental features are enabled.", + ) + + enabled_features = await self.store.list_enabled_features(user_id) + + user_features = {} + for feature in ExperimentalFeature: + if feature in enabled_features: + user_features[feature] = True + else: + user_features[feature] = False + return HTTPStatus.OK, {"features": user_features} + + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[HTTPStatus, Dict]: + """ + Enable or disable the provided features for the requester + """ + await assert_requester_is_admin(self.auth, request) + + body = parse_json_object_from_request(request) + + target_user = UserID.from_string(user_id) + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "User must be local to enable experimental features.", + ) + + features = body.get("features") + if not features: + raise SynapseError( + HTTPStatus.BAD_REQUEST, "You must provide features to set." + ) + + # validate the provided features + validated_features = {} + for feature, enabled in features.items(): + try: + validated_feature = ExperimentalFeature(feature) + validated_features[validated_feature] = enabled + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"{feature!r} is not recognised as a valid experimental feature.", + ) + + await self.store.set_features_for_user(user_id, validated_features) + + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 9c45f4650dc3..19780e4b4ca6 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -113,3 +113,28 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ret["next_token"] = start + len(users_media) return HTTPStatus.OK, ret + + +class LargestRoomsStatistics(RestServlet): + """Get the largest rooms by database size. + + Only works when using PostgreSQL. + """ + + PATTERNS = admin_patterns("/statistics/database/rooms$") + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.stats_controller = hs.get_storage_controllers().stats + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + room_sizes = await self.stats_controller.get_room_db_size_estimate() + + return HTTPStatus.OK, { + "rooms": [ + {"room_id": room_id, "estimated_size": size} + for room_id, size in room_sizes + ] + } diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index ab7d8c94191b..04561f36d7a1 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -94,7 +94,7 @@ async def on_POST( set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit) filter_id = await self.filtering.add_user_filter( - user_localpart=target_user.localpart, user_filter=content + user_id=target_user, user_filter=content ) return 200, {"filter_id": str(filter_id)} diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 6209b79b019e..9bbab5e6241e 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,7 +15,9 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Optional, Tuple +import re +from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.http.server import HttpServer @@ -288,7 +290,64 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) + + # Generate a count for each algorithm, which is hard-coded to 1. + query: Dict[str, Dict[str, Dict[str, int]]] = {} + for user_id, one_time_keys in body.get("one_time_keys", {}).items(): + for device_id, algorithm in one_time_keys.items(): + query.setdefault(user_id, {})[device_id] = {algorithm: 1} + + result = await self.e2e_keys_handler.claim_one_time_keys( + query, timeout, always_include_fallback_keys=False + ) + return 200, result + + +class UnstableOneTimeKeyServlet(RestServlet): + """ + Identical to the stable endpoint (OneTimeKeyServlet) except it allows for + querying for multiple OTKs at once and always includes fallback keys in the + response. + + POST /keys/claim HTTP/1.1 + { + "one_time_keys": { + "": { + "": ["", ...] + } } } + + HTTP/1.1 200 OK + { + "one_time_keys": { + "": { + "": { + ":": "" + } } } } + + """ + + PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")] + CATEGORY = "Encryption requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) + body = parse_json_object_from_request(request) + + # Generate a count for each algorithm. + query: Dict[str, Dict[str, Dict[str, int]]] = {} + for user_id, one_time_keys in body.get("one_time_keys", {}).items(): + for device_id, algorithms in one_time_keys.items(): + query.setdefault(user_id, {})[device_id] = Counter(algorithms) + + result = await self.e2e_keys_handler.claim_one_time_keys( + query, timeout, always_include_fallback_keys=True + ) return 200, result @@ -394,6 +453,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) + if hs.config.experimental.msc3983_appservice_otk_claims: + UnstableOneTimeKeyServlet(hs).register(http_server) if hs.config.worker.worker_app is None: SigningKeyUploadServlet(hs).register(http_server) SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b8b296bc0cb8..785dfa08d845 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,7 +19,7 @@ from synapse.api.constants import Direction from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.relations import ThreadsNextBatch @@ -49,6 +49,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self._store = hs.get_datastores().main self._relations_handler = hs.get_relations_handler() + self._support_recurse = hs.config.experimental.msc3981_recurse_relations async def on_GET( self, @@ -63,6 +64,12 @@ async def on_GET( pagination_config = await PaginationConfig.from_request( self._store, request, default_limit=5, default_dir=Direction.BACKWARDS ) + if self._support_recurse: + recurse = parse_boolean( + request, "org.matrix.msc3981.recurse", default=False + ) + else: + recurse = False # The unstable version of this API returns an extra field for client # compatibility, see https://github.com/matrix-org/synapse/issues/12930. @@ -75,6 +82,7 @@ async def on_GET( event_id=parent_id, room_id=room_id, pagin_config=pagination_config, + recurse=recurse, include_original_event=include_original_event, relation_type=relation_type, event_type=event_type, diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index f2aaab622740..0d8a63d8beda 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -50,6 +50,8 @@ def __init__(self, hs: "HomeServer"): # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + self._msc3970_enabled = hs.config.experimental.msc3970_enabled + def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable: """A helper function which returns a transaction key that can be used with TransactionCache for idempotent requests. @@ -58,6 +60,7 @@ def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hasha requests to the same endpoint. The key is formed from the HTTP request path and attributes from the requester: the access_token_id for regular users, the user ID for guest users, and the appservice ID for appservice users. + With MSC3970, for regular users, the key is based on the user ID and device ID. Args: request: The incoming request. @@ -67,11 +70,21 @@ def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hasha """ assert request.path is not None path: str = request.path.decode("utf8") + if requester.is_guest: assert requester.user is not None, "Guest requester must have a user ID set" return (path, "guest", requester.user) + elif requester.app_service is not None: return (path, "appservice", requester.app_service.id) + + # With MSC3970, we use the user ID and device ID as the transaction key + elif self._msc3970_enabled: + assert requester.user, "Requester must have a user" + assert requester.device_id, "Requester must have a device_id" + return (path, "user", requester.user, requester.device_id) + + # Otherwise, the pre-MSC3970 behaviour is to use the access token ID else: assert ( requester.access_token_id is not None diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3bdb6ec9098d..ff0454ca5706 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -155,7 +155,7 @@ async def query_keys( for key_id in key_ids: store_queries.append((server_name, key_id, None)) - cached = await self.store.get_server_keys_json(store_queries) + cached = await self.store.get_server_keys_json_for_remote(store_queries) json_results: Set[bytes] = set() diff --git a/synapse/server.py b/synapse/server.py index a79dc66a3026..a35c31a4bd44 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -763,7 +763,9 @@ def get_oidc_handler(self) -> "OidcHandler": @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer(self.config.experimental.msc3925_inhibit_edit) + return EventClientSerializer( + msc3970_enabled=self.config.experimental.msc3970_enabled + ) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py index 45101cda7adf..0ef860263104 100644 --- a/synapse/storage/controllers/__init__.py +++ b/synapse/storage/controllers/__init__.py @@ -19,6 +19,7 @@ ) from synapse.storage.controllers.purge_events import PurgeEventsStorageController from synapse.storage.controllers.state import StateStorageController +from synapse.storage.controllers.stats import StatsController from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -40,6 +41,7 @@ def __init__(self, hs: "HomeServer", stores: Databases): self.purge_events = PurgeEventsStorageController(hs, stores) self.state = StateStorageController(hs, stores) + self.stats = StatsController(hs, stores) self.persistence = None if stores.persist_events: diff --git a/synapse/storage/controllers/stats.py b/synapse/storage/controllers/stats.py new file mode 100644 index 000000000000..988e44c6af4a --- /dev/null +++ b/synapse/storage/controllers/stats.py @@ -0,0 +1,113 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import Counter +from typing import TYPE_CHECKING, Collection, List, Tuple + +from synapse.api.errors import SynapseError +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases import Databases +from synapse.storage.engines import PostgresEngine + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class StatsController: + """High level interface for getting statistics.""" + + def __init__(self, hs: "HomeServer", stores: Databases): + self.stores = stores + + async def get_room_db_size_estimate(self) -> List[Tuple[str, int]]: + """Get an estimate of the largest rooms and how much database space they + use, in bytes. + + Only works against PostgreSQL. + + Note: this uses the postgres statistics so is a very rough estimate. + """ + + # Note: We look at both tables on the main and state databases. + if not isinstance(self.stores.main.database_engine, PostgresEngine): + raise SynapseError(400, "Endpoint requires using PostgreSQL") + + if not isinstance(self.stores.state.database_engine, PostgresEngine): + raise SynapseError(400, "Endpoint requires using PostgreSQL") + + # For each "large" table, we go through and get the largest rooms + # and an estimate of how much space they take. We can then sum the + # results and return the top 10. + # + # This isn't the most accurate, but given all of these are estimates + # anyway its good enough. + room_estimates: Counter[str] = Counter() + + # Return size of the table on disk, including indexes and TOAST. + table_sql = """ + SELECT pg_total_relation_size(?) + """ + + # Get an estimate for the largest rooms and their frequency. + # + # Note: the cast here is a hack to cast from `anyarray` to an actual + # type. This ensures that psycopg2 passes us a back a a Python list. + column_sql = """ + SELECT + most_common_vals::TEXT::TEXT[], most_common_freqs::TEXT::NUMERIC[] + FROM pg_stats + WHERE tablename = ? and attname = 'room_id' + """ + + def get_room_db_size_estimate_txn( + txn: LoggingTransaction, + tables: Collection[str], + ) -> None: + for table in tables: + txn.execute(table_sql, (table,)) + row = txn.fetchone() + assert row is not None + (table_size,) = row + + txn.execute(column_sql, (table,)) + row = txn.fetchone() + assert row is not None + vals, freqs = row + + for room_id, freq in zip(vals, freqs): + room_estimates[room_id] += int(freq * table_size) + + await self.stores.main.db_pool.runInteraction( + "get_room_db_size_estimate_main", + get_room_db_size_estimate_txn, + ( + "event_json", + "events", + "event_search", + "event_edges", + "event_push_actions", + "stream_ordering_to_exterm", + ), + ) + + await self.stores.state.db_pool.runInteraction( + "get_room_db_size_estimate_state", + get_room_db_size_estimate_txn, + ("state_groups_state",), + ) + + return room_estimates.most_common(10) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index ce3d1d4e942e..7aa24ccf2121 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -95,7 +95,7 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # If we're on a process that can persist events also # instantiate a `PersistEventsStore` if hs.get_instance_name() in hs.config.worker.writers.events: - persist_events = PersistEventsStore(hs, database, main, db_conn) + persist_events = PersistEventsStore(hs, database, main, db_conn) # type: ignore[arg-type] if "state" in database_config.databases: logger.info( @@ -133,6 +133,6 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # We use local variables here to ensure that the databases do not have # optional types. - self.main = main + self.main = main # type: ignore[assignment] self.state = state self.persist_events = persist_events diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 216d25da8d78..824698372fcf 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -44,6 +44,7 @@ from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .events_forward_extremities import EventForwardExtremitiesStore +from .experimental_features import ExperimentalFeaturesStore from .filtering import FilteringWorkerStore from .keys import KeyStore from .lock import LockStore @@ -83,6 +84,7 @@ class DataStore( EventsBackgroundUpdatesStore, + ExperimentalFeaturesStore, DeviceStore, RoomMemberStore, RoomStore, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 7cb14f271fe5..e76628d48379 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -210,13 +210,13 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: ) elif row.type == EventsStreamCurrentStateRow.TypeId: assert isinstance(data, EventsStreamCurrentStateRow) - self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined] if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( + self.get_rooms_for_user_with_stream_ordering.invalidate( # type: ignore[attr-defined] (data.state_key,) ) - self.get_rooms_for_user.invalidate((data.state_key,)) + self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined] else: raise Exception("Unknown events stream row type %s" % (row.type,)) @@ -234,7 +234,7 @@ def _invalidate_caches_for_event( # This invalidates any local in-memory cached event objects, the original # process triggering the invalidation is responsible for clearing any external # cached objects. - self._invalidate_local_get_event_cache(event_id) + self._invalidate_local_get_event_cache(event_id) # type: ignore[attr-defined] self._attempt_to_invalidate_cache("have_seen_event", (room_id, event_id)) self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) @@ -247,10 +247,10 @@ def _invalidate_caches_for_event( self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,)) if not backfilled: - self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) # type: ignore[attr-defined] if redacts: - self._invalidate_local_get_event_cache(redacts) + self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined] # Caches which might leak edits must be invalidated for the event being # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) @@ -259,7 +259,7 @@ def _invalidate_caches_for_event( self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) # type: ignore[attr-defined] self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) @@ -386,6 +386,8 @@ def _send_invalidation_to_replication( ) if isinstance(self.database_engine, PostgresEngine): + assert self._cache_id_gen is not None + # get_next() returns a context manager which is designed to wrap # the transaction. However, we want to only get an ID when we want # to use it, here, so we need to call __enter__ manually, and have diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index dc7768c50cab..4bc391f21316 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1027,8 +1027,10 @@ def get_device_stream_token(self) -> int: ... async def claim_e2e_one_time_keys( - self, query_list: Iterable[Tuple[str, str, str]] - ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + self, query_list: Iterable[Tuple[str, str, str, int]] + ) -> Tuple[ + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + ]: """Take a list of one time keys out of the database. Args: @@ -1043,8 +1045,12 @@ async def claim_e2e_one_time_keys( @trace def _claim_e2e_one_time_key_simple( - txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str - ) -> Optional[Tuple[str, str]]: + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. Returns: @@ -1055,36 +1061,41 @@ def _claim_e2e_one_time_key_simple( sql = """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT 1 + LIMIT ? """ - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is None: - return None - - key_id, key_json = otk_row + txn.execute(sql, (user_id, device_id, algorithm, count)) + otk_rows = list(txn) + if not otk_rows: + return [] - self.db_pool.simple_delete_one_txn( + self.db_pool.simple_delete_many_txn( txn, table="e2e_one_time_keys_json", + column="key_id", + values=[otk_row[0] for otk_row in otk_rows], keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, - "key_id": key_id, }, ) self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return f"{algorithm}:{key_id}", key_json + return [ + (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows + ] @trace def _claim_e2e_one_time_key_returning( - txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str - ) -> Optional[Tuple[str, str]]: + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. Returns: @@ -1099,28 +1110,30 @@ def _claim_e2e_one_time_key_returning( AND key_id IN ( SELECT key_id FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT 1 + LIMIT ? ) RETURNING key_id, key_json """ txn.execute( - sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) + sql, + (user_id, device_id, algorithm, user_id, device_id, algorithm, count), ) - otk_row = txn.fetchone() - if otk_row is None: - return None + otk_rows = list(txn) + if not otk_rows: + return [] self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - key_id, key_json = otk_row - return f"{algorithm}:{key_id}", key_json + return [ + (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows + ] results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} - missing: List[Tuple[str, str, str]] = [] - for user_id, device_id, algorithm in query_list: + missing: List[Tuple[str, str, str, int]] = [] + for user_id, device_id, algorithm, count in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that # allows us to use autocommit mode. @@ -1130,37 +1143,42 @@ def _claim_e2e_one_time_key_returning( _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - claim_row = await self.db_pool.runInteraction( + claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, device_id, algorithm, + count, db_autocommit=db_autocommit, ) - if claim_row: + if claim_rows: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) - else: - missing.append((user_id, device_id, algorithm)) + for claim_row in claim_rows: + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # Did we get enough OTKs? + count -= len(claim_rows) + if count: + missing.append((user_id, device_id, algorithm, count)) return results, missing async def claim_e2e_fallback_keys( - self, query_list: Iterable[Tuple[str, str, str]] + self, query_list: Iterable[Tuple[str, str, str, bool]] ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: """Take a list of fallback keys out of the database. Args: - query_list: An iterable of tuples of (user ID, device ID, algorithm). + query_list: An iterable of tuples of + (user ID, device ID, algorithm, whether the key should be marked as used). Returns: A map of user ID -> a map device ID -> a map of key ID -> JSON. """ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} - for user_id, device_id, algorithm in query_list: + for user_id, device_id, algorithm, mark_as_used in query_list: row = await self.db_pool.simple_select_one( table="e2e_fallback_keys_json", keyvalues={ @@ -1180,7 +1198,7 @@ async def claim_e2e_fallback_keys( used = row["used"] # Mark fallback key as used if not already. - if not used: + if not used and mark_as_used: await self.db_pool.simple_update_one( table="e2e_fallback_keys_json", keyvalues={ diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index f652f54cb6d9..e71d23ad64f5 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -127,6 +127,8 @@ def __init__( self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + self._msc3970_enabled = hs.config.experimental.msc3970_enabled + @trace async def _persist_events_and_state_updates( self, @@ -977,23 +979,43 @@ def _persist_transaction_ids_txn( ) -> None: """Persist the mapping from transaction IDs to event IDs (if defined).""" - to_insert = [] + inserted_ts = self._clock.time_msec() + to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = [] + to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = [] for event, _ in events_and_contexts: - token_id = getattr(event.internal_metadata, "token_id", None) txn_id = getattr(event.internal_metadata, "txn_id", None) - if token_id and txn_id: - to_insert.append( - ( - event.event_id, - event.room_id, - event.sender, - token_id, - txn_id, - self._clock.time_msec(), + token_id = getattr(event.internal_metadata, "token_id", None) + device_id = getattr(event.internal_metadata, "device_id", None) + + if txn_id is not None: + if token_id is not None: + to_insert_token_id.append( + ( + event.event_id, + event.room_id, + event.sender, + token_id, + txn_id, + inserted_ts, + ) + ) + + if device_id is not None: + to_insert_device_id.append( + ( + event.event_id, + event.room_id, + event.sender, + device_id, + txn_id, + inserted_ts, + ) ) - ) - if to_insert: + # Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events. + # Since this is an experimental flag, we still store the mapping even if the + # flag is disabled. + if to_insert_token_id: self.db_pool.simple_insert_many_txn( txn, table="event_txn_id", @@ -1005,7 +1027,25 @@ def _persist_transaction_ids_txn( "txn_id", "inserted_ts", ), - values=to_insert, + values=to_insert_token_id, + ) + + # With MSC3970, we rely on the device_id instead to scope the txn_id for events. + # We're only inserting if MSC3970 is *enabled*, because else the pre-MSC3970 + # behaviour would allow for a UNIQUE constraint violation on this table + if to_insert_device_id and self._msc3970_enabled: + self.db_pool.simple_insert_many_txn( + txn, + table="event_txn_id_device_id", + keys=( + "event_id", + "room_id", + "user_id", + "device_id", + "txn_id", + "inserted_ts", + ), + values=to_insert_device_id, ) async def update_current_state( @@ -1127,11 +1167,15 @@ def _update_current_state_txn( # been inserted into room_memberships. txn.execute_batch( """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership) - VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + (room_id, type, state_key, event_id, membership, event_stream_ordering) + VALUES ( + ?, ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) + ) """, [ - (room_id, key[0], key[1], ev_id, ev_id) + (room_id, key[0], key[1], ev_id, ev_id, ev_id) for key, ev_id in to_insert.items() ], ) @@ -1158,11 +1202,15 @@ def _update_current_state_txn( if to_insert: txn.execute_batch( """INSERT INTO local_current_membership - (room_id, user_id, event_id, membership) - VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + (room_id, user_id, event_id, membership, event_stream_ordering) + VALUES ( + ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) + ) """, [ - (room_id, key[1], ev_id, ev_id) + (room_id, key[1], ev_id, ev_id, ev_id) for key, ev_id in to_insert.items() if key[0] == EventTypes.Member and self.is_mine_id(key[1]) ], @@ -1768,6 +1816,7 @@ def _store_room_members_txn( table="room_memberships", keys=( "event_id", + "event_stream_ordering", "user_id", "sender", "room_id", @@ -1778,6 +1827,7 @@ def _store_room_members_txn( values=[ ( event.event_id, + event.internal_metadata.stream_ordering, event.state_key, event.user_id, event.room_id, @@ -1810,6 +1860,7 @@ def _store_room_members_txn( keyvalues={"room_id": event.room_id, "user_id": event.state_key}, values={ "event_id": event.event_id, + "event_stream_ordering": event.internal_metadata.stream_ordering, "membership": event.membership, }, ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8e5ebd53efcc..176ef3bcdb25 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -2041,7 +2041,7 @@ def get_next_event_to_expire_txn( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - async def get_event_id_from_transaction_id( + async def get_event_id_from_transaction_id_and_token_id( self, room_id: str, user_id: str, token_id: int, txn_id: str ) -> Optional[str]: """Look up if we have already persisted an event for the transaction ID, @@ -2057,7 +2057,26 @@ async def get_event_id_from_transaction_id( }, retcol="event_id", allow_none=True, - desc="get_event_id_from_transaction_id", + desc="get_event_id_from_transaction_id_and_token_id", + ) + + async def get_event_id_from_transaction_id_and_device_id( + self, room_id: str, user_id: str, device_id: str, txn_id: str + ) -> Optional[str]: + """Look up if we have already persisted an event for the transaction ID, + returning the event ID if so. + """ + return await self.db_pool.simple_select_one_onecol( + table="event_txn_id_device_id", + keyvalues={ + "room_id": room_id, + "user_id": user_id, + "device_id": device_id, + "txn_id": txn_id, + }, + retcol="event_id", + allow_none=True, + desc="get_event_id_from_transaction_id_and_device_id", ) async def get_already_persisted_events( @@ -2087,7 +2106,7 @@ async def get_already_persisted_events( # Check if this is a duplicate of an event we've already # persisted. - existing = await self.get_event_id_from_transaction_id( + existing = await self.get_event_id_from_transaction_id_and_token_id( event.room_id, event.sender, token_id, txn_id ) if existing: @@ -2103,11 +2122,17 @@ async def _cleanup_old_transaction_ids(self) -> None: """Cleans out transaction id mappings older than 24hrs.""" def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? """ - one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + txn.execute(sql, (one_day_ago,)) + + sql = """ + DELETE FROM event_txn_id_device_id + WHERE inserted_ts < ? + """ txn.execute(sql, (one_day_ago,)) return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py new file mode 100644 index 000000000000..cf3226ae5a70 --- /dev/null +++ b/synapse/storage/databases/main/experimental_features.py @@ -0,0 +1,75 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict + +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.databases.main import CacheInvalidationWorkerStore +from synapse.types import StrCollection +from synapse.util.caches.descriptors import cached + +if TYPE_CHECKING: + from synapse.rest.admin.experimental_features import ExperimentalFeature + from synapse.server import HomeServer + + +class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ) -> None: + super().__init__(database, db_conn, hs) + + @cached() + async def list_enabled_features(self, user_id: str) -> StrCollection: + """ + Checks to see what features are enabled for a given user + Args: + user: + the user to be queried on + Returns: + the features currently enabled for the user + """ + enabled = await self.db_pool.simple_select_list( + "per_user_experimental_features", + {"user_id": user_id, "enabled": True}, + ["feature"], + ) + + return [feature["feature"] for feature in enabled] + + async def set_features_for_user( + self, + user: str, + features: Dict["ExperimentalFeature", bool], + ) -> None: + """ + Enables or disables features for a given user + Args: + user: + the user for whom to enable/disable the features + features: + pairs of features and True/False for whether the feature should be enabled + """ + for feature, enabled in features.items(): + await self.db_pool.simple_upsert( + table="per_user_experimental_features", + keyvalues={"feature": feature, "user_id": user}, + values={"enabled": enabled}, + insertion_values={"user_id": user, "feature": feature}, + ) + + await self.invalidate_cache_and_stream("list_enabled_features", (user,)) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 8e57c8e5a07a..50516402f96f 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -16,15 +16,38 @@ from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json +from typing_extensions import TYPE_CHECKING from synapse.api.errors import Codes, StoreError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction -from synapse.types import JsonDict +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict, UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + class FilteringWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + self.db_pool.updates.register_background_index_update( + "full_users_filters_unique_idx", + index_name="full_users_unique_idx", + table="user_filters", + columns=["full_user_id, filter_id"], + unique=True, + ) + @cached(num_args=2) async def get_user_filter( self, user_localpart: str, filter_id: Union[int, str] @@ -46,7 +69,7 @@ async def get_user_filter( return db_to_json(def_json) - async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: + async def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -56,13 +79,13 @@ def _do_txn(txn: LoggingTransaction) -> int: "SELECT filter_id FROM user_filters " "WHERE user_id = ? AND filter_json = ?" ) - txn.execute(sql, (user_localpart, bytearray(def_json))) + txn.execute(sql, (user_id.localpart, bytearray(def_json))) filter_id_response = txn.fetchone() if filter_id_response is not None: return filter_id_response[0] sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" - txn.execute(sql, (user_localpart,)) + txn.execute(sql, (user_id.localpart,)) max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_id is None: filter_id = 0 @@ -70,10 +93,18 @@ def _do_txn(txn: LoggingTransaction) -> int: filter_id = max_id + 1 sql = ( - "INSERT INTO user_filters (user_id, filter_id, filter_json)" - "VALUES(?, ?, ?)" + "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)" + "VALUES(?, ?, ?, ?)" + ) + txn.execute( + sql, + ( + user_id.to_string(), + user_id.localpart, + filter_id, + bytearray(def_json), + ), ) - txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) return filter_id diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 89c37a4eb560..1666e3c43b44 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -14,10 +14,12 @@ # limitations under the License. import itertools +import json import logging from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple from signedjson.key import decode_verify_key_bytes +from unpaddedbase64 import decode_base64 from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction @@ -36,15 +38,16 @@ class KeyStore(SQLBaseStore): """Persistence for signature verification keys""" @cached() - def _get_server_verify_key( + def _get_server_signature_key( self, server_name_and_key_id: Tuple[str, str] ) -> FetchKeyResult: raise NotImplementedError() @cachedList( - cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + cached_method_name="_get_server_signature_key", + list_name="server_name_and_key_ids", ) - async def get_server_verify_keys( + async def get_server_signature_keys( self, server_name_and_key_ids: Iterable[Tuple[str, str]] ) -> Dict[Tuple[str, str], FetchKeyResult]: """ @@ -62,10 +65,12 @@ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) - sql = ( - "SELECT server_name, key_id, verify_key, ts_valid_until_ms " - "FROM server_signature_keys WHERE 1=0" - ) + " OR (server_name=? AND key_id=?)" * len(batch) + sql = """ + SELECT server_name, key_id, verify_key, ts_valid_until_ms + FROM server_signature_keys WHERE 1=0 + """ + " OR (server_name=? AND key_id=?)" * len( + batch + ) txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) @@ -89,9 +94,9 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: _get_keys(txn, batch) return keys - return await self.db_pool.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_signature_keys", _txn) - async def store_server_verify_keys( + async def store_server_signature_keys( self, from_server: str, ts_added_ms: int, @@ -119,7 +124,7 @@ async def store_server_verify_keys( ) ) # invalidate takes a tuple corresponding to the params of - # _get_server_verify_key. _get_server_verify_key only takes one + # _get_server_signature_key. _get_server_signature_key only takes one # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) @@ -134,10 +139,10 @@ async def store_server_verify_keys( "verify_key", ), value_values=value_values, - desc="store_server_verify_keys", + desc="store_server_signature_keys", ) - invalidate = self._get_server_verify_key.invalidate + invalidate = self._get_server_signature_key.invalidate for i in invalidations: invalidate((i,)) @@ -180,7 +185,75 @@ async def store_server_keys_json( desc="store_server_keys_json", ) + # invalidate takes a tuple corresponding to the params of + # _get_server_keys_json. _get_server_keys_json only takes one + # param, which is itself the 2-tuple (server_name, key_id). + self._get_server_keys_json.invalidate((((server_name, key_id),))) + + @cached() + def _get_server_keys_json( + self, server_name_and_key_id: Tuple[str, str] + ) -> FetchKeyResult: + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids" + ) async def get_server_keys_json( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], FetchKeyResult]: + """ + Args: + server_name_and_key_ids: + iterable of (server_name, key-id) tuples to fetch keys for + + Returns: + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown + """ + keys = {} + + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = """ + SELECT server_name, key_id, key_json, ts_valid_until_ms + FROM server_keys_json WHERE 1=0 + """ + " OR (server_name=? AND key_id=?)" * len( + batch + ) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn: + if ts_valid_until_ms is None: + # Old keys may be stored with a ts_valid_until_ms of null, + # in which case we treat this as if it was set to `0`, i.e. + # it won't match key requests that define a minimum + # `ts_valid_until_ms`. + ts_valid_until_ms = 0 + + # The entire signed JSON response is stored in server_keys_json, + # fetch out the bits needed. + key_json = json.loads(bytes(key_json_bytes)) + key_base64 = key_json["verify_keys"][key_id]["key"] + + keys[(server_name, key_id)] = FetchKeyResult( + verify_key=decode_verify_key_bytes( + key_id, decode_base64(key_base64) + ), + valid_until_ts=ts_valid_until_ms, + ) + + def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return await self.db_pool.runInteraction("get_server_keys_json", _txn) + + async def get_server_keys_json_for_remote( self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: """Retrieve the key json for a list of server_keys and key ids. @@ -188,8 +261,10 @@ async def get_server_keys_json( that server, key_id, and source triplet entry will be an empty list. The JSON is returned as a byte array so that it can be efficiently used in an HTTP response. + Args: server_keys: List of (server_name, key_id, source) triplets. + Returns: A mapping from (server_name, key_id, source) triplets to a list of dicts """ diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index a1747f04ce72..b109f8c07f1e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -11,14 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.roommember import ProfileInfo +from synapse.types import UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class ProfileWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + self.db_pool.updates.register_background_index_update( + "profiles_full_user_id_key_idx", + index_name="profiles_full_user_id_key", + table="profiles", + columns=["full_user_id"], + unique=True, + ) + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( @@ -54,28 +74,36 @@ async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: desc="get_profile_avatar_url", ) - async def create_profile(self, user_localpart: str) -> None: + async def create_profile(self, user_id: UserID) -> None: + user_localpart = user_id.localpart await self.db_pool.simple_insert( - table="profiles", values={"user_id": user_localpart}, desc="create_profile" + table="profiles", + values={"user_id": user_localpart, "full_user_id": user_id.to_string()}, + desc="create_profile", ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: Optional[str] + self, user_id: UserID, new_displayname: Optional[str] ) -> None: + user_localpart = user_id.localpart await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - values={"displayname": new_displayname}, + values={ + "displayname": new_displayname, + "full_user_id": user_id.to_string(), + }, desc="set_profile_displayname", ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: Optional[str] + self, user_id: UserID, new_avatar_url: Optional[str] ) -> None: + user_localpart = user_id.localpart await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - values={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()}, desc="set_profile_avatar_url", ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 7a7c0d9c753d..efbd3e75d99e 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -428,14 +428,16 @@ def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: "partial_state_events", "partial_state_rooms_servers", "partial_state_rooms", + # Note: the _membership(s) tables have foreign keys to the `events` table + # so must be deleted first. + "local_current_membership", + "room_memberships", "events", "federation_inbound_events_staging", - "local_current_membership", "receipts_graph", "receipts_linearized", "room_aliases", "room_depth", - "room_memberships", "room_stats_state", "room_stats_current", "room_stats_earliest_token", diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 717237e02496..676d03bb7e14 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -2414,8 +2414,8 @@ def _register_user( # *obviously* the 'profiles' table uses localpart for user_id # while everything else uses the full mxid. txn.execute( - "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", - (user_id_obj.localpart, create_profile_with_displayname), + "INSERT INTO profiles(full_user_id, user_id, displayname) VALUES (?,?,?)", + (user_id, user_id_obj.localpart, create_profile_with_displayname), ) if self.hs.config.stats.stats_enabled: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 3955a8a9a589..4a6c6c724d33 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -172,6 +172,7 @@ async def get_relations_for_event( direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, + recurse: bool = False, ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. @@ -186,6 +187,7 @@ async def get_relations_for_event( oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. + recurse: Whether to recursively find relations. Returns: A tuple of: @@ -200,8 +202,8 @@ async def get_relations_for_event( # Ensure bad limits aren't being passed in. assert limit >= 0 - where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event.event_id, room_id] + where_clause = ["room_id = ?"] + where_args: List[Union[str, int]] = [room_id] is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: @@ -229,23 +231,52 @@ async def get_relations_for_event( if pagination_clause: where_clause.append(pagination_clause) - sql = """ - SELECT event_id, relation_type, sender, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE %s - ORDER BY topological_ordering %s, stream_ordering %s - LIMIT ? - """ % ( - " AND ".join(where_clause), - order, - order, - ) + # If a recursive query is requested then the filters are applied after + # recursively following relationships from the requested event to children + # up to 3-relations deep. + # + # If no recursion is needed then the event_relations table is queried + # for direct children of the requested event. + if recurse: + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relation_type, relates_to_id, 0 AS depth + FROM event_relations + WHERE relates_to_id = ? + UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.event_id = e.relates_to_id + WHERE depth <= 3 + ) + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering + FROM related_events + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ?; + """ % ( + " AND ".join(where_clause), + order, + order, + ) + else: + sql = """ + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) def _get_recent_references_for_event_txn( txn: LoggingTransaction, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - txn.execute(sql, where_args + [limit + 1]) + txn.execute(sql, [event.event_id] + where_args + [limit + 1]) events = [] topo_orderings: List[int] = [] @@ -965,7 +996,7 @@ async def get_thread_id(self, event_id: str) -> str: # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type, 0 depth + SELECT event_id, relates_to_id, relation_type, 0 AS depth FROM event_relations WHERE event_id = ? UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 @@ -1025,7 +1056,7 @@ async def get_thread_id_for_receipts(self, event_id: str) -> str: sql = """ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type, 0 depth + SELECT event_id, relates_to_id, relation_type, 0 AS depth FROM event_relations WHERE event_id = ? UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 2a1c6fa31bc2..38b7abd8010e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -22,7 +22,7 @@ from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION from synapse.storage.types import Cursor @@ -168,7 +168,9 @@ def prepare_database( def _setup_new_database( - cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str] + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + databases: Collection[str], ) -> None: """Sets up the physical database by finding a base set of "full schemas" and then applying any necessary deltas, including schemas from the given data @@ -289,7 +291,7 @@ def _setup_new_database( def _upgrade_existing_database( - cur: Cursor, + cur: LoggingTransaction, current_schema_state: _SchemaState, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index d3103a6c7a05..1672976209d6 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 74 # remember to update the list below when updating +SCHEMA_VERSION = 76 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -91,13 +91,22 @@ - A query on `event_stream_ordering` column has now been disambiguated (i.e. the codebase can handle the `current_state_events`, `local_current_memberships` and `room_memberships` tables having an `event_stream_ordering` column). + +Changes in SCHEMA_VERSION = 75: + - The `event_stream_ordering` column in membership tables (`current_state_events`, + `local_current_membership` & `room_memberships`) is now being populated for new + rows. When the background job to populate historical rows lands this will + become the compat schema version. + +Changes in SCHEMA_VERSION = 76: + - Adds a full_user_id column to tables profiles and user_filters. """ SCHEMA_COMPAT_VERSION = ( - # The threads_id column must exist for event_push_actions, event_push_summary, - # receipts_linearized, and receipts_graph. - 73 + # Queries against `event_stream_ordering` columns in membership tables must + # be disambiguated. + 74 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/20/pushers.py b/synapse/storage/schema/main/delta/20/pushers.py index 45b846e6a7d5..08ae0efc2112 100644 --- a/synapse/storage/schema/main/delta/20/pushers.py +++ b/synapse/storage/schema/main/delta/20/pushers.py @@ -24,10 +24,13 @@ import logging +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine + logger = logging.getLogger(__name__) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: logger.info("Porting pushers table...") cur.execute( """ @@ -61,8 +64,8 @@ def run_create(cur, database_engine, *args, **kwargs): """ ) count = 0 - for row in cur.fetchall(): - row = list(row) + for tuple_row in cur.fetchall(): + row = list(tuple_row) row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( @@ -81,7 +84,3 @@ def run_create(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/25/fts.py b/synapse/storage/schema/main/delta/25/fts.py index 21f57825d4ed..831f8e914d76 100644 --- a/synapse/storage/schema/main/delta/25/fts.py +++ b/synapse/storage/schema/main/delta/25/fts.py @@ -14,7 +14,8 @@ import json import logging -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -41,7 +42,7 @@ ) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) @@ -72,7 +73,3 @@ def run_create(cur, database_engine, *args, **kwargs): ) cur.execute(sql, ("event_search", progress_json)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/27/ts.py b/synapse/storage/schema/main/delta/27/ts.py index 1c6058063fb6..8962afdedae0 100644 --- a/synapse/storage/schema/main/delta/27/ts.py +++ b/synapse/storage/schema/main/delta/27/ts.py @@ -14,6 +14,8 @@ import json import logging +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -25,7 +27,7 @@ ) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -51,7 +53,3 @@ def run_create(cur, database_engine, *args, **kwargs): ) cur.execute(sql, ("event_origin_server_ts", progress_json)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 4b4b166e37a6..b9d8df12313c 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Dict, Iterable, List, Tuple, cast from synapse.config.appservice import load_appservices +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine logger = logging.getLogger(__name__) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") @@ -27,9 +31,13 @@ def run_create(cur, database_engine, *args, **kwargs): pass -def run_upgrade(cur, database_engine, config, *args, **kwargs): +def run_upgrade( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: cur.execute("SELECT name FROM users") - rows = cur.fetchall() + rows = cast(Iterable[Tuple[str]], cur.fetchall()) config_files = [] try: @@ -39,7 +47,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): appservices = load_appservices(config.server.server_name, config_files) - owned = {} + owned: Dict[str, List[str]] = {} for row in rows: user_id = row[0] diff --git a/synapse/storage/schema/main/delta/31/pushers.py b/synapse/storage/schema/main/delta/31/pushers_0.py similarity index 88% rename from synapse/storage/schema/main/delta/31/pushers.py rename to synapse/storage/schema/main/delta/31/pushers_0.py index 5be81c806a28..e772e2dc65a0 100644 --- a/synapse/storage/schema/main/delta/31/pushers.py +++ b/synapse/storage/schema/main/delta/31/pushers_0.py @@ -20,14 +20,17 @@ import logging +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine + logger = logging.getLogger(__name__) -def token_to_stream_ordering(token): +def token_to_stream_ordering(token: str) -> int: return int(token[1:].split("_")[0]) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: logger.info("Porting pushers table, delta 31...") cur.execute( """ @@ -61,8 +64,8 @@ def run_create(cur, database_engine, *args, **kwargs): """ ) count = 0 - for row in cur.fetchall(): - row = list(row) + for tuple_row in cur.fetchall(): + row = list(tuple_row) row[12] = token_to_stream_ordering(row[12]) cur.execute( """ @@ -80,7 +83,3 @@ def run_create(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/31/search_update.py b/synapse/storage/schema/main/delta/31/search_update.py index b84c844e3af4..e20e92e454c6 100644 --- a/synapse/storage/schema/main/delta/31/search_update.py +++ b/synapse/storage/schema/main/delta/31/search_update.py @@ -14,7 +14,8 @@ import json import logging -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ """ -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if not isinstance(database_engine, PostgresEngine): return @@ -56,7 +57,3 @@ def run_create(cur, database_engine, *args, **kwargs): ) cur.execute(sql, ("event_search_order", progress_json)) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/33/event_fields.py b/synapse/storage/schema/main/delta/33/event_fields.py index e928c66a8f2d..8d806f5b525c 100644 --- a/synapse/storage/schema/main/delta/33/event_fields.py +++ b/synapse/storage/schema/main/delta/33/event_fields.py @@ -14,6 +14,8 @@ import json import logging +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -25,7 +27,7 @@ """ -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -51,7 +53,3 @@ def run_create(cur, database_engine, *args, **kwargs): ) cur.execute(sql, ("event_fields_sender_url", progress_json)) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/33/remote_media_ts.py b/synapse/storage/schema/main/delta/33/remote_media_ts.py index 3907189e29fc..35499e43b526 100644 --- a/synapse/storage/schema/main/delta/33/remote_media_ts.py +++ b/synapse/storage/schema/main/delta/33/remote_media_ts.py @@ -14,14 +14,22 @@ import time +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine + ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: cur.execute(ALTER_TABLE) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_upgrade( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: cur.execute( "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), diff --git a/synapse/storage/schema/main/delta/34/cache_stream.py b/synapse/storage/schema/main/delta/34/cache_stream.py index cf09e43e2bf2..682c86da1abd 100644 --- a/synapse/storage/schema/main/delta/34/cache_stream.py +++ b/synapse/storage/schema/main/delta/34/cache_stream.py @@ -14,7 +14,8 @@ import logging -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -34,13 +35,9 @@ """ -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if not isinstance(database_engine, PostgresEngine): return for statement in get_statements(CREATE_TABLE.splitlines()): cur.execute(statement) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/34/received_txn_purge.py b/synapse/storage/schema/main/delta/34/received_txn_purge.py index 67d505e68bf4..dcfe3bc45a97 100644 --- a/synapse/storage/schema/main/delta/34/received_txn_purge.py +++ b/synapse/storage/schema/main/delta/34/received_txn_purge.py @@ -14,19 +14,16 @@ import logging -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine logger = logging.getLogger(__name__) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if isinstance(database_engine, PostgresEngine): cur.execute("TRUNCATE received_transactions") else: cur.execute("DELETE FROM received_transactions") cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/37/remove_auth_idx.py b/synapse/storage/schema/main/delta/37/remove_auth_idx.py index a3778841699c..d672f9b43cdf 100644 --- a/synapse/storage/schema/main/delta/37/remove_auth_idx.py +++ b/synapse/storage/schema/main/delta/37/remove_auth_idx.py @@ -14,7 +14,8 @@ import logging -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ """ -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: for statement in get_statements(DROP_INDICES.splitlines()): cur.execute(statement) @@ -79,7 +80,3 @@ def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(drop_constraint.splitlines()): cur.execute(statement) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/42/user_dir.py b/synapse/storage/schema/main/delta/42/user_dir.py index 506f326f4db4..7e5c307c628f 100644 --- a/synapse/storage/schema/main/delta/42/user_dir.py +++ b/synapse/storage/schema/main/delta/42/user_dir.py @@ -14,7 +14,8 @@ import logging -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ """ -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: for statement in get_statements(BOTH_TABLES.splitlines()): cur.execute(statement) @@ -78,7 +79,3 @@ def run_create(cur, database_engine, *args, **kwargs): cur.execute(statement) else: raise Exception("Unrecognized database engine") - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/48/group_unique_indexes.py b/synapse/storage/schema/main/delta/48/group_unique_indexes.py index 49f5f2c00324..ad2da4c8af84 100644 --- a/synapse/storage/schema/main/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/main/delta/48/group_unique_indexes.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine + +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.prepare_database import get_statements FIX_INDEXES = """ @@ -34,7 +36,7 @@ """ -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" # remove duplicates from group_users & group_invites tables @@ -57,7 +59,3 @@ def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(FIX_INDEXES.splitlines()): cur.execute(statement) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/50/make_event_content_nullable.py b/synapse/storage/schema/main/delta/50/make_event_content_nullable.py index acd6ad1e1fca..3e8a348b8aad 100644 --- a/synapse/storage/schema/main/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/main/delta/50/make_event_content_nullable.py @@ -53,16 +53,13 @@ import logging -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine logger = logging.getLogger(__name__) -def run_create(cur, database_engine, *args, **kwargs): - pass - - -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if isinstance(database_engine, PostgresEngine): cur.execute( """ @@ -76,7 +73,9 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" ) - (oldsql,) = cur.fetchone() + row = cur.fetchone() + assert row is not None + (oldsql,) = row sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") if sql == oldsql: @@ -85,7 +84,9 @@ def run_upgrade(cur, database_engine, *args, **kwargs): logger.info("Replacing definition of 'events' with: %s", sql) cur.execute("PRAGMA schema_version") - (oldver,) = cur.fetchone() + row = cur.fetchone() + assert row is not None + (oldver,) = row cur.execute("PRAGMA writable_schema=ON") cur.execute( "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", diff --git a/synapse/storage/schema/main/delta/56/unique_user_filter_index.py b/synapse/storage/schema/main/delta/56/unique_user_filter_index.py index bb7296852a61..2461f87d7727 100644 --- a/synapse/storage/schema/main/delta/56/unique_user_filter_index.py +++ b/synapse/storage/schema/main/delta/56/unique_user_filter_index.py @@ -1,7 +1,8 @@ import logging from io import StringIO -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -16,11 +17,7 @@ """ -def run_upgrade(cur, database_engine, *args, **kwargs): - pass - - -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if isinstance(database_engine, PostgresEngine): select_clause = """ SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json diff --git a/synapse/storage/schema/main/delta/57/local_current_membership.py b/synapse/storage/schema/main/delta/57/local_current_membership.py index d25093c19fde..cc0f2109bb23 100644 --- a/synapse/storage/schema/main/delta/57/local_current_membership.py +++ b/synapse/storage/schema/main/delta/57/local_current_membership.py @@ -27,7 +27,16 @@ # equivalent behaviour as if the server had remained in the room). -def run_upgrade(cur, database_engine, config, *args, **kwargs): +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine + + +def run_upgrade( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: # We need to do the insert in `run_upgrade` section as we don't have access # to `config` in `run_create`. @@ -77,7 +86,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): ) -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: cur.execute( """ CREATE TABLE local_current_membership ( diff --git a/synapse/storage/schema/main/delta/58/06dlols_unique_idx.py b/synapse/storage/schema/main/delta/58/06dlols_unique_idx.py index d353f2bcb361..4eaab9e08600 100644 --- a/synapse/storage/schema/main/delta/58/06dlols_unique_idx.py +++ b/synapse/storage/schema/main/delta/58/06dlols_unique_idx.py @@ -20,18 +20,14 @@ import logging from io import StringIO +from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.prepare_database import execute_statements_from_stream -from synapse.storage.types import Cursor logger = logging.getLogger(__name__) -def run_upgrade(*args, **kwargs): - pass - - -def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: # some instances might already have this index, in which case we can skip this if isinstance(database_engine, PostgresEngine): cur.execute( diff --git a/synapse/storage/schema/main/delta/58/11user_id_seq.py b/synapse/storage/schema/main/delta/58/11user_id_seq.py index 4310ec12ce1a..32f7e0a252c7 100644 --- a/synapse/storage/schema/main/delta/58/11user_id_seq.py +++ b/synapse/storage/schema/main/delta/58/11user_id_seq.py @@ -16,19 +16,16 @@ Adds a postgres SEQUENCE for generating guest user IDs. """ +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.registration import ( find_max_generated_user_id_localpart, ) -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if not isinstance(database_engine, PostgresEngine): return next_id = find_max_generated_user_id_localpart(cur) + 1 cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/59/01ignored_user.py b/synapse/storage/schema/main/delta/59/01ignored_user.py index 9e8f35c1d24b..c53e2bade25c 100644 --- a/synapse/storage/schema/main/delta/59/01ignored_user.py +++ b/synapse/storage/schema/main/delta/59/01ignored_user.py @@ -20,18 +20,14 @@ from io import StringIO from synapse.storage._base import db_to_json +from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.prepare_database import execute_statements_from_stream -from synapse.storage.types import Cursor logger = logging.getLogger(__name__) -def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): - pass - - -def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: logger.info("Creating ignored_users table") execute_statements_from_stream(cur, StringIO(_create_commands)) diff --git a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py index f8d7db9f2ef3..4a06b65888df 100644 --- a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py +++ b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py @@ -16,11 +16,11 @@ This migration handles the process of changing the type of `room_depth.min_depth` to a BIGINT. """ +from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.types import Cursor -def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if not isinstance(database_engine, PostgresEngine): # this only applies to postgres - sqlite does not distinguish between big and # little ints. @@ -64,7 +64,3 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs (6103, 'replace_room_depth_min_depth', '{}', 'populate_room_depth2') """ ) - - -def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py index a2ec4fc26edb..9210026ddee9 100644 --- a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -18,11 +18,11 @@ Triggers cannot be expressed in .sql files, so we have to use a separate file. """ +from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine -from synapse.storage.types import Cursor -def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: # complain if the room_id in partial_state_events doesn't match # that in `events`. We already have a fk constraint which ensures that the event # exists in `events`, so all we have to do is raise if there is a row with a diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py index 24bd4b391eee..6c112425f2f0 100644 --- a/synapse/storage/schema/main/delta/69/01as_txn_seq.py +++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py @@ -17,10 +17,11 @@ Adds a postgres SEQUENCE for generating application service transaction IDs. """ -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if isinstance(database_engine, PostgresEngine): # If we already have some AS TXNs we want to start from the current # maximum value. There are two potential places this is stored - the @@ -30,10 +31,12 @@ def run_create(cur, database_engine, *args, **kwargs): cur.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns") row = cur.fetchone() + assert row is not None txn_max = row[0] cur.execute("SELECT COALESCE(max(last_txn), 0) FROM application_services_state") row = cur.fetchone() + assert row is not None last_txn_max = row[0] start_val = max(last_txn_max, txn_max) + 1 diff --git a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py index 55a5d092cc67..2ec1830c6ffb 100644 --- a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py +++ b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py @@ -14,10 +14,11 @@ import json -from synapse.storage.types import Cursor +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine -def run_create(cur: Cursor, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: """Add a bg update to populate the `state_key` and `rejection_reason` columns of `events`""" # we know that any new events will have the columns populated (and that has been @@ -27,7 +28,9 @@ def run_create(cur: Cursor, database_engine, *args, **kwargs): # current min and max stream orderings, since that is guaranteed to include all # the events that were stored before the new columns were added. cur.execute("SELECT MIN(stream_ordering), MAX(stream_ordering) FROM events") - (min_stream_ordering, max_stream_ordering) = cur.fetchone() + row = cur.fetchone() + assert row is not None + (min_stream_ordering, max_stream_ordering) = row if min_stream_ordering is None: # no rows, nothing to do. diff --git a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py index b5853d125c6a..5c3e3584a21b 100644 --- a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py +++ b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py @@ -19,9 +19,16 @@ Note the background job must still remain defined in the database class. """ +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_upgrade( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: cur.execute("SELECT update_name FROM background_updates") rows = cur.fetchall() for row in rows: diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py index 3de0a709eba7..c7ed258e9df2 100644 --- a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py +++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py @@ -13,11 +13,11 @@ # limitations under the License. import json +from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine -from synapse.storage.types import Cursor -def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: """ Upgrade the event_search table to use the porter tokenizer if it isn't already @@ -38,6 +38,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: # Re-run the background job to re-populate the event_search table. cur.execute("SELECT MIN(stream_ordering) FROM events") row = cur.fetchone() + assert row is not None min_stream_id = row[0] # If there are not any events, nothing to do. @@ -46,6 +47,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: cur.execute("SELECT MAX(stream_ordering) FROM events") row = cur.fetchone() + assert row is not None max_stream_id = row[0] progress = { diff --git a/synapse/storage/schema/main/delta/74/03_membership_tables_event_stream_ordering.sql.postgres b/synapse/storage/schema/main/delta/74/03_membership_tables_event_stream_ordering.sql.postgres new file mode 100644 index 000000000000..ceb750a9fa51 --- /dev/null +++ b/synapse/storage/schema/main/delta/74/03_membership_tables_event_stream_ordering.sql.postgres @@ -0,0 +1,29 @@ +/* Copyright 2022 Beeper + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in` events` which +-- we use to improve database performance by reduring JOINs. + +-- NOTE: these are set to NOT VALID to prevent locks while adding the column on large existing tables, +-- which will be validated in a later migration. For all new/updated rows the FKEY will be checked. + +ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT; +ALTER TABLE current_state_events ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID; + +ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT; +ALTER TABLE local_current_membership ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID; + +ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT; +ALTER TABLE room_memberships ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID; diff --git a/synapse/storage/schema/main/delta/74/03_membership_tables_event_stream_ordering.sql.sqlite b/synapse/storage/schema/main/delta/74/03_membership_tables_event_stream_ordering.sql.sqlite new file mode 100644 index 000000000000..6f6283fdb769 --- /dev/null +++ b/synapse/storage/schema/main/delta/74/03_membership_tables_event_stream_ordering.sql.sqlite @@ -0,0 +1,23 @@ +/* Copyright 2022 Beeper + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in` events` which +-- we use to improve database performance by reduring JOINs. + +-- NOTE: sqlite does not support ADD CONSTRAINT so we add the new columns with FK constraint as-is + +ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering); +ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering); +ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering); diff --git a/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py b/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py new file mode 100644 index 000000000000..2ee2bc9422a6 --- /dev/null +++ b/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py @@ -0,0 +1,79 @@ +# Copyright 2022 Beeper +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the room membership tables to enforce consistency. +Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine + + +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: + # Complain if the `event_stream_ordering` in membership tables doesn't match + # the `stream_ordering` row with the same `event_id` in `events`. + if isinstance(database_engine, Sqlite3Engine): + for table in ( + "current_state_events", + "local_current_membership", + "room_memberships", + ): + cur.execute( + f""" + CREATE TRIGGER IF NOT EXISTS {table}_bad_event_stream_ordering + BEFORE INSERT ON {table} + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect event_stream_ordering in {table}') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.stream_ordering != NEW.event_stream_ordering + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_event_stream_ordering() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.stream_ordering != NEW.event_stream_ordering + ) THEN + RAISE EXCEPTION 'Incorrect event_stream_ordering'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + for table in ( + "current_state_events", + "local_current_membership", + "room_memberships", + ): + cur.execute( + f""" + CREATE TRIGGER check_event_stream_ordering BEFORE INSERT OR UPDATE ON {table} + FOR EACH ROW + EXECUTE PROCEDURE check_event_stream_ordering() + """ + ) + else: + raise NotImplementedError("Unknown database engine") diff --git a/synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql b/synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql new file mode 100644 index 000000000000..517a821a561d --- /dev/null +++ b/synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql @@ -0,0 +1,53 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- For MSC3970, in addition to the (room_id, user_id, token_id, txn_id) -> event_id mapping for each local event, +-- we also store the (room_id, user_id, device_id, txn_id) -> event_id mapping. +-- +-- This adds a new event_txn_id_device_id table. + +-- A map of recent events persisted with transaction IDs. Used to deduplicate +-- send event requests with the same transaction ID. +-- +-- Note: with MSC3970, transaction IDs are scoped to the +-- room ID/user ID/device ID that was used to make the request. +-- +-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the +-- event or device we don't want to try and de-duplicate the event. +CREATE TABLE IF NOT EXISTS event_txn_id_device_id ( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + txn_id TEXT NOT NULL, + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (user_id, device_id) + REFERENCES devices (user_id, device_id) ON DELETE CASCADE +); + +-- This ensures that there is only one mapping per event_id. +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_device_id_event_id + ON event_txn_id_device_id(event_id); + +-- This ensures that there is only one mapping per (room_id, user_id, device_id, txn_id) tuple. +-- Events are usually looked up using this index. +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_device_id_txn_id + ON event_txn_id_device_id(room_id, user_id, device_id, txn_id); + +-- This table is cleaned up regularly, removing the oldest entries, hence this index. +CREATE INDEX IF NOT EXISTS event_txn_id_device_id_ts + ON event_txn_id_device_id(inserted_ts); diff --git a/synapse/storage/schema/main/delta/76/01_add_profiles_full_user_id_column.sql b/synapse/storage/schema/main/delta/76/01_add_profiles_full_user_id_column.sql new file mode 100644 index 000000000000..9cd680325ad0 --- /dev/null +++ b/synapse/storage/schema/main/delta/76/01_add_profiles_full_user_id_column.sql @@ -0,0 +1,20 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE profiles ADD COLUMN full_user_id TEXT; + +-- Make sure the column has a unique constraint, mirroring the `profiles_user_id_key` +-- constraint. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7501, 'profiles_full_user_id_key_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/76/02_add_user_filters_full_user_id_column.sql b/synapse/storage/schema/main/delta/76/02_add_user_filters_full_user_id_column.sql new file mode 100644 index 000000000000..fd231adeef9e --- /dev/null +++ b/synapse/storage/schema/main/delta/76/02_add_user_filters_full_user_id_column.sql @@ -0,0 +1,20 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE user_filters ADD COLUMN full_user_id TEXT; + +-- Add a unique index on the new column, mirroring the `user_filters_unique` unique +-- index. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7502, 'full_users_filters_unique_idx', '{}'); \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/76/03_per_user_experimental_features.sql b/synapse/storage/schema/main/delta/76/03_per_user_experimental_features.sql new file mode 100644 index 000000000000..c4ef81846ceb --- /dev/null +++ b/synapse/storage/schema/main/delta/76/03_per_user_experimental_features.sql @@ -0,0 +1,27 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Table containing experimental features and whether they are enabled for a given user +CREATE TABLE per_user_experimental_features ( + -- The User ID to check/set the feature for + user_id TEXT NOT NULL, + -- Contains features to be enabled/disabled + feature TEXT NOT NULL, + -- whether the feature is enabled/disabled for a given user, defaults to disabled + enabled BOOLEAN DEFAULT FALSE, + FOREIGN KEY (user_id) REFERENCES users(name), + PRIMARY KEY (user_id, feature) +); + diff --git a/synapse/storage/schema/state/delta/47/state_group_seq.py b/synapse/storage/schema/state/delta/47/state_group_seq.py index 9fd1ccf6f792..42aff502273b 100644 --- a/synapse/storage/schema/state/delta/47/state_group_seq.py +++ b/synapse/storage/schema/state/delta/47/state_group_seq.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.engines import PostgresEngine +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -def run_create(cur, database_engine, *args, **kwargs): +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: if isinstance(database_engine, PostgresEngine): # if we already have some state groups, we want to start making new # ones with a higher id. cur.execute("SELECT max(id) FROM state_groups") row = cur.fetchone() + assert row is not None if row[0] is None: start_val = 1 @@ -28,7 +30,3 @@ def run_create(cur, database_engine, *args, **kwargs): start_val = row[0] + 1 cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6c6a9ab4b4a9..222449baac81 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -26,13 +26,15 @@ from synapse.api.filtering import Filter from synapse.api.presence import UserPresenceState from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from synapse.util.frozenutils import freeze from tests import unittest from tests.events.test_utils import MockEvent +user_id = UserID.from_string("@test_user:test") +user2_id = UserID.from_string("@test_user2:test") user_localpart = "test_user" @@ -437,7 +439,7 @@ def test_filter_presence_match(self) -> None: user_filter_json = {"presence": {"senders": ["@foo:bar"]}} filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) presence_states = [ @@ -467,7 +469,7 @@ def test_filter_presence_no_match(self) -> None: filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart + "2", user_filter=user_filter_json + user_id=user2_id, user_filter=user_filter_json ) ) presence_states = [ @@ -495,7 +497,7 @@ def test_filter_room_state_match(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") @@ -514,7 +516,7 @@ def test_filter_room_state_no_match(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) event = MockEvent( @@ -598,7 +600,7 @@ def test_add_filter(self) -> None: filter_id = self.get_success( self.filtering.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) @@ -619,7 +621,7 @@ def test_get_filter(self) -> None: filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 7deb923a280d..15fce165b611 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -195,11 +195,11 @@ async def post_json_get_json( MISSING_KEYS = [ # Known user, known device, missing algorithm. - ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), + ("@alice:example.org", "DEVICE_2", "xyz", 1), # Known user, missing device. - ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), + ("@alice:example.org", "DEVICE_3", "signed_curve25519", 1), # Unknown user. - ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), + ("@bob:example.org", "DEVICE_4", "signed_curve25519", 1), ] claimed_keys, missing = self.get_success( @@ -207,9 +207,8 @@ async def post_json_get_json( self.service, [ # Found devices - ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), - ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), - ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"), + ("@alice:example.org", "DEVICE_1", "signed_curve25519", 1), + ("@alice:example.org", "DEVICE_2", "signed_curve25519", 1), ] + MISSING_KEYS, ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 66102ab93487..7c63b2ea4c15 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -190,10 +190,23 @@ def test_verify_json_for_server(self) -> None: kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key("1") - r = self.hs.get_datastores().main.store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_keys_json( "server9", - int(time.time() * 1000), - {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)}, + get_key_id(key1), + from_server="test", + ts_now_ms=int(time.time() * 1000), + ts_expires_ms=1000, + # The entire response gets signed & stored, just include the bits we + # care about. + key_json_bytes=canonicaljson.encode_canonical_json( + { + "verify_keys": { + get_key_id(key1): { + "key": encode_verify_key_base64(get_verify_key(key1)) + } + } + } + ), ) self.get_success(r) @@ -280,17 +293,13 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: mock_fetcher = Mock() mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) - kr = keyring.Keyring( - self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) - ) - key1 = signedjson.key.generate_signing_key("1") - r = self.hs.get_datastores().main.store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_signature_keys( "server9", int(time.time() * 1000), # None is not a valid value in FetchKeyResult, but we're abusing this # API to insert null values into the database. The nulls get converted - # to 0 when fetched in KeyStore.get_server_verify_keys. + # to 0 when fetched in KeyStore.get_server_signature_keys. {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type] ) self.get_success(r) @@ -298,27 +307,12 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: json1: JsonDict = {} signedjson.sign.sign_json(json1, "server9", key1) - # should fail immediately on an unsigned object - d = kr.verify_json_for_server("server9", {}, 0) - self.get_failure(d, SynapseError) - - # should fail on a signed object with a non-zero minimum_valid_until_ms, - # as it tries to refetch the keys and fails. - d = kr.verify_json_for_server("server9", json1, 500) - self.get_failure(d, SynapseError) - - # We expect the keyring tried to refetch the key once. - mock_fetcher.get_keys.assert_called_once_with( - "server9", [get_key_id(key1)], 500 - ) - # should succeed on a signed object with a 0 minimum_valid_until_ms - d = kr.verify_json_for_server( - "server9", - json1, - 0, + d = self.hs.get_datastores().main.get_server_signature_keys( + [("server9", get_key_id(key1))] ) - self.get_success(d) + result = self.get_success(d) + self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0) def test_verify_json_dedupes_key_requests(self) -> None: """Two requests for the same key should be deduped.""" @@ -464,7 +458,9 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json_for_remote( + [lookup_triplet] + ) ) res_keys = key_json[lookup_triplet] self.assertEqual(len(res_keys), 1) @@ -582,7 +578,9 @@ def test_get_keys_from_perspectives(self) -> None: # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json_for_remote( + [lookup_triplet] + ) ) res_keys = key_json[lookup_triplet] self.assertEqual(len(res_keys), 1) @@ -703,7 +701,9 @@ def test_get_perspectives_own_key(self) -> None: # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json_for_remote( + [lookup_triplet] + ) ) res_keys = key_json[lookup_triplet] self.assertEqual(len(res_keys), 1) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 33af8770fde8..129d7cfd93f5 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -75,7 +75,7 @@ def test_join_too_large(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -106,7 +106,7 @@ def test_join_too_large_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -143,7 +143,7 @@ def test_join_too_large_once_joined(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -200,7 +200,7 @@ def test_join_too_large_no_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -230,7 +230,7 @@ def test_join_too_large_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 013b9ee5504f..72d05840613e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -160,7 +160,9 @@ def test_claim_one_time_key(self) -> None: res2 = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -203,7 +205,9 @@ def test_fallback_key(self) -> None: # key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -220,7 +224,9 @@ def test_fallback_key(self) -> None: # claiming an OTK again should return the same fallback key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -267,7 +273,9 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -277,7 +285,9 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -296,7 +306,9 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -304,6 +316,75 @@ def test_fallback_key(self) -> None: {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) + def test_fallback_key_always_returned(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + otk = {"alg1:k2": "key2"} + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + # Upload a OTK & fallback key. + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": otk, "fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Claiming an OTK and requesting to always return the fallback key should + # return both. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}}, + }, + ) + + # This should not mark the key as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Claiming an OTK again should return only the fallback key. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # And mark it as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, []) + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname @@ -971,7 +1052,7 @@ def test_query_appservice(self) -> None: # Setup a response, but only for device 2. self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) ) # we shouldn't have any unused fallback keys yet @@ -998,12 +1079,9 @@ def test_query_appservice(self) -> None: # query the fallback keys. claim_res = self.get_success( self.handler.claim_one_time_keys( - { - "one_time_keys": { - local_user: {device_id_1: "alg1", device_id_2: "alg1"} - } - }, + {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}}, timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -1016,6 +1094,153 @@ def test_query_appservice(self) -> None: }, ) + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice_with_fallback(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}} + otk = {"alg1:k2": {"desc": "key2"}} + as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}} + as_otk = {"alg1:k4": {"desc": "key4"}} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) + ) + + # Claim OTKs, which will ask the appservice and do nothing else. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id_1: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: {**as_otk, **as_fallback_key}} + }, + }, + ) + + # Now upload a fallback key. + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # The appservice will return only the OTK. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_otk}}, []) + ) + + # Claim OTKs, which should return the OTK from the appservice and the + # uploaded fallback key. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id_1: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: {**as_otk, **fallback_key}} + }, + }, + ) + + # But the fallback key should not be marked as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Now upload a OTK. + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"one_time_keys": otk}, + ) + ) + + # Claim OTKs, which will return information only from the database. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id_1: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}}, + }, + ) + + # But the fallback key should not be marked as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Finally, return only the fallback key from the appservice. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_fallback_key}}, []) + ) + + # Claim OTKs, which will return only the fallback key from the database. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id_1: {"alg1": 1}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id_1: as_fallback_key}}, + }, + ) + @override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) def test_query_local_devices_appservice(self) -> None: """Test that querying of appservices for keys overrides responses from the database.""" diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 7c174782da36..64a9a22afeca 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -66,9 +66,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_profile_handler() def test_get_my_name(self) -> None: - self.get_success( - self.store.set_profile_displayname(self.frank.localpart, "Frank") - ) + self.get_success(self.store.set_profile_displayname(self.frank, "Frank")) displayname = self.get_success(self.handler.get_displayname(self.frank)) @@ -121,9 +119,7 @@ def test_set_my_name_if_disabled(self) -> None: self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed - self.get_success( - self.store.set_profile_displayname(self.frank.localpart, "Frank") - ) + self.get_success(self.store.set_profile_displayname(self.frank, "Frank")) self.assertEqual( ( @@ -166,8 +162,14 @@ def test_get_other_name(self) -> None: ) def test_incoming_fed_query(self) -> None: - self.get_success(self.store.create_profile("caroline")) - self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) + self.get_success( + self.store.create_profile(UserID.from_string("@caroline:test")) + ) + self.get_success( + self.store.set_profile_displayname( + UserID.from_string("@caroline:test"), "Caroline" + ) + ) response = self.get_success( self.query_handlers["profile"]( @@ -183,9 +185,7 @@ def test_incoming_fed_query(self) -> None: def test_get_my_avatar(self) -> None: self.get_success( - self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" - ) + self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png") ) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) @@ -237,9 +237,7 @@ def test_set_my_avatar_if_disabled(self) -> None: # Setting displayname for the first time is allowed self.get_success( - self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" - ) + self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png") ) self.assertEqual( diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index fdd22a8e9437..d89a91c59d93 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -26,7 +26,7 @@ from synapse.api.errors import RequestSendFailed from synapse.http.matrixfederationclient import ( - JsonParser, + ByteParser, MatrixFederationHttpClient, MatrixFederationRequest, ) @@ -618,9 +618,9 @@ def test_too_big(self) -> None: while not test_d.called: protocol.dataReceived(b"a" * chunk_size) sent += chunk_size - self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE) + self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE) - self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE) + self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE) f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index a8f6436836be..645a00b4b124 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -372,3 +372,130 @@ def test_purge_history(self) -> None: self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("complete", channel.json_body["status"]) + + +class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/experimental_features" + + def test_enable_and_disable(self) -> None: + """ + Test basic functionality of ExperimentalFeatures endpoint + """ + # test enabling features works + url = f"{self.url}/{self.other_user}" + channel = self.make_request( + "PUT", + url, + content={ + "features": {"msc3026": True, "msc2654": True}, + }, + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200) + + # list which features are enabled and ensure the ones we enabled are listed + self.assertEqual(channel.code, 200) + url = f"{self.url}/{self.other_user}" + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + True, + channel.json_body["features"]["msc3026"], + ) + self.assertEqual( + True, + channel.json_body["features"]["msc2654"], + ) + + # test disabling a feature works + url = f"{self.url}/{self.other_user}" + channel = self.make_request( + "PUT", + url, + content={"features": {"msc3026": False}}, + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200) + + # list the features enabled/disabled and ensure they are still are correct + self.assertEqual(channel.code, 200) + url = f"{self.url}/{self.other_user}" + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + False, + channel.json_body["features"]["msc3026"], + ) + self.assertEqual( + True, + channel.json_body["features"]["msc2654"], + ) + self.assertEqual( + False, + channel.json_body["features"]["msc3881"], + ) + self.assertEqual( + False, + channel.json_body["features"]["msc3967"], + ) + + # test nothing blows up if you try to disable a feature that isn't already enabled + url = f"{self.url}/{self.other_user}" + channel = self.make_request( + "PUT", + url, + content={"features": {"msc3026": False}}, + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200) + + # test trying to enable a feature without an admin access token is denied + url = f"{self.url}/f{self.other_user}" + channel = self.make_request( + "PUT", + url, + content={"features": {"msc3881": True}}, + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + {"errcode": "M_FORBIDDEN", "error": "You are not a server admin"}, + ) + + # test trying to enable a bogus msc is denied + url = f"{self.url}/{self.other_user}" + channel = self.make_request( + "PUT", + url, + content={"features": {"msc6666": True}}, + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 400) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "'msc6666' is not recognised as a valid experimental feature.", + }, + ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index b4241ceaf023..434bb56d4451 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -802,9 +802,21 @@ def test_order_by(self) -> None: # Set avatar URL to all users, that no user has a NULL value to avoid # different sort order between SQlite and PostreSQL - self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3")) - self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2")) - self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1")) + self.get_success( + self.store.set_profile_avatar_url( + UserID.from_string("@user1:test"), "mxc://url3" + ) + ) + self.get_success( + self.store.set_profile_avatar_url( + UserID.from_string("@user2:test"), "mxc://url2" + ) + ) + self.get_success( + self.store.set_profile_avatar_url( + UserID.from_string("@admin:test"), "mxc://url1" + ) + ) # order by default (name) self._order_test([self.admin_user, user1, user2], None) @@ -1127,7 +1139,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # set attributes for user self.get_success( - self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + self.store.set_profile_avatar_url( + UserID.from_string("@user:test"), "mxc://servername/mediaid" + ) ) self.get_success( self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) @@ -1257,7 +1271,9 @@ def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None: Reproduces #12257. """ # Patch `self.other_user` to have an empty string as their avatar. - self.get_success(self.store.set_profile_avatar_url("user", "")) + self.get_success( + self.store.set_profile_avatar_url(UserID.from_string("@user:test"), "") + ) # Check we can still erase them. channel = self.make_request( @@ -2311,7 +2327,9 @@ def test_deactivate_user(self) -> None: # set attributes for user self.get_success( - self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + self.store.set_profile_avatar_url( + UserID.from_string("@user:test"), "mxc://servername/mediaid" + ) ) self.get_success( self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 91678abf1311..9faa9de05076 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -17,6 +17,7 @@ from synapse.api.errors import Codes from synapse.rest.client import filter from synapse.server import HomeServer +from synapse.types import UserID from synapse.util import Clock from tests import unittest @@ -76,7 +77,8 @@ def test_add_filter_non_local_user(self) -> None: def test_get_filter(self) -> None: filter_id = self.get_success( self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER + user_id=UserID.from_string("@apple:test"), + user_filter=self.EXAMPLE_FILTER, ) ) self.reactor.advance(1) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index cc47ac560f28..75439416c175 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -950,6 +950,125 @@ def test_pagination_from_sync_and_messages(self) -> None: ) +class RecursiveRelationTestCase(BaseRelationsTestCase): + @override_config({"experimental_features": {"msc3981_recurse_relations": True}}) + def test_recursive_relations(self) -> None: + """Generate a complex, multi-level relationship tree and query it.""" + # Create a thread with a few messages in it. + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + # Add annotations. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + ) + annotation_1 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1 + ) + annotation_2 = channel.json_body["event_id"] + + # Add a reference to part of the thread, then edit the reference and annotate it. + channel = self._send_relation( + RelationTypes.REFERENCE, "m.room.test", parent_id=thread_2 + ) + reference_1 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "c", parent_id=reference_1 + ) + annotation_3 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.test", + parent_id=reference_1, + ) + edit = channel.json_body["event_id"] + + # Also more events off the root. + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "d") + annotation_4 = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}" + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The above events should be returned in creation order. + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + event_ids, + [ + thread_1, + thread_2, + annotation_1, + annotation_2, + reference_1, + annotation_3, + edit, + annotation_4, + ], + ) + + @override_config({"experimental_features": {"msc3981_recurse_relations": True}}) + def test_recursive_relations_with_filter(self) -> None: + """The event_type and rel_type still apply.""" + # Create a thread with a few messages in it. + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_1 = channel.json_body["event_id"] + + # Add annotations. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1 + ) + annotation_1 = channel.json_body["event_id"] + + # Add a reference to part of the thread, then edit the reference and annotate it. + channel = self._send_relation( + RelationTypes.REFERENCE, "m.room.test", parent_id=thread_1 + ) + reference_1 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "org.matrix.reaction", "c", parent_id=reference_1 + ) + annotation_2 = channel.json_body["event_id"] + + # Fetch only annotations, but recursively. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}" + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The above events should be returned in creation order. + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(event_ids, [annotation_1, annotation_2]) + + # Fetch only m.reactions, but recursively. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction" + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The above events should be returned in creation order. + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(event_ids, [annotation_1]) + + class BundledAggregationsTestCase(BaseRelationsTestCase): """ See RelationsTestCase.test_edit for a similar test for edits. diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 5901d80f26d7..5d7c13e6d04c 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -37,13 +37,13 @@ def decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): - def test_get_server_verify_keys(self) -> None: + def test_get_server_signature_keys(self) -> None: store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" self.get_success( - store.store_server_verify_keys( + store.store_server_signature_keys( "from_server", 10, { @@ -54,7 +54,7 @@ def test_get_server_verify_keys(self) -> None: ) res = self.get_success( - store.get_server_verify_keys( + store.get_server_signature_keys( [ ("server1", key_id_1), ("server1", key_id_2), @@ -87,7 +87,7 @@ def test_cache(self) -> None: key_id_2 = "ed25519:key2" self.get_success( - store.store_server_verify_keys( + store.store_server_signature_keys( "from_server", 0, { @@ -98,7 +98,7 @@ def test_cache(self) -> None: ) res = self.get_success( - store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) + store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) ) self.assertEqual(len(res.keys()), 2) @@ -111,20 +111,20 @@ def test_cache(self) -> None: self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit - res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)])) + res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)])) self.assertEqual(len(res.keys()), 1) self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) - d = store.store_server_verify_keys( + d = store.store_server_signature_keys( "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)} ) self.get_success(d) res = self.get_success( - store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) + store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) ) self.assertEqual(len(res.keys()), 2) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 5806cb0e4bb7..27f450e22d1d 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -29,9 +29,9 @@ def setUp(self) -> None: def test_get_users_paginate(self) -> None: self.get_success(self.store.register_user(self.user.to_string(), "pass")) - self.get_success(self.store.create_profile(self.user.localpart)) + self.get_success(self.store.create_profile(self.user)) self.get_success( - self.store.set_profile_displayname(self.user.localpart, self.displayname) + self.store.set_profile_displayname(self.user, self.displayname) ) users, total = self.get_success( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a019d06e09c5..6ec34997ea53 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -27,11 +27,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.u_frank = UserID.from_string("@frank:test") def test_displayname(self) -> None: - self.get_success(self.store.create_profile(self.u_frank.localpart)) + self.get_success(self.store.create_profile(self.u_frank)) - self.get_success( - self.store.set_profile_displayname(self.u_frank.localpart, "Frank") - ) + self.get_success(self.store.set_profile_displayname(self.u_frank, "Frank")) self.assertEqual( "Frank", @@ -43,21 +41,17 @@ def test_displayname(self) -> None: ) # test set to None - self.get_success( - self.store.set_profile_displayname(self.u_frank.localpart, None) - ) + self.get_success(self.store.set_profile_displayname(self.u_frank, None)) self.assertIsNone( self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) def test_avatar_url(self) -> None: - self.get_success(self.store.create_profile(self.u_frank.localpart)) + self.get_success(self.store.create_profile(self.u_frank)) self.get_success( - self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" - ) + self.store.set_profile_avatar_url(self.u_frank, "http://my.site/here") ) self.assertEqual( @@ -70,9 +64,7 @@ def test_avatar_url(self) -> None: ) # test set to None - self.get_success( - self.store.set_profile_avatar_url(self.u_frank.localpart, None) - ) + self.get_success(self.store.set_profile_avatar_url(self.u_frank, None)) self.assertIsNone( self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) diff --git a/tests/test_federation.py b/tests/test_federation.py index 6476dfc7a1a7..ab51b2a2adca 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -268,7 +268,9 @@ def test_cross_signing_keys_retry(self) -> None: # Resync the device list. device_handler = self.hs.get_device_handler() self.get_success( - device_handler.device_list_updater.user_device_resync(remote_user_id), + device_handler.device_list_updater.multi_user_device_resync( + [remote_user_id] + ), ) # Retrieve the cross-signing keys for this user. diff --git a/tests/unittest.py b/tests/unittest.py index 93fee1c0e6e4..ee2f78ab0163 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -16,6 +16,7 @@ import gc import hashlib import hmac +import json import logging import secrets import time @@ -53,6 +54,7 @@ from synapse import events from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.config._base import Config, RootConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.crypto.event_signing import add_hashes_and_signatures @@ -67,7 +69,6 @@ ) from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer -from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree @@ -124,6 +125,63 @@ def new(*args: P.args, **kwargs: P.kwargs) -> R: return _around +_TConfig = TypeVar("_TConfig", Config, RootConfig) + + +def deepcopy_config(config: _TConfig) -> _TConfig: + new_config: _TConfig + + if isinstance(config, RootConfig): + new_config = config.__class__(config.config_files) # type: ignore[arg-type] + else: + new_config = config.__class__(config.root) + + for attr_name in config.__dict__: + if attr_name.startswith("__") or attr_name == "root": + continue + attr = getattr(config, attr_name) + if isinstance(attr, Config): + new_attr = deepcopy_config(attr) + else: + new_attr = attr + + setattr(new_config, attr_name, new_attr) + + return new_config + + +_make_homeserver_config_obj_cache: Dict[str, Union[RootConfig, Config]] = {} + + +def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig: + """Creates a :class:`HomeServerConfig` instance with the given configuration dict. + + This is equivalent to:: + + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config, "", "") + + but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed, + to avoid validating the whole configuration every time. + """ + cache_key = json.dumps(config) + + if cache_key in _make_homeserver_config_obj_cache: + # Cache hit: reuse the existing instance + config_obj = _make_homeserver_config_obj_cache[cache_key] + else: + # Cache miss; create the actual instance + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config, "", "") + + # Add to the cache + _make_homeserver_config_obj_cache[cache_key] = config_obj + + assert isinstance(config_obj, RootConfig) + + return deepcopy_config(config_obj) + + class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the @@ -528,8 +586,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: config = kwargs["config"] # Parse the config from a config dict into a HomeServerConfig - config_obj = HomeServerConfig() - config_obj.parse_config_dict(config, "", "") + config_obj = make_homeserver_config_obj(config) kwargs["config"] = config_obj async def run_bg_updates() -> None: @@ -790,15 +847,23 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastores().main.store_server_verify_keys( + hs.get_datastores().main.store_server_keys_json( + self.OTHER_SERVER_NAME, + verify_key_id, from_server=self.OTHER_SERVER_NAME, - ts_added_ms=clock.time_msec(), - verify_keys={ - (self.OTHER_SERVER_NAME, verify_key_id): FetchKeyResult( - verify_key=verify_key, - valid_until_ts=clock.time_msec() + 10000, - ), - }, + ts_now_ms=clock.time_msec(), + ts_expires_ms=clock.time_msec() + 10000, + key_json_bytes=canonicaljson.encode_canonical_json( + { + "verify_keys": { + verify_key_id: { + "key": signedjson.key.encode_verify_key_base64( + verify_key + ) + } + } + } + ), ) ) diff --git a/tests/utils.py b/tests/utils.py index a0ac11bc5cd2..e73b46944bd9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -131,6 +131,9 @@ def default_config( # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", + # Disable trusted key servers, otherwise unit tests might try to actually + # reach out to matrix.org. + "trusted_key_servers": [], "event_cache_size": 1, "enable_registration": True, "enable_registration_captcha": False,