Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust implementation of getEncryptionInfoForEvent #3718

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
IHttpOpts,
IToDeviceEvent,
MatrixClient,
MatrixEvent,
MatrixHttpApi,
TypedEventEmitter,
} from "../../../src";
Expand All @@ -38,7 +39,13 @@ import { CryptoBackend } from "../../../src/common-crypto/CryptoBackend";
import { IEventDecryptionResult } from "../../../src/@types/crypto";
import { OutgoingRequestProcessor } from "../../../src/rust-crypto/OutgoingRequestProcessor";
import { ServerSideSecretStorage } from "../../../src/secret-storage";
import { CryptoCallbacks, ImportRoomKeysOpts, VerificationRequest } from "../../../src/crypto-api";
import {
CryptoCallbacks,
EventShieldColour,
EventShieldReason,
ImportRoomKeysOpts,
VerificationRequest,
} from "../../../src/crypto-api";
import * as testData from "../../test-utils/test-data";
import { defer } from "../../../src/utils";

Expand Down Expand Up @@ -373,6 +380,111 @@ describe("RustCrypto", () => {
});
});

describe(".getEncryptionInfoForEvent", () => {
let rustCrypto: RustCrypto;
let olmMachine: Mocked<RustSdkCryptoJs.OlmMachine>;

beforeEach(() => {
olmMachine = {
getRoomEventEncryptionInfo: jest.fn(),
} as unknown as Mocked<RustSdkCryptoJs.OlmMachine>;
rustCrypto = new RustCrypto(
olmMachine,
{} as MatrixClient["http"],
TEST_USER,
TEST_DEVICE_ID,
{} as ServerSideSecretStorage,
{} as CryptoCallbacks,
);
});

afterEach(() => {
jest.restoreAllMocks();
});

async function makeEncryptedEvent(): Promise<MatrixEvent> {
const encryptedEvent = mkEvent({
event: true,
type: "m.room.encrypted",
content: { algorithm: "fake_alg" },
room: "!room:id",
});
encryptedEvent.event.event_id = "$event:id";
const mockCryptoBackend = {
decryptEvent: () =>
({
clearEvent: { content: { body: "1234" } },
} as unknown as IEventDecryptionResult),
} as unknown as CryptoBackend;
await encryptedEvent.attemptDecryption(mockCryptoBackend);
return encryptedEvent;
}

it("should handle unencrypted events", async () => {
const event = mkEvent({ event: true, type: "m.room.message", content: { body: "xyz" } });
const res = await rustCrypto.getEncryptionInfoForEvent(event);
expect(res).toBe(null);
expect(olmMachine.getRoomEventEncryptionInfo).not.toHaveBeenCalled();
});

it("passes the event into the OlmMachine", async () => {
const encryptedEvent = await makeEncryptedEvent();
const res = await rustCrypto.getEncryptionInfoForEvent(encryptedEvent);
expect(res).toBe(null);
expect(olmMachine.getRoomEventEncryptionInfo).toHaveBeenCalledTimes(1);
const [passedEvent, passedRoom] = olmMachine.getRoomEventEncryptionInfo.mock.calls[0];
expect(passedRoom.toString()).toEqual("!room:id");
expect(JSON.parse(passedEvent)).toStrictEqual(
expect.objectContaining({
event_id: "$event:id",
}),
);
});

it.each([
[RustSdkCryptoJs.ShieldColor.None, EventShieldColour.NONE],
[RustSdkCryptoJs.ShieldColor.Grey, EventShieldColour.GREY],
[RustSdkCryptoJs.ShieldColor.Red, EventShieldColour.RED],
])("gets the right shield color (%i)", async (rustShield, expectedShield) => {
const mockEncryptionInfo = {
shieldState: jest.fn().mockReturnValue({ color: rustShield, message: null }),
} as unknown as RustSdkCryptoJs.EncryptionInfo;
olmMachine.getRoomEventEncryptionInfo.mockResolvedValue(mockEncryptionInfo);

const res = await rustCrypto.getEncryptionInfoForEvent(await makeEncryptedEvent());
expect(mockEncryptionInfo.shieldState).toHaveBeenCalledWith(false);
expect(res).not.toBe(null);
expect(res!.shieldColour).toEqual(expectedShield);
});

it.each([
[null, null],
["Encrypted by an unverified user.", EventShieldReason.UNVERIFIED_IDENTITY],
["Encrypted by a device not verified by its owner.", EventShieldReason.UNSIGNED_DEVICE],
[
"The authenticity of this encrypted message can't be guaranteed on this device.",
EventShieldReason.AUTHENTICITY_NOT_GUARANTEED,
],
["Encrypted by an unknown or deleted device.", EventShieldReason.UNKNOWN_DEVICE],
["bloop", EventShieldReason.UNKNOWN],
])("gets the right shield reason (%s)", async (rustReason, expectedReason) => {
// suppress the warning from the unknown shield reason
jest.spyOn(console, "warn").mockImplementation(() => {});

const mockEncryptionInfo = {
shieldState: jest
.fn()
.mockReturnValue({ color: RustSdkCryptoJs.ShieldColor.None, message: rustReason }),
} as unknown as RustSdkCryptoJs.EncryptionInfo;
olmMachine.getRoomEventEncryptionInfo.mockResolvedValue(mockEncryptionInfo);

const res = await rustCrypto.getEncryptionInfoForEvent(await makeEncryptedEvent());
expect(mockEncryptionInfo.shieldState).toHaveBeenCalledWith(false);
expect(res).not.toBe(null);
expect(res!.shieldReason).toEqual(expectedReason);
});
});

describe("get|setTrustCrossSignedDevices", () => {
let rustCrypto: RustCrypto;

Expand Down
84 changes: 72 additions & 12 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import {
DeviceVerificationStatus,
EventEncryptionInfo,
EventShieldColour,
EventShieldReason,
GeneratedSecretStorageKey,
ImportRoomKeyProgressData,
ImportRoomKeysOpts,
Expand Down Expand Up @@ -788,10 +789,7 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
* Implementation of {@link CryptoApi.getEncryptionInfoForEvent}.
*/
public async getEncryptionInfoForEvent(event: MatrixEvent): Promise<EventEncryptionInfo | null> {
return {
shieldColour: EventShieldColour.NONE,
shieldReason: null,
};
return this.eventDecryptor.getEncryptionInfoForEvent(event);
}

/**
Expand Down Expand Up @@ -1484,14 +1482,7 @@ class EventDecryptor {

try {
const res = (await this.olmMachine.decryptRoomEvent(
JSON.stringify({
event_id: event.getId(),
type: event.getWireType(),
sender: event.getSender(),
state_key: event.getStateKey(),
content: event.getWireContent(),
origin_server_ts: event.getTs(),
}),
stringifyEvent(event),
new RustSdkCryptoJs.RoomId(event.getRoomId()!),
)) as RustSdkCryptoJs.DecryptedRoomEvent;

Expand Down Expand Up @@ -1549,6 +1540,20 @@ class EventDecryptor {
}
}

public async getEncryptionInfoForEvent(event: MatrixEvent): Promise<EventEncryptionInfo | null> {
if (!event.getClearContent()) {
// not successfully decrypted
return null;
}

const encryptionInfo = await this.olmMachine.getRoomEventEncryptionInfo(
stringifyEvent(event),
new RustSdkCryptoJs.RoomId(event.getRoomId()!),
);

return rustEncryptionInfoToJsEncryptionInfo(encryptionInfo);
}

/**
* Look for events which are waiting for a given megolm session
*
Expand Down Expand Up @@ -1606,6 +1611,61 @@ class EventDecryptor {
}
}

function stringifyEvent(event: MatrixEvent): string {
return JSON.stringify({
event_id: event.getId(),
type: event.getWireType(),
sender: event.getSender(),
state_key: event.getStateKey(),
content: event.getWireContent(),
origin_server_ts: event.getTs(),
});
}

function rustEncryptionInfoToJsEncryptionInfo(
encryptionInfo: RustSdkCryptoJs.EncryptionInfo | undefined,
): EventEncryptionInfo | null {
if (encryptionInfo === undefined) {
// not decrypted here
return null;
}

// TODO: use strict shield semantics.
const shieldState = encryptionInfo.shieldState(false);

let shieldColour: EventShieldColour;
switch (shieldState.color) {
case RustSdkCryptoJs.ShieldColor.Grey:
shieldColour = EventShieldColour.GREY;
break;
case RustSdkCryptoJs.ShieldColor.None:
shieldColour = EventShieldColour.NONE;
break;
default:
shieldColour = EventShieldColour.RED;
}

let shieldReason: EventShieldReason | null;
if (shieldState.message === null) {
shieldReason = null;
} else if (shieldState.message === "Encrypted by an unverified user.") {
shieldReason = EventShieldReason.UNVERIFIED_IDENTITY;
} else if (shieldState.message === "Encrypted by a device not verified by its owner.") {
shieldReason = EventShieldReason.UNSIGNED_DEVICE;
} else if (
shieldState.message === "The authenticity of this encrypted message can't be guaranteed on this device."
) {
shieldReason = EventShieldReason.AUTHENTICITY_NOT_GUARANTEED;
} else if (shieldState.message === "Encrypted by an unknown or deleted device.") {
shieldReason = EventShieldReason.UNKNOWN_DEVICE;
} else {
logger.warn(`Unknown shield state message '${shieldState.message}'`);
shieldReason = EventShieldReason.UNKNOWN;
}

return { shieldColour, shieldReason };
}

type RustCryptoEvents =
| CryptoEvent.VerificationRequestReceived
| CryptoEvent.UserTrustStatusChanged
Expand Down
Loading