diff --git a/jest.config.js b/jest.config.js index 9ccd6d24..18525aff 100644 --- a/jest.config.js +++ b/jest.config.js @@ -41,10 +41,10 @@ module.exports = { // An object that configures minimum threshold enforcement for coverage results coverageThreshold: { global: { - branches: 77.77, - functions: 84.84, - lines: 79.31, - statements: 79.31, + branches: 84.21, + functions: 94.11, + lines: 95, + statements: 95, }, }, diff --git a/src/DeferredPromise.test.ts b/src/DeferredPromise.test.ts new file mode 100644 index 00000000..b5917447 --- /dev/null +++ b/src/DeferredPromise.test.ts @@ -0,0 +1,28 @@ +import { DeferredPromise } from './DeferredPromise'; +import { ensureDefined } from './util'; + +describe('DeferredPromise', () => { + it('should not fail to create a DeferredPromise', () => { + expect(() => new DeferredPromise()).not.toThrow(); + }); + + it('should define resolve and reject fields', () => { + const promise = new DeferredPromise(); + expect(promise.resolve).toBeDefined(); + expect(promise.reject).toBeDefined(); + }); + + it('resolves with the correct value', async () => { + const deferred = new DeferredPromise(); + ensureDefined(deferred.resolve); + deferred.resolve('hello'); + expect(await deferred.promise).toBe('hello'); + }); + + it('rejects with the correct reason', async () => { + const deferred = new DeferredPromise(); + ensureDefined(deferred.reject); + deferred.reject('error'); + await expect(deferred.promise).rejects.toBe('error'); + }); +}); diff --git a/src/DeferredPromise.ts b/src/DeferredPromise.ts new file mode 100644 index 00000000..529c82b7 --- /dev/null +++ b/src/DeferredPromise.ts @@ -0,0 +1,23 @@ +/** + * A deferred promise can be resolved by a caller different from the one who + * created it. + * + * Example: + * - "A" creates a deferred promise "P", adds it to a list, and awaits it + * - "B" gets "P" from the list and resolves it + * - "A" gets the resolved value + */ +export class DeferredPromise { + promise: Promise; + + resolve?: (value: T | PromiseLike) => void; + + reject?: (reason?: any) => void; + + constructor() { + this.promise = new Promise((resolve, reject) => { + this.resolve = resolve; + this.reject = reject; + }); + } +} diff --git a/src/snap-keyring.test.ts b/src/SnapKeyring.test.ts similarity index 80% rename from src/snap-keyring.test.ts rename to src/SnapKeyring.test.ts index 31efd464..2aadfbca 100644 --- a/src/snap-keyring.test.ts +++ b/src/SnapKeyring.test.ts @@ -58,6 +58,25 @@ describe('SnapKeyring', () => { }), ).rejects.toThrow('Method not supported: invalid'); }); + + it('should submit an async request and return the result', async () => { + mockSnapController.handleRequest.mockResolvedValue({ pending: true }); + const requestPromise = keyring.signPersonalMessage( + accounts[0].address, + 'hello', + ); + + const { calls } = mockSnapController.handleRequest.mock; + const requestId = calls[calls.length - 1][0].request.params.request.id; + await keyring.handleKeyringSnapMessage(snapId, { + method: 'submitResponse', + params: { + id: requestId, + result: '0x123', + }, + }); + expect(await requestPromise).toBe('0x123'); + }); }); describe('getAccounts', () => { @@ -226,4 +245,30 @@ describe('SnapKeyring', () => { ); }); }); + + describe('removeAccount', () => { + it('should throw an error if the account is not found', async () => { + await expect(keyring.removeAccount('0x0')).rejects.toThrow( + 'Account address not found: 0x0', + ); + }); + + it('should remove an account', async () => { + mockSnapController.handleRequest.mockResolvedValue(null); + await keyring.removeAccount(accounts[0].address); + expect(await keyring.getAccounts()).toStrictEqual([accounts[1].address]); + }); + + it('should remove the account and warn if snap fails', async () => { + const spy = jest.spyOn(console, 'error').mockImplementation(); + mockSnapController.handleRequest.mockRejectedValue('error'); + await keyring.removeAccount(accounts[0].address); + expect(await keyring.getAccounts()).toStrictEqual([accounts[1].address]); + expect(console.error).toHaveBeenCalledWith( + 'Account "0xC728514Df8A7F9271f4B7a4dd2Aa6d2D723d3eE3" may not have been removed from snap "local:snap.mock":', + 'error', + ); + spy.mockRestore(); + }); + }); }); diff --git a/src/snap-keyring.ts b/src/SnapKeyring.ts similarity index 91% rename from src/snap-keyring.ts rename to src/SnapKeyring.ts index 202ca732..6d70e8df 100644 --- a/src/snap-keyring.ts +++ b/src/SnapKeyring.ts @@ -12,8 +12,9 @@ import { assert, object, string, record, Infer } from 'superstruct'; import { v4 as uuid } from 'uuid'; import { CaseInsensitiveMap } from './CaseInsensitiveMap'; +import { DeferredPromise } from './DeferredPromise'; import { SnapMessage, SnapMessageStruct } from './types'; -import { DeferredPromise, strictMask, toJson, unique } from './util'; +import { strictMask, toJson, unique } from './util'; export const SNAP_KEYRING_TYPE = 'Snap Keyring'; @@ -74,7 +75,7 @@ export class SnapKeyring extends EventEmitter { // Don't call the snap back to list the accounts. The main use case for // this method is to allow the snap to verify if the keyring's state is // in sync with the snap's state. - return Array.from(this.#addressToAccount.values()).filter( + return [...this.#addressToAccount.values()].filter( (account) => this.#addressToSnapId.get(account.address) === snapId, ); } @@ -82,7 +83,7 @@ export class SnapKeyring extends EventEmitter { case 'submitResponse': { const { id, result } = params as any; // FIXME: add a struct for this this.#resolveRequest(id, result); - return true; + return null; } default: @@ -132,7 +133,7 @@ export class SnapKeyring extends EventEmitter { // Do not call the snap here. This method is called by the UI, keep it // _fast_. return unique( - Array.from(this.#addressToAccount.values(), (account) => account.address), + [...this.#addressToAccount.values()].map((account) => account.address), ); } @@ -151,6 +152,12 @@ export class SnapKeyring extends EventEmitter { ): Promise { const { account, snapId } = this.#resolveAddress(address); const id = uuid(); + + // Create the promise before calling the snap to prevent a race condition + // where the snap responds before we have a chance to create it. + const promise = new DeferredPromise(); + this.#pendingRequests.set(id, promise); + const response = await this.#snapClient.withSnapId(snapId).submitRequest({ account: account.id, scope: '', // Chain ID in CAIP-2 format. @@ -162,12 +169,13 @@ export class SnapKeyring extends EventEmitter { }, }); + // The snap can respond immediately if the request is not async. In that + // case we should delete the promise to prevent a leak. if (!response.pending) { + this.#pendingRequests.delete(id); return response.result; } - const promise = new DeferredPromise(); - this.#pendingRequests.set(id, promise); return promise.promise; } @@ -182,7 +190,7 @@ export class SnapKeyring extends EventEmitter { address: string, transaction: TypedTransaction, _opts = {}, - ) { + ): Promise { const tx = toJson({ ...transaction.toJSON(), type: transaction.type, @@ -280,9 +288,10 @@ export class SnapKeyring extends EventEmitter { * @param address - Address of the account to remove. */ async removeAccount(address: string): Promise { + const { account, snapId } = this.#resolveAddress(address); + // Always remove the account from the maps, even if the snap is going to // fail to delete it. - const { account, snapId } = this.#resolveAddress(address); this.#removeAccountFromMaps(account); try { @@ -292,7 +301,7 @@ export class SnapKeyring extends EventEmitter { // with the account deletion, otherwise the account will be stuck in the // keyring. console.error( - `Cannot talk to snap "${snapId}", continuing with the deletion of account ${address}.`, + `Account "${address}" may not have been removed from snap "${snapId}":`, error, ); } @@ -304,14 +313,12 @@ export class SnapKeyring extends EventEmitter { * @param extraSnapIds - Extra snap IDs to sync accounts for. */ async #syncAllSnapsAccounts(...extraSnapIds: string[]): Promise { - const snapIds = unique( - Array.from(this.#addressToSnapId.values()).concat(extraSnapIds), - ); - - for (const snapId of snapIds) { + const snapIds = [...this.#addressToSnapId.values()].concat(extraSnapIds); + for (const snapId of unique(snapIds)) { try { await this.#syncSnapAccounts(snapId); } catch (error) { + // Log the error and continue with the other snaps. console.error(`Failed to sync accounts for snap "${snapId}":`, error); } } @@ -325,12 +332,12 @@ export class SnapKeyring extends EventEmitter { async #syncSnapAccounts(snapId: string): Promise { // Get new accounts first, before removing the old ones. This way, if // something goes wrong, we don't lose the old accounts. + const oldAccounts = this.#getAccountsBySnapId(snapId); const newAccounts = await this.#snapClient .withSnapId(snapId) .listAccounts(); // Remove the old accounts from the maps. - const oldAccounts = this.#getAccountsBySnapId(snapId); for (const account of oldAccounts) { this.#removeAccountFromMaps(account); } @@ -367,14 +374,13 @@ export class SnapKeyring extends EventEmitter { * @param result - Result of the request. */ #resolveRequest(id: string, result: any): void { - const signingPromise = this.#pendingRequests.get(id); - if (signingPromise?.resolve === undefined) { - console.warn(`No pending request found for ID: ${id}`); - return; + const promise = this.#pendingRequests.get(id); + if (promise?.resolve === undefined) { + throw new Error(`No pending request found for ID: ${id}`); } this.#pendingRequests.delete(id); - signingPromise.resolve(result); + promise.resolve(result); } /** diff --git a/src/index.ts b/src/index.ts index a3be23ca..d01d020e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,2 +1,2 @@ export * from './types'; -export * from './snap-keyring'; +export * from './SnapKeyring'; diff --git a/src/util.test.ts b/src/util.test.ts index df18eba9..d60c7ccd 100644 --- a/src/util.test.ts +++ b/src/util.test.ts @@ -1,15 +1,54 @@ -import { DeferredPromise } from './util'; - -describe('DeferredPromise', () => { - describe('constructor', () => { - it('should not fail to create a DeferredPromise', () => { - expect(() => new DeferredPromise()).not.toThrow(); - }); - - it('should define resolve and reject fields', () => { - const promise = new DeferredPromise(); - expect(promise.resolve).toBeDefined(); - expect(promise.reject).toBeDefined(); - }); +import { ensureDefined, toJson, unique } from './util'; + +describe('unique', () => { + it('returns an empty array when given an empty array', () => { + const arr: number[] = []; + const result = unique(arr); + expect(result).toStrictEqual([]); + }); + + it('returns an array with unique elements', () => { + const arr = [1, 2, 2, 3, 3, 3]; + const result = unique(arr); + expect(result).toStrictEqual([1, 2, 3]); + }); + + it('returns an array with unique objects', () => { + const obj1 = { name: 'John' }; + const obj2 = { name: 'Jane' }; + const arr = [obj1, obj1, obj2]; + const result = unique(arr); + expect(result).toStrictEqual([{ name: 'John' }, { name: 'Jane' }]); + }); +}); + +describe('toJson', () => { + it('correctly serializes an object to JSON', () => { + const obj = { name: 'John', age: 30 }; + const json = toJson(obj); + expect(json).toStrictEqual(obj); + }); + + it('correctly serializes an array to JSON', () => { + const arr = [1, 2, 3]; + const json = toJson(arr); + expect(json).toStrictEqual(arr); + }); + + it('correctly serializes an object with defined and non-undefined fields to JSON', () => { + const obj = { name: 'John', age: undefined }; + const expectedJson = { name: 'John' }; + const json = toJson(obj); + expect(json).toStrictEqual(expectedJson); + }); +}); + +describe('ensureDefined', () => { + it('does not throw an error when value is defined', () => { + expect(() => ensureDefined('hello')).not.toThrow(); + }); + + it('throws an error when value is undefined', () => { + expect(() => ensureDefined(undefined)).toThrow('Argument is undefined'); }); }); diff --git a/src/util.ts b/src/util.ts index f938b86e..44b186ea 100644 --- a/src/util.ts +++ b/src/util.ts @@ -1,30 +1,6 @@ import type { Json } from '@metamask/utils'; import { Struct, assert } from 'superstruct'; -/** - * A deferred promise can be resolved by a caller different from the one who - * created it. - * - * Example: - * - "A" creates a deferred promise "P", adds it to a list, and awaits it - * - "B" gets "P" from the list and resolves it - * - "A" gets the resolved value - */ -export class DeferredPromise { - promise: Promise; - - resolve?: (value: T | PromiseLike) => void; - - reject?: (reason?: any) => void; - - constructor() { - this.promise = new Promise((resolve, reject) => { - this.resolve = resolve; - this.reject = reject; - }); - } -} - /** * Assert that a value is valid according to a struct. * @@ -68,3 +44,14 @@ export function unique(array: T[]): T[] { export function toJson(value: any): T { return JSON.parse(JSON.stringify(value)) as T; } + +/** + * Asserts that the given value is defined. + * + * @param value - Value to check. + */ +export function ensureDefined(value: T | undefined): asserts value is T { + if (value === undefined) { + throw new Error('Argument is undefined'); + } +}