diff --git a/src/agent/service/nodesCrossSignClaim.ts b/src/agent/service/nodesCrossSignClaim.ts index 0e0e5a53e1..2c9793ba08 100644 --- a/src/agent/service/nodesCrossSignClaim.ts +++ b/src/agent/service/nodesCrossSignClaim.ts @@ -39,7 +39,7 @@ function nodesCrossSignClaim({ true, ); try { - await db.withTransactionF((tran) => { + await db.withTransactionF(async (tran) => { const readStatus = await genClaims.read(); // If nothing to read, end and destroy if (readStatus.done) { diff --git a/src/agent/service/nodesHolePunchMessageSend.ts b/src/agent/service/nodesHolePunchMessageSend.ts index 65738e89c3..c610d7428f 100644 --- a/src/agent/service/nodesHolePunchMessageSend.ts +++ b/src/agent/service/nodesHolePunchMessageSend.ts @@ -58,7 +58,7 @@ function nodesHolePunchMessageSend({ // Firstly, check if this node is the desired node // If so, then we want to make this node start sending hole punching packets // back to the source node. - await db.withTransactionF((tran) => { + await db.withTransactionF(async (tran) => { if (keyManager.getNodeId().equals(targetId)) { const [host, port] = networkUtils.parseAddress( call.request.getProxyAddress(), diff --git a/src/agent/service/vaultsGitInfoGet.ts b/src/agent/service/vaultsGitInfoGet.ts index 4c614203a8..0fb18c96ad 100644 --- a/src/agent/service/vaultsGitInfoGet.ts +++ b/src/agent/service/vaultsGitInfoGet.ts @@ -34,7 +34,7 @@ function vaultsGitInfoGet({ ): Promise => { const genWritable = grpcUtils.generatorWritable(call, true); try { - await db.withTransactionF((tran) => { + await db.withTransactionF(async (tran) => { const vaultIdFromName = await vaultManager.getVaultId( call.request.getVault()?.getNameOrId() as VaultName, tran, diff --git a/src/agent/service/vaultsGitPackGet.ts b/src/agent/service/vaultsGitPackGet.ts index f388f98558..f8aa5dc3d8 100644 --- a/src/agent/service/vaultsGitPackGet.ts +++ b/src/agent/service/vaultsGitPackGet.ts @@ -57,7 +57,7 @@ function vaultsGitPackGet({ const nodeId = connectionInfo.remoteNodeId; const nodeIdEncoded = nodesUtils.encodeNodeId(nodeId); const nameOrId = meta.get('vaultNameOrId').pop()!.toString(); - await db.withTransactionF((tran) => { + await db.withTransactionF(async (tran) => { const vaultIdFromName = await vaultManager.getVaultId( nameOrId as VaultName, tran, diff --git a/src/agent/service/vaultsScan.ts b/src/agent/service/vaultsScan.ts index 2fbd498610..f827191085 100644 --- a/src/agent/service/vaultsScan.ts +++ b/src/agent/service/vaultsScan.ts @@ -36,7 +36,7 @@ function vaultsScan({ } const nodeId = connectionInfo.remoteNodeId; try { - await db.withTransactionF((tran) => { + await db.withTransactionF(async (tran) => { const listResponse = vaultManager.handleScanVaults(nodeId, tran); for await (const { vaultId, diff --git a/src/discovery/Discovery.ts b/src/discovery/Discovery.ts index 3c5e57a492..138dc1163a 100644 --- a/src/discovery/Discovery.ts +++ b/src/discovery/Discovery.ts @@ -24,9 +24,7 @@ import { status, } from '@matrixai/async-init/dist/CreateDestroyStartStop'; import { IdInternal } from '@matrixai/id'; -import { Lock } from '@matrixai/async-locks'; import * as idUtils from '@matrixai/id/dist/utils'; -import * as resources from '@matrixai/resources'; import * as discoveryUtils from './utils'; import * as discoveryErrors from './errors'; import * as nodesErrors from '../nodes/errors'; @@ -91,7 +89,6 @@ class Discovery { protected discoveryProcess: Promise; protected queuePlug = promise(); protected queueDrained = promise(); - protected lock: Lock = new Lock(); public constructor({ keyManager, @@ -420,22 +417,19 @@ class Discovery { } /** - * Simple check for whether the Discovery Queue is empty. Uses a - * transaction lock to ensure consistency. + * Simple check for whether the Discovery Queue is empty. */ protected async queueIsEmpty(): Promise { - return await this.lock.withF(async () => { - let nextDiscoveryQueueId: DiscoveryQueueId | undefined; - const keyIterator = this.db.iterator(this.discoveryQueueDbPath, { - limit: 1, - values: false, - }); - for await (const [keyPath] of keyIterator) { - const key = keyPath[0] as Buffer; - nextDiscoveryQueueId = IdInternal.fromBuffer(key); - } - return nextDiscoveryQueueId == null; + let nextDiscoveryQueueId: DiscoveryQueueId | undefined; + const keyIterator = this.db.iterator(this.discoveryQueueDbPath, { + limit: 1, + values: false, }); + for await (const [keyPath] of keyIterator) { + const key = keyPath[0] as Buffer; + nextDiscoveryQueueId = IdInternal.fromBuffer(key); + } + return nextDiscoveryQueueId == null; } /** @@ -446,25 +440,23 @@ class Discovery { protected async pushKeyToDiscoveryQueue( gestaltKey: GestaltKey, ): Promise { - await resources.withF( - [this.db.transaction(), this.lock.lock()], - async ([tran]) => { - const valueIterator = tran.iterator( - this.discoveryQueueDbPath, - { valueAsBuffer: false }, - ); - for await (const [, value] of valueIterator) { - if (value === gestaltKey) { - return; - } + await this.db.withTransactionF(async (tran) => { + await tran.lock(gestaltKey); + const valueIterator = tran.iterator( + this.discoveryQueueDbPath, + { valueAsBuffer: false }, + ); + for await (const [, value] of valueIterator) { + if (value === gestaltKey) { + return; } - const discoveryQueueId = this.discoveryQueueIdGenerator(); - await tran.put( - [...this.discoveryQueueDbPath, idUtils.toBuffer(discoveryQueueId)], - gestaltKey, - ); - }, - ); + } + const discoveryQueueId = this.discoveryQueueIdGenerator(); + await tran.put( + [...this.discoveryQueueDbPath, idUtils.toBuffer(discoveryQueueId)], + gestaltKey, + ); + }); this.queuePlug.resolveP(); } @@ -476,12 +468,7 @@ class Discovery { protected async removeKeyFromDiscoveryQueue( keyId: DiscoveryQueueId, ): Promise { - await this.lock.withF(async () => { - await this.db.del([ - ...this.discoveryQueueDbPath, - idUtils.toBuffer(keyId), - ]); - }); + await this.db.del([...this.discoveryQueueDbPath, idUtils.toBuffer(keyId)]); } /** diff --git a/src/notifications/NotificationsManager.ts b/src/notifications/NotificationsManager.ts index f616431216..e2569d4eaf 100644 --- a/src/notifications/NotificationsManager.ts +++ b/src/notifications/NotificationsManager.ts @@ -1,4 +1,4 @@ -import type { DB, DBTransaction, KeyPath, LevelPath } from '@matrixai/db'; +import type { DB, DBTransaction, LevelPath } from '@matrixai/db'; import type { NotificationId, Notification, @@ -12,13 +12,11 @@ import type NodeConnectionManager from '../nodes/NodeConnectionManager'; import type { NodeId } from '../nodes/types'; import Logger from '@matrixai/logger'; import { IdInternal } from '@matrixai/id'; -import { Lock, LockBox } from '@matrixai/async-locks'; import { CreateDestroyStartStop, ready, } from '@matrixai/async-init/dist/CreateDestroyStartStop'; import { utils as idUtils } from '@matrixai/id'; -import { withF } from '@matrixai/resources'; import * as notificationsUtils from './utils'; import * as notificationsErrors from './errors'; import * as notificationsPB from '../proto/js/polykey/v1/notifications/notifications_pb'; @@ -78,7 +76,6 @@ class NotificationsManager { protected nodeManager: NodeManager; protected nodeConnectionManager: NodeConnectionManager; protected messageCap: number; - protected locks: LockBox = new LockBox(); /** * Top level stores MESSAGE_COUNT_KEY -> number (of messages) @@ -123,36 +120,30 @@ class NotificationsManager { public async start({ fresh = false, }: { fresh?: boolean } = {}): Promise { - await withF( - [ - this.db.transaction(), - this.locks.lock([ - [...this.notificationsDbPath, MESSAGE_COUNT_KEY], - Lock, - ]), - ], - async ([tran]) => { - this.logger.info(`Starting ${this.constructor.name}`); - if (fresh) { - await tran.clear(this.notificationsDbPath); - } + await this.db.withTransactionF(async (tran) => { + await tran.lock( + [...this.notificationsDbPath, MESSAGE_COUNT_KEY].toString(), + ); + this.logger.info(`Starting ${this.constructor.name}`); + if (fresh) { + await tran.clear(this.notificationsDbPath); + } - // Getting latest ID and creating ID generator - let latestId: NotificationId | undefined; - const keyIterator = tran.iterator(this.notificationsMessagesDbPath, { - limit: 1, - reverse: true, - values: false, - }); - for await (const [keyPath] of keyIterator) { - const key = keyPath[0] as Buffer; - latestId = IdInternal.fromBuffer(key); - } - this.notificationIdGenerator = - notificationsUtils.createNotificationIdGenerator(latestId); - this.logger.info(`Started ${this.constructor.name}`); - }, - ); + // Getting latest ID and creating ID generator + let latestId: NotificationId | undefined; + const keyIterator = tran.iterator(this.notificationsMessagesDbPath, { + limit: 1, + reverse: true, + values: false, + }); + for await (const [keyPath] of keyIterator) { + const key = keyPath[0] as Buffer; + latestId = IdInternal.fromBuffer(key); + } + this.notificationIdGenerator = + notificationsUtils.createNotificationIdGenerator(latestId); + this.logger.info(`Started ${this.constructor.name}`); + }); } public async stop() { @@ -168,20 +159,6 @@ class NotificationsManager { this.logger.info(`Destroyed ${this.constructor.name}`); } - @ready(new notificationsErrors.ErrorNotificationsNotRunning()) - public async withTransactionF( - ...params: [...keys: Array, f: (tran: DBTransaction) => Promise] - ): Promise { - const f = params.pop() as (tran: DBTransaction) => Promise; - const lockRequests = (params as Array).map<[KeyPath, typeof Lock]>( - (key) => [key, Lock], - ); - return withF( - [this.db.transaction(), this.locks.lock(...lockRequests)], - ([tran]) => f(tran), - ); - } - /** * Send a notification to another node * The `data` parameter must match one of the NotificationData types outlined in ./types @@ -218,10 +195,12 @@ class NotificationsManager { ): Promise { const messageCountPath = [...this.notificationsDbPath, MESSAGE_COUNT_KEY]; if (tran == null) { - return this.withTransactionF(messageCountPath, (tran) => + return this.db.withTransactionF(async (tran) => this.receiveNotification(notification, tran), ); } + + await tran.lock(messageCountPath.toString()); const nodePerms = await this.acl.getNodePerm( nodesUtils.decodeNodeId(notification.senderId)!, ); @@ -269,7 +248,7 @@ class NotificationsManager { tran?: DBTransaction; } = {}): Promise> { if (tran == null) { - return this.withTransactionF((tran) => + return this.db.withTransactionF((tran) => this.readNotifications({ unread, number, order, tran }), ); } @@ -309,7 +288,7 @@ class NotificationsManager { tran?: DBTransaction, ): Promise { if (tran == null) { - return this.withTransactionF((tran) => + return this.db.withTransactionF((tran) => this.findGestaltInvite(fromNode, tran), ); } @@ -331,10 +310,10 @@ class NotificationsManager { public async clearNotifications(tran?: DBTransaction): Promise { const messageCountPath = [...this.notificationsDbPath, MESSAGE_COUNT_KEY]; if (tran == null) { - return this.withTransactionF(messageCountPath, (tran) => - this.clearNotifications(tran), - ); + return this.db.withTransactionF((tran) => this.clearNotifications(tran)); } + + await tran.lock(messageCountPath.toString()); const notificationIds = await this.getNotificationIds('all', tran); const numMessages = await tran.get(messageCountPath); if (numMessages !== undefined) { diff --git a/src/sigchain/Sigchain.ts b/src/sigchain/Sigchain.ts index 45c0c999e3..1cba404465 100644 --- a/src/sigchain/Sigchain.ts +++ b/src/sigchain/Sigchain.ts @@ -1,4 +1,4 @@ -import type { DB, DBTransaction, KeyPath, LevelPath } from '@matrixai/db'; +import type { DB, DBTransaction, LevelPath } from '@matrixai/db'; import type { ChainDataEncoded } from './types'; import type { ClaimData, @@ -16,7 +16,6 @@ import { CreateDestroyStartStop, ready, } from '@matrixai/async-init/dist/CreateDestroyStartStop'; -import { Lock, LockBox } from '@matrixai/async-locks'; import { withF } from '@matrixai/resources'; import * as sigchainErrors from './errors'; import * as claimsUtils from '../claims/utils'; @@ -32,7 +31,6 @@ class Sigchain { protected logger: Logger; protected keyManager: KeyManager; protected db: DB; - protected locks: LockBox = new LockBox(); // Top-level database for the sigchain domain protected sigchainDbPath: LevelPath = [this.constructor.name]; // ClaimId (the lexicographic integer of the sequence number) @@ -124,20 +122,6 @@ class Sigchain { this.logger.info(`Destroyed ${this.constructor.name}`); } - @ready(new sigchainErrors.ErrorSigchainNotRunning()) - public async withTransactionF( - ...params: [...keys: Array, f: (tran: DBTransaction) => Promise] - ): Promise { - const f = params.pop() as (tran: DBTransaction) => Promise; - const lockRequests = (params as Array).map<[KeyPath, typeof Lock]>( - (key) => [key, Lock], - ); - return withF( - [this.db.transaction(), this.locks.lock(...lockRequests)], - ([tran]) => f(tran), - ); - } - /** * Helper function to create claims internally in the Sigchain class. * Wraps claims::createClaim() with the static information common to all @@ -186,10 +170,10 @@ class Sigchain { this.sequenceNumberKey, ]; if (tran == null) { - return this.withTransactionF(claimIdPath, sequenceNumberPath, (tran) => - this.addClaim(claimData, tran), - ); + return this.db.withTransactionF((tran) => this.addClaim(claimData, tran)); } + + await tran.lock(claimIdPath.toString(), sequenceNumberPath.toString()); const prevSequenceNumber = await this.getSequenceNumber(tran); const newSequenceNumber = prevSequenceNumber + 1; const claim = await this.createClaim({ @@ -223,10 +207,12 @@ class Sigchain { this.sequenceNumberKey, ]; if (tran == null) { - return this.withTransactionF(claimIdPath, sequenceNumberPath, (tran) => + return this.db.withTransactionF((tran) => this.addExistingClaim(claim, tran), ); } + + await tran.lock(claimIdPath.toString(), sequenceNumberPath.toString()); const decodedClaim = claimsUtils.decodeClaim(claim); const prevSequenceNumber = await this.getSequenceNumber(tran); const expectedSequenceNumber = prevSequenceNumber + 1; @@ -255,10 +241,12 @@ class Sigchain { this.sequenceNumberKey, ]; if (tran == null) { - return this.withTransactionF(sequenceNumberPath, (tran) => + return this.db.withTransactionF((tran) => this.createIntermediaryClaim(claimData, tran), ); } + + await tran.lock(sequenceNumberPath.toString()); const claim = await this.createClaim({ hPrev: await this.getHashPrevious(tran), seq: (await this.getSequenceNumber(tran)) + 1, @@ -279,7 +267,7 @@ class Sigchain { @ready(new sigchainErrors.ErrorSigchainNotRunning()) public async getChainData(tran?: DBTransaction): Promise { if (tran == null) { - return this.withTransactionF((tran) => this.getChainData(tran)); + return this.db.withTransactionF((tran) => this.getChainData(tran)); } const chainData: ChainDataEncoded = {}; const readIterator = tran.iterator( @@ -308,7 +296,9 @@ class Sigchain { tran?: DBTransaction, ): Promise> { if (tran == null) { - return this.withTransactionF((tran) => this.getClaims(claimType, tran)); + return this.db.withTransactionF((tran) => + this.getClaims(claimType, tran), + ); } const relevantClaims: Array = []; const readIterator = tran.iterator( @@ -374,7 +364,7 @@ class Sigchain { tran?: DBTransaction, ): Promise { if (tran == null) { - return this.withTransactionF((tran) => this.getClaim(claimId, tran)); + return this.db.withTransactionF((tran) => this.getClaim(claimId, tran)); } const claim = await tran.get([ ...this.sigchainClaimsDbPath, @@ -391,7 +381,7 @@ class Sigchain { tran?: DBTransaction, ): Promise> { if (tran == null) { - return this.withTransactionF((tran) => this.getSeqMap(tran)); + return this.db.withTransactionF((tran) => this.getSeqMap(tran)); } const map: Record = {}; const claimStream = tran.iterator(this.sigchainClaimsDbPath, { diff --git a/tests/acl/ACL.test.ts b/tests/acl/ACL.test.ts index 45e1b8bafc..f5f994b9e6 100644 --- a/tests/acl/ACL.test.ts +++ b/tests/acl/ACL.test.ts @@ -407,7 +407,7 @@ describe(ACL.name, () => { test('transactional operations', async () => { const acl = await ACL.createACL({ db, logger }); const p1 = acl.getNodePerms(); - const p2 = acl.withTransactionF(async (tran) => { + const p2 = db.withTransactionF(async (tran) => { await acl.setNodesPerm( [nodeIdG1First, nodeIdG1Second] as Array, { diff --git a/tests/sigchain/Sigchain.test.ts b/tests/sigchain/Sigchain.test.ts index 45da1b6653..01bb35d625 100644 --- a/tests/sigchain/Sigchain.test.ts +++ b/tests/sigchain/Sigchain.test.ts @@ -96,7 +96,7 @@ describe('Sigchain', () => { }); test('async start initialises the sequence number', async () => { const sigchain = await Sigchain.createSigchain({ keyManager, db, logger }); - const sequenceNumber = await sigchain.withTransactionF(async (tran) => + const sequenceNumber = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getSequenceNumber(tran), ); @@ -237,11 +237,11 @@ describe('Sigchain', () => { // Create a claim // Firstly, check that we can add an existing claim if it's the first claim // in the sigchain - const hPrev1 = await sigchain.withTransactionF(async (tran) => + const hPrev1 = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getHashPrevious(tran), ); - const seq1 = await sigchain.withTransactionF(async (tran) => + const seq1 = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getSequenceNumber(tran), ); @@ -259,11 +259,11 @@ describe('Sigchain', () => { kid: nodeIdAEncoded, }); await sigchain.addExistingClaim(claim1); - const hPrev2 = await sigchain.withTransactionF(async (tran) => + const hPrev2 = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getHashPrevious(tran), ); - const seq2 = await sigchain.withTransactionF(async (tran) => + const seq2 = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getSequenceNumber(tran), ); @@ -283,11 +283,11 @@ describe('Sigchain', () => { kid: nodeIdAEncoded, }); await sigchain.addExistingClaim(claim2); - const hPrev3 = await sigchain.withTransactionF(async (tran) => + const hPrev3 = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getHashPrevious(tran), ); - const seq3 = await sigchain.withTransactionF(async (tran) => + const seq3 = await db.withTransactionF(async (tran) => // @ts-ignore - get protected method sigchain.getSequenceNumber(tran), );