Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle late-arriving m.room_key.withheld messages #4310

Merged
merged 5 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions spec/integ/crypto/crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2343,13 +2343,12 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
])(
"Decryption fails with withheld error if a withheld notice with code '%s' is received",
(withheldCode, expectedMessage, expectedErrorCode) => {
// TODO: test arrival after the event too.
it.each(["before"])("%s the event", async (when) => {
it.each(["before", "after"])("%s the event", async (when) => {
expectAliceKeyQuery({ device_keys: { "@alice:localhost": {} }, failures: {} });
await startClientAndAwaitFirstSync();

// A promise which resolves, with the MatrixEvent which wraps the event, once the decryption fails.
const awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
let awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);

// Send Alice an encrypted room event which looks like it was encrypted with a megolm session
async function sendEncryptedEvent() {
Expand Down Expand Up @@ -2393,6 +2392,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
await sendEncryptedEvent();
} else {
await sendEncryptedEvent();
// Make sure that the first attempt to decrypt has happened before the withheld arrives
await awaitDecryption;
awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
await sendWithheldMessage();
}

Expand Down
1 change: 1 addition & 0 deletions spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ describe("initRustCrypto", () => {
deleteSecretsFromInbox: jest.fn(),
registerReceiveSecretCallback: jest.fn(),
registerDevicesUpdatedCallback: jest.fn(),
registerRoomKeysWithheldCallback: jest.fn(),
outgoingRequests: jest.fn(),
isBackupEnabled: jest.fn().mockResolvedValue(false),
verifyBackup: jest.fn().mockResolvedValue({ trusted: jest.fn().mockReturnValue(false) }),
Expand Down
3 changes: 3 additions & 0 deletions src/rust-crypto/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ async function initOlmMachine(
await olmMachine.registerRoomKeyUpdatedCallback((sessions: RustSdkCryptoJs.RoomKeyInfo[]) =>
rustCrypto.onRoomKeysUpdated(sessions),
);
await olmMachine.registerRoomKeysWithheldCallback((withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]) =>
rustCrypto.onRoomKeysWithheld(withheld),
);
await olmMachine.registerUserIdentityUpdatedCallback((userId: RustSdkCryptoJs.UserId) =>
rustCrypto.onUserIdentityUpdated(userId),
);
Expand Down
75 changes: 51 additions & 24 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,7 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
this.logger.debug(
`Got update for session ${key.senderKey.toBase64()}|${key.sessionId} in ${key.roomId.toString()}`,
);
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(key);
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(key.roomId.toString(), key.sessionId);
if (pendingList.length === 0) return;

this.logger.debug(
Expand All @@ -1507,6 +1507,37 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
}
}

/**
* Callback for `OlmMachine.registerRoomKeyWithheldCallback`.
*
* Called by the rust sdk whenever we are told that a key has been withheld. We see if we had any events that
* failed to decrypt for the given session, and update their status if so.
*
* @param withheld - Details of the withheld sessions.
*/
public async onRoomKeysWithheld(withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]): Promise<void> {
for (const session of withheld) {
this.logger.debug(`Got withheld message for session ${session.sessionId} in ${session.roomId.toString()}`);
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(
session.roomId.toString(),
session.sessionId,
);
if (pendingList.length === 0) return;

// The easiest way to update the status of the event is to have another go at decrypting it.
this.logger.debug(
"Retrying decryption on events:",
pendingList.map((e) => `${e.getId()}`),
);

for (const ev of pendingList) {
ev.attemptDecryption(this, { isRetry: true }).catch((_e) => {
// It's somewhat expected that we still can't decrypt here.
});
}
}
}

/**
* Callback for `OlmMachine.registerUserIdentityUpdatedCallback`
*
Expand Down Expand Up @@ -1683,7 +1714,7 @@ class EventDecryptor {
/**
* Events which we couldn't decrypt due to unknown sessions / indexes.
*
* Map from senderKey to sessionId to Set of MatrixEvents
* Map from roomId to sessionId to Set of MatrixEvents
*/
private eventsPendingKey = new MapWithDefault<string, MapWithDefault<string, Set<MatrixEvent>>>(
() => new MapWithDefault<string, Set<MatrixEvent>>(() => new Set()),
Expand Down Expand Up @@ -1843,54 +1874,50 @@ class EventDecryptor {
* Look for events which are waiting for a given megolm session
*
* Returns a list of events which were encrypted by `session` and could not be decrypted
*
* @param session -
*/
public getEventsPendingRoomKey(session: RustSdkCryptoJs.RoomKeyInfo): MatrixEvent[] {
const senderPendingEvents = this.eventsPendingKey.get(session.senderKey.toBase64());
if (!senderPendingEvents) return [];
public getEventsPendingRoomKey(roomId: string, sessionId: string): MatrixEvent[] {
const roomPendingEvents = this.eventsPendingKey.get(roomId);
if (!roomPendingEvents) return [];

const sessionPendingEvents = senderPendingEvents.get(session.sessionId);
const sessionPendingEvents = roomPendingEvents.get(sessionId);
if (!sessionPendingEvents) return [];

const roomId = session.roomId.toString();
return [...sessionPendingEvents].filter((ev) => ev.getRoomId() === roomId);
return [...sessionPendingEvents];
}

/**
* Add an event to the list of those awaiting their session keys.
*/
private addEventToPendingList(event: MatrixEvent): void {
const content = event.getWireContent();
const senderKey = content.sender_key;
const sessionId = content.session_id;
const roomId = event.getRoomId();
// We shouldn't have events without a room id here.
if (!roomId) return;

const senderPendingEvents = this.eventsPendingKey.getOrCreate(senderKey);
const sessionPendingEvents = senderPendingEvents.getOrCreate(sessionId);
const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId);
const sessionPendingEvents = roomPendingEvents.getOrCreate(event.getWireContent().session_id);
sessionPendingEvents.add(event);
}

/**
* Remove an event from the list of those awaiting their session keys.
*/
private removeEventFromPendingList(event: MatrixEvent): void {
const content = event.getWireContent();
const senderKey = content.sender_key;
const sessionId = content.session_id;
const roomId = event.getRoomId();
if (!roomId) return;

const senderPendingEvents = this.eventsPendingKey.get(senderKey);
if (!senderPendingEvents) return;
const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId);
if (!roomPendingEvents) return;

const sessionPendingEvents = senderPendingEvents.get(sessionId);
const sessionPendingEvents = roomPendingEvents.get(event.getWireContent().session_id);
if (!sessionPendingEvents) return;

sessionPendingEvents.delete(event);

// also clean up the higher-level maps if they are now empty
if (sessionPendingEvents.size === 0) {
senderPendingEvents.delete(sessionId);
if (senderPendingEvents.size === 0) {
this.eventsPendingKey.delete(senderKey);
roomPendingEvents.delete(event.getWireContent().session_id);
if (roomPendingEvents.size === 0) {
this.eventsPendingKey.delete(roomId);
}
}
}
Expand Down
Loading