diff --git a/spec/integ/crypto/crypto.spec.ts b/spec/integ/crypto/crypto.spec.ts index 66baa23ddc8..c1a665e2c06 100644 --- a/spec/integ/crypto/crypto.spec.ts +++ b/spec/integ/crypto/crypto.spec.ts @@ -2343,13 +2343,12 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, ])( "Decryption fails with withheld error if a withheld notice with code '%s' is received", (withheldCode, expectedMessage, expectedErrorCode) => { - // TODO: test arrival after the event too. - it.each(["before"])("%s the event", async (when) => { + it.each(["before", "after"])("%s the event", async (when) => { expectAliceKeyQuery({ device_keys: { "@alice:localhost": {} }, failures: {} }); await startClientAndAwaitFirstSync(); // A promise which resolves, with the MatrixEvent which wraps the event, once the decryption fails. - const awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted); + let awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted); // Send Alice an encrypted room event which looks like it was encrypted with a megolm session async function sendEncryptedEvent() { @@ -2393,6 +2392,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, await sendEncryptedEvent(); } else { await sendEncryptedEvent(); + // Make sure that the first attempt to decrypt has happened before the withheld arrives + await awaitDecryption; + awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted); await sendWithheldMessage(); } diff --git a/spec/unit/rust-crypto/rust-crypto.spec.ts b/spec/unit/rust-crypto/rust-crypto.spec.ts index 966f25a5af5..2b161c778f7 100644 --- a/spec/unit/rust-crypto/rust-crypto.spec.ts +++ b/spec/unit/rust-crypto/rust-crypto.spec.ts @@ -95,6 +95,7 @@ describe("initRustCrypto", () => { deleteSecretsFromInbox: jest.fn(), registerReceiveSecretCallback: jest.fn(), registerDevicesUpdatedCallback: jest.fn(), + registerRoomKeysWithheldCallback: jest.fn(), outgoingRequests: jest.fn(), isBackupEnabled: jest.fn().mockResolvedValue(false), verifyBackup: jest.fn().mockResolvedValue({ trusted: jest.fn().mockReturnValue(false) }), diff --git a/src/rust-crypto/index.ts b/src/rust-crypto/index.ts index dc9a42af743..0c9e162106f 100644 --- a/src/rust-crypto/index.ts +++ b/src/rust-crypto/index.ts @@ -174,6 +174,9 @@ async function initOlmMachine( await olmMachine.registerRoomKeyUpdatedCallback((sessions: RustSdkCryptoJs.RoomKeyInfo[]) => rustCrypto.onRoomKeysUpdated(sessions), ); + await olmMachine.registerRoomKeysWithheldCallback((withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]) => + rustCrypto.onRoomKeysWithheld(withheld), + ); await olmMachine.registerUserIdentityUpdatedCallback((userId: RustSdkCryptoJs.UserId) => rustCrypto.onUserIdentityUpdated(userId), ); diff --git a/src/rust-crypto/rust-crypto.ts b/src/rust-crypto/rust-crypto.ts index c320ba9c6e9..98876482897 100644 --- a/src/rust-crypto/rust-crypto.ts +++ b/src/rust-crypto/rust-crypto.ts @@ -1486,7 +1486,7 @@ export class RustCrypto extends TypedEventEmitter { + for (const session of withheld) { + this.logger.debug(`Got withheld message for session ${session.sessionId} in ${session.roomId.toString()}`); + const pendingList = this.eventDecryptor.getEventsPendingRoomKey( + session.roomId.toString(), + session.sessionId, + ); + if (pendingList.length === 0) return; + + // The easiest way to update the status of the event is to have another go at decrypting it. + this.logger.debug( + "Retrying decryption on events:", + pendingList.map((e) => `${e.getId()}`), + ); + + for (const ev of pendingList) { + ev.attemptDecryption(this, { isRetry: true }).catch((_e) => { + // It's somewhat expected that we still can't decrypt here. + }); + } + } + } + /** * Callback for `OlmMachine.registerUserIdentityUpdatedCallback` * @@ -1683,7 +1714,7 @@ class EventDecryptor { /** * Events which we couldn't decrypt due to unknown sessions / indexes. * - * Map from senderKey to sessionId to Set of MatrixEvents + * Map from roomId to sessionId to Set of MatrixEvents */ private eventsPendingKey = new MapWithDefault>>( () => new MapWithDefault>(() => new Set()), @@ -1843,30 +1874,27 @@ class EventDecryptor { * Look for events which are waiting for a given megolm session * * Returns a list of events which were encrypted by `session` and could not be decrypted - * - * @param session - */ - public getEventsPendingRoomKey(session: RustSdkCryptoJs.RoomKeyInfo): MatrixEvent[] { - const senderPendingEvents = this.eventsPendingKey.get(session.senderKey.toBase64()); - if (!senderPendingEvents) return []; + public getEventsPendingRoomKey(roomId: string, sessionId: string): MatrixEvent[] { + const roomPendingEvents = this.eventsPendingKey.get(roomId); + if (!roomPendingEvents) return []; - const sessionPendingEvents = senderPendingEvents.get(session.sessionId); + const sessionPendingEvents = roomPendingEvents.get(sessionId); if (!sessionPendingEvents) return []; - const roomId = session.roomId.toString(); - return [...sessionPendingEvents].filter((ev) => ev.getRoomId() === roomId); + return [...sessionPendingEvents]; } /** * Add an event to the list of those awaiting their session keys. */ private addEventToPendingList(event: MatrixEvent): void { - const content = event.getWireContent(); - const senderKey = content.sender_key; - const sessionId = content.session_id; + const roomId = event.getRoomId(); + // We shouldn't have events without a room id here. + if (!roomId) return; - const senderPendingEvents = this.eventsPendingKey.getOrCreate(senderKey); - const sessionPendingEvents = senderPendingEvents.getOrCreate(sessionId); + const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId); + const sessionPendingEvents = roomPendingEvents.getOrCreate(event.getWireContent().session_id); sessionPendingEvents.add(event); } @@ -1874,23 +1902,22 @@ class EventDecryptor { * Remove an event from the list of those awaiting their session keys. */ private removeEventFromPendingList(event: MatrixEvent): void { - const content = event.getWireContent(); - const senderKey = content.sender_key; - const sessionId = content.session_id; + const roomId = event.getRoomId(); + if (!roomId) return; - const senderPendingEvents = this.eventsPendingKey.get(senderKey); - if (!senderPendingEvents) return; + const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId); + if (!roomPendingEvents) return; - const sessionPendingEvents = senderPendingEvents.get(sessionId); + const sessionPendingEvents = roomPendingEvents.get(event.getWireContent().session_id); if (!sessionPendingEvents) return; sessionPendingEvents.delete(event); // also clean up the higher-level maps if they are now empty if (sessionPendingEvents.size === 0) { - senderPendingEvents.delete(sessionId); - if (senderPendingEvents.size === 0) { - this.eventsPendingKey.delete(senderKey); + roomPendingEvents.delete(event.getWireContent().session_id); + if (roomPendingEvents.size === 0) { + this.eventsPendingKey.delete(roomId); } } }