diff --git a/packages/core/src/agent/Agent.ts b/packages/core/src/agent/Agent.ts index d1c446e6f1..2e016f9025 100644 --- a/packages/core/src/agent/Agent.ts +++ b/packages/core/src/agent/Agent.ts @@ -2,13 +2,13 @@ import type { Logger } from '../logger' import type { InboundTransport } from '../transport/InboundTransport' import type { OutboundTransport } from '../transport/OutboundTransport' import type { InitConfig } from '../types' -import type { Wallet } from '../wallet/Wallet' import type { AgentDependencies } from './AgentDependencies' import type { AgentMessageReceivedEvent } from './Events' import type { TransportSession } from './TransportService' import type { Subscription } from 'rxjs' import type { DependencyContainer } from 'tsyringe' +import { Subject } from 'rxjs' import { concatMap, takeUntil } from 'rxjs/operators' import { container as baseContainer } from 'tsyringe' @@ -42,6 +42,7 @@ import { WalletModule } from '../wallet/WalletModule' import { WalletError } from '../wallet/error' import { AgentConfig } from './AgentConfig' +import { AgentContext } from './AgentContext' import { Dispatcher } from './Dispatcher' import { EnvelopeService } from './EnvelopeService' import { EventEmitter } from './EventEmitter' @@ -60,8 +61,9 @@ export class Agent { protected messageSender: MessageSender private _isInitialized = false public messageSubscription: Subscription - private walletService: Wallet private routingService: RoutingService + private agentContext: AgentContext + private stop$ = new Subject() public readonly connections: ConnectionsModule public readonly proofs: ProofsModule @@ -112,8 +114,8 @@ export class Agent { this.messageSender = this.dependencyManager.resolve(MessageSender) this.messageReceiver = this.dependencyManager.resolve(MessageReceiver) this.transportService = this.dependencyManager.resolve(TransportService) - this.walletService = this.dependencyManager.resolve(InjectionSymbols.Wallet) this.routingService = this.dependencyManager.resolve(RoutingService) + this.agentContext = this.dependencyManager.resolve(AgentContext) // We set the modules in the constructor because that allows to set them as read-only this.connections = this.dependencyManager.resolve(ConnectionsModule) @@ -134,8 +136,12 @@ export class Agent { this.messageSubscription = this.eventEmitter .observable(AgentEventTypes.AgentMessageReceived) .pipe( - takeUntil(this.agentConfig.stop$), - concatMap((e) => this.messageReceiver.receiveMessage(e.payload.message, { connection: e.payload.connection })) + takeUntil(this.stop$), + concatMap((e) => + this.messageReceiver.receiveMessage(this.agentContext, e.payload.message, { + connection: e.payload.connection, + }) + ) ) .subscribe() } @@ -185,7 +191,7 @@ export class Agent { // Make sure the storage is up to date const storageUpdateService = this.dependencyManager.resolve(StorageUpdateService) - const isStorageUpToDate = await storageUpdateService.isUpToDate() + const isStorageUpToDate = await storageUpdateService.isUpToDate(this.agentContext) this.logger.info(`Agent storage is ${isStorageUpToDate ? '' : 'not '}up to date.`) if (!isStorageUpToDate && this.agentConfig.autoUpdateStorageOnStartup) { @@ -194,7 +200,7 @@ export class Agent { await updateAssistant.initialize() await updateAssistant.update() } else if (!isStorageUpToDate) { - const currentVersion = await storageUpdateService.getCurrentStorageVersion() + const currentVersion = await storageUpdateService.getCurrentStorageVersion(this.agentContext) // Close wallet to prevent un-initialized agent with initialized wallet await this.wallet.close() throw new AriesFrameworkError( @@ -208,9 +214,11 @@ export class Agent { if (publicDidSeed) { // If an agent has publicDid it will be used as routing key. - await this.walletService.initPublicDid({ seed: publicDidSeed }) + await this.agentContext.wallet.initPublicDid({ seed: publicDidSeed }) } + // set the pools on the ledger. + this.ledger.setPools(this.agentContext.config.indyLedgers) // As long as value isn't false we will async connect to all genesis pools on startup if (connectToIndyLedgersOnStartup) { this.ledger.connectToPools().catch((error) => { @@ -243,7 +251,7 @@ export class Agent { public async shutdown() { // All observables use takeUntil with the stop$ observable // this means all observables will stop running if a value is emitted on this observable - this.agentConfig.stop$.next(true) + this.stop$.next(true) // Stop transports const allTransports = [...this.inboundTransports, ...this.outboundTransports] @@ -258,11 +266,11 @@ export class Agent { } public get publicDid() { - return this.walletService.publicDid + return this.agentContext.wallet.publicDid } public async receiveMessage(inboundMessage: unknown, session?: TransportSession) { - return await this.messageReceiver.receiveMessage(inboundMessage, { session }) + return await this.messageReceiver.receiveMessage(this.agentContext, inboundMessage, { session }) } public get injectionContainer() { @@ -273,6 +281,10 @@ export class Agent { return this.agentConfig } + public get context() { + return this.agentContext + } + private async getMediationConnection(mediatorInvitationUrl: string) { const outOfBandInvitation = this.oob.parseInvitation(mediatorInvitationUrl) const outOfBandRecord = await this.oob.findByInvitationId(outOfBandInvitation.id) @@ -281,7 +293,7 @@ export class Agent { if (!connection) { this.logger.debug('Mediation connection does not exist, creating connection') // We don't want to use the current default mediator when connecting to another mediator - const routing = await this.routingService.getRouting({ useDefaultMediator: false }) + const routing = await this.routingService.getRouting(this.agentContext, { useDefaultMediator: false }) this.logger.debug('Routing created', routing) const { connectionRecord: newConnection } = await this.oob.receiveInvitation(outOfBandInvitation, { @@ -303,7 +315,7 @@ export class Agent { } private registerDependencies(dependencyManager: DependencyManager) { - dependencyManager.registerInstance(AgentConfig, this.agentConfig) + const dependencies = this.agentConfig.agentDependencies // Register internal dependencies dependencyManager.registerSingleton(EventEmitter) @@ -318,11 +330,14 @@ export class Agent { dependencyManager.registerSingleton(StorageVersionRepository) dependencyManager.registerSingleton(StorageUpdateService) + dependencyManager.registerInstance(AgentConfig, this.agentConfig) + dependencyManager.registerInstance(InjectionSymbols.AgentDependencies, dependencies) + dependencyManager.registerInstance(InjectionSymbols.FileSystem, new dependencies.FileSystem()) + dependencyManager.registerInstance(InjectionSymbols.Stop$, this.stop$) + // Register possibly already defined services if (!dependencyManager.isRegistered(InjectionSymbols.Wallet)) { - this.dependencyManager.registerSingleton(IndyWallet) - const wallet = this.dependencyManager.resolve(IndyWallet) - dependencyManager.registerInstance(InjectionSymbols.Wallet, wallet) + dependencyManager.registerContextScoped(InjectionSymbols.Wallet, IndyWallet) } if (!dependencyManager.isRegistered(InjectionSymbols.Logger)) { dependencyManager.registerInstance(InjectionSymbols.Logger, this.logger) @@ -352,5 +367,7 @@ export class Agent { IndyModule, W3cVcModule ) + + dependencyManager.registerInstance(AgentContext, new AgentContext({ dependencyManager })) } } diff --git a/packages/core/src/agent/AgentConfig.ts b/packages/core/src/agent/AgentConfig.ts index bb8ca24b56..e43b17c183 100644 --- a/packages/core/src/agent/AgentConfig.ts +++ b/packages/core/src/agent/AgentConfig.ts @@ -1,10 +1,7 @@ import type { Logger } from '../logger' -import type { FileSystem } from '../storage/FileSystem' import type { InitConfig } from '../types' import type { AgentDependencies } from './AgentDependencies' -import { Subject } from 'rxjs' - import { DID_COMM_TRANSPORT_QUEUE } from '../constants' import { AriesFrameworkError } from '../error' import { ConsoleLogger, LogLevel } from '../logger' @@ -17,17 +14,12 @@ export class AgentConfig { public label: string public logger: Logger public readonly agentDependencies: AgentDependencies - public readonly fileSystem: FileSystem - - // $stop is used for agent shutdown signal - public readonly stop$ = new Subject() public constructor(initConfig: InitConfig, agentDependencies: AgentDependencies) { this.initConfig = initConfig this.label = initConfig.label this.logger = initConfig.logger ?? new ConsoleLogger(LogLevel.off) this.agentDependencies = agentDependencies - this.fileSystem = new agentDependencies.FileSystem() const { mediatorConnectionsInvite, clearDefaultMediator, defaultMediatorId } = this.initConfig diff --git a/packages/core/src/agent/AgentContext.ts b/packages/core/src/agent/AgentContext.ts new file mode 100644 index 0000000000..a8e176d67f --- /dev/null +++ b/packages/core/src/agent/AgentContext.ts @@ -0,0 +1,32 @@ +import type { DependencyManager } from '../plugins' +import type { Wallet } from '../wallet' + +import { InjectionSymbols } from '../constants' + +import { AgentConfig } from './AgentConfig' + +export class AgentContext { + /** + * Dependency manager holds all dependencies for the current context. Possibly a child of a parent dependency manager, + * in which case all singleton dependencies from the parent context are also available to this context. + */ + public readonly dependencyManager: DependencyManager + + public constructor({ dependencyManager }: { dependencyManager: DependencyManager }) { + this.dependencyManager = dependencyManager + } + + /** + * Convenience method to access the agent config for the current context. + */ + public get config() { + return this.dependencyManager.resolve(AgentConfig) + } + + /** + * Convenience method to access the wallet for the current context. + */ + public get wallet() { + return this.dependencyManager.resolve(InjectionSymbols.Wallet) + } +} diff --git a/packages/core/src/agent/Dispatcher.ts b/packages/core/src/agent/Dispatcher.ts index d659da8f44..e55a324f85 100644 --- a/packages/core/src/agent/Dispatcher.ts +++ b/packages/core/src/agent/Dispatcher.ts @@ -1,13 +1,13 @@ -import type { Logger } from '../logger' import type { OutboundMessage, OutboundServiceMessage } from '../types' import type { AgentMessage } from './AgentMessage' import type { AgentMessageProcessedEvent } from './Events' import type { Handler } from './Handler' import type { InboundMessageContext } from './models/InboundMessageContext' -import { AgentConfig } from '../agent/AgentConfig' +import { InjectionSymbols } from '../constants' import { AriesFrameworkError } from '../error/AriesFrameworkError' -import { injectable } from '../plugins' +import { Logger } from '../logger' +import { injectable, inject } from '../plugins' import { canHandleMessageType, parseMessageType } from '../utils/messageType' import { ProblemReportMessage } from './../modules/problem-reports/messages/ProblemReportMessage' @@ -23,10 +23,14 @@ class Dispatcher { private eventEmitter: EventEmitter private logger: Logger - public constructor(messageSender: MessageSender, eventEmitter: EventEmitter, agentConfig: AgentConfig) { + public constructor( + messageSender: MessageSender, + eventEmitter: EventEmitter, + @inject(InjectionSymbols.Logger) logger: Logger + ) { this.messageSender = messageSender this.eventEmitter = eventEmitter - this.logger = agentConfig.logger + this.logger = logger } public registerHandler(handler: Handler) { @@ -70,7 +74,7 @@ class Dispatcher { } if (outboundMessage && isOutboundServiceMessage(outboundMessage)) { - await this.messageSender.sendMessageToService({ + await this.messageSender.sendMessageToService(messageContext.agentContext, { message: outboundMessage.payload, service: outboundMessage.service, senderKey: outboundMessage.senderKey, @@ -78,11 +82,11 @@ class Dispatcher { }) } else if (outboundMessage) { outboundMessage.sessionId = messageContext.sessionId - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(messageContext.agentContext, outboundMessage) } // Emit event that allows to hook into received messages - this.eventEmitter.emit({ + this.eventEmitter.emit(messageContext.agentContext, { type: AgentEventTypes.AgentMessageProcessed, payload: { message: messageContext.message, diff --git a/packages/core/src/agent/EnvelopeService.ts b/packages/core/src/agent/EnvelopeService.ts index de6e7e9e5b..d2ca8e4e51 100644 --- a/packages/core/src/agent/EnvelopeService.ts +++ b/packages/core/src/agent/EnvelopeService.ts @@ -1,14 +1,12 @@ -import type { Logger } from '../logger' import type { EncryptedMessage, PlaintextMessage } from '../types' +import type { AgentContext } from './AgentContext' import type { AgentMessage } from './AgentMessage' import { InjectionSymbols } from '../constants' import { Key, KeyType } from '../crypto' +import { Logger } from '../logger' import { ForwardMessage } from '../modules/routing/messages' import { inject, injectable } from '../plugins' -import { Wallet } from '../wallet/Wallet' - -import { AgentConfig } from './AgentConfig' export interface EnvelopeKeys { recipientKeys: Key[] @@ -18,28 +16,28 @@ export interface EnvelopeKeys { @injectable() export class EnvelopeService { - private wallet: Wallet private logger: Logger - private config: AgentConfig - public constructor(@inject(InjectionSymbols.Wallet) wallet: Wallet, agentConfig: AgentConfig) { - this.wallet = wallet - this.logger = agentConfig.logger - this.config = agentConfig + public constructor(@inject(InjectionSymbols.Logger) logger: Logger) { + this.logger = logger } - public async packMessage(payload: AgentMessage, keys: EnvelopeKeys): Promise { + public async packMessage( + agentContext: AgentContext, + payload: AgentMessage, + keys: EnvelopeKeys + ): Promise { const { recipientKeys, routingKeys, senderKey } = keys let recipientKeysBase58 = recipientKeys.map((key) => key.publicKeyBase58) const routingKeysBase58 = routingKeys.map((key) => key.publicKeyBase58) const senderKeyBase58 = senderKey && senderKey.publicKeyBase58 // pass whether we want to use legacy did sov prefix - const message = payload.toJSON({ useLegacyDidSovPrefix: this.config.useLegacyDidSovPrefix }) + const message = payload.toJSON({ useLegacyDidSovPrefix: agentContext.config.useLegacyDidSovPrefix }) this.logger.debug(`Pack outbound message ${message['@type']}`) - let encryptedMessage = await this.wallet.pack(message, recipientKeysBase58, senderKeyBase58 ?? undefined) + let encryptedMessage = await agentContext.wallet.pack(message, recipientKeysBase58, senderKeyBase58 ?? undefined) // If the message has routing keys (mediator) pack for each mediator for (const routingKeyBase58 of routingKeysBase58) { @@ -51,17 +49,20 @@ export class EnvelopeService { recipientKeysBase58 = [routingKeyBase58] this.logger.debug('Forward message created', forwardMessage) - const forwardJson = forwardMessage.toJSON({ useLegacyDidSovPrefix: this.config.useLegacyDidSovPrefix }) + const forwardJson = forwardMessage.toJSON({ useLegacyDidSovPrefix: agentContext.config.useLegacyDidSovPrefix }) // Forward messages are anon packed - encryptedMessage = await this.wallet.pack(forwardJson, [routingKeyBase58], undefined) + encryptedMessage = await agentContext.wallet.pack(forwardJson, [routingKeyBase58], undefined) } return encryptedMessage } - public async unpackMessage(encryptedMessage: EncryptedMessage): Promise { - const decryptedMessage = await this.wallet.unpack(encryptedMessage) + public async unpackMessage( + agentContext: AgentContext, + encryptedMessage: EncryptedMessage + ): Promise { + const decryptedMessage = await agentContext.wallet.unpack(encryptedMessage) const { recipientKey, senderKey, plaintextMessage } = decryptedMessage return { recipientKey: recipientKey ? Key.fromPublicKeyBase58(recipientKey, KeyType.Ed25519) : undefined, diff --git a/packages/core/src/agent/EventEmitter.ts b/packages/core/src/agent/EventEmitter.ts index 62caae137c..284dcc1709 100644 --- a/packages/core/src/agent/EventEmitter.ts +++ b/packages/core/src/agent/EventEmitter.ts @@ -1,24 +1,30 @@ +import type { AgentContext } from './AgentContext' import type { BaseEvent } from './Events' import type { EventEmitter as NativeEventEmitter } from 'events' -import { fromEventPattern } from 'rxjs' +import { fromEventPattern, Subject } from 'rxjs' import { takeUntil } from 'rxjs/operators' -import { injectable } from '../plugins' +import { InjectionSymbols } from '../constants' +import { injectable, inject } from '../plugins' -import { AgentConfig } from './AgentConfig' +import { AgentDependencies } from './AgentDependencies' @injectable() export class EventEmitter { - private agentConfig: AgentConfig private eventEmitter: NativeEventEmitter - - public constructor(agentConfig: AgentConfig) { - this.agentConfig = agentConfig - this.eventEmitter = new agentConfig.agentDependencies.EventEmitterClass() + private stop$: Subject + + public constructor( + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies, + @inject(InjectionSymbols.Stop$) stop$: Subject + ) { + this.eventEmitter = new agentDependencies.EventEmitterClass() + this.stop$ = stop$ } - public emit(data: T) { + // agentContext is currently not used, but already making required as it will be used soon + public emit(agentContext: AgentContext, data: T) { this.eventEmitter.emit(data.type, data) } @@ -34,6 +40,6 @@ export class EventEmitter { return fromEventPattern( (handler) => this.on(event, handler), (handler) => this.off(event, handler) - ).pipe(takeUntil(this.agentConfig.stop$)) + ).pipe(takeUntil(this.stop$)) } } diff --git a/packages/core/src/agent/MessageReceiver.ts b/packages/core/src/agent/MessageReceiver.ts index 78f18caca5..31282e0252 100644 --- a/packages/core/src/agent/MessageReceiver.ts +++ b/packages/core/src/agent/MessageReceiver.ts @@ -1,20 +1,21 @@ -import type { Logger } from '../logger' import type { ConnectionRecord } from '../modules/connections' import type { InboundTransport } from '../transport' -import type { PlaintextMessage, EncryptedMessage } from '../types' +import type { EncryptedMessage, PlaintextMessage } from '../types' +import type { AgentContext } from './AgentContext' import type { AgentMessage } from './AgentMessage' import type { DecryptedMessageContext } from './EnvelopeService' import type { TransportSession } from './TransportService' +import { InjectionSymbols } from '../constants' import { AriesFrameworkError } from '../error' -import { ConnectionsModule } from '../modules/connections' +import { Logger } from '../logger' +import { ConnectionService } from '../modules/connections' import { ProblemReportError, ProblemReportMessage, ProblemReportReason } from '../modules/problem-reports' -import { injectable } from '../plugins' +import { injectable, inject } from '../plugins' import { isValidJweStructure } from '../utils/JWE' import { JsonTransformer } from '../utils/JsonTransformer' import { canHandleMessageType, parseMessageType, replaceLegacyDidSovPrefixOnMessage } from '../utils/messageType' -import { AgentConfig } from './AgentConfig' import { Dispatcher } from './Dispatcher' import { EnvelopeService } from './EnvelopeService' import { MessageSender } from './MessageSender' @@ -24,30 +25,28 @@ import { InboundMessageContext } from './models/InboundMessageContext' @injectable() export class MessageReceiver { - private config: AgentConfig private envelopeService: EnvelopeService private transportService: TransportService private messageSender: MessageSender private dispatcher: Dispatcher private logger: Logger - private connectionsModule: ConnectionsModule + private connectionService: ConnectionService public readonly inboundTransports: InboundTransport[] = [] public constructor( - config: AgentConfig, envelopeService: EnvelopeService, transportService: TransportService, messageSender: MessageSender, - connectionsModule: ConnectionsModule, - dispatcher: Dispatcher + connectionService: ConnectionService, + dispatcher: Dispatcher, + @inject(InjectionSymbols.Logger) logger: Logger ) { - this.config = config this.envelopeService = envelopeService this.transportService = transportService this.messageSender = messageSender - this.connectionsModule = connectionsModule + this.connectionService = connectionService this.dispatcher = dispatcher - this.logger = this.config.logger + this.logger = logger } public registerInboundTransport(inboundTransport: InboundTransport) { @@ -61,27 +60,36 @@ export class MessageReceiver { * @param inboundMessage the message to receive and handle */ public async receiveMessage( + agentContext: AgentContext, inboundMessage: unknown, { session, connection }: { session?: TransportSession; connection?: ConnectionRecord } ) { - this.logger.debug(`Agent ${this.config.label} received message`) + this.logger.debug(`Agent ${agentContext.config.label} received message`) if (this.isEncryptedMessage(inboundMessage)) { - await this.receiveEncryptedMessage(inboundMessage as EncryptedMessage, session) + await this.receiveEncryptedMessage(agentContext, inboundMessage as EncryptedMessage, session) } else if (this.isPlaintextMessage(inboundMessage)) { - await this.receivePlaintextMessage(inboundMessage, connection) + await this.receivePlaintextMessage(agentContext, inboundMessage, connection) } else { throw new AriesFrameworkError('Unable to parse incoming message: unrecognized format') } } - private async receivePlaintextMessage(plaintextMessage: PlaintextMessage, connection?: ConnectionRecord) { - const message = await this.transformAndValidate(plaintextMessage) - const messageContext = new InboundMessageContext(message, { connection }) + private async receivePlaintextMessage( + agentContext: AgentContext, + plaintextMessage: PlaintextMessage, + connection?: ConnectionRecord + ) { + const message = await this.transformAndValidate(agentContext, plaintextMessage) + const messageContext = new InboundMessageContext(message, { connection, agentContext }) await this.dispatcher.dispatch(messageContext) } - private async receiveEncryptedMessage(encryptedMessage: EncryptedMessage, session?: TransportSession) { - const decryptedMessage = await this.decryptMessage(encryptedMessage) + private async receiveEncryptedMessage( + agentContext: AgentContext, + encryptedMessage: EncryptedMessage, + session?: TransportSession + ) { + const decryptedMessage = await this.decryptMessage(agentContext, encryptedMessage) const { plaintextMessage, senderKey, recipientKey } = decryptedMessage this.logger.info( @@ -89,9 +97,9 @@ export class MessageReceiver { plaintextMessage ) - const connection = await this.findConnectionByMessageKeys(decryptedMessage) + const connection = await this.findConnectionByMessageKeys(agentContext, decryptedMessage) - const message = await this.transformAndValidate(plaintextMessage, connection) + const message = await this.transformAndValidate(agentContext, plaintextMessage, connection) const messageContext = new InboundMessageContext(message, { // Only make the connection available in message context if the connection is ready @@ -100,6 +108,7 @@ export class MessageReceiver { connection: connection?.isReady ? connection : undefined, senderKey, recipientKey, + agentContext, }) // We want to save a session if there is a chance of returning outbound message via inbound transport. @@ -133,9 +142,12 @@ export class MessageReceiver { * * @param message the received inbound message to decrypt */ - private async decryptMessage(message: EncryptedMessage): Promise { + private async decryptMessage( + agentContext: AgentContext, + message: EncryptedMessage + ): Promise { try { - return await this.envelopeService.unpackMessage(message) + return await this.envelopeService.unpackMessage(agentContext, message) } catch (error) { this.logger.error('Error while decrypting message', { error, @@ -160,6 +172,7 @@ export class MessageReceiver { } private async transformAndValidate( + agentContext: AgentContext, plaintextMessage: PlaintextMessage, connection?: ConnectionRecord | null ): Promise { @@ -167,21 +180,21 @@ export class MessageReceiver { try { message = await this.transformMessage(plaintextMessage) } catch (error) { - if (connection) await this.sendProblemReportMessage(error.message, connection, plaintextMessage) + if (connection) await this.sendProblemReportMessage(agentContext, error.message, connection, plaintextMessage) throw error } return message } - private async findConnectionByMessageKeys({ - recipientKey, - senderKey, - }: DecryptedMessageContext): Promise { + private async findConnectionByMessageKeys( + agentContext: AgentContext, + { recipientKey, senderKey }: DecryptedMessageContext + ): Promise { // We only fetch connections that are sent in AuthCrypt mode if (!recipientKey || !senderKey) return null // Try to find the did records that holds the sender and recipient keys - return this.connectionsModule.findByKeys({ + return this.connectionService.findByKeys(agentContext, { senderKey, recipientKey, }) @@ -228,6 +241,7 @@ export class MessageReceiver { * @param plaintextMessage received inbound message */ private async sendProblemReportMessage( + agentContext: AgentContext, message: string, connection: ConnectionRecord, plaintextMessage: PlaintextMessage @@ -247,7 +261,7 @@ export class MessageReceiver { }) const outboundMessage = createOutboundMessage(connection, problemReportMessage) if (outboundMessage) { - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(agentContext, outboundMessage) } } } diff --git a/packages/core/src/agent/MessageSender.ts b/packages/core/src/agent/MessageSender.ts index 389be6eca9..e763afb2b6 100644 --- a/packages/core/src/agent/MessageSender.ts +++ b/packages/core/src/agent/MessageSender.ts @@ -4,6 +4,7 @@ import type { DidDocument } from '../modules/dids' import type { OutOfBandRecord } from '../modules/oob/repository' import type { OutboundTransport } from '../transport/OutboundTransport' import type { OutboundMessage, OutboundPackage, EncryptedMessage } from '../types' +import type { AgentContext } from './AgentContext' import type { AgentMessage } from './AgentMessage' import type { EnvelopeKeys } from './EnvelopeService' import type { TransportSession } from './TransportService' @@ -65,16 +66,19 @@ export class MessageSender { this.outboundTransports.push(outboundTransport) } - public async packMessage({ - keys, - message, - endpoint, - }: { - keys: EnvelopeKeys - message: AgentMessage - endpoint: string - }): Promise { - const encryptedMessage = await this.envelopeService.packMessage(message, keys) + public async packMessage( + agentContext: AgentContext, + { + keys, + message, + endpoint, + }: { + keys: EnvelopeKeys + message: AgentMessage + endpoint: string + } + ): Promise { + const encryptedMessage = await this.envelopeService.packMessage(agentContext, message, keys) return { payload: encryptedMessage, @@ -83,24 +87,27 @@ export class MessageSender { } } - private async sendMessageToSession(session: TransportSession, message: AgentMessage) { + private async sendMessageToSession(agentContext: AgentContext, session: TransportSession, message: AgentMessage) { this.logger.debug(`Existing ${session.type} transport session has been found.`) if (!session.keys) { throw new AriesFrameworkError(`There are no keys for the given ${session.type} transport session.`) } - const encryptedMessage = await this.envelopeService.packMessage(message, session.keys) + const encryptedMessage = await this.envelopeService.packMessage(agentContext, message, session.keys) await session.send(encryptedMessage) } - public async sendPackage({ - connection, - encryptedMessage, - options, - }: { - connection: ConnectionRecord - encryptedMessage: EncryptedMessage - options?: { transportPriority?: TransportPriorityOptions } - }) { + public async sendPackage( + agentContext: AgentContext, + { + connection, + encryptedMessage, + options, + }: { + connection: ConnectionRecord + encryptedMessage: EncryptedMessage + options?: { transportPriority?: TransportPriorityOptions } + } + ) { const errors: Error[] = [] // Try to send to already open session @@ -116,7 +123,11 @@ export class MessageSender { } // Retrieve DIDComm services - const { services, queueService } = await this.retrieveServicesByConnection(connection, options?.transportPriority) + const { services, queueService } = await this.retrieveServicesByConnection( + agentContext, + connection, + options?.transportPriority + ) if (this.outboundTransports.length === 0 && !queueService) { throw new AriesFrameworkError('Agent has no outbound transport!') @@ -167,6 +178,7 @@ export class MessageSender { } public async sendMessage( + agentContext: AgentContext, outboundMessage: OutboundMessage, options?: { transportPriority?: TransportPriorityOptions @@ -193,7 +205,7 @@ export class MessageSender { if (session?.inboundMessage?.hasReturnRouting(payload.threadId)) { this.logger.debug(`Found session with return routing for message '${payload.id}' (connection '${connection.id}'`) try { - await this.sendMessageToSession(session, payload) + await this.sendMessageToSession(agentContext, session, payload) return } catch (error) { errors.push(error) @@ -203,6 +215,7 @@ export class MessageSender { // Retrieve DIDComm services const { services, queueService } = await this.retrieveServicesByConnection( + agentContext, connection, options?.transportPriority, outOfBand @@ -215,7 +228,7 @@ export class MessageSender { ) } - const ourDidDocument = await this.didResolverService.resolveDidDocument(connection.did) + const ourDidDocument = await this.didResolverService.resolveDidDocument(agentContext, connection.did) const ourAuthenticationKeys = getAuthenticationKeys(ourDidDocument) // TODO We're selecting just the first authentication key. Is it ok? @@ -234,7 +247,7 @@ export class MessageSender { for await (const service of services) { try { // Enable return routing if the our did document does not have any inbound endpoint for given sender key - await this.sendMessageToService({ + await this.sendMessageToService(agentContext, { message: payload, service, senderKey: firstOurAuthenticationKey, @@ -265,7 +278,7 @@ export class MessageSender { senderKey: firstOurAuthenticationKey, } - const encryptedMessage = await this.envelopeService.packMessage(payload, keys) + const encryptedMessage = await this.envelopeService.packMessage(agentContext, payload, keys) this.messageRepository.add(connection.id, encryptedMessage) return } @@ -279,19 +292,22 @@ export class MessageSender { throw new AriesFrameworkError(`Message is undeliverable to connection ${connection.id} (${connection.theirLabel})`) } - public async sendMessageToService({ - message, - service, - senderKey, - returnRoute, - connectionId, - }: { - message: AgentMessage - service: ResolvedDidCommService - senderKey: Key - returnRoute?: boolean - connectionId?: string - }) { + public async sendMessageToService( + agentContext: AgentContext, + { + message, + service, + senderKey, + returnRoute, + connectionId, + }: { + message: AgentMessage + service: ResolvedDidCommService + senderKey: Key + returnRoute?: boolean + connectionId?: string + } + ) { if (this.outboundTransports.length === 0) { throw new AriesFrameworkError('Agent has no outbound transport!') } @@ -326,7 +342,7 @@ export class MessageSender { throw error } - const outboundPackage = await this.packMessage({ message, keys, endpoint: service.serviceEndpoint }) + const outboundPackage = await this.packMessage(agentContext, { message, keys, endpoint: service.serviceEndpoint }) outboundPackage.endpoint = service.serviceEndpoint outboundPackage.connectionId = connectionId for (const transport of this.outboundTransports) { @@ -341,9 +357,9 @@ export class MessageSender { throw new AriesFrameworkError(`Unable to send message to service: ${service.serviceEndpoint}`) } - private async retrieveServicesFromDid(did: string) { + private async retrieveServicesFromDid(agentContext: AgentContext, did: string) { this.logger.debug(`Resolving services for did ${did}.`) - const didDocument = await this.didResolverService.resolveDidDocument(did) + const didDocument = await this.didResolverService.resolveDidDocument(agentContext, did) const didCommServices: ResolvedDidCommService[] = [] @@ -362,7 +378,7 @@ export class MessageSender { // Resolve dids to DIDDocs to retrieve routingKeys const routingKeys = [] for (const routingKey of didCommService.routingKeys ?? []) { - const routingDidDocument = await this.didResolverService.resolveDidDocument(routingKey) + const routingDidDocument = await this.didResolverService.resolveDidDocument(agentContext, routingKey) routingKeys.push(keyReferenceToKey(routingDidDocument, routingKey)) } @@ -385,6 +401,7 @@ export class MessageSender { } private async retrieveServicesByConnection( + agentContext: AgentContext, connection: ConnectionRecord, transportPriority?: TransportPriorityOptions, outOfBand?: OutOfBandRecord @@ -398,14 +415,14 @@ export class MessageSender { if (connection.theirDid) { this.logger.debug(`Resolving services for connection theirDid ${connection.theirDid}.`) - didCommServices = await this.retrieveServicesFromDid(connection.theirDid) + didCommServices = await this.retrieveServicesFromDid(agentContext, connection.theirDid) } else if (outOfBand) { this.logger.debug(`Resolving services from out-of-band record ${outOfBand?.id}.`) if (connection.isRequester) { for (const service of outOfBand.outOfBandInvitation.services) { // Resolve dids to DIDDocs to retrieve services if (typeof service === 'string') { - didCommServices = await this.retrieveServicesFromDid(service) + didCommServices = await this.retrieveServicesFromDid(agentContext, service) } else { // Out of band inline service contains keys encoded as did:key references didCommServices.push({ diff --git a/packages/core/src/agent/__tests__/Agent.test.ts b/packages/core/src/agent/__tests__/Agent.test.ts index 653066b9fe..558f267e14 100644 --- a/packages/core/src/agent/__tests__/Agent.test.ts +++ b/packages/core/src/agent/__tests__/Agent.test.ts @@ -1,5 +1,3 @@ -import type { Wallet } from '../../wallet/Wallet' - import { getBaseConfig } from '../../../tests/helpers' import { InjectionSymbols } from '../../constants' import { BasicMessageRepository, BasicMessageService } from '../../modules/basic-messages' @@ -23,7 +21,6 @@ import { } from '../../modules/routing' import { InMemoryMessageRepository } from '../../storage/InMemoryMessageRepository' import { IndyStorageService } from '../../storage/IndyStorageService' -import { IndyWallet } from '../../wallet/IndyWallet' import { WalletError } from '../../wallet/error' import { Agent } from '../Agent' import { Dispatcher } from '../Dispatcher' @@ -38,7 +35,7 @@ describe('Agent', () => { let agent: Agent afterEach(async () => { - const wallet = agent.dependencyManager.resolve(InjectionSymbols.Wallet) + const wallet = agent.context.wallet if (wallet.isInitialized) { await wallet.delete() @@ -59,7 +56,7 @@ describe('Agent', () => { expect.assertions(4) agent = new Agent(config, dependencies) - const wallet = agent.dependencyManager.resolve(InjectionSymbols.Wallet) + const wallet = agent.context.wallet expect(agent.isInitialized).toBe(false) expect(wallet.isInitialized).toBe(false) @@ -139,7 +136,6 @@ describe('Agent', () => { expect(container.resolve(IndyLedgerService)).toBeInstanceOf(IndyLedgerService) // Symbols, interface based - expect(container.resolve(InjectionSymbols.Wallet)).toBeInstanceOf(IndyWallet) expect(container.resolve(InjectionSymbols.Logger)).toBe(config.logger) expect(container.resolve(InjectionSymbols.MessageRepository)).toBeInstanceOf(InMemoryMessageRepository) expect(container.resolve(InjectionSymbols.StorageService)).toBeInstanceOf(IndyStorageService) @@ -182,7 +178,6 @@ describe('Agent', () => { expect(container.resolve(IndyLedgerService)).toBe(container.resolve(IndyLedgerService)) // Symbols, interface based - expect(container.resolve(InjectionSymbols.Wallet)).toBe(container.resolve(InjectionSymbols.Wallet)) expect(container.resolve(InjectionSymbols.Logger)).toBe(container.resolve(InjectionSymbols.Logger)) expect(container.resolve(InjectionSymbols.MessageRepository)).toBe( container.resolve(InjectionSymbols.MessageRepository) diff --git a/packages/core/src/agent/__tests__/Dispatcher.test.ts b/packages/core/src/agent/__tests__/Dispatcher.test.ts index ec5f60160f..5a735449c6 100644 --- a/packages/core/src/agent/__tests__/Dispatcher.test.ts +++ b/packages/core/src/agent/__tests__/Dispatcher.test.ts @@ -1,6 +1,8 @@ import type { Handler } from '../Handler' -import { getAgentConfig } from '../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext } from '../../../tests/helpers' import { parseMessageType } from '../../utils/messageType' import { AgentMessage } from '../AgentMessage' import { Dispatcher } from '../Dispatcher' @@ -48,8 +50,9 @@ class TestHandler implements Handler { describe('Dispatcher', () => { const agentConfig = getAgentConfig('DispatcherTest') + const agentContext = getAgentContext() const MessageSenderMock = MessageSender as jest.Mock - const eventEmitter = new EventEmitter(agentConfig) + const eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) const fakeProtocolHandler = new TestHandler([CustomProtocolMessage]) const connectionHandler = new TestHandler([ ConnectionInvitationTestMessage, @@ -57,7 +60,7 @@ describe('Dispatcher', () => { ConnectionResponseTestMessage, ]) - const dispatcher = new Dispatcher(new MessageSenderMock(), eventEmitter, agentConfig) + const dispatcher = new Dispatcher(new MessageSenderMock(), eventEmitter, agentConfig.logger) dispatcher.registerHandler(connectionHandler) dispatcher.registerHandler(new TestHandler([NotificationAckTestMessage])) @@ -138,9 +141,9 @@ describe('Dispatcher', () => { describe('dispatch()', () => { it('calls the handle method of the handler', async () => { - const dispatcher = new Dispatcher(new MessageSenderMock(), eventEmitter, agentConfig) + const dispatcher = new Dispatcher(new MessageSenderMock(), eventEmitter, agentConfig.logger) const customProtocolMessage = new CustomProtocolMessage() - const inboundMessageContext = new InboundMessageContext(customProtocolMessage) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const mockHandle = jest.fn() dispatcher.registerHandler({ supportedMessages: [CustomProtocolMessage], handle: mockHandle }) @@ -151,9 +154,9 @@ describe('Dispatcher', () => { }) it('throws an error if no handler for the message could be found', async () => { - const dispatcher = new Dispatcher(new MessageSenderMock(), eventEmitter, agentConfig) + const dispatcher = new Dispatcher(new MessageSenderMock(), eventEmitter, agentConfig.logger) const customProtocolMessage = new CustomProtocolMessage() - const inboundMessageContext = new InboundMessageContext(customProtocolMessage) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const mockHandle = jest.fn() dispatcher.registerHandler({ supportedMessages: [], handle: mockHandle }) diff --git a/packages/core/src/agent/__tests__/MessageSender.test.ts b/packages/core/src/agent/__tests__/MessageSender.test.ts index d7158a9f47..96adad3bdd 100644 --- a/packages/core/src/agent/__tests__/MessageSender.test.ts +++ b/packages/core/src/agent/__tests__/MessageSender.test.ts @@ -6,7 +6,7 @@ import type { OutboundMessage, EncryptedMessage } from '../../types' import type { ResolvedDidCommService } from '../MessageSender' import { TestMessage } from '../../../tests/TestMessage' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../tests/helpers' +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../tests/helpers' import testLogger from '../../../tests/logger' import { Key, KeyType } from '../../crypto' import { ReturnRouteTypes } from '../../decorators/transport/TransportDecorator' @@ -117,6 +117,8 @@ describe('MessageSender', () => { let messageRepository: MessageRepository let connection: ConnectionRecord let outboundMessage: OutboundMessage + const agentConfig = getAgentConfig('MessageSender') + const agentContext = getAgentContext() describe('sendMessage', () => { beforeEach(() => { @@ -124,7 +126,7 @@ describe('MessageSender', () => { DidResolverServiceMock.mockClear() outboundTransport = new DummyHttpOutboundTransport() - messageRepository = new InMemoryMessageRepository(getAgentConfig('MessageSender')) + messageRepository = new InMemoryMessageRepository(agentConfig.logger) messageSender = new MessageSender( enveloperService, transportService, @@ -154,7 +156,9 @@ describe('MessageSender', () => { }) test('throw error when there is no outbound transport', async () => { - await expect(messageSender.sendMessage(outboundMessage)).rejects.toThrow(/Message is undeliverable to connection/) + await expect(messageSender.sendMessage(agentContext, outboundMessage)).rejects.toThrow( + /Message is undeliverable to connection/ + ) }) test('throw error when there is no service or queue', async () => { @@ -162,7 +166,7 @@ describe('MessageSender', () => { didResolverServiceResolveMock.mockResolvedValue(getMockDidDocument({ service: [] })) - await expect(messageSender.sendMessage(outboundMessage)).rejects.toThrow( + await expect(messageSender.sendMessage(agentContext, outboundMessage)).rejects.toThrow( `Message is undeliverable to connection test-123 (Test 123)` ) }) @@ -175,7 +179,7 @@ describe('MessageSender', () => { messageSender.registerOutboundTransport(outboundTransport) const sendMessageSpy = jest.spyOn(outboundTransport, 'sendMessage') - await messageSender.sendMessage(outboundMessage) + await messageSender.sendMessage(agentContext, outboundMessage) expect(sendMessageSpy).toHaveBeenCalledWith({ connectionId: 'test-123', @@ -191,9 +195,9 @@ describe('MessageSender', () => { const sendMessageSpy = jest.spyOn(outboundTransport, 'sendMessage') - await messageSender.sendMessage(outboundMessage) + await messageSender.sendMessage(agentContext, outboundMessage) - expect(didResolverServiceResolveMock).toHaveBeenCalledWith(connection.theirDid) + expect(didResolverServiceResolveMock).toHaveBeenCalledWith(agentContext, connection.theirDid) expect(sendMessageSpy).toHaveBeenCalledWith({ connectionId: 'test-123', payload: encryptedMessage, @@ -210,7 +214,7 @@ describe('MessageSender', () => { new Error(`Unable to resolve did document for did '${connection.theirDid}': notFound`) ) - await expect(messageSender.sendMessage(outboundMessage)).rejects.toThrowError( + await expect(messageSender.sendMessage(agentContext, outboundMessage)).rejects.toThrowError( `Unable to resolve did document for did '${connection.theirDid}': notFound` ) }) @@ -222,7 +226,7 @@ describe('MessageSender', () => { messageSender.registerOutboundTransport(outboundTransport) const sendMessageSpy = jest.spyOn(outboundTransport, 'sendMessage') - await messageSender.sendMessage(outboundMessage) + await messageSender.sendMessage(agentContext, outboundMessage) expect(sendMessageSpy).toHaveBeenCalledWith({ connectionId: 'test-123', @@ -239,7 +243,7 @@ describe('MessageSender', () => { const sendMessageSpy = jest.spyOn(outboundTransport, 'sendMessage') const sendMessageToServiceSpy = jest.spyOn(messageSender, 'sendMessageToService') - await messageSender.sendMessage({ ...outboundMessage, sessionId: 'session-123' }) + await messageSender.sendMessage(agentContext, { ...outboundMessage, sessionId: 'session-123' }) expect(session.send).toHaveBeenCalledTimes(1) expect(session.send).toHaveBeenNthCalledWith(1, encryptedMessage) @@ -253,9 +257,9 @@ describe('MessageSender', () => { const sendMessageSpy = jest.spyOn(outboundTransport, 'sendMessage') const sendMessageToServiceSpy = jest.spyOn(messageSender, 'sendMessageToService') - await messageSender.sendMessage(outboundMessage) + await messageSender.sendMessage(agentContext, outboundMessage) - const [[sendMessage]] = sendMessageToServiceSpy.mock.calls + const [[, sendMessage]] = sendMessageToServiceSpy.mock.calls expect(sendMessage).toMatchObject({ connectionId: 'test-123', @@ -283,9 +287,9 @@ describe('MessageSender', () => { // Simulate the case when the first call fails sendMessageSpy.mockRejectedValueOnce(new Error()) - await messageSender.sendMessage(outboundMessage) + await messageSender.sendMessage(agentContext, outboundMessage) - const [, [sendMessage]] = sendMessageToServiceSpy.mock.calls + const [, [, sendMessage]] = sendMessageToServiceSpy.mock.calls expect(sendMessage).toMatchObject({ connectionId: 'test-123', message: outboundMessage.payload, @@ -306,7 +310,9 @@ describe('MessageSender', () => { test('throw error when message endpoint is not supported by outbound transport schemes', async () => { messageSender.registerOutboundTransport(new DummyWsOutboundTransport()) - await expect(messageSender.sendMessage(outboundMessage)).rejects.toThrow(/Message is undeliverable to connection/) + await expect(messageSender.sendMessage(agentContext, outboundMessage)).rejects.toThrow( + /Message is undeliverable to connection/ + ) }) }) @@ -324,7 +330,7 @@ describe('MessageSender', () => { messageSender = new MessageSender( enveloperService, transportService, - new InMemoryMessageRepository(getAgentConfig('MessageSenderTest')), + new InMemoryMessageRepository(agentConfig.logger), logger, didResolverService ) @@ -338,7 +344,7 @@ describe('MessageSender', () => { test('throws error when there is no outbound transport', async () => { await expect( - messageSender.sendMessageToService({ + messageSender.sendMessageToService(agentContext, { message: new TestMessage(), senderKey, service, @@ -350,7 +356,7 @@ describe('MessageSender', () => { messageSender.registerOutboundTransport(outboundTransport) const sendMessageSpy = jest.spyOn(outboundTransport, 'sendMessage') - await messageSender.sendMessageToService({ + await messageSender.sendMessageToService(agentContext, { message: new TestMessage(), senderKey, service, @@ -371,7 +377,7 @@ describe('MessageSender', () => { const message = new TestMessage() message.setReturnRouting(ReturnRouteTypes.all) - await messageSender.sendMessageToService({ + await messageSender.sendMessageToService(agentContext, { message, senderKey, service, @@ -388,7 +394,7 @@ describe('MessageSender', () => { test('throw error when message endpoint is not supported by outbound transport schemes', async () => { messageSender.registerOutboundTransport(new DummyWsOutboundTransport()) await expect( - messageSender.sendMessageToService({ + messageSender.sendMessageToService(agentContext, { message: new TestMessage(), senderKey, service, @@ -400,7 +406,7 @@ describe('MessageSender', () => { describe('packMessage', () => { beforeEach(() => { outboundTransport = new DummyHttpOutboundTransport() - messageRepository = new InMemoryMessageRepository(getAgentConfig('PackMessage')) + messageRepository = new InMemoryMessageRepository(agentConfig.logger) messageSender = new MessageSender( enveloperService, transportService, @@ -426,7 +432,7 @@ describe('MessageSender', () => { routingKeys: [], senderKey: senderKey, } - const result = await messageSender.packMessage({ message, keys, endpoint }) + const result = await messageSender.packMessage(agentContext, { message, keys, endpoint }) expect(result).toEqual({ payload: encryptedMessage, diff --git a/packages/core/src/agent/index.ts b/packages/core/src/agent/index.ts new file mode 100644 index 0000000000..615455eb43 --- /dev/null +++ b/packages/core/src/agent/index.ts @@ -0,0 +1 @@ +export * from './AgentContext' diff --git a/packages/core/src/agent/models/InboundMessageContext.ts b/packages/core/src/agent/models/InboundMessageContext.ts index be7e1d4eb9..a31d7a8614 100644 --- a/packages/core/src/agent/models/InboundMessageContext.ts +++ b/packages/core/src/agent/models/InboundMessageContext.ts @@ -1,5 +1,6 @@ import type { Key } from '../../crypto' import type { ConnectionRecord } from '../../modules/connections' +import type { AgentContext } from '../AgentContext' import type { AgentMessage } from '../AgentMessage' import { AriesFrameworkError } from '../../error' @@ -9,6 +10,7 @@ export interface MessageContextParams { sessionId?: string senderKey?: Key recipientKey?: Key + agentContext: AgentContext } export class InboundMessageContext { @@ -17,13 +19,15 @@ export class InboundMessageContext { public sessionId?: string public senderKey?: Key public recipientKey?: Key + public readonly agentContext: AgentContext - public constructor(message: T, context: MessageContextParams = {}) { + public constructor(message: T, context: MessageContextParams) { this.message = message this.recipientKey = context.recipientKey this.senderKey = context.senderKey this.connection = context.connection this.sessionId = context.sessionId + this.agentContext = context.agentContext } /** diff --git a/packages/core/src/cache/PersistedLruCache.ts b/packages/core/src/cache/PersistedLruCache.ts index d680a2dca1..ab00e0d14e 100644 --- a/packages/core/src/cache/PersistedLruCache.ts +++ b/packages/core/src/cache/PersistedLruCache.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../agent' import type { CacheRepository } from './CacheRepository' import { LRUMap } from 'lru_map' @@ -16,22 +17,22 @@ export class PersistedLruCache { this.cacheRepository = cacheRepository } - public async get(key: string) { - const cache = await this.getCache() + public async get(agentContext: AgentContext, key: string) { + const cache = await this.getCache(agentContext) return cache.get(key) } - public async set(key: string, value: CacheValue) { - const cache = await this.getCache() + public async set(agentContext: AgentContext, key: string, value: CacheValue) { + const cache = await this.getCache(agentContext) cache.set(key, value) - await this.persistCache() + await this.persistCache(agentContext) } - private async getCache() { + private async getCache(agentContext: AgentContext) { if (!this._cache) { - const cacheRecord = await this.fetchCacheRecord() + const cacheRecord = await this.fetchCacheRecord(agentContext) this._cache = this.lruFromRecord(cacheRecord) } @@ -45,8 +46,8 @@ export class PersistedLruCache { ) } - private async fetchCacheRecord() { - let cacheRecord = await this.cacheRepository.findById(this.cacheId) + private async fetchCacheRecord(agentContext: AgentContext) { + let cacheRecord = await this.cacheRepository.findById(agentContext, this.cacheId) if (!cacheRecord) { cacheRecord = new CacheRecord({ @@ -54,16 +55,17 @@ export class PersistedLruCache { entries: [], }) - await this.cacheRepository.save(cacheRecord) + await this.cacheRepository.save(agentContext, cacheRecord) } return cacheRecord } - private async persistCache() { - const cache = await this.getCache() + private async persistCache(agentContext: AgentContext) { + const cache = await this.getCache(agentContext) await this.cacheRepository.update( + agentContext, new CacheRecord({ entries: cache.toJSON(), id: this.cacheId, diff --git a/packages/core/src/cache/__tests__/PersistedLruCache.test.ts b/packages/core/src/cache/__tests__/PersistedLruCache.test.ts index dc75ce6c1f..c7b893108d 100644 --- a/packages/core/src/cache/__tests__/PersistedLruCache.test.ts +++ b/packages/core/src/cache/__tests__/PersistedLruCache.test.ts @@ -1,4 +1,4 @@ -import { mockFunction } from '../../../tests/helpers' +import { getAgentContext, mockFunction } from '../../../tests/helpers' import { CacheRecord } from '../CacheRecord' import { CacheRepository } from '../CacheRepository' import { PersistedLruCache } from '../PersistedLruCache' @@ -6,6 +6,8 @@ import { PersistedLruCache } from '../PersistedLruCache' jest.mock('../CacheRepository') const CacheRepositoryMock = CacheRepository as jest.Mock +const agentContext = getAgentContext() + describe('PersistedLruCache', () => { let cacheRepository: CacheRepository let cache: PersistedLruCache @@ -30,42 +32,42 @@ describe('PersistedLruCache', () => { }) ) - expect(await cache.get('doesnotexist')).toBeUndefined() - expect(await cache.get('test')).toBe('somevalue') - expect(findMock).toHaveBeenCalledWith('cacheId') + expect(await cache.get(agentContext, 'doesnotexist')).toBeUndefined() + expect(await cache.get(agentContext, 'test')).toBe('somevalue') + expect(findMock).toHaveBeenCalledWith(agentContext, 'cacheId') }) it('should set the value in the persisted record', async () => { const updateMock = mockFunction(cacheRepository.update).mockResolvedValue() - await cache.set('test', 'somevalue') - const [[cacheRecord]] = updateMock.mock.calls + await cache.set(agentContext, 'test', 'somevalue') + const [[, cacheRecord]] = updateMock.mock.calls expect(cacheRecord.entries.length).toBe(1) expect(cacheRecord.entries[0].key).toBe('test') expect(cacheRecord.entries[0].value).toBe('somevalue') - expect(await cache.get('test')).toBe('somevalue') + expect(await cache.get(agentContext, 'test')).toBe('somevalue') }) it('should remove least recently used entries if entries are added that exceed the limit', async () => { // Set first value in cache, resolves fine - await cache.set('one', 'valueone') - expect(await cache.get('one')).toBe('valueone') + await cache.set(agentContext, 'one', 'valueone') + expect(await cache.get(agentContext, 'one')).toBe('valueone') // Set two more entries in the cache. Third item // exceeds limit, so first item gets removed - await cache.set('two', 'valuetwo') - await cache.set('three', 'valuethree') - expect(await cache.get('one')).toBeUndefined() - expect(await cache.get('two')).toBe('valuetwo') - expect(await cache.get('three')).toBe('valuethree') + await cache.set(agentContext, 'two', 'valuetwo') + await cache.set(agentContext, 'three', 'valuethree') + expect(await cache.get(agentContext, 'one')).toBeUndefined() + expect(await cache.get(agentContext, 'two')).toBe('valuetwo') + expect(await cache.get(agentContext, 'three')).toBe('valuethree') // Get two from the cache, meaning three will be removed first now // because it is not recently used - await cache.get('two') - await cache.set('four', 'valuefour') - expect(await cache.get('three')).toBeUndefined() - expect(await cache.get('two')).toBe('valuetwo') + await cache.get(agentContext, 'two') + await cache.set(agentContext, 'four', 'valuefour') + expect(await cache.get(agentContext, 'three')).toBeUndefined() + expect(await cache.get(agentContext, 'two')).toBe('valuetwo') }) }) diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index 4b2eb6f0ea..9d7fdcbc61 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -1,8 +1,11 @@ export const InjectionSymbols = { - Wallet: Symbol('Wallet'), MessageRepository: Symbol('MessageRepository'), StorageService: Symbol('StorageService'), Logger: Symbol('Logger'), + AgentDependencies: Symbol('AgentDependencies'), + Stop$: Symbol('Stop$'), + FileSystem: Symbol('FileSystem'), + Wallet: Symbol('Wallet'), } export const DID_COMM_TRANSPORT_QUEUE = 'didcomm:transport/queue' diff --git a/packages/core/src/crypto/JwsService.ts b/packages/core/src/crypto/JwsService.ts index 8e631d4185..29a7f390e0 100644 --- a/packages/core/src/crypto/JwsService.ts +++ b/packages/core/src/crypto/JwsService.ts @@ -1,11 +1,10 @@ +import type { AgentContext } from '../agent' import type { Buffer } from '../utils' import type { Jws, JwsGeneralFormat } from './JwsTypes' -import { InjectionSymbols } from '../constants' import { AriesFrameworkError } from '../error' -import { inject, injectable } from '../plugins' +import { injectable } from '../plugins' import { JsonEncoder, TypedArrayEncoder } from '../utils' -import { Wallet } from '../wallet' import { WalletError } from '../wallet/error' import { Key } from './Key' @@ -18,19 +17,16 @@ const JWS_ALG = 'EdDSA' @injectable() export class JwsService { - private wallet: Wallet - - public constructor(@inject(InjectionSymbols.Wallet) wallet: Wallet) { - this.wallet = wallet - } - - public async createJws({ payload, verkey, header }: CreateJwsOptions): Promise { + public async createJws( + agentContext: AgentContext, + { payload, verkey, header }: CreateJwsOptions + ): Promise { const base64Payload = TypedArrayEncoder.toBase64URL(payload) const base64Protected = JsonEncoder.toBase64URL(this.buildProtected(verkey)) const key = Key.fromPublicKeyBase58(verkey, KeyType.Ed25519) const signature = TypedArrayEncoder.toBase64URL( - await this.wallet.sign({ data: TypedArrayEncoder.fromString(`${base64Protected}.${base64Payload}`), key }) + await agentContext.wallet.sign({ data: TypedArrayEncoder.fromString(`${base64Protected}.${base64Payload}`), key }) ) return { @@ -43,7 +39,7 @@ export class JwsService { /** * Verify a JWS */ - public async verifyJws({ jws, payload }: VerifyJwsOptions): Promise { + public async verifyJws(agentContext: AgentContext, { jws, payload }: VerifyJwsOptions): Promise { const base64Payload = TypedArrayEncoder.toBase64URL(payload) const signatures = 'signatures' in jws ? jws.signatures : [jws] @@ -71,7 +67,7 @@ export class JwsService { signerVerkeys.push(verkey) try { - const isValid = await this.wallet.verify({ key, data, signature }) + const isValid = await agentContext.wallet.verify({ key, data, signature }) if (!isValid) { return { diff --git a/packages/core/src/crypto/__tests__/JwsService.test.ts b/packages/core/src/crypto/__tests__/JwsService.test.ts index 87ced7bd95..d3371200ff 100644 --- a/packages/core/src/crypto/__tests__/JwsService.test.ts +++ b/packages/core/src/crypto/__tests__/JwsService.test.ts @@ -1,6 +1,7 @@ +import type { AgentContext } from '../../agent' import type { Wallet } from '@aries-framework/core' -import { getAgentConfig } from '../../../tests/helpers' +import { getAgentConfig, getAgentContext } from '../../../tests/helpers' import { DidKey } from '../../modules/dids' import { Buffer, JsonEncoder } from '../../utils' import { IndyWallet } from '../../wallet/IndyWallet' @@ -13,15 +14,19 @@ import * as didJwsz6Mkv from './__fixtures__/didJwsz6Mkv' describe('JwsService', () => { let wallet: Wallet + let agentContext: AgentContext let jwsService: JwsService beforeAll(async () => { const config = getAgentConfig('JwsService') - wallet = new IndyWallet(config) + wallet = new IndyWallet(config.agentDependencies, config.logger) + agentContext = getAgentContext({ + wallet, + }) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(config.walletConfig!) - jwsService = new JwsService(wallet) + jwsService = new JwsService() }) afterAll(async () => { @@ -36,7 +41,7 @@ describe('JwsService', () => { const key = Key.fromPublicKeyBase58(verkey, KeyType.Ed25519) const kid = new DidKey(key).did - const jws = await jwsService.createJws({ + const jws = await jwsService.createJws(agentContext, { payload, verkey, header: { kid }, @@ -50,7 +55,7 @@ describe('JwsService', () => { it('returns true if the jws signature matches the payload', async () => { const payload = JsonEncoder.toBuffer(didJwsz6Mkf.DATA_JSON) - const { isValid, signerVerkeys } = await jwsService.verifyJws({ + const { isValid, signerVerkeys } = await jwsService.verifyJws(agentContext, { payload, jws: didJwsz6Mkf.JWS_JSON, }) @@ -62,7 +67,7 @@ describe('JwsService', () => { it('returns all verkeys that signed the jws', async () => { const payload = JsonEncoder.toBuffer(didJwsz6Mkf.DATA_JSON) - const { isValid, signerVerkeys } = await jwsService.verifyJws({ + const { isValid, signerVerkeys } = await jwsService.verifyJws(agentContext, { payload, jws: { signatures: [didJwsz6Mkf.JWS_JSON, didJwsz6Mkv.JWS_JSON] }, }) @@ -74,7 +79,7 @@ describe('JwsService', () => { it('returns false if the jws signature does not match the payload', async () => { const payload = JsonEncoder.toBuffer({ ...didJwsz6Mkf.DATA_JSON, did: 'another_did' }) - const { isValid, signerVerkeys } = await jwsService.verifyJws({ + const { isValid, signerVerkeys } = await jwsService.verifyJws(agentContext, { payload, jws: didJwsz6Mkf.JWS_JSON, }) @@ -85,7 +90,7 @@ describe('JwsService', () => { it('throws an error if the jws signatures array does not contain a JWS', async () => { await expect( - jwsService.verifyJws({ + jwsService.verifyJws(agentContext, { payload: new Buffer([]), jws: { signatures: [] }, }) diff --git a/packages/core/src/decorators/signature/SignatureDecoratorUtils.test.ts b/packages/core/src/decorators/signature/SignatureDecoratorUtils.test.ts index 749332603f..0f216a372b 100644 --- a/packages/core/src/decorators/signature/SignatureDecoratorUtils.test.ts +++ b/packages/core/src/decorators/signature/SignatureDecoratorUtils.test.ts @@ -41,7 +41,7 @@ describe('Decorators | Signature | SignatureDecoratorUtils', () => { beforeAll(async () => { const config = getAgentConfig('SignatureDecoratorUtilsTest') - wallet = new IndyWallet(config) + wallet = new IndyWallet(config.agentDependencies, config.logger) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(config.walletConfig!) }) diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 036a70959b..f13ef7691c 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,6 +1,7 @@ // reflect-metadata used for class-transformer + class-validator import 'reflect-metadata' +export { AgentContext } from './agent/AgentContext' export { Agent } from './agent/Agent' export { EventEmitter } from './agent/EventEmitter' export { Handler, HandlerInboundMessage } from './agent/Handler' diff --git a/packages/core/src/modules/basic-messages/BasicMessagesModule.ts b/packages/core/src/modules/basic-messages/BasicMessagesModule.ts index 8d38643c4b..796ffa2334 100644 --- a/packages/core/src/modules/basic-messages/BasicMessagesModule.ts +++ b/packages/core/src/modules/basic-messages/BasicMessagesModule.ts @@ -1,6 +1,7 @@ import type { DependencyManager } from '../../plugins' import type { BasicMessageTags } from './repository/BasicMessageRecord' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' @@ -17,29 +18,32 @@ export class BasicMessagesModule { private basicMessageService: BasicMessageService private messageSender: MessageSender private connectionService: ConnectionService + private agentContext: AgentContext public constructor( dispatcher: Dispatcher, basicMessageService: BasicMessageService, messageSender: MessageSender, - connectionService: ConnectionService + connectionService: ConnectionService, + agentContext: AgentContext ) { this.basicMessageService = basicMessageService this.messageSender = messageSender this.connectionService = connectionService + this.agentContext = agentContext this.registerHandlers(dispatcher) } public async sendMessage(connectionId: string, message: string) { - const connection = await this.connectionService.getById(connectionId) + const connection = await this.connectionService.getById(this.agentContext, connectionId) - const basicMessage = await this.basicMessageService.createMessage(message, connection) + const basicMessage = await this.basicMessageService.createMessage(this.agentContext, message, connection) const outboundMessage = createOutboundMessage(connection, basicMessage) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) } public async findAllByQuery(query: Partial) { - return this.basicMessageService.findAllByQuery(query) + return this.basicMessageService.findAllByQuery(this.agentContext, query) } private registerHandlers(dispatcher: Dispatcher) { diff --git a/packages/core/src/modules/basic-messages/__tests__/BasicMessageService.test.ts b/packages/core/src/modules/basic-messages/__tests__/BasicMessageService.test.ts index 8b64f2e50c..ad2fbfa547 100644 --- a/packages/core/src/modules/basic-messages/__tests__/BasicMessageService.test.ts +++ b/packages/core/src/modules/basic-messages/__tests__/BasicMessageService.test.ts @@ -1,66 +1,69 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' -import type { StorageService } from '../../../storage/StorageService' -import type { BasicMessageStateChangedEvent } from '../BasicMessageEvents' - -import { getAgentConfig, getMockConnection } from '../../../../tests/helpers' +import { getAgentContext, getMockConnection } from '../../../../tests/helpers' import { EventEmitter } from '../../../agent/EventEmitter' import { InboundMessageContext } from '../../../agent/models/InboundMessageContext' -import { IndyStorageService } from '../../../storage/IndyStorageService' -import { Repository } from '../../../storage/Repository' -import { IndyWallet } from '../../../wallet/IndyWallet' -import { BasicMessageEventTypes } from '../BasicMessageEvents' import { BasicMessageRole } from '../BasicMessageRole' import { BasicMessage } from '../messages' import { BasicMessageRecord } from '../repository/BasicMessageRecord' +import { BasicMessageRepository } from '../repository/BasicMessageRepository' import { BasicMessageService } from '../services' +jest.mock('../repository/BasicMessageRepository') +const BasicMessageRepositoryMock = BasicMessageRepository as jest.Mock +const basicMessageRepository = new BasicMessageRepositoryMock() + +jest.mock('../../../agent/EventEmitter') +const EventEmitterMock = EventEmitter as jest.Mock +const eventEmitter = new EventEmitterMock() + +const agentContext = getAgentContext() + describe('BasicMessageService', () => { + let basicMessageService: BasicMessageService const mockConnectionRecord = getMockConnection({ id: 'd3849ac3-c981-455b-a1aa-a10bea6cead8', did: 'did:sov:C2SsBf5QUQpqSAQfhu3sd2', }) - let wallet: IndyWallet - let storageService: StorageService - let agentConfig: AgentConfig - - beforeAll(async () => { - agentConfig = getAgentConfig('BasicMessageServiceTest') - wallet = new IndyWallet(agentConfig) - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await wallet.createAndOpen(agentConfig.walletConfig!) - storageService = new IndyStorageService(wallet, agentConfig) + beforeEach(() => { + basicMessageService = new BasicMessageService(basicMessageRepository, eventEmitter) }) - afterAll(async () => { - await wallet.delete() - }) + describe('createMessage', () => { + it(`creates message and record, and emits message and basic message record`, async () => { + const message = await basicMessageService.createMessage(agentContext, 'hello', mockConnectionRecord) - describe('save', () => { - let basicMessageRepository: Repository - let basicMessageService: BasicMessageService - let eventEmitter: EventEmitter + expect(message.content).toBe('hello') - beforeEach(() => { - eventEmitter = new EventEmitter(agentConfig) - basicMessageRepository = new Repository(BasicMessageRecord, storageService, eventEmitter) - basicMessageService = new BasicMessageService(basicMessageRepository, eventEmitter) + expect(basicMessageRepository.save).toHaveBeenCalledWith(agentContext, expect.any(BasicMessageRecord)) + expect(eventEmitter.emit).toHaveBeenCalledWith(agentContext, { + type: 'BasicMessageStateChanged', + payload: { + basicMessageRecord: expect.objectContaining({ + connectionId: mockConnectionRecord.id, + id: expect.any(String), + sentTime: expect.any(String), + content: 'hello', + role: BasicMessageRole.Sender, + }), + message, + }, + }) }) + }) - it(`emits newMessage with message and basic message record`, async () => { - const eventListenerMock = jest.fn() - eventEmitter.on(BasicMessageEventTypes.BasicMessageStateChanged, eventListenerMock) - + describe('save', () => { + it(`stores record and emits message and basic message record`, async () => { const basicMessage = new BasicMessage({ id: '123', content: 'message', }) - const messageContext = new InboundMessageContext(basicMessage) + const messageContext = new InboundMessageContext(basicMessage, { agentContext }) await basicMessageService.save(messageContext, mockConnectionRecord) - expect(eventListenerMock).toHaveBeenCalledWith({ + expect(basicMessageRepository.save).toHaveBeenCalledWith(agentContext, expect.any(BasicMessageRecord)) + expect(eventEmitter.emit).toHaveBeenCalledWith(agentContext, { type: 'BasicMessageStateChanged', payload: { basicMessageRecord: expect.objectContaining({ diff --git a/packages/core/src/modules/basic-messages/services/BasicMessageService.ts b/packages/core/src/modules/basic-messages/services/BasicMessageService.ts index 749258deda..dff23b0f7e 100644 --- a/packages/core/src/modules/basic-messages/services/BasicMessageService.ts +++ b/packages/core/src/modules/basic-messages/services/BasicMessageService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../agent' import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext' import type { ConnectionRecord } from '../../connections/repository/ConnectionRecord' import type { BasicMessageStateChangedEvent } from '../BasicMessageEvents' @@ -21,8 +22,7 @@ export class BasicMessageService { this.eventEmitter = eventEmitter } - public async createMessage(message: string, connectionRecord: ConnectionRecord) { - connectionRecord.assertReady() + public async createMessage(agentContext: AgentContext, message: string, connectionRecord: ConnectionRecord) { const basicMessage = new BasicMessage({ content: message }) const basicMessageRecord = new BasicMessageRecord({ @@ -32,8 +32,8 @@ export class BasicMessageService { role: BasicMessageRole.Sender, }) - await this.basicMessageRepository.save(basicMessageRecord) - this.emitStateChangedEvent(basicMessageRecord, basicMessage) + await this.basicMessageRepository.save(agentContext, basicMessageRecord) + this.emitStateChangedEvent(agentContext, basicMessageRecord, basicMessage) return basicMessage } @@ -41,7 +41,7 @@ export class BasicMessageService { /** * @todo use connection from message context */ - public async save({ message }: InboundMessageContext, connection: ConnectionRecord) { + public async save({ message, agentContext }: InboundMessageContext, connection: ConnectionRecord) { const basicMessageRecord = new BasicMessageRecord({ sentTime: message.sentTime.toISOString(), content: message.content, @@ -49,19 +49,23 @@ export class BasicMessageService { role: BasicMessageRole.Receiver, }) - await this.basicMessageRepository.save(basicMessageRecord) - this.emitStateChangedEvent(basicMessageRecord, message) + await this.basicMessageRepository.save(agentContext, basicMessageRecord) + this.emitStateChangedEvent(agentContext, basicMessageRecord, message) } - private emitStateChangedEvent(basicMessageRecord: BasicMessageRecord, basicMessage: BasicMessage) { + private emitStateChangedEvent( + agentContext: AgentContext, + basicMessageRecord: BasicMessageRecord, + basicMessage: BasicMessage + ) { const clonedBasicMessageRecord = JsonTransformer.clone(basicMessageRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: BasicMessageEventTypes.BasicMessageStateChanged, payload: { message: basicMessage, basicMessageRecord: clonedBasicMessageRecord }, }) } - public async findAllByQuery(query: Partial) { - return this.basicMessageRepository.findByQuery(query) + public async findAllByQuery(agentContext: AgentContext, query: Partial) { + return this.basicMessageRepository.findByQuery(agentContext, query) } } diff --git a/packages/core/src/modules/connections/ConnectionsModule.ts b/packages/core/src/modules/connections/ConnectionsModule.ts index 8fc16406c2..d81459cdad 100644 --- a/packages/core/src/modules/connections/ConnectionsModule.ts +++ b/packages/core/src/modules/connections/ConnectionsModule.ts @@ -4,7 +4,7 @@ import type { OutOfBandRecord } from '../oob/repository' import type { ConnectionRecord } from './repository/ConnectionRecord' import type { Routing } from './services' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' @@ -34,7 +34,6 @@ import { TrustPingService } from './services/TrustPingService' @module() @injectable() export class ConnectionsModule { - private agentConfig: AgentConfig private didExchangeProtocol: DidExchangeProtocol private connectionService: ConnectionService private outOfBandService: OutOfBandService @@ -43,10 +42,10 @@ export class ConnectionsModule { private routingService: RoutingService private didRepository: DidRepository private didResolverService: DidResolverService + private agentContext: AgentContext public constructor( dispatcher: Dispatcher, - agentConfig: AgentConfig, didExchangeProtocol: DidExchangeProtocol, connectionService: ConnectionService, outOfBandService: OutOfBandService, @@ -54,9 +53,9 @@ export class ConnectionsModule { routingService: RoutingService, didRepository: DidRepository, didResolverService: DidResolverService, - messageSender: MessageSender + messageSender: MessageSender, + agentContext: AgentContext ) { - this.agentConfig = agentConfig this.didExchangeProtocol = didExchangeProtocol this.connectionService = connectionService this.outOfBandService = outOfBandService @@ -65,6 +64,8 @@ export class ConnectionsModule { this.didRepository = didRepository this.messageSender = messageSender this.didResolverService = didResolverService + this.agentContext = agentContext + this.registerHandlers(dispatcher) } @@ -81,18 +82,20 @@ export class ConnectionsModule { ) { const { protocol, label, alias, imageUrl, autoAcceptConnection } = config - const routing = config.routing || (await this.routingService.getRouting({ mediatorId: outOfBandRecord.mediatorId })) + const routing = + config.routing || + (await this.routingService.getRouting(this.agentContext, { mediatorId: outOfBandRecord.mediatorId })) let result if (protocol === HandshakeProtocol.DidExchange) { - result = await this.didExchangeProtocol.createRequest(outOfBandRecord, { + result = await this.didExchangeProtocol.createRequest(this.agentContext, outOfBandRecord, { label, alias, routing, autoAcceptConnection, }) } else if (protocol === HandshakeProtocol.Connections) { - result = await this.connectionService.createRequest(outOfBandRecord, { + result = await this.connectionService.createRequest(this.agentContext, outOfBandRecord, { label, alias, imageUrl, @@ -105,7 +108,7 @@ export class ConnectionsModule { const { message, connectionRecord } = result const outboundMessage = createOutboundMessage(connectionRecord, message, outOfBandRecord) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return connectionRecord } @@ -117,7 +120,7 @@ export class ConnectionsModule { * @returns connection record */ public async acceptRequest(connectionId: string): Promise { - const connectionRecord = await this.connectionService.findById(connectionId) + const connectionRecord = await this.connectionService.findById(this.agentContext, connectionId) if (!connectionRecord) { throw new AriesFrameworkError(`Connection record ${connectionId} not found.`) } @@ -125,21 +128,29 @@ export class ConnectionsModule { throw new AriesFrameworkError(`Connection record ${connectionId} does not have out-of-band record.`) } - const outOfBandRecord = await this.outOfBandService.findById(connectionRecord.outOfBandId) + const outOfBandRecord = await this.outOfBandService.findById(this.agentContext, connectionRecord.outOfBandId) if (!outOfBandRecord) { throw new AriesFrameworkError(`Out-of-band record ${connectionRecord.outOfBandId} not found.`) } let outboundMessage if (connectionRecord.protocol === HandshakeProtocol.DidExchange) { - const message = await this.didExchangeProtocol.createResponse(connectionRecord, outOfBandRecord) + const message = await this.didExchangeProtocol.createResponse( + this.agentContext, + connectionRecord, + outOfBandRecord + ) outboundMessage = createOutboundMessage(connectionRecord, message) } else { - const { message } = await this.connectionService.createResponse(connectionRecord, outOfBandRecord) + const { message } = await this.connectionService.createResponse( + this.agentContext, + connectionRecord, + outOfBandRecord + ) outboundMessage = createOutboundMessage(connectionRecord, message) } - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return connectionRecord } @@ -151,34 +162,38 @@ export class ConnectionsModule { * @returns connection record */ public async acceptResponse(connectionId: string): Promise { - const connectionRecord = await this.connectionService.getById(connectionId) + const connectionRecord = await this.connectionService.getById(this.agentContext, connectionId) let outboundMessage if (connectionRecord.protocol === HandshakeProtocol.DidExchange) { if (!connectionRecord.outOfBandId) { throw new AriesFrameworkError(`Connection ${connectionRecord.id} does not have outOfBandId!`) } - const outOfBandRecord = await this.outOfBandService.findById(connectionRecord.outOfBandId) + const outOfBandRecord = await this.outOfBandService.findById(this.agentContext, connectionRecord.outOfBandId) if (!outOfBandRecord) { throw new AriesFrameworkError( `OutOfBand record for connection ${connectionRecord.id} with outOfBandId ${connectionRecord.outOfBandId} not found!` ) } - const message = await this.didExchangeProtocol.createComplete(connectionRecord, outOfBandRecord) + const message = await this.didExchangeProtocol.createComplete( + this.agentContext, + connectionRecord, + outOfBandRecord + ) outboundMessage = createOutboundMessage(connectionRecord, message) } else { - const { message } = await this.connectionService.createTrustPing(connectionRecord, { + const { message } = await this.connectionService.createTrustPing(this.agentContext, connectionRecord, { responseRequested: false, }) outboundMessage = createOutboundMessage(connectionRecord, message) } - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return connectionRecord } public async returnWhenIsConnected(connectionId: string, options?: { timeoutMs: number }): Promise { - return this.connectionService.returnWhenIsConnected(connectionId, options?.timeoutMs) + return this.connectionService.returnWhenIsConnected(this.agentContext, connectionId, options?.timeoutMs) } /** @@ -187,7 +202,7 @@ export class ConnectionsModule { * @returns List containing all connection records */ public getAll() { - return this.connectionService.getAll() + return this.connectionService.getAll(this.agentContext) } /** @@ -199,7 +214,7 @@ export class ConnectionsModule { * */ public getById(connectionId: string): Promise { - return this.connectionService.getById(connectionId) + return this.connectionService.getById(this.agentContext, connectionId) } /** @@ -209,7 +224,7 @@ export class ConnectionsModule { * @returns The connection record or null if not found */ public findById(connectionId: string): Promise { - return this.connectionService.findById(connectionId) + return this.connectionService.findById(this.agentContext, connectionId) } /** @@ -218,31 +233,11 @@ export class ConnectionsModule { * @param connectionId the connection record id */ public async deleteById(connectionId: string) { - return this.connectionService.deleteById(connectionId) - } - - public async findByKeys({ senderKey, recipientKey }: { senderKey: Key; recipientKey: Key }) { - const theirDidRecord = await this.didRepository.findByRecipientKey(senderKey) - if (theirDidRecord) { - const ourDidRecord = await this.didRepository.findByRecipientKey(recipientKey) - if (ourDidRecord) { - const connectionRecord = await this.connectionService.findSingleByQuery({ - did: ourDidRecord.id, - theirDid: theirDidRecord.id, - }) - if (connectionRecord && connectionRecord.isReady) return connectionRecord - } - } - - this.agentConfig.logger.debug( - `No connection record found for encrypted message with recipient key ${recipientKey.fingerprint} and sender key ${senderKey.fingerprint}` - ) - - return null + return this.connectionService.deleteById(this.agentContext, connectionId) } public async findAllByOutOfBandId(outOfBandId: string) { - return this.connectionService.findAllByOutOfBandId(outOfBandId) + return this.connectionService.findAllByOutOfBandId(this.agentContext, outOfBandId) } /** @@ -254,21 +249,20 @@ export class ConnectionsModule { * @returns The connection record */ public getByThreadId(threadId: string): Promise { - return this.connectionService.getByThreadId(threadId) + return this.connectionService.getByThreadId(this.agentContext, threadId) } public async findByDid(did: string): Promise { - return this.connectionService.findByTheirDid(did) + return this.connectionService.findByTheirDid(this.agentContext, did) } public async findByInvitationDid(invitationDid: string): Promise { - return this.connectionService.findByInvitationDid(invitationDid) + return this.connectionService.findByInvitationDid(this.agentContext, invitationDid) } private registerHandlers(dispatcher: Dispatcher) { dispatcher.registerHandler( new ConnectionRequestHandler( - this.agentConfig, this.connectionService, this.outOfBandService, this.routingService, @@ -276,12 +270,7 @@ export class ConnectionsModule { ) ) dispatcher.registerHandler( - new ConnectionResponseHandler( - this.agentConfig, - this.connectionService, - this.outOfBandService, - this.didResolverService - ) + new ConnectionResponseHandler(this.connectionService, this.outOfBandService, this.didResolverService) ) dispatcher.registerHandler(new AckMessageHandler(this.connectionService)) dispatcher.registerHandler(new TrustPingMessageHandler(this.trustPingService, this.connectionService)) @@ -289,7 +278,6 @@ export class ConnectionsModule { dispatcher.registerHandler( new DidExchangeRequestHandler( - this.agentConfig, this.didExchangeProtocol, this.outOfBandService, this.routingService, @@ -299,7 +287,6 @@ export class ConnectionsModule { dispatcher.registerHandler( new DidExchangeResponseHandler( - this.agentConfig, this.didExchangeProtocol, this.outOfBandService, this.connectionService, diff --git a/packages/core/src/modules/connections/DidExchangeProtocol.ts b/packages/core/src/modules/connections/DidExchangeProtocol.ts index e5e8554a9f..a1a865ccd2 100644 --- a/packages/core/src/modules/connections/DidExchangeProtocol.ts +++ b/packages/core/src/modules/connections/DidExchangeProtocol.ts @@ -1,18 +1,19 @@ +import type { AgentContext } from '../../agent' import type { ResolvedDidCommService } from '../../agent/MessageSender' import type { InboundMessageContext } from '../../agent/models/InboundMessageContext' -import type { Logger } from '../../logger' import type { ParsedMessageType } from '../../utils/messageType' import type { OutOfBandDidCommService } from '../oob/domain/OutOfBandDidCommService' import type { OutOfBandRecord } from '../oob/repository' import type { ConnectionRecord } from './repository' import type { Routing } from './services/ConnectionService' -import { AgentConfig } from '../../agent/AgentConfig' +import { InjectionSymbols } from '../../constants' import { Key, KeyType } from '../../crypto' import { JwsService } from '../../crypto/JwsService' import { Attachment, AttachmentData } from '../../decorators/attachment/Attachment' import { AriesFrameworkError } from '../../error' -import { injectable } from '../../plugins' +import { Logger } from '../../logger' +import { inject, injectable } from '../../plugins' import { JsonEncoder } from '../../utils/JsonEncoder' import { JsonTransformer } from '../../utils/JsonTransformer' import { DidDocument } from '../dids' @@ -23,7 +24,7 @@ import { didKeyToInstanceOfKey } from '../dids/helpers' import { DidKey } from '../dids/methods/key/DidKey' import { getNumAlgoFromPeerDid, PeerDidNumAlgo } from '../dids/methods/peer/didPeer' import { didDocumentJsonToNumAlgo1Did } from '../dids/methods/peer/peerDidNumAlgo1' -import { DidRepository, DidRecord } from '../dids/repository' +import { DidRecord, DidRepository } from '../dids/repository' import { OutOfBandRole } from '../oob/domain/OutOfBandRole' import { OutOfBandState } from '../oob/domain/OutOfBandState' @@ -32,7 +33,7 @@ import { DidExchangeProblemReportError, DidExchangeProblemReportReason } from '. import { DidExchangeCompleteMessage } from './messages/DidExchangeCompleteMessage' import { DidExchangeRequestMessage } from './messages/DidExchangeRequestMessage' import { DidExchangeResponseMessage } from './messages/DidExchangeResponseMessage' -import { HandshakeProtocol, DidExchangeRole, DidExchangeState } from './models' +import { DidExchangeRole, DidExchangeState, HandshakeProtocol } from './models' import { ConnectionService } from './services' interface DidExchangeRequestParams { @@ -46,26 +47,25 @@ interface DidExchangeRequestParams { @injectable() export class DidExchangeProtocol { - private config: AgentConfig private connectionService: ConnectionService private jwsService: JwsService private didRepository: DidRepository private logger: Logger public constructor( - config: AgentConfig, connectionService: ConnectionService, didRepository: DidRepository, - jwsService: JwsService + jwsService: JwsService, + @inject(InjectionSymbols.Logger) logger: Logger ) { - this.config = config this.connectionService = connectionService this.didRepository = didRepository this.jwsService = jwsService - this.logger = config.logger + this.logger = logger } public async createRequest( + agentContext: AgentContext, outOfBandRecord: OutOfBandRecord, params: DidExchangeRequestParams ): Promise<{ message: DidExchangeRequestMessage; connectionRecord: ConnectionRecord }> { @@ -81,7 +81,7 @@ export class DidExchangeProtocol { // We take just the first one for now. const [invitationDid] = outOfBandInvitation.invitationDids - const connectionRecord = await this.connectionService.createConnection({ + const connectionRecord = await this.connectionService.createConnection(agentContext, { protocol: HandshakeProtocol.DidExchange, role: DidExchangeRole.Requester, alias, @@ -96,15 +96,17 @@ export class DidExchangeProtocol { DidExchangeStateMachine.assertCreateMessageState(DidExchangeRequestMessage.type, connectionRecord) // Create message - const label = params.label ?? this.config.label - const didDocument = await this.createPeerDidDoc(this.routingToServices(routing)) + const label = params.label ?? agentContext.config.label + const didDocument = await this.createPeerDidDoc(agentContext, this.routingToServices(routing)) const parentThreadId = outOfBandInvitation.id const message = new DidExchangeRequestMessage({ label, parentThreadId, did: didDocument.id, goal, goalCode }) // Create sign attachment containing didDoc if (getNumAlgoFromPeerDid(didDocument.id) === PeerDidNumAlgo.GenesisDoc) { - const didDocAttach = await this.createSignedAttachment(didDocument, [routing.recipientKey.publicKeyBase58]) + const didDocAttach = await this.createSignedAttachment(agentContext, didDocument, [ + routing.recipientKey.publicKeyBase58, + ]) message.didDoc = didDocAttach } @@ -115,7 +117,7 @@ export class DidExchangeProtocol { connectionRecord.autoAcceptConnection = autoAcceptConnection } - await this.updateState(DidExchangeRequestMessage.type, connectionRecord) + await this.updateState(agentContext, DidExchangeRequestMessage.type, connectionRecord) this.logger.debug(`Create message ${DidExchangeRequestMessage.type.messageTypeUri} end`, { connectionRecord, message, @@ -163,7 +165,7 @@ export class DidExchangeProtocol { ) } - const didDocument = await this.extractDidDocument(message) + const didDocument = await this.extractDidDocument(messageContext.agentContext, message) const didRecord = new DidRecord({ id: message.did, role: DidDocumentRole.Received, @@ -184,9 +186,9 @@ export class DidExchangeProtocol { didDocument: 'omitted...', }) - await this.didRepository.save(didRecord) + await this.didRepository.save(messageContext.agentContext, didRecord) - const connectionRecord = await this.connectionService.createConnection({ + const connectionRecord = await this.connectionService.createConnection(messageContext.agentContext, { protocol: HandshakeProtocol.DidExchange, role: DidExchangeRole.Responder, state: DidExchangeState.RequestReceived, @@ -198,12 +200,13 @@ export class DidExchangeProtocol { outOfBandId: outOfBandRecord.id, }) - await this.updateState(DidExchangeRequestMessage.type, connectionRecord) + await this.updateState(messageContext.agentContext, DidExchangeRequestMessage.type, connectionRecord) this.logger.debug(`Process message ${DidExchangeRequestMessage.type.messageTypeUri} end`, connectionRecord) return connectionRecord } public async createResponse( + agentContext: AgentContext, connectionRecord: ConnectionRecord, outOfBandRecord: OutOfBandRecord, routing?: Routing @@ -233,11 +236,12 @@ export class DidExchangeProtocol { })) } - const didDocument = await this.createPeerDidDoc(services) + const didDocument = await this.createPeerDidDoc(agentContext, services) const message = new DidExchangeResponseMessage({ did: didDocument.id, threadId }) if (getNumAlgoFromPeerDid(didDocument.id) === PeerDidNumAlgo.GenesisDoc) { const didDocAttach = await this.createSignedAttachment( + agentContext, didDocument, Array.from( new Set( @@ -253,7 +257,7 @@ export class DidExchangeProtocol { connectionRecord.did = didDocument.id - await this.updateState(DidExchangeResponseMessage.type, connectionRecord) + await this.updateState(agentContext, DidExchangeResponseMessage.type, connectionRecord) this.logger.debug(`Create message ${DidExchangeResponseMessage.type.messageTypeUri} end`, { connectionRecord, message, @@ -299,6 +303,7 @@ export class DidExchangeProtocol { } const didDocument = await this.extractDidDocument( + messageContext.agentContext, message, outOfBandRecord.outOfBandInvitation.getRecipientKeys().map((key) => key.publicKeyBase58) ) @@ -320,16 +325,17 @@ export class DidExchangeProtocol { didDocument: 'omitted...', }) - await this.didRepository.save(didRecord) + await this.didRepository.save(messageContext.agentContext, didRecord) connectionRecord.theirDid = message.did - await this.updateState(DidExchangeResponseMessage.type, connectionRecord) + await this.updateState(messageContext.agentContext, DidExchangeResponseMessage.type, connectionRecord) this.logger.debug(`Process message ${DidExchangeResponseMessage.type.messageTypeUri} end`, connectionRecord) return connectionRecord } public async createComplete( + agentContext: AgentContext, connectionRecord: ConnectionRecord, outOfBandRecord: OutOfBandRecord ): Promise { @@ -351,7 +357,7 @@ export class DidExchangeProtocol { const message = new DidExchangeCompleteMessage({ threadId, parentThreadId }) - await this.updateState(DidExchangeCompleteMessage.type, connectionRecord) + await this.updateState(agentContext, DidExchangeCompleteMessage.type, connectionRecord) this.logger.debug(`Create message ${DidExchangeCompleteMessage.type.messageTypeUri} end`, { connectionRecord, message, @@ -384,18 +390,22 @@ export class DidExchangeProtocol { }) } - await this.updateState(DidExchangeCompleteMessage.type, connectionRecord) + await this.updateState(messageContext.agentContext, DidExchangeCompleteMessage.type, connectionRecord) this.logger.debug(`Process message ${DidExchangeCompleteMessage.type.messageTypeUri} end`, { connectionRecord }) return connectionRecord } - private async updateState(messageType: ParsedMessageType, connectionRecord: ConnectionRecord) { + private async updateState( + agentContext: AgentContext, + messageType: ParsedMessageType, + connectionRecord: ConnectionRecord + ) { this.logger.debug(`Updating state`, { connectionRecord }) const nextState = DidExchangeStateMachine.nextState(messageType, connectionRecord) - return this.connectionService.updateState(connectionRecord, nextState) + return this.connectionService.updateState(agentContext, connectionRecord, nextState) } - private async createPeerDidDoc(services: ResolvedDidCommService[]) { + private async createPeerDidDoc(agentContext: AgentContext, services: ResolvedDidCommService[]) { const didDocument = createDidDocumentFromServices(services) const peerDid = didDocumentJsonToNumAlgo1Did(didDocument.toJSON()) @@ -419,12 +429,12 @@ export class DidExchangeProtocol { didDocument: 'omitted...', }) - await this.didRepository.save(didRecord) + await this.didRepository.save(agentContext, didRecord) this.logger.debug('Did record created.', didRecord) return didDocument } - private async createSignedAttachment(didDoc: DidDocument, verkeys: string[]) { + private async createSignedAttachment(agentContext: AgentContext, didDoc: DidDocument, verkeys: string[]) { const didDocAttach = new Attachment({ mimeType: 'application/json', data: new AttachmentData({ @@ -438,7 +448,7 @@ export class DidExchangeProtocol { const kid = new DidKey(key).did const payload = JsonEncoder.toBuffer(didDoc) - const jws = await this.jwsService.createJws({ + const jws = await this.jwsService.createJws(agentContext, { payload, verkey, header: { @@ -460,6 +470,7 @@ export class DidExchangeProtocol { * @returns verified DID document content from message attachment */ private async extractDidDocument( + agentContext: AgentContext, message: DidExchangeRequestMessage | DidExchangeResponseMessage, invitationKeysBase58: string[] = [] ): Promise { @@ -485,7 +496,7 @@ export class DidExchangeProtocol { this.logger.trace('DidDocument JSON', json) const payload = JsonEncoder.toBuffer(json) - const { isValid, signerVerkeys } = await this.jwsService.verifyJws({ jws, payload }) + const { isValid, signerVerkeys } = await this.jwsService.verifyJws(agentContext, { jws, payload }) const didDocument = JsonTransformer.fromJSON(json, DidDocument) const didDocumentKeysBase58 = didDocument.authentication diff --git a/packages/core/src/modules/connections/__tests__/ConnectionService.test.ts b/packages/core/src/modules/connections/__tests__/ConnectionService.test.ts index b27ec3fed1..9f16191403 100644 --- a/packages/core/src/modules/connections/__tests__/ConnectionService.test.ts +++ b/packages/core/src/modules/connections/__tests__/ConnectionService.test.ts @@ -1,7 +1,16 @@ +import type { AgentContext } from '../../../agent' import type { Wallet } from '../../../wallet/Wallet' import type { Routing } from '../services/ConnectionService' -import { getAgentConfig, getMockConnection, getMockOutOfBand, mockFunction } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { + getAgentConfig, + getAgentContext, + getMockConnection, + getMockOutOfBand, + mockFunction, +} from '../../../../tests/helpers' import { AgentMessage } from '../../../agent/AgentMessage' import { EventEmitter } from '../../../agent/EventEmitter' import { InboundMessageContext } from '../../../agent/models/InboundMessageContext' @@ -39,21 +48,23 @@ const DidRepositoryMock = DidRepository as jest.Mock const connectionImageUrl = 'https://example.com/image.png' -describe('ConnectionService', () => { - const agentConfig = getAgentConfig('ConnectionServiceTest', { - endpoints: ['http://agent.com:8080'], - connectionImageUrl, - }) +const agentConfig = getAgentConfig('ConnectionServiceTest', { + endpoints: ['http://agent.com:8080'], + connectionImageUrl, +}) +describe('ConnectionService', () => { let wallet: Wallet let connectionRepository: ConnectionRepository let didRepository: DidRepository let connectionService: ConnectionService let eventEmitter: EventEmitter let myRouting: Routing + let agentContext: AgentContext beforeAll(async () => { - wallet = new IndyWallet(agentConfig) + wallet = new IndyWallet(agentConfig.agentDependencies, agentConfig.logger) + agentContext = getAgentContext({ wallet, agentConfig }) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(agentConfig.walletConfig!) }) @@ -63,10 +74,10 @@ describe('ConnectionService', () => { }) beforeEach(async () => { - eventEmitter = new EventEmitter(agentConfig) + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) connectionRepository = new ConnectionRepositoryMock() didRepository = new DidRepositoryMock() - connectionService = new ConnectionService(wallet, agentConfig, connectionRepository, didRepository, eventEmitter) + connectionService = new ConnectionService(agentConfig.logger, connectionRepository, didRepository, eventEmitter) myRouting = { recipientKey: Key.fromFingerprint('z6MkwFkSP4uv5PhhKJCGehtjuZedkotC7VF64xtMsxuM8R3W'), endpoints: agentConfig.endpoints ?? [], @@ -82,7 +93,7 @@ describe('ConnectionService', () => { const outOfBand = getMockOutOfBand({ state: OutOfBandState.PrepareResponse }) const config = { routing: myRouting } - const { connectionRecord, message } = await connectionService.createRequest(outOfBand, config) + const { connectionRecord, message } = await connectionService.createRequest(agentContext, outOfBand, config) expect(connectionRecord.state).toBe(DidExchangeState.RequestSent) expect(message.label).toBe(agentConfig.label) @@ -119,7 +130,7 @@ describe('ConnectionService', () => { const outOfBand = getMockOutOfBand({ state: OutOfBandState.PrepareResponse }) const config = { label: 'Custom label', routing: myRouting } - const { message } = await connectionService.createRequest(outOfBand, config) + const { message } = await connectionService.createRequest(agentContext, outOfBand, config) expect(message.label).toBe('Custom label') }) @@ -130,7 +141,7 @@ describe('ConnectionService', () => { const outOfBand = getMockOutOfBand({ state: OutOfBandState.PrepareResponse, imageUrl: connectionImageUrl }) const config = { label: 'Custom label', routing: myRouting } - const { connectionRecord } = await connectionService.createRequest(outOfBand, config) + const { connectionRecord } = await connectionService.createRequest(agentContext, outOfBand, config) expect(connectionRecord.imageUrl).toBe(connectionImageUrl) }) @@ -141,7 +152,7 @@ describe('ConnectionService', () => { const outOfBand = getMockOutOfBand({ state: OutOfBandState.PrepareResponse }) const config = { imageUrl: 'custom-image-url', routing: myRouting } - const { message } = await connectionService.createRequest(outOfBand, config) + const { message } = await connectionService.createRequest(agentContext, outOfBand, config) expect(message.imageUrl).toBe('custom-image-url') }) @@ -152,7 +163,7 @@ describe('ConnectionService', () => { const outOfBand = getMockOutOfBand({ role: OutOfBandRole.Sender, state: OutOfBandState.PrepareResponse }) const config = { routing: myRouting } - return expect(connectionService.createRequest(outOfBand, config)).rejects.toThrowError( + return expect(connectionService.createRequest(agentContext, outOfBand, config)).rejects.toThrowError( `Invalid out-of-band record role ${OutOfBandRole.Sender}, expected is ${OutOfBandRole.Receiver}.` ) }) @@ -166,7 +177,7 @@ describe('ConnectionService', () => { const outOfBand = getMockOutOfBand({ state }) const config = { routing: myRouting } - return expect(connectionService.createRequest(outOfBand, config)).rejects.toThrowError( + return expect(connectionService.createRequest(agentContext, outOfBand, config)).rejects.toThrowError( `Invalid out-of-band record state ${state}, valid states are: ${OutOfBandState.PrepareResponse}.` ) } @@ -208,6 +219,7 @@ describe('ConnectionService', () => { }) const messageContext = new InboundMessageContext(connectionRequest, { + agentContext, senderKey: theirKey, recipientKey: Key.fromPublicKeyBase58('8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K', KeyType.Ed25519), }) @@ -265,6 +277,7 @@ describe('ConnectionService', () => { }) const messageContext = new InboundMessageContext(connectionRequest, { + agentContext, connection: connectionRecord, senderKey: theirKey, recipientKey: Key.fromPublicKeyBase58('8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K', KeyType.Ed25519), @@ -297,6 +310,7 @@ describe('ConnectionService', () => { }) const messageContext = new InboundMessageContext(connectionRequest, { + agentContext, recipientKey: Key.fromPublicKeyBase58('8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K', KeyType.Ed25519), senderKey: Key.fromPublicKeyBase58('79CXkde3j8TNuMXxPdV7nLUrT2g7JAEjH5TreyVY7GEZ', KeyType.Ed25519), }) @@ -312,6 +326,7 @@ describe('ConnectionService', () => { expect.assertions(1) const inboundMessage = new InboundMessageContext(jest.fn()(), { + agentContext, recipientKey: Key.fromPublicKeyBase58('8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K', KeyType.Ed25519), senderKey: Key.fromPublicKeyBase58('79CXkde3j8TNuMXxPdV7nLUrT2g7JAEjH5TreyVY7GEZ', KeyType.Ed25519), }) @@ -329,7 +344,7 @@ describe('ConnectionService', () => { (state) => { expect.assertions(1) - const inboundMessage = new InboundMessageContext(jest.fn()(), {}) + const inboundMessage = new InboundMessageContext(jest.fn()(), { agentContext }) const outOfBand = getMockOutOfBand({ role: OutOfBandRole.Sender, state }) return expect(connectionService.processRequest(inboundMessage, outOfBand)).rejects.toThrowError( @@ -376,6 +391,7 @@ describe('ConnectionService', () => { }) const { message, connectionRecord: connectionRecord } = await connectionService.createResponse( + agentContext, mockConnection, outOfBand ) @@ -398,7 +414,7 @@ describe('ConnectionService', () => { state: DidExchangeState.RequestReceived, }) const outOfBand = getMockOutOfBand() - return expect(connectionService.createResponse(connection, outOfBand)).rejects.toThrowError( + return expect(connectionService.createResponse(agentContext, connection, outOfBand)).rejects.toThrowError( `Connection record has invalid role ${DidExchangeRole.Requester}. Expected role ${DidExchangeRole.Responder}.` ) }) @@ -420,7 +436,7 @@ describe('ConnectionService', () => { const connection = getMockConnection({ state }) const outOfBand = getMockOutOfBand() - return expect(connectionService.createResponse(connection, outOfBand)).rejects.toThrowError( + return expect(connectionService.createResponse(agentContext, connection, outOfBand)).rejects.toThrowError( `Connection record is in invalid state ${state}. Valid states are: ${DidExchangeState.RequestReceived}.` ) } @@ -478,6 +494,7 @@ describe('ConnectionService', () => { recipientKeys: [new DidKey(theirKey).did], }) const messageContext = new InboundMessageContext(connectionResponse, { + agentContext, connection: connectionRecord, senderKey: theirKey, recipientKey: Key.fromPublicKeyBase58(verkey, KeyType.Ed25519), @@ -501,6 +518,7 @@ describe('ConnectionService', () => { state: DidExchangeState.RequestSent, }) const messageContext = new InboundMessageContext(jest.fn()(), { + agentContext, connection: connectionRecord, recipientKey: Key.fromPublicKeyBase58('8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K', KeyType.Ed25519), senderKey: Key.fromPublicKeyBase58('79CXkde3j8TNuMXxPdV7nLUrT2g7JAEjH5TreyVY7GEZ', KeyType.Ed25519), @@ -561,6 +579,7 @@ describe('ConnectionService', () => { recipientKeys: [new DidKey(Key.fromPublicKeyBase58(verkey, KeyType.Ed25519)).did], }) const messageContext = new InboundMessageContext(connectionResponse, { + agentContext, connection: connectionRecord, senderKey: theirKey, recipientKey: Key.fromPublicKeyBase58(verkey, KeyType.Ed25519), @@ -594,6 +613,7 @@ describe('ConnectionService', () => { const outOfBandRecord = getMockOutOfBand({ recipientKeys: [new DidKey(theirKey).did] }) const messageContext = new InboundMessageContext(connectionResponse, { + agentContext, connection: connectionRecord, recipientKey: Key.fromPublicKeyBase58('8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K', KeyType.Ed25519), senderKey: Key.fromPublicKeyBase58('79CXkde3j8TNuMXxPdV7nLUrT2g7JAEjH5TreyVY7GEZ', KeyType.Ed25519), @@ -611,7 +631,10 @@ describe('ConnectionService', () => { const mockConnection = getMockConnection({ state: DidExchangeState.ResponseReceived }) - const { message, connectionRecord: connectionRecord } = await connectionService.createTrustPing(mockConnection) + const { message, connectionRecord: connectionRecord } = await connectionService.createTrustPing( + agentContext, + mockConnection + ) expect(connectionRecord.state).toBe(DidExchangeState.Completed) expect(message).toEqual(expect.any(TrustPingMessage)) @@ -632,7 +655,7 @@ describe('ConnectionService', () => { expect.assertions(1) const connection = getMockConnection({ state }) - return expect(connectionService.createTrustPing(connection)).rejects.toThrowError( + return expect(connectionService.createTrustPing(agentContext, connection)).rejects.toThrowError( `Connection record is in invalid state ${state}. Valid states are: ${DidExchangeState.ResponseReceived}, ${DidExchangeState.Completed}.` ) } @@ -648,7 +671,7 @@ describe('ConnectionService', () => { threadId: 'thread-id', }) - const messageContext = new InboundMessageContext(ack, {}) + const messageContext = new InboundMessageContext(ack, { agentContext }) return expect(connectionService.processAck(messageContext)).rejects.toThrowError( 'Unable to process connection ack: connection for recipient key undefined not found' @@ -668,7 +691,7 @@ describe('ConnectionService', () => { threadId: 'thread-id', }) - const messageContext = new InboundMessageContext(ack, { connection }) + const messageContext = new InboundMessageContext(ack, { agentContext, connection }) const updatedConnection = await connectionService.processAck(messageContext) @@ -688,7 +711,7 @@ describe('ConnectionService', () => { threadId: 'thread-id', }) - const messageContext = new InboundMessageContext(ack, { connection }) + const messageContext = new InboundMessageContext(ack, { agentContext, connection }) const updatedConnection = await connectionService.processAck(messageContext) @@ -701,6 +724,7 @@ describe('ConnectionService', () => { expect.assertions(1) const messageContext = new InboundMessageContext(new AgentMessage(), { + agentContext, connection: getMockConnection({ state: DidExchangeState.Completed }), }) @@ -711,6 +735,7 @@ describe('ConnectionService', () => { expect.assertions(1) const messageContext = new InboundMessageContext(new AgentMessage(), { + agentContext, connection: getMockConnection({ state: DidExchangeState.InvitationReceived }), }) @@ -728,7 +753,7 @@ describe('ConnectionService', () => { serviceEndpoint: '', routingKeys: [], }) - const messageContext = new InboundMessageContext(message) + const messageContext = new InboundMessageContext(message, { agentContext }) expect(() => connectionService.assertConnectionOrServiceDecorator(messageContext)).not.toThrow() }) @@ -759,7 +784,7 @@ describe('ConnectionService', () => { serviceEndpoint: '', routingKeys: [], }) - const messageContext = new InboundMessageContext(message, { recipientKey, senderKey }) + const messageContext = new InboundMessageContext(message, { agentContext, recipientKey, senderKey }) expect(() => connectionService.assertConnectionOrServiceDecorator(messageContext, { @@ -780,7 +805,7 @@ describe('ConnectionService', () => { }) const message = new AgentMessage() - const messageContext = new InboundMessageContext(message) + const messageContext = new InboundMessageContext(message, { agentContext }) expect(() => connectionService.assertConnectionOrServiceDecorator(messageContext, { @@ -802,7 +827,7 @@ describe('ConnectionService', () => { }) const message = new AgentMessage() - const messageContext = new InboundMessageContext(message, { recipientKey }) + const messageContext = new InboundMessageContext(message, { agentContext, recipientKey }) expect(() => connectionService.assertConnectionOrServiceDecorator(messageContext, { @@ -824,7 +849,7 @@ describe('ConnectionService', () => { }) const message = new AgentMessage() - const messageContext = new InboundMessageContext(message) + const messageContext = new InboundMessageContext(message, { agentContext }) expect(() => connectionService.assertConnectionOrServiceDecorator(messageContext, { @@ -847,6 +872,7 @@ describe('ConnectionService', () => { const message = new AgentMessage() const messageContext = new InboundMessageContext(message, { + agentContext, senderKey: Key.fromPublicKeyBase58(senderKey, KeyType.Ed25519), }) @@ -864,8 +890,8 @@ describe('ConnectionService', () => { it('getById should return value from connectionRepository.getById', async () => { const expected = getMockConnection() mockFunction(connectionRepository.getById).mockReturnValue(Promise.resolve(expected)) - const result = await connectionService.getById(expected.id) - expect(connectionRepository.getById).toBeCalledWith(expected.id) + const result = await connectionService.getById(agentContext, expected.id) + expect(connectionRepository.getById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -873,8 +899,8 @@ describe('ConnectionService', () => { it('getByThreadId should return value from connectionRepository.getSingleByQuery', async () => { const expected = getMockConnection() mockFunction(connectionRepository.getByThreadId).mockReturnValue(Promise.resolve(expected)) - const result = await connectionService.getByThreadId('threadId') - expect(connectionRepository.getByThreadId).toBeCalledWith('threadId') + const result = await connectionService.getByThreadId(agentContext, 'threadId') + expect(connectionRepository.getByThreadId).toBeCalledWith(agentContext, 'threadId') expect(result).toBe(expected) }) @@ -882,8 +908,8 @@ describe('ConnectionService', () => { it('findById should return value from connectionRepository.findById', async () => { const expected = getMockConnection() mockFunction(connectionRepository.findById).mockReturnValue(Promise.resolve(expected)) - const result = await connectionService.findById(expected.id) - expect(connectionRepository.findById).toBeCalledWith(expected.id) + const result = await connectionService.findById(agentContext, expected.id) + expect(connectionRepository.findById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -892,8 +918,8 @@ describe('ConnectionService', () => { const expected = [getMockConnection(), getMockConnection()] mockFunction(connectionRepository.getAll).mockReturnValue(Promise.resolve(expected)) - const result = await connectionService.getAll() - expect(connectionRepository.getAll).toBeCalledWith() + const result = await connectionService.getAll(agentContext) + expect(connectionRepository.getAll).toBeCalledWith(agentContext) expect(result).toEqual(expect.arrayContaining(expected)) }) diff --git a/packages/core/src/modules/connections/handlers/ConnectionRequestHandler.ts b/packages/core/src/modules/connections/handlers/ConnectionRequestHandler.ts index b9197814c1..1f55bea49a 100644 --- a/packages/core/src/modules/connections/handlers/ConnectionRequestHandler.ts +++ b/packages/core/src/modules/connections/handlers/ConnectionRequestHandler.ts @@ -1,4 +1,3 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { DidRepository } from '../../dids/repository' import type { OutOfBandService } from '../../oob/OutOfBandService' @@ -10,7 +9,6 @@ import { AriesFrameworkError } from '../../../error/AriesFrameworkError' import { ConnectionRequestMessage } from '../messages' export class ConnectionRequestHandler implements Handler { - private agentConfig: AgentConfig private connectionService: ConnectionService private outOfBandService: OutOfBandService private routingService: RoutingService @@ -18,13 +16,11 @@ export class ConnectionRequestHandler implements Handler { public supportedMessages = [ConnectionRequestMessage] public constructor( - agentConfig: AgentConfig, connectionService: ConnectionService, outOfBandService: OutOfBandService, routingService: RoutingService, didRepository: DidRepository ) { - this.agentConfig = agentConfig this.connectionService = connectionService this.outOfBandService = outOfBandService this.routingService = routingService @@ -38,7 +34,7 @@ export class ConnectionRequestHandler implements Handler { throw new AriesFrameworkError('Unable to process connection request without senderVerkey or recipientKey') } - const outOfBandRecord = await this.outOfBandService.findByRecipientKey(recipientKey) + const outOfBandRecord = await this.outOfBandService.findByRecipientKey(messageContext.agentContext, recipientKey) if (!outOfBandRecord) { throw new AriesFrameworkError(`Out-of-band record for recipient key ${recipientKey.fingerprint} was not found.`) @@ -50,18 +46,25 @@ export class ConnectionRequestHandler implements Handler { ) } - const didRecord = await this.didRepository.findByRecipientKey(senderKey) + const didRecord = await this.didRepository.findByRecipientKey(messageContext.agentContext, senderKey) if (didRecord) { throw new AriesFrameworkError(`Did record for sender key ${senderKey.fingerprint} already exists.`) } const connectionRecord = await this.connectionService.processRequest(messageContext, outOfBandRecord) - if (connectionRecord?.autoAcceptConnection ?? this.agentConfig.autoAcceptConnections) { + if (connectionRecord?.autoAcceptConnection ?? messageContext.agentContext.config.autoAcceptConnections) { // TODO: Allow rotation of keys used in the invitation for new ones not only when out-of-band is reusable - const routing = outOfBandRecord.reusable ? await this.routingService.getRouting() : undefined + const routing = outOfBandRecord.reusable + ? await this.routingService.getRouting(messageContext.agentContext) + : undefined - const { message } = await this.connectionService.createResponse(connectionRecord, outOfBandRecord, routing) + const { message } = await this.connectionService.createResponse( + messageContext.agentContext, + connectionRecord, + outOfBandRecord, + routing + ) return createOutboundMessage(connectionRecord, message, outOfBandRecord) } } diff --git a/packages/core/src/modules/connections/handlers/ConnectionResponseHandler.ts b/packages/core/src/modules/connections/handlers/ConnectionResponseHandler.ts index 6bac8d929c..6024fb7973 100644 --- a/packages/core/src/modules/connections/handlers/ConnectionResponseHandler.ts +++ b/packages/core/src/modules/connections/handlers/ConnectionResponseHandler.ts @@ -1,4 +1,3 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { DidResolverService } from '../../dids' import type { OutOfBandService } from '../../oob/OutOfBandService' @@ -9,7 +8,6 @@ import { AriesFrameworkError } from '../../../error' import { ConnectionResponseMessage } from '../messages' export class ConnectionResponseHandler implements Handler { - private agentConfig: AgentConfig private connectionService: ConnectionService private outOfBandService: OutOfBandService private didResolverService: DidResolverService @@ -17,12 +15,10 @@ export class ConnectionResponseHandler implements Handler { public supportedMessages = [ConnectionResponseMessage] public constructor( - agentConfig: AgentConfig, connectionService: ConnectionService, outOfBandService: OutOfBandService, didResolverService: DidResolverService ) { - this.agentConfig = agentConfig this.connectionService = connectionService this.outOfBandService = outOfBandService this.didResolverService = didResolverService @@ -35,7 +31,7 @@ export class ConnectionResponseHandler implements Handler { throw new AriesFrameworkError('Unable to process connection response without senderKey or recipientKey') } - const connectionRecord = await this.connectionService.getByThreadId(message.threadId) + const connectionRecord = await this.connectionService.getByThreadId(messageContext.agentContext, message.threadId) if (!connectionRecord) { throw new AriesFrameworkError(`Connection for thread ID ${message.threadId} not found!`) } @@ -44,7 +40,10 @@ export class ConnectionResponseHandler implements Handler { throw new AriesFrameworkError(`Connection record ${connectionRecord.id} has no 'did'`) } - const ourDidDocument = await this.didResolverService.resolveDidDocument(connectionRecord.did) + const ourDidDocument = await this.didResolverService.resolveDidDocument( + messageContext.agentContext, + connectionRecord.did + ) if (!ourDidDocument) { throw new AriesFrameworkError(`Did document for did ${connectionRecord.did} was not resolved!`) } @@ -58,7 +57,8 @@ export class ConnectionResponseHandler implements Handler { } const outOfBandRecord = - connectionRecord.outOfBandId && (await this.outOfBandService.findById(connectionRecord.outOfBandId)) + connectionRecord.outOfBandId && + (await this.outOfBandService.findById(messageContext.agentContext, connectionRecord.outOfBandId)) if (!outOfBandRecord) { throw new AriesFrameworkError(`Out-of-band record ${connectionRecord.outOfBandId} was not found.`) @@ -71,8 +71,10 @@ export class ConnectionResponseHandler implements Handler { // TODO: should we only send ping message in case of autoAcceptConnection or always? // In AATH we have a separate step to send the ping. So for now we'll only do it // if auto accept is enable - if (connection.autoAcceptConnection ?? this.agentConfig.autoAcceptConnections) { - const { message } = await this.connectionService.createTrustPing(connection, { responseRequested: false }) + if (connection.autoAcceptConnection ?? messageContext.agentContext.config.autoAcceptConnections) { + const { message } = await this.connectionService.createTrustPing(messageContext.agentContext, connection, { + responseRequested: false, + }) return createOutboundMessage(connection, message) } } diff --git a/packages/core/src/modules/connections/handlers/DidExchangeCompleteHandler.ts b/packages/core/src/modules/connections/handlers/DidExchangeCompleteHandler.ts index d3f4a6eae6..e138dfc49e 100644 --- a/packages/core/src/modules/connections/handlers/DidExchangeCompleteHandler.ts +++ b/packages/core/src/modules/connections/handlers/DidExchangeCompleteHandler.ts @@ -35,14 +35,17 @@ export class DidExchangeCompleteHandler implements Handler { if (!message.thread?.parentThreadId) { throw new AriesFrameworkError(`Message does not contain pthid attribute`) } - const outOfBandRecord = await this.outOfBandService.findByInvitationId(message.thread?.parentThreadId) + const outOfBandRecord = await this.outOfBandService.findByInvitationId( + messageContext.agentContext, + message.thread?.parentThreadId + ) if (!outOfBandRecord) { throw new AriesFrameworkError(`OutOfBand record for message ID ${message.thread?.parentThreadId} not found!`) } if (!outOfBandRecord.reusable) { - await this.outOfBandService.updateState(outOfBandRecord, OutOfBandState.Done) + await this.outOfBandService.updateState(messageContext.agentContext, outOfBandRecord, OutOfBandState.Done) } await this.didExchangeProtocol.processComplete(messageContext, outOfBandRecord) } diff --git a/packages/core/src/modules/connections/handlers/DidExchangeRequestHandler.ts b/packages/core/src/modules/connections/handlers/DidExchangeRequestHandler.ts index c7fdc5699c..3c18a8dc84 100644 --- a/packages/core/src/modules/connections/handlers/DidExchangeRequestHandler.ts +++ b/packages/core/src/modules/connections/handlers/DidExchangeRequestHandler.ts @@ -1,4 +1,3 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { DidRepository } from '../../dids/repository' import type { OutOfBandService } from '../../oob/OutOfBandService' @@ -13,19 +12,16 @@ import { DidExchangeRequestMessage } from '../messages' export class DidExchangeRequestHandler implements Handler { private didExchangeProtocol: DidExchangeProtocol private outOfBandService: OutOfBandService - private agentConfig: AgentConfig private routingService: RoutingService private didRepository: DidRepository public supportedMessages = [DidExchangeRequestMessage] public constructor( - agentConfig: AgentConfig, didExchangeProtocol: DidExchangeProtocol, outOfBandService: OutOfBandService, routingService: RoutingService, didRepository: DidRepository ) { - this.agentConfig = agentConfig this.didExchangeProtocol = didExchangeProtocol this.outOfBandService = outOfBandService this.routingService = routingService @@ -42,7 +38,10 @@ export class DidExchangeRequestHandler implements Handler { if (!message.thread?.parentThreadId) { throw new AriesFrameworkError(`Message does not contain 'pthid' attribute`) } - const outOfBandRecord = await this.outOfBandService.findByInvitationId(message.thread.parentThreadId) + const outOfBandRecord = await this.outOfBandService.findByInvitationId( + messageContext.agentContext, + message.thread.parentThreadId + ) if (!outOfBandRecord) { throw new AriesFrameworkError(`OutOfBand record for message ID ${message.thread?.parentThreadId} not found!`) @@ -54,7 +53,7 @@ export class DidExchangeRequestHandler implements Handler { ) } - const didRecord = await this.didRepository.findByRecipientKey(senderKey) + const didRecord = await this.didRepository.findByRecipientKey(messageContext.agentContext, senderKey) if (didRecord) { throw new AriesFrameworkError(`Did record for sender key ${senderKey.fingerprint} already exists.`) } @@ -69,12 +68,19 @@ export class DidExchangeRequestHandler implements Handler { const connectionRecord = await this.didExchangeProtocol.processRequest(messageContext, outOfBandRecord) - if (connectionRecord?.autoAcceptConnection ?? this.agentConfig.autoAcceptConnections) { + if (connectionRecord?.autoAcceptConnection ?? messageContext.agentContext.config.autoAcceptConnections) { // TODO We should add an option to not pass routing and therefore do not rotate keys and use the keys from the invitation // TODO: Allow rotation of keys used in the invitation for new ones not only when out-of-band is reusable - const routing = outOfBandRecord.reusable ? await this.routingService.getRouting() : undefined + const routing = outOfBandRecord.reusable + ? await this.routingService.getRouting(messageContext.agentContext) + : undefined - const message = await this.didExchangeProtocol.createResponse(connectionRecord, outOfBandRecord, routing) + const message = await this.didExchangeProtocol.createResponse( + messageContext.agentContext, + connectionRecord, + outOfBandRecord, + routing + ) return createOutboundMessage(connectionRecord, message, outOfBandRecord) } } diff --git a/packages/core/src/modules/connections/handlers/DidExchangeResponseHandler.ts b/packages/core/src/modules/connections/handlers/DidExchangeResponseHandler.ts index fea2841bb0..b6e3fedbfd 100644 --- a/packages/core/src/modules/connections/handlers/DidExchangeResponseHandler.ts +++ b/packages/core/src/modules/connections/handlers/DidExchangeResponseHandler.ts @@ -1,4 +1,3 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { DidResolverService } from '../../dids' import type { OutOfBandService } from '../../oob/OutOfBandService' @@ -12,7 +11,6 @@ import { DidExchangeResponseMessage } from '../messages' import { HandshakeProtocol } from '../models' export class DidExchangeResponseHandler implements Handler { - private agentConfig: AgentConfig private didExchangeProtocol: DidExchangeProtocol private outOfBandService: OutOfBandService private connectionService: ConnectionService @@ -20,13 +18,11 @@ export class DidExchangeResponseHandler implements Handler { public supportedMessages = [DidExchangeResponseMessage] public constructor( - agentConfig: AgentConfig, didExchangeProtocol: DidExchangeProtocol, outOfBandService: OutOfBandService, connectionService: ConnectionService, didResolverService: DidResolverService ) { - this.agentConfig = agentConfig this.didExchangeProtocol = didExchangeProtocol this.outOfBandService = outOfBandService this.connectionService = connectionService @@ -40,7 +36,7 @@ export class DidExchangeResponseHandler implements Handler { throw new AriesFrameworkError('Unable to process connection response without sender key or recipient key') } - const connectionRecord = await this.connectionService.getByThreadId(message.threadId) + const connectionRecord = await this.connectionService.getByThreadId(messageContext.agentContext, message.threadId) if (!connectionRecord) { throw new AriesFrameworkError(`Connection for thread ID ${message.threadId} not found!`) } @@ -49,7 +45,10 @@ export class DidExchangeResponseHandler implements Handler { throw new AriesFrameworkError(`Connection record ${connectionRecord.id} has no 'did'`) } - const ourDidDocument = await this.didResolverService.resolveDidDocument(connectionRecord.did) + const ourDidDocument = await this.didResolverService.resolveDidDocument( + messageContext.agentContext, + connectionRecord.did + ) if (!ourDidDocument) { throw new AriesFrameworkError(`Did document for did ${connectionRecord.did} was not resolved`) } @@ -73,7 +72,10 @@ export class DidExchangeResponseHandler implements Handler { throw new AriesFrameworkError(`Connection ${connectionRecord.id} does not have outOfBandId!`) } - const outOfBandRecord = await this.outOfBandService.findById(connectionRecord.outOfBandId) + const outOfBandRecord = await this.outOfBandService.findById( + messageContext.agentContext, + connectionRecord.outOfBandId + ) if (!outOfBandRecord) { throw new AriesFrameworkError( @@ -94,11 +96,15 @@ export class DidExchangeResponseHandler implements Handler { // TODO: should we only send complete message in case of autoAcceptConnection or always? // In AATH we have a separate step to send the complete. So for now we'll only do it - // if auto accept is enable - if (connection.autoAcceptConnection ?? this.agentConfig.autoAcceptConnections) { - const message = await this.didExchangeProtocol.createComplete(connection, outOfBandRecord) + // if auto accept is enabled + if (connection.autoAcceptConnection ?? messageContext.agentContext.config.autoAcceptConnections) { + const message = await this.didExchangeProtocol.createComplete( + messageContext.agentContext, + connection, + outOfBandRecord + ) if (!outOfBandRecord.reusable) { - await this.outOfBandService.updateState(outOfBandRecord, OutOfBandState.Done) + await this.outOfBandService.updateState(messageContext.agentContext, outOfBandRecord, OutOfBandState.Done) } return createOutboundMessage(connection, message) } diff --git a/packages/core/src/modules/connections/handlers/TrustPingMessageHandler.ts b/packages/core/src/modules/connections/handlers/TrustPingMessageHandler.ts index 6a37fee4b6..aec2f74ea5 100644 --- a/packages/core/src/modules/connections/handlers/TrustPingMessageHandler.ts +++ b/packages/core/src/modules/connections/handlers/TrustPingMessageHandler.ts @@ -25,7 +25,7 @@ export class TrustPingMessageHandler implements Handler { // TODO: This is better addressed in a middleware of some kind because // any message can transition the state to complete, not just an ack or trust ping if (connection.state === DidExchangeState.ResponseSent) { - await this.connectionService.updateState(connection, DidExchangeState.Completed) + await this.connectionService.updateState(messageContext.agentContext, connection, DidExchangeState.Completed) } return this.trustPingService.processPing(messageContext, connection) diff --git a/packages/core/src/modules/connections/repository/ConnectionRepository.ts b/packages/core/src/modules/connections/repository/ConnectionRepository.ts index 2ede9851c7..504b9ea655 100644 --- a/packages/core/src/modules/connections/repository/ConnectionRepository.ts +++ b/packages/core/src/modules/connections/repository/ConnectionRepository.ts @@ -1,6 +1,8 @@ +import type { AgentContext } from '../../../agent' + import { EventEmitter } from '../../../agent/EventEmitter' import { InjectionSymbols } from '../../../constants' -import { inject, injectable } from '../../../plugins' +import { injectable, inject } from '../../../plugins' import { Repository } from '../../../storage/Repository' import { StorageService } from '../../../storage/StorageService' @@ -15,14 +17,14 @@ export class ConnectionRepository extends Repository { super(ConnectionRecord, storageService, eventEmitter) } - public async findByDids({ ourDid, theirDid }: { ourDid: string; theirDid: string }) { - return this.findSingleByQuery({ + public async findByDids(agentContext: AgentContext, { ourDid, theirDid }: { ourDid: string; theirDid: string }) { + return this.findSingleByQuery(agentContext, { did: ourDid, theirDid, }) } - public getByThreadId(threadId: string): Promise { - return this.getSingleByQuery({ threadId }) + public getByThreadId(agentContext: AgentContext, threadId: string): Promise { + return this.getSingleByQuery(agentContext, { threadId }) } } diff --git a/packages/core/src/modules/connections/services/ConnectionService.ts b/packages/core/src/modules/connections/services/ConnectionService.ts index 3986fae4b6..51e2d7870a 100644 --- a/packages/core/src/modules/connections/services/ConnectionService.ts +++ b/packages/core/src/modules/connections/services/ConnectionService.ts @@ -1,6 +1,6 @@ +import type { AgentContext } from '../../../agent' import type { AgentMessage } from '../../../agent/AgentMessage' import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext' -import type { Logger } from '../../../logger' import type { AckMessage } from '../../common' import type { OutOfBandDidCommService } from '../../oob/domain/OutOfBandDidCommService' import type { OutOfBandRecord } from '../../oob/repository' @@ -11,21 +11,20 @@ import type { ConnectionRecordProps } from '../repository/ConnectionRecord' import { firstValueFrom, ReplaySubject } from 'rxjs' import { first, map, timeout } from 'rxjs/operators' -import { AgentConfig } from '../../../agent/AgentConfig' import { EventEmitter } from '../../../agent/EventEmitter' import { InjectionSymbols } from '../../../constants' import { Key } from '../../../crypto' import { signData, unpackAndVerifySignatureDecorator } from '../../../decorators/signature/SignatureDecoratorUtils' import { AriesFrameworkError } from '../../../error' +import { Logger } from '../../../logger' import { inject, injectable } from '../../../plugins' import { JsonTransformer } from '../../../utils/JsonTransformer' import { indyDidFromPublicKeyBase58 } from '../../../utils/did' -import { Wallet } from '../../../wallet/Wallet' import { DidKey, IndyAgentService } from '../../dids' import { DidDocumentRole } from '../../dids/domain/DidDocumentRole' import { didKeyToVerkey } from '../../dids/helpers' import { didDocumentJsonToNumAlgo1Did } from '../../dids/methods/peer/peerDidNumAlgo1' -import { DidRepository, DidRecord } from '../../dids/repository' +import { DidRecord, DidRepository } from '../../dids/repository' import { DidRecordMetadataKeys } from '../../dids/repository/didRecordMetadataTypes' import { OutOfBandRole } from '../../oob/domain/OutOfBandRole' import { OutOfBandState } from '../../oob/domain/OutOfBandState' @@ -33,14 +32,14 @@ import { ConnectionEventTypes } from '../ConnectionEvents' import { ConnectionProblemReportError, ConnectionProblemReportReason } from '../errors' import { ConnectionRequestMessage, ConnectionResponseMessage, TrustPingMessage } from '../messages' import { - DidExchangeRole, - DidExchangeState, + authenticationTypes, Connection, DidDoc, + DidExchangeRole, + DidExchangeState, Ed25119Sig2018, HandshakeProtocol, ReferencedAuthentication, - authenticationTypes, } from '../models' import { ConnectionRecord } from '../repository/ConnectionRecord' import { ConnectionRepository } from '../repository/ConnectionRepository' @@ -57,26 +56,21 @@ export interface ConnectionRequestParams { @injectable() export class ConnectionService { - private wallet: Wallet - private config: AgentConfig private connectionRepository: ConnectionRepository private didRepository: DidRepository private eventEmitter: EventEmitter private logger: Logger public constructor( - @inject(InjectionSymbols.Wallet) wallet: Wallet, - config: AgentConfig, + @inject(InjectionSymbols.Logger) logger: Logger, connectionRepository: ConnectionRepository, didRepository: DidRepository, eventEmitter: EventEmitter ) { - this.wallet = wallet - this.config = config this.connectionRepository = connectionRepository this.didRepository = didRepository this.eventEmitter = eventEmitter - this.logger = config.logger + this.logger = logger } /** @@ -87,6 +81,7 @@ export class ConnectionService { * @returns outbound message containing connection request */ public async createRequest( + agentContext: AgentContext, outOfBandRecord: OutOfBandRecord, config: ConnectionRequestParams ): Promise> { @@ -105,12 +100,12 @@ export class ConnectionService { // We take just the first one for now. const [invitationDid] = outOfBandInvitation.invitationDids - const { did: peerDid } = await this.createDid({ + const { did: peerDid } = await this.createDid(agentContext, { role: DidDocumentRole.Created, didDoc, }) - const connectionRecord = await this.createConnection({ + const connectionRecord = await this.createConnection(agentContext, { protocol: HandshakeProtocol.Connections, role: DidExchangeRole.Requester, state: DidExchangeState.InvitationReceived, @@ -127,10 +122,10 @@ export class ConnectionService { const { label, imageUrl, autoAcceptConnection } = config const connectionRequest = new ConnectionRequestMessage({ - label: label ?? this.config.label, + label: label ?? agentContext.config.label, did: didDoc.id, didDoc, - imageUrl: imageUrl ?? this.config.connectionImageUrl, + imageUrl: imageUrl ?? agentContext.config.connectionImageUrl, }) if (autoAcceptConnection !== undefined || autoAcceptConnection !== null) { @@ -138,7 +133,7 @@ export class ConnectionService { } connectionRecord.threadId = connectionRequest.id - await this.updateState(connectionRecord, DidExchangeState.RequestSent) + await this.updateState(agentContext, connectionRecord, DidExchangeState.RequestSent) return { connectionRecord, @@ -150,7 +145,9 @@ export class ConnectionService { messageContext: InboundMessageContext, outOfBandRecord: OutOfBandRecord ): Promise { - this.logger.debug(`Process message ${ConnectionRequestMessage.type} start`, messageContext) + this.logger.debug(`Process message ${ConnectionRequestMessage.type} start`, { + message: messageContext.message, + }) outOfBandRecord.assertRole(OutOfBandRole.Sender) outOfBandRecord.assertState(OutOfBandState.AwaitResponse) @@ -163,12 +160,12 @@ export class ConnectionService { }) } - const { did: peerDid } = await this.createDid({ + const { did: peerDid } = await this.createDid(messageContext.agentContext, { role: DidDocumentRole.Received, didDoc: message.connection.didDoc, }) - const connectionRecord = await this.createConnection({ + const connectionRecord = await this.createConnection(messageContext.agentContext, { protocol: HandshakeProtocol.Connections, role: DidExchangeRole.Responder, state: DidExchangeState.RequestReceived, @@ -181,8 +178,8 @@ export class ConnectionService { autoAcceptConnection: outOfBandRecord.autoAcceptConnection, }) - await this.connectionRepository.update(connectionRecord) - this.emitStateChangedEvent(connectionRecord, null) + await this.connectionRepository.update(messageContext.agentContext, connectionRecord) + this.emitStateChangedEvent(messageContext.agentContext, connectionRecord, null) this.logger.debug(`Process message ${ConnectionRequestMessage.type} end`, connectionRecord) return connectionRecord @@ -195,6 +192,7 @@ export class ConnectionService { * @returns outbound message containing connection response */ public async createResponse( + agentContext: AgentContext, connectionRecord: ConnectionRecord, outOfBandRecord: OutOfBandRecord, routing?: Routing @@ -211,7 +209,7 @@ export class ConnectionService { ) ) - const { did: peerDid } = await this.createDid({ + const { did: peerDid } = await this.createDid(agentContext, { role: DidDocumentRole.Created, didDoc, }) @@ -231,11 +229,11 @@ export class ConnectionService { const connectionResponse = new ConnectionResponseMessage({ threadId: connectionRecord.threadId, - connectionSig: await signData(connectionJson, this.wallet, signingKey), + connectionSig: await signData(connectionJson, agentContext.wallet, signingKey), }) connectionRecord.did = peerDid - await this.updateState(connectionRecord, DidExchangeState.ResponseSent) + await this.updateState(agentContext, connectionRecord, DidExchangeState.ResponseSent) this.logger.debug(`Create message ${ConnectionResponseMessage.type.messageTypeUri} end`, { connectionRecord, @@ -260,7 +258,9 @@ export class ConnectionService { messageContext: InboundMessageContext, outOfBandRecord: OutOfBandRecord ): Promise { - this.logger.debug(`Process message ${ConnectionResponseMessage.type} start`, messageContext) + this.logger.debug(`Process message ${ConnectionResponseMessage.type} start`, { + message: messageContext.message, + }) const { connection: connectionRecord, message, recipientKey, senderKey } = messageContext if (!recipientKey || !senderKey) { @@ -276,7 +276,10 @@ export class ConnectionService { let connectionJson = null try { - connectionJson = await unpackAndVerifySignatureDecorator(message.connectionSig, this.wallet) + connectionJson = await unpackAndVerifySignatureDecorator( + message.connectionSig, + messageContext.agentContext.wallet + ) } catch (error) { if (error instanceof AriesFrameworkError) { throw new ConnectionProblemReportError(error.message, { @@ -305,7 +308,7 @@ export class ConnectionService { throw new AriesFrameworkError('DID Document is missing.') } - const { did: peerDid } = await this.createDid({ + const { did: peerDid } = await this.createDid(messageContext.agentContext, { role: DidDocumentRole.Received, didDoc: connection.didDoc, }) @@ -313,7 +316,7 @@ export class ConnectionService { connectionRecord.theirDid = peerDid connectionRecord.threadId = message.threadId - await this.updateState(connectionRecord, DidExchangeState.ResponseReceived) + await this.updateState(messageContext.agentContext, connectionRecord, DidExchangeState.ResponseReceived) return connectionRecord } @@ -328,6 +331,7 @@ export class ConnectionService { * @returns outbound message containing trust ping message */ public async createTrustPing( + agentContext: AgentContext, connectionRecord: ConnectionRecord, config: { responseRequested?: boolean; comment?: string } = {} ): Promise> { @@ -340,7 +344,7 @@ export class ConnectionService { // Only update connection record and emit an event if the state is not already 'Complete' if (connectionRecord.state !== DidExchangeState.Completed) { - await this.updateState(connectionRecord, DidExchangeState.Completed) + await this.updateState(agentContext, connectionRecord, DidExchangeState.Completed) } return { @@ -368,7 +372,7 @@ export class ConnectionService { // TODO: This is better addressed in a middleware of some kind because // any message can transition the state to complete, not just an ack or trust ping if (connection.state === DidExchangeState.ResponseSent && connection.role === DidExchangeRole.Responder) { - await this.updateState(connection, DidExchangeState.Completed) + await this.updateState(messageContext.agentContext, connection, DidExchangeState.Completed) } return connection @@ -393,9 +397,9 @@ export class ConnectionService { } let connectionRecord - const ourDidRecords = await this.didRepository.findAllByRecipientKey(recipientKey) + const ourDidRecords = await this.didRepository.findAllByRecipientKey(messageContext.agentContext, recipientKey) for (const ourDidRecord of ourDidRecords) { - connectionRecord = await this.findByOurDid(ourDidRecord.id) + connectionRecord = await this.findByOurDid(messageContext.agentContext, ourDidRecord.id) } if (!connectionRecord) { @@ -404,7 +408,9 @@ export class ConnectionService { ) } - const theirDidRecord = connectionRecord.theirDid && (await this.didRepository.findById(connectionRecord.theirDid)) + const theirDidRecord = + connectionRecord.theirDid && + (await this.didRepository.findById(messageContext.agentContext, connectionRecord.theirDid)) if (!theirDidRecord) { throw new AriesFrameworkError(`Did record with id ${connectionRecord.theirDid} not found.`) } @@ -416,7 +422,7 @@ export class ConnectionService { } connectionRecord.errorMessage = `${connectionProblemReportMessage.description.code} : ${connectionProblemReportMessage.description.en}` - await this.update(connectionRecord) + await this.update(messageContext.agentContext, connectionRecord) return connectionRecord } @@ -494,19 +500,23 @@ export class ConnectionService { } } - public async updateState(connectionRecord: ConnectionRecord, newState: DidExchangeState) { + public async updateState(agentContext: AgentContext, connectionRecord: ConnectionRecord, newState: DidExchangeState) { const previousState = connectionRecord.state connectionRecord.state = newState - await this.connectionRepository.update(connectionRecord) + await this.connectionRepository.update(agentContext, connectionRecord) - this.emitStateChangedEvent(connectionRecord, previousState) + this.emitStateChangedEvent(agentContext, connectionRecord, previousState) } - private emitStateChangedEvent(connectionRecord: ConnectionRecord, previousState: DidExchangeState | null) { + private emitStateChangedEvent( + agentContext: AgentContext, + connectionRecord: ConnectionRecord, + previousState: DidExchangeState | null + ) { // Connection record in event should be static const clonedConnection = JsonTransformer.clone(connectionRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: ConnectionEventTypes.ConnectionStateChanged, payload: { connectionRecord: clonedConnection, @@ -515,8 +525,8 @@ export class ConnectionService { }) } - public update(connectionRecord: ConnectionRecord) { - return this.connectionRepository.update(connectionRecord) + public update(agentContext: AgentContext, connectionRecord: ConnectionRecord) { + return this.connectionRepository.update(agentContext, connectionRecord) } /** @@ -524,8 +534,8 @@ export class ConnectionService { * * @returns List containing all connection records */ - public getAll() { - return this.connectionRepository.getAll() + public getAll(agentContext: AgentContext) { + return this.connectionRepository.getAll(agentContext) } /** @@ -536,8 +546,8 @@ export class ConnectionService { * @return The connection record * */ - public getById(connectionId: string): Promise { - return this.connectionRepository.getById(connectionId) + public getById(agentContext: AgentContext, connectionId: string): Promise { + return this.connectionRepository.getById(agentContext, connectionId) } /** @@ -546,8 +556,8 @@ export class ConnectionService { * @param connectionId the connection record id * @returns The connection record or null if not found */ - public findById(connectionId: string): Promise { - return this.connectionRepository.findById(connectionId) + public findById(agentContext: AgentContext, connectionId: string): Promise { + return this.connectionRepository.findById(agentContext, connectionId) } /** @@ -555,13 +565,13 @@ export class ConnectionService { * * @param connectionId the connection record id */ - public async deleteById(connectionId: string) { - const connectionRecord = await this.getById(connectionId) - return this.connectionRepository.delete(connectionRecord) + public async deleteById(agentContext: AgentContext, connectionId: string) { + const connectionRecord = await this.getById(agentContext, connectionId) + return this.connectionRepository.delete(agentContext, connectionRecord) } - public async findSingleByQuery(query: { did: string; theirDid: string }) { - return this.connectionRepository.findSingleByQuery(query) + public async findByDids(agentContext: AgentContext, query: { ourDid: string; theirDid: string }) { + return this.connectionRepository.findByDids(agentContext, query) } /** @@ -572,33 +582,56 @@ export class ConnectionService { * @throws {RecordDuplicateError} If multiple records are found * @returns The connection record */ - public getByThreadId(threadId: string): Promise { - return this.connectionRepository.getByThreadId(threadId) + public async getByThreadId(agentContext: AgentContext, threadId: string): Promise { + return this.connectionRepository.getByThreadId(agentContext, threadId) } - public async findByTheirDid(did: string): Promise { - return this.connectionRepository.findSingleByQuery({ theirDid: did }) + public async findByTheirDid(agentContext: AgentContext, theirDid: string): Promise { + return this.connectionRepository.findSingleByQuery(agentContext, { theirDid }) } - public async findByOurDid(did: string): Promise { - return this.connectionRepository.findSingleByQuery({ did }) + public async findByOurDid(agentContext: AgentContext, ourDid: string): Promise { + return this.connectionRepository.findSingleByQuery(agentContext, { did: ourDid }) } - public async findAllByOutOfBandId(outOfBandId: string) { - return this.connectionRepository.findByQuery({ outOfBandId }) + public async findAllByOutOfBandId(agentContext: AgentContext, outOfBandId: string) { + return this.connectionRepository.findByQuery(agentContext, { outOfBandId }) } - public async findByInvitationDid(invitationDid: string) { - return this.connectionRepository.findByQuery({ invitationDid }) + public async findByInvitationDid(agentContext: AgentContext, invitationDid: string) { + return this.connectionRepository.findByQuery(agentContext, { invitationDid }) } - public async createConnection(options: ConnectionRecordProps): Promise { + public async findByKeys( + agentContext: AgentContext, + { senderKey, recipientKey }: { senderKey: Key; recipientKey: Key } + ) { + const theirDidRecord = await this.didRepository.findByRecipientKey(agentContext, senderKey) + if (theirDidRecord) { + const ourDidRecord = await this.didRepository.findByRecipientKey(agentContext, recipientKey) + if (ourDidRecord) { + const connectionRecord = await this.findByDids(agentContext, { + ourDid: ourDidRecord.id, + theirDid: theirDidRecord.id, + }) + if (connectionRecord && connectionRecord.isReady) return connectionRecord + } + } + + this.logger.debug( + `No connection record found for encrypted message with recipient key ${recipientKey.fingerprint} and sender key ${senderKey.fingerprint}` + ) + + return null + } + + public async createConnection(agentContext: AgentContext, options: ConnectionRecordProps): Promise { const connectionRecord = new ConnectionRecord(options) - await this.connectionRepository.save(connectionRecord) + await this.connectionRepository.save(agentContext, connectionRecord) return connectionRecord } - private async createDid({ role, didDoc }: { role: DidDocumentRole; didDoc: DidDoc }) { + private async createDid(agentContext: AgentContext, { role, didDoc }: { role: DidDocumentRole; didDoc: DidDoc }) { // Convert the legacy did doc to a new did document const didDocument = convertToNewDidDocument(didDoc) @@ -629,7 +662,7 @@ export class ConnectionService { didDocument: 'omitted...', }) - await this.didRepository.save(didRecord) + await this.didRepository.save(agentContext, didRecord) this.logger.debug('Did record created.', didRecord) return { did: peerDid, didDocument } } @@ -700,7 +733,11 @@ export class ConnectionService { }) } - public async returnWhenIsConnected(connectionId: string, timeoutMs = 20000): Promise { + public async returnWhenIsConnected( + agentContext: AgentContext, + connectionId: string, + timeoutMs = 20000 + ): Promise { const isConnected = (connection: ConnectionRecord) => { return connection.id === connectionId && connection.state === DidExchangeState.Completed } @@ -718,7 +755,7 @@ export class ConnectionService { ) .subscribe(subject) - const connection = await this.getById(connectionId) + const connection = await this.getById(agentContext, connectionId) if (isConnected(connection)) { subject.next(connection) } diff --git a/packages/core/src/modules/credentials/CredentialsModule.ts b/packages/core/src/modules/credentials/CredentialsModule.ts index 6c74598b0b..5aefd21ee2 100644 --- a/packages/core/src/modules/credentials/CredentialsModule.ts +++ b/packages/core/src/modules/credentials/CredentialsModule.ts @@ -1,5 +1,4 @@ import type { AgentMessage } from '../../agent/AgentMessage' -import type { Logger } from '../../logger' import type { DependencyManager } from '../../plugins' import type { DeleteCredentialOptions } from './CredentialServiceOptions' import type { @@ -7,30 +6,32 @@ import type { AcceptOfferOptions, AcceptProposalOptions, AcceptRequestOptions, - NegotiateOfferOptions, - NegotiateProposalOptions, - OfferCredentialOptions, - ProposeCredentialOptions, - ServiceMap, CreateOfferOptions, - FindOfferMessageReturn, - FindRequestMessageReturn, FindCredentialMessageReturn, + FindOfferMessageReturn, FindProposalMessageReturn, + FindRequestMessageReturn, GetFormatDataReturn, + NegotiateOfferOptions, + NegotiateProposalOptions, + OfferCredentialOptions, + ProposeCredentialOptions, SendProblemReportOptions, + ServiceMap, } from './CredentialsModuleOptions' import type { CredentialFormat } from './formats' import type { IndyCredentialFormat } from './formats/indy/IndyCredentialFormat' import type { CredentialExchangeRecord } from './repository/CredentialExchangeRecord' import type { CredentialService } from './services/CredentialService' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' +import { InjectionSymbols } from '../../constants' import { ServiceDecorator } from '../../decorators/service/ServiceDecorator' import { AriesFrameworkError } from '../../error' -import { injectable, module } from '../../plugins' +import { Logger } from '../../logger' +import { inject, injectable, module } from '../../plugins' import { DidCommMessageRole } from '../../storage' import { DidCommMessageRepository } from '../../storage/didcomm/DidCommMessageRepository' import { ConnectionService } from '../connections/services' @@ -38,10 +39,10 @@ import { RoutingService } from '../routing/services/RoutingService' import { IndyCredentialFormatService } from './formats' import { CredentialState } from './models/CredentialState' +import { RevocationNotificationService } from './protocol/revocation-notification/services' import { V1CredentialService } from './protocol/v1/V1CredentialService' import { V2CredentialService } from './protocol/v2/V2CredentialService' import { CredentialRepository } from './repository/CredentialRepository' -import { RevocationNotificationService } from './services' export interface CredentialsModule[]> { // Proposal methods @@ -95,7 +96,7 @@ export class CredentialsModule< private connectionService: ConnectionService private messageSender: MessageSender private credentialRepository: CredentialRepository - private agentConfig: AgentConfig + private agentContext: AgentContext private didCommMessageRepo: DidCommMessageRepository private routingService: RoutingService private logger: Logger @@ -104,7 +105,8 @@ export class CredentialsModule< public constructor( messageSender: MessageSender, connectionService: ConnectionService, - agentConfig: AgentConfig, + agentContext: AgentContext, + @inject(InjectionSymbols.Logger) logger: Logger, credentialRepository: CredentialRepository, mediationRecipientService: RoutingService, didCommMessageRepository: DidCommMessageRepository, @@ -117,10 +119,10 @@ export class CredentialsModule< this.messageSender = messageSender this.connectionService = connectionService this.credentialRepository = credentialRepository - this.agentConfig = agentConfig this.routingService = mediationRecipientService + this.agentContext = agentContext this.didCommMessageRepo = didCommMessageRepository - this.logger = agentConfig.logger + this.logger = logger // Dynamically build service map. This will be extracted once services are registered dynamically this.serviceMap = [v1Service, v2Service].reduce( @@ -131,7 +133,7 @@ export class CredentialsModule< {} ) as ServiceMap - this.logger.debug(`Initializing Credentials Module for agent ${this.agentConfig.label}`) + this.logger.debug(`Initializing Credentials Module for agent ${this.agentContext.config.label}`) } public getService(protocolVersion: PVT): CredentialService { @@ -155,10 +157,10 @@ export class CredentialsModule< this.logger.debug(`Got a CredentialService object for version ${options.protocolVersion}`) - const connection = await this.connectionService.getById(options.connectionId) + const connection = await this.connectionService.getById(this.agentContext, options.connectionId) // will get back a credential record -> map to Credential Exchange Record - const { credentialRecord, message } = await service.createProposal({ + const { credentialRecord, message } = await service.createProposal(this.agentContext, { connection, credentialFormats: options.credentialFormats, comment: options.comment, @@ -171,7 +173,7 @@ export class CredentialsModule< const outbound = createOutboundMessage(connection, message) this.logger.debug('In proposeCredential: Send Proposal to Issuer') - await this.messageSender.sendMessage(outbound) + await this.messageSender.sendMessage(this.agentContext, outbound) return credentialRecord } @@ -196,7 +198,7 @@ export class CredentialsModule< const service = this.getService(credentialRecord.protocolVersion) // will get back a credential record -> map to Credential Exchange Record - const { message } = await service.acceptProposal({ + const { message } = await service.acceptProposal(this.agentContext, { credentialRecord, credentialFormats: options.credentialFormats, comment: options.comment, @@ -204,9 +206,9 @@ export class CredentialsModule< }) // send the message - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) const outbound = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outbound) + await this.messageSender.sendMessage(this.agentContext, outbound) return credentialRecord } @@ -231,16 +233,16 @@ export class CredentialsModule< // with version we can get the Service const service = this.getService(credentialRecord.protocolVersion) - const { message } = await service.negotiateProposal({ + const { message } = await service.negotiateProposal(this.agentContext, { credentialRecord, credentialFormats: options.credentialFormats, comment: options.comment, autoAcceptCredential: options.autoAcceptCredential, }) - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } @@ -253,12 +255,12 @@ export class CredentialsModule< * @returns Credential exchange record associated with the sent credential offer message */ public async offerCredential(options: OfferCredentialOptions): Promise { - const connection = await this.connectionService.getById(options.connectionId) + const connection = await this.connectionService.getById(this.agentContext, options.connectionId) const service = this.getService(options.protocolVersion) this.logger.debug(`Got a CredentialService object for version ${options.protocolVersion}`) - const { message, credentialRecord } = await service.createOffer({ + const { message, credentialRecord } = await service.createOffer(this.agentContext, { credentialFormats: options.credentialFormats, autoAcceptCredential: options.autoAcceptCredential, comment: options.comment, @@ -267,7 +269,7 @@ export class CredentialsModule< this.logger.debug('Offer Message successfully created; message= ', message) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } @@ -285,13 +287,13 @@ export class CredentialsModule< const service = this.getService(credentialRecord.protocolVersion) this.logger.debug(`Got a CredentialService object for this version; version = ${service.version}`) - const offerMessage = await service.findOfferMessage(credentialRecord.id) + const offerMessage = await service.findOfferMessage(this.agentContext, credentialRecord.id) // Use connection if present if (credentialRecord.connectionId) { - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) - const { message } = await service.acceptOffer({ + const { message } = await service.acceptOffer(this.agentContext, { credentialRecord, credentialFormats: options.credentialFormats, comment: options.comment, @@ -299,14 +301,14 @@ export class CredentialsModule< }) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } // Use ~service decorator otherwise else if (offerMessage?.service) { // Create ~service decorator - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(this.agentContext) const ourService = new ServiceDecorator({ serviceEndpoint: routing.endpoints[0], recipientKeys: [routing.recipientKey.publicKeyBase58], @@ -314,7 +316,7 @@ export class CredentialsModule< }) const recipientService = offerMessage.service - const { message } = await service.acceptOffer({ + const { message } = await service.acceptOffer(this.agentContext, { credentialRecord, credentialFormats: options.credentialFormats, comment: options.comment, @@ -323,13 +325,13 @@ export class CredentialsModule< // Set and save ~service decorator to record (to remember our verkey) message.service = ourService - await this.didCommMessageRepo.saveOrUpdateAgentMessage({ + await this.didCommMessageRepo.saveOrUpdateAgentMessage(this.agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, }) - await this.messageSender.sendMessageToService({ + await this.messageSender.sendMessageToService(this.agentContext, { message, service: recipientService.resolvedDidCommService, senderKey: ourService.resolvedDidCommService.recipientKeys[0], @@ -352,7 +354,7 @@ export class CredentialsModule< // with version we can get the Service const service = this.getService(credentialRecord.protocolVersion) - await service.updateState(credentialRecord, CredentialState.Declined) + await service.updateState(this.agentContext, credentialRecord, CredentialState.Declined) return credentialRecord } @@ -361,7 +363,7 @@ export class CredentialsModule< const credentialRecord = await this.getById(options.credentialRecordId) const service = this.getService(credentialRecord.protocolVersion) - const { message } = await service.negotiateOffer({ + const { message } = await service.negotiateOffer(this.agentContext, { credentialFormats: options.credentialFormats, credentialRecord, comment: options.comment, @@ -374,9 +376,9 @@ export class CredentialsModule< ) } - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } @@ -394,7 +396,7 @@ export class CredentialsModule< const service = this.getService(options.protocolVersion) this.logger.debug(`Got a CredentialService object for version ${options.protocolVersion}`) - const { message, credentialRecord } = await service.createOffer({ + const { message, credentialRecord } = await service.createOffer(this.agentContext, { credentialFormats: options.credentialFormats, comment: options.comment, autoAcceptCredential: options.autoAcceptCredential, @@ -420,7 +422,7 @@ export class CredentialsModule< this.logger.debug(`Got a CredentialService object for version ${credentialRecord.protocolVersion}`) - const { message } = await service.acceptRequest({ + const { message } = await service.acceptRequest(this.agentContext, { credentialRecord, credentialFormats: options.credentialFormats, comment: options.comment, @@ -428,14 +430,14 @@ export class CredentialsModule< }) this.logger.debug('We have a credential message (sending outbound): ', message) - const requestMessage = await service.findRequestMessage(credentialRecord.id) - const offerMessage = await service.findOfferMessage(credentialRecord.id) + const requestMessage = await service.findRequestMessage(this.agentContext, credentialRecord.id) + const offerMessage = await service.findOfferMessage(this.agentContext, credentialRecord.id) // Use connection if present if (credentialRecord.connectionId) { - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } @@ -445,13 +447,13 @@ export class CredentialsModule< const ourService = offerMessage.service message.service = ourService - await this.didCommMessageRepo.saveOrUpdateAgentMessage({ + await this.didCommMessageRepo.saveOrUpdateAgentMessage(this.agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, }) - await this.messageSender.sendMessageToService({ + await this.messageSender.sendMessageToService(this.agentContext, { message, service: recipientService.resolvedDidCommService, senderKey: ourService.resolvedDidCommService.recipientKeys[0], @@ -484,18 +486,18 @@ export class CredentialsModule< this.logger.debug(`Got a CredentialService object for version ${credentialRecord.protocolVersion}`) - const { message } = await service.acceptCredential({ + const { message } = await service.acceptCredential(this.agentContext, { credentialRecord, }) - const requestMessage = await service.findRequestMessage(credentialRecord.id) - const credentialMessage = await service.findCredentialMessage(credentialRecord.id) + const requestMessage = await service.findRequestMessage(this.agentContext, credentialRecord.id) + const credentialMessage = await service.findCredentialMessage(this.agentContext, credentialRecord.id) if (credentialRecord.connectionId) { - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } @@ -504,7 +506,7 @@ export class CredentialsModule< const recipientService = credentialMessage.service const ourService = requestMessage.service - await this.messageSender.sendMessageToService({ + await this.messageSender.sendMessageToService(this.agentContext, { message, service: recipientService.resolvedDidCommService, senderKey: ourService.resolvedDidCommService.recipientKeys[0], @@ -532,15 +534,15 @@ export class CredentialsModule< if (!credentialRecord.connectionId) { throw new AriesFrameworkError(`No connectionId found for credential record '${credentialRecord.id}'.`) } - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, credentialRecord.connectionId) const service = this.getService(credentialRecord.protocolVersion) - const problemReportMessage = service.createProblemReport({ message: options.message }) + const problemReportMessage = service.createProblemReport(this.agentContext, { message: options.message }) problemReportMessage.setThread({ threadId: credentialRecord.threadId, }) const outboundMessage = createOutboundMessage(connection, problemReportMessage) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return credentialRecord } @@ -549,7 +551,7 @@ export class CredentialsModule< const credentialRecord = await this.getById(credentialRecordId) const service = this.getService(credentialRecord.protocolVersion) - return service.getFormatData(credentialRecordId) + return service.getFormatData(this.agentContext, credentialRecordId) } /** @@ -561,7 +563,7 @@ export class CredentialsModule< * */ public getById(credentialRecordId: string): Promise { - return this.credentialRepository.getById(credentialRecordId) + return this.credentialRepository.getById(this.agentContext, credentialRecordId) } /** @@ -570,7 +572,7 @@ export class CredentialsModule< * @returns List containing all credential records */ public getAll(): Promise { - return this.credentialRepository.getAll() + return this.credentialRepository.getAll(this.agentContext) } /** @@ -580,7 +582,7 @@ export class CredentialsModule< * @returns The credential record or null if not found */ public findById(credentialRecordId: string): Promise { - return this.credentialRepository.findById(credentialRecordId) + return this.credentialRepository.findById(this.agentContext, credentialRecordId) } /** @@ -592,31 +594,31 @@ export class CredentialsModule< public async deleteById(credentialId: string, options?: DeleteCredentialOptions) { const credentialRecord = await this.getById(credentialId) const service = this.getService(credentialRecord.protocolVersion) - return service.delete(credentialRecord, options) + return service.delete(this.agentContext, credentialRecord, options) } public async findProposalMessage(credentialExchangeId: string): Promise> { const service = await this.getServiceForCredentialExchangeId(credentialExchangeId) - return service.findProposalMessage(credentialExchangeId) + return service.findProposalMessage(this.agentContext, credentialExchangeId) } public async findOfferMessage(credentialExchangeId: string): Promise> { const service = await this.getServiceForCredentialExchangeId(credentialExchangeId) - return service.findOfferMessage(credentialExchangeId) + return service.findOfferMessage(this.agentContext, credentialExchangeId) } public async findRequestMessage(credentialExchangeId: string): Promise> { const service = await this.getServiceForCredentialExchangeId(credentialExchangeId) - return service.findRequestMessage(credentialExchangeId) + return service.findRequestMessage(this.agentContext, credentialExchangeId) } public async findCredentialMessage(credentialExchangeId: string): Promise> { const service = await this.getServiceForCredentialExchangeId(credentialExchangeId) - return service.findCredentialMessage(credentialExchangeId) + return service.findCredentialMessage(this.agentContext, credentialExchangeId) } private async getServiceForCredentialExchangeId(credentialExchangeId: string) { diff --git a/packages/core/src/modules/credentials/formats/CredentialFormatService.ts b/packages/core/src/modules/credentials/formats/CredentialFormatService.ts index bd2e8fc34c..9d3f6f5da9 100644 --- a/packages/core/src/modules/credentials/formats/CredentialFormatService.ts +++ b/packages/core/src/modules/credentials/formats/CredentialFormatService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../agent' import type { EventEmitter } from '../../../agent/EventEmitter' import type { CredentialRepository } from '../repository' import type { CredentialFormat } from './CredentialFormat' @@ -34,30 +35,48 @@ export abstract class CredentialFormatService): Promise - abstract processProposal(options: FormatProcessOptions): Promise - abstract acceptProposal(options: FormatAcceptProposalOptions): Promise + abstract createProposal( + agentContext: AgentContext, + options: FormatCreateProposalOptions + ): Promise + abstract processProposal(agentContext: AgentContext, options: FormatProcessOptions): Promise + abstract acceptProposal( + agentContext: AgentContext, + options: FormatAcceptProposalOptions + ): Promise // offer methods - abstract createOffer(options: FormatCreateOfferOptions): Promise - abstract processOffer(options: FormatProcessOptions): Promise - abstract acceptOffer(options: FormatAcceptOfferOptions): Promise + abstract createOffer( + agentContext: AgentContext, + options: FormatCreateOfferOptions + ): Promise + abstract processOffer(agentContext: AgentContext, options: FormatProcessOptions): Promise + abstract acceptOffer(agentContext: AgentContext, options: FormatAcceptOfferOptions): Promise // request methods - abstract createRequest(options: FormatCreateRequestOptions): Promise - abstract processRequest(options: FormatProcessOptions): Promise - abstract acceptRequest(options: FormatAcceptRequestOptions): Promise + abstract createRequest( + agentContext: AgentContext, + options: FormatCreateRequestOptions + ): Promise + abstract processRequest(agentContext: AgentContext, options: FormatProcessOptions): Promise + abstract acceptRequest( + agentContext: AgentContext, + options: FormatAcceptRequestOptions + ): Promise // credential methods - abstract processCredential(options: FormatProcessOptions): Promise + abstract processCredential(agentContext: AgentContext, options: FormatProcessOptions): Promise // auto accept methods - abstract shouldAutoRespondToProposal(options: FormatAutoRespondProposalOptions): boolean - abstract shouldAutoRespondToOffer(options: FormatAutoRespondOfferOptions): boolean - abstract shouldAutoRespondToRequest(options: FormatAutoRespondRequestOptions): boolean - abstract shouldAutoRespondToCredential(options: FormatAutoRespondCredentialOptions): boolean + abstract shouldAutoRespondToProposal(agentContext: AgentContext, options: FormatAutoRespondProposalOptions): boolean + abstract shouldAutoRespondToOffer(agentContext: AgentContext, options: FormatAutoRespondOfferOptions): boolean + abstract shouldAutoRespondToRequest(agentContext: AgentContext, options: FormatAutoRespondRequestOptions): boolean + abstract shouldAutoRespondToCredential( + agentContext: AgentContext, + options: FormatAutoRespondCredentialOptions + ): boolean - abstract deleteCredentialById(credentialId: string): Promise + abstract deleteCredentialById(agentContext: AgentContext, credentialId: string): Promise abstract supportsFormat(format: string): boolean diff --git a/packages/core/src/modules/credentials/formats/indy/IndyCredentialFormatService.ts b/packages/core/src/modules/credentials/formats/indy/IndyCredentialFormatService.ts index fea9cb1bda..8f9e380b3c 100644 --- a/packages/core/src/modules/credentials/formats/indy/IndyCredentialFormatService.ts +++ b/packages/core/src/modules/credentials/formats/indy/IndyCredentialFormatService.ts @@ -1,36 +1,35 @@ +import type { AgentContext } from '../../../../agent' import type { Attachment } from '../../../../decorators/attachment/Attachment' -import type { Logger } from '../../../../logger' import type { LinkedAttachment } from '../../../../utils/LinkedAttachment' import type { CredentialPreviewAttributeOptions } from '../../models/CredentialPreviewAttribute' import type { CredentialExchangeRecord } from '../../repository/CredentialExchangeRecord' import type { - FormatAutoRespondCredentialOptions, FormatAcceptOfferOptions, FormatAcceptProposalOptions, FormatAcceptRequestOptions, + FormatAutoRespondCredentialOptions, + FormatAutoRespondOfferOptions, + FormatAutoRespondProposalOptions, + FormatAutoRespondRequestOptions, FormatCreateOfferOptions, FormatCreateOfferReturn, FormatCreateProposalOptions, FormatCreateProposalReturn, FormatCreateReturn, FormatProcessOptions, - FormatAutoRespondOfferOptions, - FormatAutoRespondProposalOptions, - FormatAutoRespondRequestOptions, } from '../CredentialFormatServiceOptions' import type { IndyCredentialFormat } from './IndyCredentialFormat' import type * as Indy from 'indy-sdk' -import { AgentConfig } from '../../../../agent/AgentConfig' import { EventEmitter } from '../../../../agent/EventEmitter' import { InjectionSymbols } from '../../../../constants' import { AriesFrameworkError } from '../../../../error' +import { Logger } from '../../../../logger' import { inject, injectable } from '../../../../plugins' import { JsonTransformer } from '../../../../utils/JsonTransformer' import { MessageValidator } from '../../../../utils/MessageValidator' import { getIndyDidFromVerificationMethod } from '../../../../utils/did' import { uuid } from '../../../../utils/uuid' -import { Wallet } from '../../../../wallet/Wallet' import { ConnectionService } from '../../../connections' import { DidResolverService, findVerificationMethodByKeyType } from '../../../dids' import { IndyHolderService, IndyIssuerService } from '../../../indy' @@ -57,7 +56,6 @@ export class IndyCredentialFormatService extends CredentialFormatService): Promise { + public async createProposal( + agentContext: AgentContext, + { credentialFormats, credentialRecord }: FormatCreateProposalOptions + ): Promise { const format = new CredentialFormatSpec({ format: INDY_CRED_FILTER, }) @@ -133,7 +129,7 @@ export class IndyCredentialFormatService extends CredentialFormatService { + public async processProposal(agentContext: AgentContext, { attachment }: FormatProcessOptions): Promise { const credProposalJson = attachment.getDataAsJson() if (!credProposalJson) { @@ -144,12 +140,15 @@ export class IndyCredentialFormatService extends CredentialFormatService): Promise { + public async acceptProposal( + agentContext: AgentContext, + { + attachId, + credentialFormats, + credentialRecord, + proposalAttachment, + }: FormatAcceptProposalOptions + ): Promise { const indyFormat = credentialFormats?.indy const credentialProposal = JsonTransformer.fromJSON(proposalAttachment.getDataAsJson(), IndyCredPropose) @@ -167,7 +166,7 @@ export class IndyCredentialFormatService extends CredentialFormatService): Promise { + public async createOffer( + agentContext: AgentContext, + { credentialFormats, credentialRecord, attachId }: FormatCreateOfferOptions + ): Promise { const indyFormat = credentialFormats.indy if (!indyFormat) { throw new AriesFrameworkError('Missing indy credentialFormat data') } - const { format, attachment, previewAttributes } = await this.createIndyOffer({ + const { format, attachment, previewAttributes } = await this.createIndyOffer(agentContext, { credentialRecord, attachId, attributes: indyFormat.attributes, @@ -208,7 +206,7 @@ export class IndyCredentialFormatService extends CredentialFormatService() @@ -220,24 +218,28 @@ export class IndyCredentialFormatService extends CredentialFormatService): Promise { + public async acceptOffer( + agentContext: AgentContext, + { credentialFormats, credentialRecord, attachId, offerAttachment }: FormatAcceptOfferOptions + ): Promise { const indyFormat = credentialFormats?.indy - const holderDid = indyFormat?.holderDid ?? (await this.getIndyHolderDid(credentialRecord)) + const holderDid = indyFormat?.holderDid ?? (await this.getIndyHolderDid(agentContext, credentialRecord)) const credentialOffer = offerAttachment.getDataAsJson() - const credentialDefinition = await this.indyLedgerService.getCredentialDefinition(credentialOffer.cred_def_id) + const credentialDefinition = await this.indyLedgerService.getCredentialDefinition( + agentContext, + credentialOffer.cred_def_id + ) - const [credentialRequest, credentialRequestMetadata] = await this.indyHolderService.createCredentialRequest({ - holderDid, - credentialOffer, - credentialDefinition, - }) + const [credentialRequest, credentialRequestMetadata] = await this.indyHolderService.createCredentialRequest( + agentContext, + { + holderDid, + credentialOffer, + credentialDefinition, + } + ) credentialRecord.metadata.set(CredentialMetadataKeys.IndyRequest, credentialRequestMetadata) credentialRecord.metadata.set(CredentialMetadataKeys.IndyCredential, { credentialDefinitionId: credentialOffer.cred_def_id, @@ -264,16 +266,14 @@ export class IndyCredentialFormatService extends CredentialFormatService { + public async processRequest(agentContext: AgentContext, options: FormatProcessOptions): Promise { // not needed for Indy } - public async acceptRequest({ - credentialRecord, - attachId, - offerAttachment, - requestAttachment, - }: FormatAcceptRequestOptions): Promise { + public async acceptRequest( + agentContext: AgentContext, + { credentialRecord, attachId, offerAttachment, requestAttachment }: FormatAcceptRequestOptions + ): Promise { // Assert credential attributes const credentialAttributes = credentialRecord.credentialAttributes if (!credentialAttributes) { @@ -290,7 +290,7 @@ export class IndyCredentialFormatService extends CredentialFormatService { + public async processCredential( + agentContext: AgentContext, + { credentialRecord, attachment }: FormatProcessOptions + ): Promise { const credentialRequestMetadata = credentialRecord.metadata.get(CredentialMetadataKeys.IndyRequest) if (!credentialRequestMetadata) { @@ -328,9 +331,12 @@ export class IndyCredentialFormatService extends CredentialFormatService() - const credentialDefinition = await this.indyLedgerService.getCredentialDefinition(indyCredential.cred_def_id) + const credentialDefinition = await this.indyLedgerService.getCredentialDefinition( + agentContext, + indyCredential.cred_def_id + ) const revocationRegistry = indyCredential.rev_reg_id - ? await this.indyLedgerService.getRevocationRegistryDefinition(indyCredential.rev_reg_id) + ? await this.indyLedgerService.getRevocationRegistryDefinition(agentContext, indyCredential.rev_reg_id) : null if (!credentialRecord.credentialAttributes) { @@ -343,7 +349,7 @@ export class IndyCredentialFormatService extends CredentialFormatService { - await this.indyHolderService.deleteCredential(credentialRecordId) + public async deleteCredentialById(agentContext: AgentContext, credentialRecordId: string): Promise { + await this.indyHolderService.deleteCredential(agentContext, credentialRecordId) } - public shouldAutoRespondToProposal({ offerAttachment, proposalAttachment }: FormatAutoRespondProposalOptions) { + public shouldAutoRespondToProposal( + agentContext: AgentContext, + { offerAttachment, proposalAttachment }: FormatAutoRespondProposalOptions + ) { const credentialProposalJson = proposalAttachment.getDataAsJson() const credentialProposal = JsonTransformer.fromJSON(credentialProposalJson, IndyCredPropose) @@ -406,7 +415,10 @@ export class IndyCredentialFormatService extends CredentialFormatService() const credentialRequestJson = requestAttachment.getDataAsJson() return credentialOfferJson.cred_def_id == credentialRequestJson.cred_def_id } - public shouldAutoRespondToCredential({ - credentialRecord, - requestAttachment, - credentialAttachment, - }: FormatAutoRespondCredentialOptions) { + public shouldAutoRespondToCredential( + agentContext: AgentContext, + { credentialRecord, requestAttachment, credentialAttachment }: FormatAutoRespondCredentialOptions + ) { const credentialJson = credentialAttachment.getDataAsJson() const credentialRequestJson = requestAttachment.getDataAsJson() @@ -444,33 +458,36 @@ export class IndyCredentialFormatService extends CredentialFormatService { + private async createIndyOffer( + agentContext: AgentContext, + { + credentialRecord, + attachId, + credentialDefinitionId, + attributes, + linkedAttachments, + }: { + credentialDefinitionId: string + credentialRecord: CredentialExchangeRecord + attachId?: string + attributes: CredentialPreviewAttributeOptions[] + linkedAttachments?: LinkedAttachment[] + } + ): Promise { // if the proposal has an attachment Id use that, otherwise the generated id of the formats object const format = new CredentialFormatSpec({ attachId: attachId, format: INDY_CRED_ABSTRACT, }) - const offer = await this.indyIssuerService.createCredentialOffer(credentialDefinitionId) + const offer = await this.indyIssuerService.createCredentialOffer(agentContext, credentialDefinitionId) const { previewAttributes } = this.getCredentialLinkedAttachments(attributes, linkedAttachments) if (!previewAttributes) { throw new AriesFrameworkError('Missing required preview attributes for indy offer') } - await this.assertPreviewAttributesMatchSchemaAttributes(offer, previewAttributes) + await this.assertPreviewAttributesMatchSchemaAttributes(agentContext, offer, previewAttributes) credentialRecord.metadata.set(CredentialMetadataKeys.IndyCredential, { schemaId: offer.schema_id, @@ -483,22 +500,23 @@ export class IndyCredentialFormatService extends CredentialFormatService { - const schema = await this.indyLedgerService.getSchema(offer.schema_id) + const schema = await this.indyLedgerService.getSchema(agentContext, offer.schema_id) IndyCredentialUtils.checkAttributesMatch(schema, attributes) } - private async getIndyHolderDid(credentialRecord: CredentialExchangeRecord) { + private async getIndyHolderDid(agentContext: AgentContext, credentialRecord: CredentialExchangeRecord) { // If we have a connection id we try to extract the did from the connection did document. if (credentialRecord.connectionId) { - const connection = await this.connectionService.getById(credentialRecord.connectionId) + const connection = await this.connectionService.getById(agentContext, credentialRecord.connectionId) if (!connection.did) { throw new AriesFrameworkError(`Connection record ${connection.id} has no 'did'`) } - const resolved = await this.didResolver.resolve(connection.did) + const resolved = await this.didResolver.resolve(agentContext, connection.did) if (resolved.didDocument) { const verificationMethod = await findVerificationMethodByKeyType( @@ -515,7 +533,7 @@ export class IndyCredentialFormatService extends CredentialFormatService({ + this.eventEmitter.emit(agentContext, { type: CredentialEventTypes.RevocationNotificationReceived, payload: { credentialRecord: clonedCredentialRecord, @@ -91,6 +93,7 @@ export class RevocationNotificationService { const connection = messageContext.assertReadyConnection() await this.processRevocationNotification( + messageContext.agentContext, indyRevocationRegistryId, indyCredentialRevocationId, connection, @@ -132,6 +135,7 @@ export class RevocationNotificationService { const comment = messageContext.message.comment const connection = messageContext.assertReadyConnection() await this.processRevocationNotification( + messageContext.agentContext, indyRevocationRegistryId, indyCredentialRevocationId, connection, diff --git a/packages/core/src/modules/credentials/protocol/revocation-notification/services/__tests__/RevocationNotificationService.test.ts b/packages/core/src/modules/credentials/protocol/revocation-notification/services/__tests__/RevocationNotificationService.test.ts index bea53b40e1..9222b3fdcf 100644 --- a/packages/core/src/modules/credentials/protocol/revocation-notification/services/__tests__/RevocationNotificationService.test.ts +++ b/packages/core/src/modules/credentials/protocol/revocation-notification/services/__tests__/RevocationNotificationService.test.ts @@ -1,7 +1,10 @@ +import type { AgentContext } from '../../../../../../agent' import type { RevocationNotificationReceivedEvent } from '../../../../CredentialEvents' +import { Subject } from 'rxjs' + import { CredentialExchangeRecord, CredentialState, InboundMessageContext } from '../../../../../..' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../../../../tests/helpers' +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../../../../../tests/helpers' import { Dispatcher } from '../../../../../../agent/Dispatcher' import { EventEmitter } from '../../../../../../agent/EventEmitter' import { DidExchangeState } from '../../../../../connections' @@ -25,6 +28,7 @@ const connection = getMockConnection({ describe('RevocationNotificationService', () => { let revocationNotificationService: RevocationNotificationService + let agentContext: AgentContext let eventEmitter: EventEmitter beforeEach(() => { @@ -32,12 +36,14 @@ describe('RevocationNotificationService', () => { indyLedgers: [], }) - eventEmitter = new EventEmitter(agentConfig) + agentContext = getAgentContext() + + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) revocationNotificationService = new RevocationNotificationService( credentialRepository, eventEmitter, - agentConfig, - dispatcher + dispatcher, + agentConfig.logger ) }) @@ -82,6 +88,7 @@ describe('RevocationNotificationService', () => { }) const messageContext = new InboundMessageContext(revocationNotificationMessage, { connection, + agentContext, }) await revocationNotificationService.v1ProcessRevocationNotification(messageContext) @@ -123,7 +130,7 @@ describe('RevocationNotificationService', () => { issueThread: revocationNotificationThreadId, comment: 'Credential has been revoked', }) - const messageContext = new InboundMessageContext(revocationNotificationMessage, { connection }) + const messageContext = new InboundMessageContext(revocationNotificationMessage, { connection, agentContext }) await revocationNotificationService.v1ProcessRevocationNotification(messageContext) @@ -143,7 +150,7 @@ describe('RevocationNotificationService', () => { issueThread: revocationNotificationThreadId, comment: 'Credential has been revoked', }) - const messageContext = new InboundMessageContext(revocationNotificationMessage) + const messageContext = new InboundMessageContext(revocationNotificationMessage, { agentContext }) await revocationNotificationService.v1ProcessRevocationNotification(messageContext) @@ -187,9 +194,7 @@ describe('RevocationNotificationService', () => { revocationFormat: 'indy-anoncreds', comment: 'Credential has been revoked', }) - const messageContext = new InboundMessageContext(revocationNotificationMessage, { - connection, - }) + const messageContext = new InboundMessageContext(revocationNotificationMessage, { agentContext, connection }) await revocationNotificationService.v2ProcessRevocationNotification(messageContext) @@ -231,7 +236,7 @@ describe('RevocationNotificationService', () => { revocationFormat: 'indy-anoncreds', comment: 'Credential has been revoked', }) - const messageContext = new InboundMessageContext(revocationNotificationMessage, { connection }) + const messageContext = new InboundMessageContext(revocationNotificationMessage, { connection, agentContext }) await revocationNotificationService.v2ProcessRevocationNotification(messageContext) @@ -252,7 +257,7 @@ describe('RevocationNotificationService', () => { revocationFormat: 'indy-anoncreds', comment: 'Credential has been revoked', }) - const messageContext = new InboundMessageContext(revocationNotificationMessage) + const messageContext = new InboundMessageContext(revocationNotificationMessage, { agentContext }) await revocationNotificationService.v2ProcessRevocationNotification(messageContext) diff --git a/packages/core/src/modules/credentials/protocol/v1/V1CredentialService.ts b/packages/core/src/modules/credentials/protocol/v1/V1CredentialService.ts index 1d52e4b197..938dbf7dd6 100644 --- a/packages/core/src/modules/credentials/protocol/v1/V1CredentialService.ts +++ b/packages/core/src/modules/credentials/protocol/v1/V1CredentialService.ts @@ -1,5 +1,5 @@ +import type { AgentContext } from '../../../../agent' import type { AgentMessage } from '../../../../agent/AgentMessage' -import type { HandlerInboundMessage } from '../../../../agent/Handler' import type { InboundMessageContext } from '../../../../agent/models/InboundMessageContext' import type { ProblemReportMessage } from '../../../problem-reports' import type { @@ -18,12 +18,13 @@ import type { GetFormatDataReturn } from '../../CredentialsModuleOptions' import type { CredentialFormat } from '../../formats' import type { IndyCredentialFormat } from '../../formats/indy/IndyCredentialFormat' -import { AgentConfig } from '../../../../agent/AgentConfig' import { Dispatcher } from '../../../../agent/Dispatcher' import { EventEmitter } from '../../../../agent/EventEmitter' +import { InjectionSymbols } from '../../../../constants' import { Attachment, AttachmentData } from '../../../../decorators/attachment/Attachment' import { AriesFrameworkError } from '../../../../error' -import { injectable } from '../../../../plugins' +import { Logger } from '../../../../logger' +import { inject, injectable } from '../../../../plugins' import { DidCommMessageRepository, DidCommMessageRole } from '../../../../storage' import { JsonTransformer } from '../../../../utils' import { isLinkedAttachment } from '../../../../utils/attachment' @@ -36,7 +37,7 @@ import { IndyCredentialFormatService } from '../../formats/indy/IndyCredentialFo import { IndyCredPropose } from '../../formats/indy/models' import { AutoAcceptCredential } from '../../models/CredentialAutoAcceptType' import { CredentialState } from '../../models/CredentialState' -import { CredentialRepository, CredentialExchangeRecord } from '../../repository' +import { CredentialExchangeRecord, CredentialRepository } from '../../repository' import { CredentialService } from '../../services' import { composeAutoAccept } from '../../util/composeAutoAccept' import { arePreviewAttributesEqual } from '../../util/previewAttributes' @@ -71,17 +72,16 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat public constructor( connectionService: ConnectionService, didCommMessageRepository: DidCommMessageRepository, - agentConfig: AgentConfig, + @inject(InjectionSymbols.Logger) logger: Logger, routingService: RoutingService, dispatcher: Dispatcher, eventEmitter: EventEmitter, credentialRepository: CredentialRepository, formatService: IndyCredentialFormatService ) { - super(credentialRepository, didCommMessageRepository, eventEmitter, dispatcher, agentConfig) + super(credentialRepository, didCommMessageRepository, eventEmitter, dispatcher, logger) this.connectionService = connectionService this.formatService = formatService - this.didCommMessageRepository = didCommMessageRepository this.routingService = routingService this.registerHandlers() @@ -110,12 +110,10 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns Object containing proposal message and associated credential record * */ - public async createProposal({ - connection, - credentialFormats, - comment, - autoAcceptCredential, - }: CreateProposalOptions<[IndyCredentialFormat]>): Promise> { + public async createProposal( + agentContext: AgentContext, + { connection, credentialFormats, comment, autoAcceptCredential }: CreateProposalOptions<[IndyCredentialFormat]> + ): Promise> { this.assertOnlyIndyFormat(credentialFormats) if (!credentialFormats.indy) { @@ -137,7 +135,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat }) // call create proposal for validation of the proposal and addition of linked attachments - const { previewAttributes, attachment } = await this.formatService.createProposal({ + const { previewAttributes, attachment } = await this.formatService.createProposal(agentContext, { credentialFormats, credentialRecord, }) @@ -159,15 +157,15 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat comment, }) - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, }) credentialRecord.credentialAttributes = previewAttributes - await this.credentialRepository.save(credentialRecord) - this.emitStateChangedEvent(credentialRecord, null) + await this.credentialRepository.save(agentContext, credentialRecord) + this.emitStateChangedEvent(agentContext, credentialRecord, null) return { credentialRecord, message } } @@ -189,7 +187,11 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat this.logger.debug(`Processing credential proposal with message id ${proposalMessage.id}`) - let credentialRecord = await this.findByThreadAndConnectionId(proposalMessage.threadId, connection?.id) + let credentialRecord = await this.findByThreadAndConnectionId( + messageContext.agentContext, + proposalMessage.threadId, + connection?.id + ) // Credential record already exists, this is a response to an earlier message sent by us if (credentialRecord) { @@ -199,11 +201,14 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.OfferSent) - const proposalCredentialMessage = await this.didCommMessageRepository.findAgentMessage({ - associatedRecordId: credentialRecord.id, - messageClass: V1ProposeCredentialMessage, - }) - const offerCredentialMessage = await this.didCommMessageRepository.findAgentMessage({ + const proposalCredentialMessage = await this.didCommMessageRepository.findAgentMessage( + messageContext.agentContext, + { + associatedRecordId: credentialRecord.id, + messageClass: V1ProposeCredentialMessage, + } + ) + const offerCredentialMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1OfferCredentialMessage, }) @@ -213,7 +218,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat previousSentMessage: offerCredentialMessage ?? undefined, }) - await this.formatService.processProposal({ + await this.formatService.processProposal(messageContext.agentContext, { credentialRecord, attachment: new Attachment({ data: new AttachmentData({ @@ -223,8 +228,8 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat }) // Update record - await this.updateState(credentialRecord, CredentialState.ProposalReceived) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.updateState(messageContext.agentContext, credentialRecord, CredentialState.ProposalReceived) + await this.didCommMessageRepository.saveOrUpdateAgentMessage(messageContext.agentContext, { agentMessage: proposalMessage, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, @@ -244,10 +249,10 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat this.connectionService.assertConnectionOrServiceDecorator(messageContext) // Save record - await this.credentialRepository.save(credentialRecord) - this.emitStateChangedEvent(credentialRecord, null) + await this.credentialRepository.save(messageContext.agentContext, credentialRecord) + this.emitStateChangedEvent(messageContext.agentContext, credentialRecord, null) - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(messageContext.agentContext, { agentMessage: proposalMessage, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, @@ -261,20 +266,21 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @param options The object containing config options * @returns Object containing proposal message and associated credential record */ - public async acceptProposal({ - credentialRecord, - credentialFormats, - comment, - autoAcceptCredential, - }: AcceptProposalOptions<[IndyCredentialFormat]>): Promise< - CredentialProtocolMsgReturnType - > { + public async acceptProposal( + agentContext: AgentContext, + { + credentialRecord, + credentialFormats, + comment, + autoAcceptCredential, + }: AcceptProposalOptions<[IndyCredentialFormat]> + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.ProposalReceived) if (credentialFormats) this.assertOnlyIndyFormat(credentialFormats) - const proposalMessage = await this.didCommMessageRepository.getAgentMessage({ + const proposalMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1ProposeCredentialMessage, }) @@ -284,7 +290,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat // if the user provided other attributes in the credentialFormats array. credentialRecord.credentialAttributes = proposalMessage.credentialPreview?.attributes - const { attachment, previewAttributes } = await this.formatService.acceptProposal({ + const { attachment, previewAttributes } = await this.formatService.acceptProposal(agentContext, { attachId: INDY_CREDENTIAL_OFFER_ATTACHMENT_ID, credentialFormats, credentialRecord, @@ -312,9 +318,9 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat credentialRecord.credentialAttributes = previewAttributes credentialRecord.autoAcceptCredential = autoAcceptCredential ?? credentialRecord.autoAcceptCredential - await this.updateState(credentialRecord, CredentialState.OfferSent) + await this.updateState(agentContext, credentialRecord, CredentialState.OfferSent) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -331,20 +337,21 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns Credential record associated with the credential offer and the corresponding new offer message * */ - public async negotiateProposal({ - credentialFormats, - credentialRecord, - comment, - autoAcceptCredential, - }: NegotiateProposalOptions<[IndyCredentialFormat]>): Promise< - CredentialProtocolMsgReturnType - > { + public async negotiateProposal( + agentContext: AgentContext, + { + credentialFormats, + credentialRecord, + comment, + autoAcceptCredential, + }: NegotiateProposalOptions<[IndyCredentialFormat]> + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.ProposalReceived) if (credentialFormats) this.assertOnlyIndyFormat(credentialFormats) - const { attachment, previewAttributes } = await this.formatService.createOffer({ + const { attachment, previewAttributes } = await this.formatService.createOffer(agentContext, { attachId: INDY_CREDENTIAL_OFFER_ATTACHMENT_ID, credentialFormats, credentialRecord, @@ -366,9 +373,9 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat credentialRecord.credentialAttributes = previewAttributes credentialRecord.autoAcceptCredential = autoAcceptCredential ?? credentialRecord.autoAcceptCredential - await this.updateState(credentialRecord, CredentialState.OfferSent) + await this.updateState(agentContext, credentialRecord, CredentialState.OfferSent) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -385,12 +392,10 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns Object containing offer message and associated credential record * */ - public async createOffer({ - credentialFormats, - autoAcceptCredential, - comment, - connection, - }: CreateOfferOptions<[IndyCredentialFormat]>): Promise> { + public async createOffer( + agentContext: AgentContext, + { credentialFormats, autoAcceptCredential, comment, connection }: CreateOfferOptions<[IndyCredentialFormat]> + ): Promise> { // Assert if (credentialFormats) this.assertOnlyIndyFormat(credentialFormats) @@ -410,7 +415,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat protocolVersion: 'v1', }) - const { attachment, previewAttributes } = await this.formatService.createOffer({ + const { attachment, previewAttributes } = await this.formatService.createOffer(agentContext, { attachId: INDY_CREDENTIAL_OFFER_ATTACHMENT_ID, credentialFormats, credentialRecord, @@ -431,15 +436,15 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat attachments: credentialFormats.indy.linkedAttachments?.map((linkedAttachments) => linkedAttachments.attachment), }) - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, agentMessage: message, role: DidCommMessageRole.Sender, }) credentialRecord.credentialAttributes = previewAttributes - await this.credentialRepository.save(credentialRecord) - this.emitStateChangedEvent(credentialRecord, null) + await this.credentialRepository.save(agentContext, credentialRecord) + this.emitStateChangedEvent(agentContext, credentialRecord, null) return { message, credentialRecord } } @@ -455,13 +460,17 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * */ public async processOffer( - messageContext: HandlerInboundMessage + messageContext: InboundMessageContext ): Promise { const { message: offerMessage, connection } = messageContext this.logger.debug(`Processing credential offer with id ${offerMessage.id}`) - let credentialRecord = await this.findByThreadAndConnectionId(offerMessage.threadId, connection?.id) + let credentialRecord = await this.findByThreadAndConnectionId( + messageContext.agentContext, + offerMessage.threadId, + connection?.id + ) const offerAttachment = offerMessage.getOfferAttachmentById(INDY_CREDENTIAL_OFFER_ATTACHMENT_ID) if (!offerAttachment) { @@ -471,11 +480,14 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat } if (credentialRecord) { - const proposalCredentialMessage = await this.didCommMessageRepository.findAgentMessage({ - associatedRecordId: credentialRecord.id, - messageClass: V1ProposeCredentialMessage, - }) - const offerCredentialMessage = await this.didCommMessageRepository.findAgentMessage({ + const proposalCredentialMessage = await this.didCommMessageRepository.findAgentMessage( + messageContext.agentContext, + { + associatedRecordId: credentialRecord.id, + messageClass: V1ProposeCredentialMessage, + } + ) + const offerCredentialMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1OfferCredentialMessage, }) @@ -488,17 +500,17 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat previousSentMessage: proposalCredentialMessage ?? undefined, }) - await this.formatService.processOffer({ + await this.formatService.processOffer(messageContext.agentContext, { credentialRecord, attachment: offerAttachment, }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(messageContext.agentContext, { agentMessage: offerMessage, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) - await this.updateState(credentialRecord, CredentialState.OfferReceived) + await this.updateState(messageContext.agentContext, credentialRecord, CredentialState.OfferReceived) return credentialRecord } else { @@ -513,19 +525,19 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat // Assert this.connectionService.assertConnectionOrServiceDecorator(messageContext) - await this.formatService.processOffer({ + await this.formatService.processOffer(messageContext.agentContext, { credentialRecord, attachment: offerAttachment, }) // Save in repository - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(messageContext.agentContext, { agentMessage: offerMessage, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) - await this.credentialRepository.save(credentialRecord) - this.emitStateChangedEvent(credentialRecord, null) + await this.credentialRepository.save(messageContext.agentContext, credentialRecord) + this.emitStateChangedEvent(messageContext.agentContext, credentialRecord, null) return credentialRecord } @@ -538,17 +550,15 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns Object containing request message and associated credential record * */ - public async acceptOffer({ - credentialRecord, - credentialFormats, - comment, - autoAcceptCredential, - }: AcceptOfferOptions<[IndyCredentialFormat]>): Promise> { + public async acceptOffer( + agentContext: AgentContext, + { credentialRecord, credentialFormats, comment, autoAcceptCredential }: AcceptOfferOptions<[IndyCredentialFormat]> + ): Promise> { // Assert credential credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.OfferReceived) - const offerMessage = await this.didCommMessageRepository.getAgentMessage({ + const offerMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1OfferCredentialMessage, }) @@ -560,7 +570,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat ) } - const { attachment } = await this.formatService.acceptOffer({ + const { attachment } = await this.formatService.acceptOffer(agentContext, { credentialRecord, credentialFormats, attachId: INDY_CREDENTIAL_REQUEST_ATTACHMENT_ID, @@ -580,12 +590,12 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat isLinkedAttachment(attachment) ) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: requestMessage, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, }) - await this.updateState(credentialRecord, CredentialState.RequestSent) + await this.updateState(agentContext, credentialRecord, CredentialState.RequestSent) return { message: requestMessage, credentialRecord } } @@ -600,12 +610,15 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns credential record associated with the credential request message * */ - public async negotiateOffer({ - credentialFormats, - credentialRecord, - autoAcceptCredential, - comment, - }: NegotiateOfferOptions<[IndyCredentialFormat]>): Promise> { + public async negotiateOffer( + agentContext: AgentContext, + { + credentialFormats, + credentialRecord, + autoAcceptCredential, + comment, + }: NegotiateOfferOptions<[IndyCredentialFormat]> + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.OfferReceived) @@ -624,7 +637,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat // call create proposal for validation of the proposal and addition of linked attachments // As the format is different for v1 of the issue credential protocol we won't be using the attachment - const { previewAttributes, attachment } = await this.formatService.createProposal({ + const { previewAttributes, attachment } = await this.formatService.createProposal(agentContext, { credentialFormats, credentialRecord, }) @@ -647,7 +660,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat message.setThread({ threadId: credentialRecord.threadId }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -657,7 +670,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat credentialRecord.credentialAttributes = previewAttributes credentialRecord.linkedAttachments = linkedAttachments?.map((attachment) => attachment.attachment) credentialRecord.autoAcceptCredential = autoAcceptCredential ?? credentialRecord.autoAcceptCredential - await this.updateState(credentialRecord, CredentialState.ProposalSent) + await this.updateState(agentContext, credentialRecord, CredentialState.ProposalSent) return { credentialRecord, message } } @@ -688,14 +701,18 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat this.logger.debug(`Processing credential request with id ${requestMessage.id}`) - const credentialRecord = await this.getByThreadAndConnectionId(requestMessage.threadId, connection?.id) + const credentialRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + requestMessage.threadId, + connection?.id + ) this.logger.trace('Credential record found when processing credential request', credentialRecord) - const proposalMessage = await this.didCommMessageRepository.findAgentMessage({ + const proposalMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1ProposeCredentialMessage, }) - const offerMessage = await this.didCommMessageRepository.findAgentMessage({ + const offerMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1OfferCredentialMessage, }) @@ -716,18 +733,18 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat ) } - await this.formatService.processRequest({ + await this.formatService.processRequest(messageContext.agentContext, { credentialRecord, attachment: requestAttachment, }) - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(messageContext.agentContext, { agentMessage: requestMessage, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) - await this.updateState(credentialRecord, CredentialState.RequestReceived) + await this.updateState(messageContext.agentContext, credentialRecord, CredentialState.RequestReceived) return credentialRecord } @@ -738,21 +755,19 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns Object containing issue credential message and associated credential record * */ - public async acceptRequest({ - credentialRecord, - credentialFormats, - comment, - autoAcceptCredential, - }: AcceptRequestOptions<[IndyCredentialFormat]>): Promise> { + public async acceptRequest( + agentContext: AgentContext, + { credentialRecord, credentialFormats, comment, autoAcceptCredential }: AcceptRequestOptions<[IndyCredentialFormat]> + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.RequestReceived) - const offerMessage = await this.didCommMessageRepository.getAgentMessage({ + const offerMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1OfferCredentialMessage, }) - const requestMessage = await this.didCommMessageRepository.getAgentMessage({ + const requestMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1RequestCredentialMessage, }) @@ -766,7 +781,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat ) } - const { attachment: credentialsAttach } = await this.formatService.acceptRequest({ + const { attachment: credentialsAttach } = await this.formatService.acceptRequest(agentContext, { credentialRecord, requestAttachment, offerAttachment, @@ -783,14 +798,14 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat issueMessage.setThread({ threadId: credentialRecord.threadId }) issueMessage.setPleaseAck() - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(agentContext, { agentMessage: issueMessage, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, }) credentialRecord.autoAcceptCredential = autoAcceptCredential ?? credentialRecord.autoAcceptCredential - await this.updateState(credentialRecord, CredentialState.CredentialIssued) + await this.updateState(agentContext, credentialRecord, CredentialState.CredentialIssued) return { message: issueMessage, credentialRecord } } @@ -809,13 +824,17 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat this.logger.debug(`Processing credential with id ${issueMessage.id}`) - const credentialRecord = await this.getByThreadAndConnectionId(issueMessage.threadId, connection?.id) + const credentialRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + issueMessage.threadId, + connection?.id + ) - const requestCredentialMessage = await this.didCommMessageRepository.findAgentMessage({ + const requestCredentialMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1RequestCredentialMessage, }) - const offerCredentialMessage = await this.didCommMessageRepository.findAgentMessage({ + const offerCredentialMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1OfferCredentialMessage, }) @@ -833,18 +852,18 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat throw new AriesFrameworkError('Missing indy credential attachment in processCredential') } - await this.formatService.processCredential({ + await this.formatService.processCredential(messageContext.agentContext, { attachment: issueAttachment, credentialRecord, }) - await this.didCommMessageRepository.saveAgentMessage({ + await this.didCommMessageRepository.saveAgentMessage(messageContext.agentContext, { agentMessage: issueMessage, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) - await this.updateState(credentialRecord, CredentialState.CredentialReceived) + await this.updateState(messageContext.agentContext, credentialRecord, CredentialState.CredentialReceived) return credentialRecord } @@ -856,9 +875,10 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns Object containing credential acknowledgement message and associated credential record * */ - public async acceptCredential({ - credentialRecord, - }: AcceptCredentialOptions): Promise> { + public async acceptCredential( + agentContext: AgentContext, + { credentialRecord }: AcceptCredentialOptions + ): Promise> { credentialRecord.assertProtocolVersion('v1') credentialRecord.assertState(CredentialState.CredentialReceived) @@ -868,7 +888,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat threadId: credentialRecord.threadId, }) - await this.updateState(credentialRecord, CredentialState.Done) + await this.updateState(agentContext, credentialRecord, CredentialState.Done) return { message: ackMessage, credentialRecord } } @@ -887,13 +907,17 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat this.logger.debug(`Processing credential ack with id ${ackMessage.id}`) - const credentialRecord = await this.getByThreadAndConnectionId(ackMessage.threadId, connection?.id) + const credentialRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + ackMessage.threadId, + connection?.id + ) - const requestCredentialMessage = await this.didCommMessageRepository.getAgentMessage({ + const requestCredentialMessage = await this.didCommMessageRepository.getAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1RequestCredentialMessage, }) - const issueCredentialMessage = await this.didCommMessageRepository.getAgentMessage({ + const issueCredentialMessage = await this.didCommMessageRepository.getAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1IssueCredentialMessage, }) @@ -907,7 +931,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat }) // Update record - await this.updateState(credentialRecord, CredentialState.Done) + await this.updateState(messageContext.agentContext, credentialRecord, CredentialState.Done) return credentialRecord } @@ -919,7 +943,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat * @returns a {@link V1CredentialProblemReportMessage} * */ - public createProblemReport(options: CreateProblemReportOptions): ProblemReportMessage { + public createProblemReport(agentContext: AgentContext, options: CreateProblemReportOptions): ProblemReportMessage { return new V1CredentialProblemReportMessage({ description: { en: options.message, @@ -929,18 +953,24 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat } // AUTO RESPOND METHODS - public async shouldAutoRespondToProposal(options: { - credentialRecord: CredentialExchangeRecord - proposalMessage: V1ProposeCredentialMessage - }): Promise { + public async shouldAutoRespondToProposal( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + proposalMessage: V1ProposeCredentialMessage + } + ): Promise { const { credentialRecord, proposalMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const offerMessage = await this.findOfferMessage(credentialRecord.id) + const offerMessage = await this.findOfferMessage(agentContext, credentialRecord.id) // Do not auto accept if missing properties if (!offerMessage || !offerMessage.credentialPreview) return false @@ -959,18 +989,24 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat ) } - public async shouldAutoRespondToOffer(options: { - credentialRecord: CredentialExchangeRecord - offerMessage: V1OfferCredentialMessage - }) { + public async shouldAutoRespondToOffer( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + offerMessage: V1OfferCredentialMessage + } + ) { const { credentialRecord, offerMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const proposalMessage = await this.findProposalMessage(credentialRecord.id) + const proposalMessage = await this.findProposalMessage(agentContext, credentialRecord.id) // Do not auto accept if missing properties if (!offerMessage.credentialPreview) return false @@ -989,18 +1025,24 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat ) } - public async shouldAutoRespondToRequest(options: { - credentialRecord: CredentialExchangeRecord - requestMessage: V1RequestCredentialMessage - }) { + public async shouldAutoRespondToRequest( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + requestMessage: V1RequestCredentialMessage + } + ) { const { credentialRecord, requestMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const offerMessage = await this.findOfferMessage(credentialRecord.id) + const offerMessage = await this.findOfferMessage(agentContext, credentialRecord.id) if (!offerMessage) return false const offerAttachment = offerMessage.getOfferAttachmentById(INDY_CREDENTIAL_OFFER_ATTACHMENT_ID) @@ -1008,26 +1050,32 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat if (!offerAttachment || !requestAttachment) return false - return this.formatService.shouldAutoRespondToRequest({ + return this.formatService.shouldAutoRespondToRequest(agentContext, { credentialRecord, offerAttachment, requestAttachment, }) } - public async shouldAutoRespondToCredential(options: { - credentialRecord: CredentialExchangeRecord - credentialMessage: V1IssueCredentialMessage - }) { + public async shouldAutoRespondToCredential( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + credentialMessage: V1IssueCredentialMessage + } + ) { const { credentialRecord, credentialMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const requestMessage = await this.findRequestMessage(credentialRecord.id) - const offerMessage = await this.findOfferMessage(credentialRecord.id) + const requestMessage = await this.findRequestMessage(agentContext, credentialRecord.id) + const offerMessage = await this.findOfferMessage(agentContext, credentialRecord.id) const credentialAttachment = credentialMessage.getCredentialAttachmentById(INDY_CREDENTIAL_ATTACHMENT_ID) if (!credentialAttachment) return false @@ -1037,7 +1085,7 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat const offerAttachment = offerMessage?.getOfferAttachmentById(INDY_CREDENTIAL_OFFER_ATTACHMENT_ID) - return this.formatService.shouldAutoRespondToCredential({ + return this.formatService.shouldAutoRespondToCredential(agentContext, { credentialRecord, credentialAttachment, requestAttachment, @@ -1045,41 +1093,44 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat }) } - public async findProposalMessage(credentialExchangeId: string) { - return await this.didCommMessageRepository.findAgentMessage({ + public async findProposalMessage(agentContext: AgentContext, credentialExchangeId: string) { + return await this.didCommMessageRepository.findAgentMessage(agentContext, { associatedRecordId: credentialExchangeId, messageClass: V1ProposeCredentialMessage, }) } - public async findOfferMessage(credentialExchangeId: string) { - return await this.didCommMessageRepository.findAgentMessage({ + public async findOfferMessage(agentContext: AgentContext, credentialExchangeId: string) { + return await this.didCommMessageRepository.findAgentMessage(agentContext, { associatedRecordId: credentialExchangeId, messageClass: V1OfferCredentialMessage, }) } - public async findRequestMessage(credentialExchangeId: string) { - return await this.didCommMessageRepository.findAgentMessage({ + public async findRequestMessage(agentContext: AgentContext, credentialExchangeId: string) { + return await this.didCommMessageRepository.findAgentMessage(agentContext, { associatedRecordId: credentialExchangeId, messageClass: V1RequestCredentialMessage, }) } - public async findCredentialMessage(credentialExchangeId: string) { - return await this.didCommMessageRepository.findAgentMessage({ + public async findCredentialMessage(agentContext: AgentContext, credentialExchangeId: string) { + return await this.didCommMessageRepository.findAgentMessage(agentContext, { associatedRecordId: credentialExchangeId, messageClass: V1IssueCredentialMessage, }) } - public async getFormatData(credentialExchangeId: string): Promise> { + public async getFormatData( + agentContext: AgentContext, + credentialExchangeId: string + ): Promise> { // TODO: we could looking at fetching all record using a single query and then filtering based on the type of the message. const [proposalMessage, offerMessage, requestMessage, credentialMessage] = await Promise.all([ - this.findProposalMessage(credentialExchangeId), - this.findOfferMessage(credentialExchangeId), - this.findRequestMessage(credentialExchangeId), - this.findCredentialMessage(credentialExchangeId), + this.findProposalMessage(agentContext, credentialExchangeId), + this.findOfferMessage(agentContext, credentialExchangeId), + this.findRequestMessage(agentContext, credentialExchangeId), + this.findCredentialMessage(agentContext, credentialExchangeId), ]) const indyProposal = proposalMessage @@ -1117,14 +1168,12 @@ export class V1CredentialService extends CredentialService<[IndyCredentialFormat } protected registerHandlers() { - this.dispatcher.registerHandler(new V1ProposeCredentialHandler(this, this.agentConfig)) - this.dispatcher.registerHandler( - new V1OfferCredentialHandler(this, this.agentConfig, this.routingService, this.didCommMessageRepository) - ) + this.dispatcher.registerHandler(new V1ProposeCredentialHandler(this, this.logger)) this.dispatcher.registerHandler( - new V1RequestCredentialHandler(this, this.agentConfig, this.didCommMessageRepository) + new V1OfferCredentialHandler(this, this.routingService, this.didCommMessageRepository, this.logger) ) - this.dispatcher.registerHandler(new V1IssueCredentialHandler(this, this.agentConfig, this.didCommMessageRepository)) + this.dispatcher.registerHandler(new V1RequestCredentialHandler(this, this.didCommMessageRepository, this.logger)) + this.dispatcher.registerHandler(new V1IssueCredentialHandler(this, this.didCommMessageRepository, this.logger)) this.dispatcher.registerHandler(new V1CredentialAckHandler(this)) this.dispatcher.registerHandler(new V1CredentialProblemReportHandler(this)) } diff --git a/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceCred.test.ts b/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceCred.test.ts index e0188dd352..7f2760a501 100644 --- a/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceCred.test.ts +++ b/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceCred.test.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../../../agent' import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { GetAgentMessageOptions } from '../../../../../storage/didcomm/DidCommMessageRepository' import type { CredentialStateChangedEvent } from '../../../CredentialEvents' @@ -5,7 +6,9 @@ import type { IndyCredentialViewMetadata } from '../../../formats/indy/models' import type { CredentialPreviewAttribute } from '../../../models' import type { CustomCredentialTags } from '../../../repository/CredentialExchangeRecord' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../../../../tests/helpers' import { Dispatcher } from '../../../../../agent/Dispatcher' import { EventEmitter } from '../../../../../agent/EventEmitter' import { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' @@ -35,12 +38,12 @@ import { INDY_CREDENTIAL_OFFER_ATTACHMENT_ID, INDY_CREDENTIAL_REQUEST_ATTACHMENT_ID, V1CredentialAckMessage, - V1CredentialPreview, V1CredentialProblemReportMessage, V1IssueCredentialMessage, V1OfferCredentialMessage, V1ProposeCredentialMessage, V1RequestCredentialMessage, + V1CredentialPreview, } from '../messages' // Mock classes @@ -132,7 +135,7 @@ const didCommMessageRecord = new DidCommMessageRecord({ }) // eslint-disable-next-line @typescript-eslint/no-explicit-any -const getAgentMessageMock = async (options: GetAgentMessageOptions) => { +const getAgentMessageMock = async (agentContext: AgentContext, options: GetAgentMessageOptions) => { if (options.messageClass === V1ProposeCredentialMessage) { return credentialProposalMessage } @@ -218,12 +221,14 @@ const mockCredentialRecord = ({ describe('V1CredentialService', () => { let eventEmitter: EventEmitter let agentConfig: AgentConfig + let agentContext: AgentContext let credentialService: V1CredentialService beforeEach(async () => { // real objects agentConfig = getAgentConfig('V1CredentialServiceCredTest') - eventEmitter = new EventEmitter(agentConfig) + agentContext = getAgentContext() + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) // mock function implementations mockFunction(connectionService.getById).mockResolvedValue(connection) @@ -238,7 +243,7 @@ describe('V1CredentialService', () => { credentialService = new V1CredentialService( connectionService, didCommMessageRepository, - agentConfig, + agentConfig.logger, routingService, dispatcher, eventEmitter, @@ -276,7 +281,7 @@ describe('V1CredentialService', () => { }) // when - const { message } = await credentialService.acceptOffer({ + const { message } = await credentialService.acceptOffer(agentContext, { comment: 'hello', autoAcceptCredential: AutoAcceptCredential.Never, credentialRecord, @@ -299,7 +304,7 @@ describe('V1CredentialService', () => { 'requests~attach': [JsonTransformer.toJSON(requestAttachment)], }) expect(credentialRepository.update).toHaveBeenCalledTimes(1) - expect(indyCredentialFormatService.acceptOffer).toHaveBeenCalledWith({ + expect(indyCredentialFormatService.acceptOffer).toHaveBeenCalledWith(agentContext, { credentialRecord, attachId: INDY_CREDENTIAL_REQUEST_ATTACHMENT_ID, offerAttachment, @@ -309,7 +314,7 @@ describe('V1CredentialService', () => { }, }, }) - expect(didCommMessageRepository.saveOrUpdateAgentMessage).toHaveBeenCalledWith({ + expect(didCommMessageRepository.saveOrUpdateAgentMessage).toHaveBeenCalledWith(agentContext, { agentMessage: message, associatedRecordId: '84353745-8bd9-42e1-8d81-238ca77c29d2', role: DidCommMessageRole.Sender, @@ -333,12 +338,12 @@ describe('V1CredentialService', () => { }) // when - await credentialService.acceptOffer({ + await credentialService.acceptOffer(agentContext, { credentialRecord, }) // then - expect(updateStateSpy).toHaveBeenCalledWith(credentialRecord, CredentialState.RequestSent) + expect(updateStateSpy).toHaveBeenCalledWith(agentContext, credentialRecord, CredentialState.RequestSent) }) const validState = CredentialState.OfferReceived @@ -347,7 +352,7 @@ describe('V1CredentialService', () => { await Promise.all( invalidCredentialStates.map(async (state) => { await expect( - credentialService.acceptOffer({ credentialRecord: mockCredentialRecord({ state }) }) + credentialService.acceptOffer(agentContext, { credentialRecord: mockCredentialRecord({ state }) }) ).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`) }) ) @@ -366,6 +371,7 @@ describe('V1CredentialService', () => { }) credentialRequest.setThread({ threadId: 'somethreadid' }) messageContext = new InboundMessageContext(credentialRequest, { + agentContext, connection, }) }) @@ -380,7 +386,7 @@ describe('V1CredentialService', () => { const returnedCredentialRecord = await credentialService.processRequest(messageContext) // then - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) @@ -397,7 +403,7 @@ describe('V1CredentialService', () => { const returnedCredentialRecord = await credentialService.processRequest(messageContext) // then - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) @@ -438,10 +444,11 @@ describe('V1CredentialService', () => { }) // when - await credentialService.acceptRequest({ credentialRecord }) + await credentialService.acceptRequest(agentContext, { credentialRecord }) // then expect(credentialRepository.update).toHaveBeenCalledWith( + agentContext, expect.objectContaining({ state: CredentialState.CredentialIssued, }) @@ -470,7 +477,7 @@ describe('V1CredentialService', () => { eventEmitter.on(CredentialEventTypes.CredentialStateChanged, eventListenerMock) // when - await credentialService.acceptRequest({ credentialRecord }) + await credentialService.acceptRequest(agentContext, { credentialRecord }) // then expect(eventListenerMock).toHaveBeenCalledWith({ @@ -504,7 +511,7 @@ describe('V1CredentialService', () => { }) // when - const { message } = await credentialService.acceptRequest({ credentialRecord, comment }) + const { message } = await credentialService.acceptRequest(agentContext, { credentialRecord, comment }) // then expect(message.toJSON()).toMatchObject({ @@ -518,7 +525,7 @@ describe('V1CredentialService', () => { '~please_ack': expect.any(Object), }) - expect(indyCredentialFormatService.acceptRequest).toHaveBeenCalledWith({ + expect(indyCredentialFormatService.acceptRequest).toHaveBeenCalledWith(agentContext, { credentialRecord, requestAttachment, offerAttachment, @@ -538,9 +545,7 @@ describe('V1CredentialService', () => { credentialAttachments: [credentialAttachment], }) credentialResponse.setThread({ threadId: 'somethreadid' }) - const messageContext = new InboundMessageContext(credentialResponse, { - connection, - }) + const messageContext = new InboundMessageContext(credentialResponse, { agentContext, connection }) mockFunction(credentialRepository.getSingleByQuery).mockResolvedValue(credentialRecord) @@ -548,18 +553,18 @@ describe('V1CredentialService', () => { await credentialService.processCredential(messageContext) // then - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) - expect(didCommMessageRepository.saveAgentMessage).toHaveBeenCalledWith({ + expect(didCommMessageRepository.saveAgentMessage).toHaveBeenCalledWith(agentContext, { agentMessage: credentialResponse, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) - expect(indyCredentialFormatService.processCredential).toHaveBeenNthCalledWith(1, { + expect(indyCredentialFormatService.processCredential).toHaveBeenNthCalledWith(1, agentContext, { attachment: credentialAttachment, credentialRecord, }) @@ -583,11 +588,11 @@ describe('V1CredentialService', () => { const repositoryUpdateSpy = jest.spyOn(credentialRepository, 'update') // when - await credentialService.acceptCredential({ credentialRecord: credential }) + await credentialService.acceptCredential(agentContext, { credentialRecord: credential }) // then expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1) - const [[updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls + const [[, updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls expect(updatedCredentialRecord).toMatchObject({ state: CredentialState.Done, }) @@ -598,7 +603,7 @@ describe('V1CredentialService', () => { eventEmitter.on(CredentialEventTypes.CredentialStateChanged, eventListenerMock) // when - await credentialService.acceptCredential({ credentialRecord: credential }) + await credentialService.acceptCredential(agentContext, { credentialRecord: credential }) // then expect(eventListenerMock).toHaveBeenCalledWith({ @@ -617,7 +622,9 @@ describe('V1CredentialService', () => { mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(credential)) // when - const { message: ackMessage } = await credentialService.acceptCredential({ credentialRecord: credential }) + const { message: ackMessage } = await credentialService.acceptCredential(agentContext, { + credentialRecord: credential, + }) // then expect(ackMessage.toJSON()).toMatchObject({ @@ -635,7 +642,7 @@ describe('V1CredentialService', () => { await Promise.all( invalidCredentialStates.map(async (state) => { await expect( - credentialService.acceptCredential({ + credentialService.acceptCredential(agentContext, { credentialRecord: mockCredentialRecord({ state, threadId, @@ -661,9 +668,7 @@ describe('V1CredentialService', () => { status: AckStatus.OK, threadId: 'somethreadid', }) - messageContext = new InboundMessageContext(credentialRequest, { - connection, - }) + messageContext = new InboundMessageContext(credentialRequest, { agentContext, connection }) }) test(`updates state to ${CredentialState.Done} and returns credential record`, async () => { @@ -679,12 +684,12 @@ describe('V1CredentialService', () => { const expectedCredentialRecord = { state: CredentialState.Done, } - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1) - const [[updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls + const [[, updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls expect(updatedCredentialRecord).toMatchObject(expectedCredentialRecord) expect(returnedCredentialRecord).toMatchObject(expectedCredentialRecord) }) @@ -708,7 +713,7 @@ describe('V1CredentialService', () => { mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(credential)) // when - const credentialProblemReportMessage = credentialService.createProblemReport({ message }) + const credentialProblemReportMessage = credentialService.createProblemReport(agentContext, { message }) credentialProblemReportMessage.setThread({ threadId }) // then @@ -742,9 +747,7 @@ describe('V1CredentialService', () => { }, }) credentialProblemReportMessage.setThread({ threadId: 'somethreadid' }) - messageContext = new InboundMessageContext(credentialProblemReportMessage, { - connection, - }) + messageContext = new InboundMessageContext(credentialProblemReportMessage, { agentContext, connection }) }) test(`updates problem report error message and returns credential record`, async () => { @@ -760,12 +763,12 @@ describe('V1CredentialService', () => { const expectedCredentialRecord = { errorMessage: 'issuance-abandoned: Indy error', } - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1) - const [[updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls + const [[, updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls expect(updatedCredentialRecord).toMatchObject(expectedCredentialRecord) expect(returnedCredentialRecord).toMatchObject(expectedCredentialRecord) }) @@ -775,8 +778,8 @@ describe('V1CredentialService', () => { it('getById should return value from credentialRepository.getById', async () => { const expected = mockCredentialRecord() mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.getById(expected.id) - expect(credentialRepository.getById).toBeCalledWith(expected.id) + const result = await credentialService.getById(agentContext, expected.id) + expect(credentialRepository.getById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -784,8 +787,8 @@ describe('V1CredentialService', () => { it('getById should return value from credentialRepository.getSingleByQuery', async () => { const expected = mockCredentialRecord() mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.getByThreadAndConnectionId('threadId', 'connectionId') - expect(credentialRepository.getSingleByQuery).toBeCalledWith({ + const result = await credentialService.getByThreadAndConnectionId(agentContext, 'threadId', 'connectionId') + expect(credentialRepository.getSingleByQuery).toBeCalledWith(agentContext, { threadId: 'threadId', connectionId: 'connectionId', }) @@ -796,8 +799,8 @@ describe('V1CredentialService', () => { it('findById should return value from credentialRepository.findById', async () => { const expected = mockCredentialRecord() mockFunction(credentialRepository.findById).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.findById(expected.id) - expect(credentialRepository.findById).toBeCalledWith(expected.id) + const result = await credentialService.findById(agentContext, expected.id) + expect(credentialRepository.findById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -806,8 +809,8 @@ describe('V1CredentialService', () => { const expected = [mockCredentialRecord(), mockCredentialRecord()] mockFunction(credentialRepository.getAll).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.getAll() - expect(credentialRepository.getAll).toBeCalledWith() + const result = await credentialService.getAll(agentContext) + expect(credentialRepository.getAll).toBeCalledWith(agentContext) expect(result).toEqual(expect.arrayContaining(expected)) }) @@ -819,8 +822,8 @@ describe('V1CredentialService', () => { mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(credentialRecord)) const repositoryDeleteSpy = jest.spyOn(credentialRepository, 'delete') - await credentialService.delete(credentialRecord) - expect(repositoryDeleteSpy).toHaveBeenNthCalledWith(1, credentialRecord) + await credentialService.delete(agentContext, credentialRecord) + expect(repositoryDeleteSpy).toHaveBeenNthCalledWith(1, agentContext, credentialRecord) }) it('should call deleteCredentialById in indyCredentialFormatService if deleteAssociatedCredential is true', async () => { @@ -829,12 +832,16 @@ describe('V1CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord, { + await credentialService.delete(agentContext, credentialRecord, { deleteAssociatedCredentials: true, deleteAssociatedDidCommMessages: false, }) - expect(deleteCredentialMock).toHaveBeenNthCalledWith(1, credentialRecord.credentials[0].credentialRecordId) + expect(deleteCredentialMock).toHaveBeenNthCalledWith( + 1, + agentContext, + credentialRecord.credentials[0].credentialRecordId + ) }) it('should not call deleteCredentialById in indyCredentialFormatService if deleteAssociatedCredential is false', async () => { @@ -843,7 +850,7 @@ describe('V1CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord, { + await credentialService.delete(agentContext, credentialRecord, { deleteAssociatedCredentials: false, deleteAssociatedDidCommMessages: false, }) @@ -857,9 +864,13 @@ describe('V1CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord) + await credentialService.delete(agentContext, credentialRecord) - expect(deleteCredentialMock).toHaveBeenNthCalledWith(1, credentialRecord.credentials[0].credentialRecordId) + expect(deleteCredentialMock).toHaveBeenNthCalledWith( + 1, + agentContext, + credentialRecord.credentials[0].credentialRecordId + ) }) it('deleteAssociatedDidCommMessages should default to true', async () => { const deleteCredentialMock = mockFunction(indyCredentialFormatService.deleteCredentialById) @@ -867,9 +878,13 @@ describe('V1CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord) + await credentialService.delete(agentContext, credentialRecord) - expect(deleteCredentialMock).toHaveBeenNthCalledWith(1, credentialRecord.credentials[0].credentialRecordId) + expect(deleteCredentialMock).toHaveBeenNthCalledWith( + 1, + agentContext, + credentialRecord.credentials[0].credentialRecordId + ) expect(didCommMessageRepository.delete).toHaveBeenCalledTimes(3) }) }) @@ -890,14 +905,18 @@ describe('V1CredentialService', () => { const repositoryUpdateSpy = jest.spyOn(credentialRepository, 'update') // when - await credentialService.declineOffer(credential) + await credentialService.declineOffer(agentContext, credential) // then const expectedCredentialState = { state: CredentialState.Declined, } expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1) - expect(repositoryUpdateSpy).toHaveBeenNthCalledWith(1, expect.objectContaining(expectedCredentialState)) + expect(repositoryUpdateSpy).toHaveBeenNthCalledWith( + 1, + agentContext, + expect.objectContaining(expectedCredentialState) + ) }) test(`emits stateChange event from ${CredentialState.OfferReceived} to ${CredentialState.Declined}`, async () => { @@ -908,7 +927,7 @@ describe('V1CredentialService', () => { mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(Promise.resolve(credential)) // when - await credentialService.declineOffer(credential) + await credentialService.declineOffer(agentContext, credential) // then expect(eventListenerMock).toHaveBeenCalledTimes(1) @@ -930,7 +949,7 @@ describe('V1CredentialService', () => { await Promise.all( invalidCredentialStates.map(async (state) => { await expect( - credentialService.declineOffer(mockCredentialRecord({ state, tags: { threadId } })) + credentialService.declineOffer(agentContext, mockCredentialRecord({ state, tags: { threadId } })) ).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`) }) ) diff --git a/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceProposeOffer.test.ts b/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceProposeOffer.test.ts index 92c3434bd8..7ebe9467aa 100644 --- a/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceProposeOffer.test.ts +++ b/packages/core/src/modules/credentials/protocol/v1/__tests__/V1CredentialServiceProposeOffer.test.ts @@ -1,9 +1,10 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { CredentialStateChangedEvent } from '../../../CredentialEvents' import type { CreateOfferOptions, CreateProposalOptions } from '../../../CredentialServiceOptions' import type { IndyCredentialFormat } from '../../../formats/indy/IndyCredentialFormat' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../../../../tests/helpers' import { Dispatcher } from '../../../../../agent/Dispatcher' import { EventEmitter } from '../../../../../agent/EventEmitter' import { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' @@ -51,6 +52,9 @@ const indyCredentialFormatService = new IndyCredentialFormatServiceMock() const dispatcher = new DispatcherMock() const connectionService = new ConnectionServiceMock() +const agentConfig = getAgentConfig('V1CredentialServiceProposeOfferTest') +const agentContext = getAgentContext() + // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore indyCredentialFormatService.credentialRecordType = 'indy' @@ -89,13 +93,11 @@ const proposalAttachment = new Attachment({ describe('V1CredentialServiceProposeOffer', () => { let eventEmitter: EventEmitter - let agentConfig: AgentConfig + let credentialService: V1CredentialService beforeEach(async () => { - // real objects - agentConfig = getAgentConfig('V1CredentialServiceProposeOfferTest') - eventEmitter = new EventEmitter(agentConfig) + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) // mock function implementations mockFunction(connectionService.getById).mockResolvedValue(connection) @@ -107,7 +109,7 @@ describe('V1CredentialServiceProposeOffer', () => { credentialService = new V1CredentialService( connectionService, didCommMessageRepository, - agentConfig, + agentConfig.logger, routingService, dispatcher, eventEmitter, @@ -148,11 +150,12 @@ describe('V1CredentialServiceProposeOffer', () => { }), }) - await credentialService.createProposal(proposeOptions) + await credentialService.createProposal(agentContext, proposeOptions) // then expect(repositorySaveSpy).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ type: CredentialExchangeRecord.type, id: expect.any(String), @@ -175,7 +178,7 @@ describe('V1CredentialServiceProposeOffer', () => { }), }) - await credentialService.createProposal(proposeOptions) + await credentialService.createProposal(agentContext, proposeOptions) expect(eventListenerMock).toHaveBeenCalledWith({ type: 'CredentialStateChanged', @@ -198,7 +201,7 @@ describe('V1CredentialServiceProposeOffer', () => { previewAttributes: credentialPreview.attributes, }) - const { message } = await credentialService.createProposal(proposeOptions) + const { message } = await credentialService.createProposal(agentContext, proposeOptions) expect(message.toJSON()).toMatchObject({ '@id': expect.any(String), @@ -253,12 +256,12 @@ describe('V1CredentialServiceProposeOffer', () => { const repositorySaveSpy = jest.spyOn(credentialRepository, 'save') - await credentialService.createOffer(offerOptions) + await credentialService.createOffer(agentContext, offerOptions) // then expect(repositorySaveSpy).toHaveBeenCalledTimes(1) - const [[createdCredentialRecord]] = repositorySaveSpy.mock.calls + const [[, createdCredentialRecord]] = repositorySaveSpy.mock.calls expect(createdCredentialRecord).toMatchObject({ type: CredentialExchangeRecord.type, id: expect.any(String), @@ -282,7 +285,7 @@ describe('V1CredentialServiceProposeOffer', () => { previewAttributes: credentialPreview.attributes, }) - await credentialService.createOffer(offerOptions) + await credentialService.createOffer(agentContext, offerOptions) expect(eventListenerMock).toHaveBeenCalledWith({ type: 'CredentialStateChanged', @@ -304,7 +307,7 @@ describe('V1CredentialServiceProposeOffer', () => { }), }) - await expect(credentialService.createOffer(offerOptions)).rejects.toThrowError( + await expect(credentialService.createOffer(agentContext, offerOptions)).rejects.toThrowError( 'Missing required credential preview from indy format service' ) }) @@ -319,7 +322,7 @@ describe('V1CredentialServiceProposeOffer', () => { previewAttributes: credentialPreview.attributes, }) - const { message: credentialOffer } = await credentialService.createOffer(offerOptions) + const { message: credentialOffer } = await credentialService.createOffer(agentContext, offerOptions) expect(credentialOffer.toJSON()).toMatchObject({ '@id': expect.any(String), '@type': 'https://didcomm.org/issue-credential/1.0/offer-credential', @@ -350,9 +353,7 @@ describe('V1CredentialServiceProposeOffer', () => { credentialPreview: credentialPreview, offerAttachments: [offerAttachment], }) - const messageContext = new InboundMessageContext(credentialOfferMessage, { - connection, - }) + const messageContext = new InboundMessageContext(credentialOfferMessage, { agentContext, connection }) test(`creates and return credential record in ${CredentialState.OfferReceived} state with offer, thread ID`, async () => { // when @@ -361,6 +362,7 @@ describe('V1CredentialServiceProposeOffer', () => { // then expect(credentialRepository.save).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ type: CredentialExchangeRecord.type, id: expect.any(String), diff --git a/packages/core/src/modules/credentials/protocol/v1/__tests__/v1-credentials.e2e.test.ts b/packages/core/src/modules/credentials/protocol/v1/__tests__/v1-credentials.e2e.test.ts index 6ee984bca7..1d25498a3a 100644 --- a/packages/core/src/modules/credentials/protocol/v1/__tests__/v1-credentials.e2e.test.ts +++ b/packages/core/src/modules/credentials/protocol/v1/__tests__/v1-credentials.e2e.test.ts @@ -96,7 +96,7 @@ describe('v1 credentials', () => { }) const didCommMessageRepository = faberAgent.dependencyManager.resolve(DidCommMessageRepository) - const offerMessageRecord = await didCommMessageRepository.findAgentMessage({ + const offerMessageRecord = await didCommMessageRepository.findAgentMessage(faberAgent.context, { associatedRecordId: faberCredentialRecord.id, messageClass: V1OfferCredentialMessage, }) diff --git a/packages/core/src/modules/credentials/protocol/v1/handlers/V1IssueCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v1/handlers/V1IssueCredentialHandler.ts index 0a213e9c30..bf12db449a 100644 --- a/packages/core/src/modules/credentials/protocol/v1/handlers/V1IssueCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v1/handlers/V1IssueCredentialHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' +import type { Logger } from '../../../../../logger' import type { DidCommMessageRepository } from '../../../../../storage' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' import type { V1CredentialService } from '../V1CredentialService' @@ -9,24 +9,24 @@ import { V1IssueCredentialMessage, V1RequestCredentialMessage } from '../message export class V1IssueCredentialHandler implements Handler { private credentialService: V1CredentialService - private agentConfig: AgentConfig private didCommMessageRepository: DidCommMessageRepository + private logger: Logger public supportedMessages = [V1IssueCredentialMessage] public constructor( credentialService: V1CredentialService, - agentConfig: AgentConfig, - didCommMessageRepository: DidCommMessageRepository + didCommMessageRepository: DidCommMessageRepository, + logger: Logger ) { this.credentialService = credentialService - this.agentConfig = agentConfig this.didCommMessageRepository = didCommMessageRepository + this.logger = logger } public async handle(messageContext: HandlerInboundMessage) { const credentialRecord = await this.credentialService.processCredential(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToCredential({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToCredential(messageContext.agentContext, { credentialRecord, credentialMessage: messageContext.message, }) @@ -40,14 +40,14 @@ export class V1IssueCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending acknowledgement with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending acknowledgement with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) - const { message } = await this.credentialService.acceptCredential({ + const { message } = await this.credentialService.acceptCredential(messageContext.agentContext, { credentialRecord, }) - const requestMessage = await this.didCommMessageRepository.getAgentMessage({ + const requestMessage = await this.didCommMessageRepository.getAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V1RequestCredentialMessage, }) @@ -65,6 +65,6 @@ export class V1IssueCredentialHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create credential ack`) + this.logger.error(`Could not automatically create credential ack`) } } diff --git a/packages/core/src/modules/credentials/protocol/v1/handlers/V1OfferCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v1/handlers/V1OfferCredentialHandler.ts index fab851c7df..207cbff379 100644 --- a/packages/core/src/modules/credentials/protocol/v1/handlers/V1OfferCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v1/handlers/V1OfferCredentialHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' +import type { Logger } from '../../../../../logger' import type { DidCommMessageRepository } from '../../../../../storage' import type { RoutingService } from '../../../../routing/services/RoutingService' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' @@ -12,27 +12,27 @@ import { V1OfferCredentialMessage } from '../messages' export class V1OfferCredentialHandler implements Handler { private credentialService: V1CredentialService - private agentConfig: AgentConfig private routingService: RoutingService private didCommMessageRepository: DidCommMessageRepository + private logger: Logger public supportedMessages = [V1OfferCredentialMessage] public constructor( credentialService: V1CredentialService, - agentConfig: AgentConfig, routingService: RoutingService, - didCommMessageRepository: DidCommMessageRepository + didCommMessageRepository: DidCommMessageRepository, + logger: Logger ) { this.credentialService = credentialService - this.agentConfig = agentConfig this.routingService = routingService this.didCommMessageRepository = didCommMessageRepository + this.logger = logger } public async handle(messageContext: HandlerInboundMessage) { const credentialRecord = await this.credentialService.processOffer(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToOffer({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToOffer(messageContext.agentContext, { credentialRecord, offerMessage: messageContext.message, }) @@ -46,15 +46,15 @@ export class V1OfferCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending request with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending request with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) if (messageContext.connection) { - const { message } = await this.credentialService.acceptOffer({ credentialRecord }) + const { message } = await this.credentialService.acceptOffer(messageContext.agentContext, { credentialRecord }) return createOutboundMessage(messageContext.connection, message) } else if (messageContext.message.service) { - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(messageContext.agentContext) const ourService = new ServiceDecorator({ serviceEndpoint: routing.endpoints[0], recipientKeys: [routing.recipientKey.publicKeyBase58], @@ -62,7 +62,7 @@ export class V1OfferCredentialHandler implements Handler { }) const recipientService = messageContext.message.service - const { message } = await this.credentialService.acceptOffer({ + const { message } = await this.credentialService.acceptOffer(messageContext.agentContext, { credentialRecord, credentialFormats: { indy: { @@ -73,7 +73,7 @@ export class V1OfferCredentialHandler implements Handler { // Set and save ~service decorator to record (to remember our verkey) message.service = ourService - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(messageContext.agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -86,6 +86,6 @@ export class V1OfferCredentialHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create credential request`) + this.logger.error(`Could not automatically create credential request`) } } diff --git a/packages/core/src/modules/credentials/protocol/v1/handlers/V1ProposeCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v1/handlers/V1ProposeCredentialHandler.ts index ed5992e136..38c32018d7 100644 --- a/packages/core/src/modules/credentials/protocol/v1/handlers/V1ProposeCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v1/handlers/V1ProposeCredentialHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' +import type { Logger } from '../../../../../logger' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' import type { V1CredentialService } from '../V1CredentialService' @@ -8,21 +8,24 @@ import { V1ProposeCredentialMessage } from '../messages' export class V1ProposeCredentialHandler implements Handler { private credentialService: V1CredentialService - private agentConfig: AgentConfig + private logger: Logger public supportedMessages = [V1ProposeCredentialMessage] - public constructor(credentialService: V1CredentialService, agentConfig: AgentConfig) { + public constructor(credentialService: V1CredentialService, logger: Logger) { this.credentialService = credentialService - this.agentConfig = agentConfig + this.logger = logger } public async handle(messageContext: HandlerInboundMessage) { const credentialRecord = await this.credentialService.processProposal(messageContext) - const shouldAutoAcceptProposal = await this.credentialService.shouldAutoRespondToProposal({ - credentialRecord, - proposalMessage: messageContext.message, - }) + const shouldAutoAcceptProposal = await this.credentialService.shouldAutoRespondToProposal( + messageContext.agentContext, + { + credentialRecord, + proposalMessage: messageContext.message, + } + ) if (shouldAutoAcceptProposal) { return await this.acceptProposal(credentialRecord, messageContext) @@ -33,16 +36,16 @@ export class V1ProposeCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending offer with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending offer with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) if (!messageContext.connection) { - this.agentConfig.logger.error('No connection on the messageContext, aborting auto accept') + this.logger.error('No connection on the messageContext, aborting auto accept') return } - const { message } = await this.credentialService.acceptProposal({ + const { message } = await this.credentialService.acceptProposal(messageContext.agentContext, { credentialRecord, }) diff --git a/packages/core/src/modules/credentials/protocol/v1/handlers/V1RequestCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v1/handlers/V1RequestCredentialHandler.ts index 7e91d61e25..a5eb94ad41 100644 --- a/packages/core/src/modules/credentials/protocol/v1/handlers/V1RequestCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v1/handlers/V1RequestCredentialHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' +import type { Logger } from '../../../../../logger' import type { DidCommMessageRepository } from '../../../../../storage' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' import type { V1CredentialService } from '../V1CredentialService' @@ -9,25 +9,25 @@ import { DidCommMessageRole } from '../../../../../storage' import { V1RequestCredentialMessage } from '../messages' export class V1RequestCredentialHandler implements Handler { - private agentConfig: AgentConfig private credentialService: V1CredentialService private didCommMessageRepository: DidCommMessageRepository + private logger: Logger public supportedMessages = [V1RequestCredentialMessage] public constructor( credentialService: V1CredentialService, - agentConfig: AgentConfig, - didCommMessageRepository: DidCommMessageRepository + didCommMessageRepository: DidCommMessageRepository, + logger: Logger ) { this.credentialService = credentialService - this.agentConfig = agentConfig + this.logger = logger this.didCommMessageRepository = didCommMessageRepository } public async handle(messageContext: HandlerInboundMessage) { const credentialRecord = await this.credentialService.processRequest(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToRequest({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToRequest(messageContext.agentContext, { credentialRecord, requestMessage: messageContext.message, }) @@ -41,13 +41,13 @@ export class V1RequestCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending credential with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending credential with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) - const offerMessage = await this.credentialService.findOfferMessage(credentialRecord.id) + const offerMessage = await this.credentialService.findOfferMessage(messageContext.agentContext, credentialRecord.id) - const { message } = await this.credentialService.acceptRequest({ + const { message } = await this.credentialService.acceptRequest(messageContext.agentContext, { credentialRecord, }) @@ -60,7 +60,7 @@ export class V1RequestCredentialHandler implements Handler { // Set ~service, update message in record (for later use) message.setService(ourService) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(messageContext.agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -73,6 +73,6 @@ export class V1RequestCredentialHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create credential request`) + this.logger.error(`Could not automatically create credential request`) } } diff --git a/packages/core/src/modules/credentials/protocol/v2/CredentialFormatCoordinator.ts b/packages/core/src/modules/credentials/protocol/v2/CredentialFormatCoordinator.ts index b63635ffe0..0ee7ea021c 100644 --- a/packages/core/src/modules/credentials/protocol/v2/CredentialFormatCoordinator.ts +++ b/packages/core/src/modules/credentials/protocol/v2/CredentialFormatCoordinator.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../../agent' import type { Attachment } from '../../../../decorators/attachment/Attachment' import type { DidCommMessageRepository } from '../../../../storage' import type { CredentialFormat, CredentialFormatPayload, CredentialFormatService } from '../../formats' @@ -28,24 +29,27 @@ export class CredentialFormatCoordinator { * @returns The created {@link V2ProposeCredentialMessage} * */ - public async createProposal({ - credentialFormats, - formatServices, - credentialRecord, - comment, - }: { - formatServices: CredentialFormatService[] - credentialFormats: CredentialFormatPayload - credentialRecord: CredentialExchangeRecord - comment?: string - }): Promise { + public async createProposal( + agentContext: AgentContext, + { + credentialFormats, + formatServices, + credentialRecord, + comment, + }: { + formatServices: CredentialFormatService[] + credentialFormats: CredentialFormatPayload + credentialRecord: CredentialExchangeRecord + comment?: string + } + ): Promise { // create message. there are two arrays in each message, one for formats the other for attachments const formats: CredentialFormatSpec[] = [] const proposalAttachments: Attachment[] = [] let credentialPreview: V2CredentialPreview | undefined for (const formatService of formatServices) { - const { format, attachment, previewAttributes } = await formatService.createProposal({ + const { format, attachment, previewAttributes } = await formatService.createProposal(agentContext, { credentialFormats, credentialRecord, }) @@ -72,7 +76,7 @@ export class CredentialFormatCoordinator { message.setThread({ threadId: credentialRecord.threadId }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -81,48 +85,54 @@ export class CredentialFormatCoordinator { return message } - public async processProposal({ - credentialRecord, - message, - formatServices, - }: { - credentialRecord: CredentialExchangeRecord - message: V2ProposeCredentialMessage - formatServices: CredentialFormatService[] - }) { + public async processProposal( + agentContext: AgentContext, + { + credentialRecord, + message, + formatServices, + }: { + credentialRecord: CredentialExchangeRecord + message: V2ProposeCredentialMessage + formatServices: CredentialFormatService[] + } + ) { for (const formatService of formatServices) { const attachment = this.getAttachmentForService(formatService, message.formats, message.proposalAttachments) - await formatService.processProposal({ + await formatService.processProposal(agentContext, { attachment, credentialRecord, }) } - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) } - public async acceptProposal({ - credentialRecord, - credentialFormats, - formatServices, - comment, - }: { - credentialRecord: CredentialExchangeRecord - credentialFormats?: CredentialFormatPayload - formatServices: CredentialFormatService[] - comment?: string - }) { + public async acceptProposal( + agentContext: AgentContext, + { + credentialRecord, + credentialFormats, + formatServices, + comment, + }: { + credentialRecord: CredentialExchangeRecord + credentialFormats?: CredentialFormatPayload + formatServices: CredentialFormatService[] + comment?: string + } + ) { // create message. there are two arrays in each message, one for formats the other for attachments const formats: CredentialFormatSpec[] = [] const offerAttachments: Attachment[] = [] let credentialPreview: V2CredentialPreview | undefined - const proposalMessage = await this.didCommMessageRepository.getAgentMessage({ + const proposalMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V2ProposeCredentialMessage, }) @@ -139,7 +149,7 @@ export class CredentialFormatCoordinator { proposalMessage.proposalAttachments ) - const { attachment, format, previewAttributes } = await formatService.acceptProposal({ + const { attachment, format, previewAttributes } = await formatService.acceptProposal(agentContext, { credentialRecord, credentialFormats, proposalAttachment, @@ -174,7 +184,7 @@ export class CredentialFormatCoordinator { message.setThread({ threadId: credentialRecord.threadId }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, @@ -190,24 +200,27 @@ export class CredentialFormatCoordinator { * @returns The created {@link V2OfferCredentialMessage} * */ - public async createOffer({ - credentialFormats, - formatServices, - credentialRecord, - comment, - }: { - formatServices: CredentialFormatService[] - credentialFormats: CredentialFormatPayload - credentialRecord: CredentialExchangeRecord - comment?: string - }): Promise { + public async createOffer( + agentContext: AgentContext, + { + credentialFormats, + formatServices, + credentialRecord, + comment, + }: { + formatServices: CredentialFormatService[] + credentialFormats: CredentialFormatPayload + credentialRecord: CredentialExchangeRecord + comment?: string + } + ): Promise { // create message. there are two arrays in each message, one for formats the other for attachments const formats: CredentialFormatSpec[] = [] const offerAttachments: Attachment[] = [] let credentialPreview: V2CredentialPreview | undefined for (const formatService of formatServices) { - const { format, attachment, previewAttributes } = await formatService.createOffer({ + const { format, attachment, previewAttributes } = await formatService.createOffer(agentContext, { credentialFormats, credentialRecord, }) @@ -241,7 +254,7 @@ export class CredentialFormatCoordinator { message.setThread({ threadId: credentialRecord.threadId }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -250,43 +263,49 @@ export class CredentialFormatCoordinator { return message } - public async processOffer({ - credentialRecord, - message, - formatServices, - }: { - credentialRecord: CredentialExchangeRecord - message: V2OfferCredentialMessage - formatServices: CredentialFormatService[] - }) { + public async processOffer( + agentContext: AgentContext, + { + credentialRecord, + message, + formatServices, + }: { + credentialRecord: CredentialExchangeRecord + message: V2OfferCredentialMessage + formatServices: CredentialFormatService[] + } + ) { for (const formatService of formatServices) { const attachment = this.getAttachmentForService(formatService, message.formats, message.offerAttachments) - await formatService.processOffer({ + await formatService.processOffer(agentContext, { attachment, credentialRecord, }) } - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) } - public async acceptOffer({ - credentialRecord, - credentialFormats, - formatServices, - comment, - }: { - credentialRecord: CredentialExchangeRecord - credentialFormats?: CredentialFormatPayload - formatServices: CredentialFormatService[] - comment?: string - }) { - const offerMessage = await this.didCommMessageRepository.getAgentMessage({ + public async acceptOffer( + agentContext: AgentContext, + { + credentialRecord, + credentialFormats, + formatServices, + comment, + }: { + credentialRecord: CredentialExchangeRecord + credentialFormats?: CredentialFormatPayload + formatServices: CredentialFormatService[] + comment?: string + } + ) { + const offerMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V2OfferCredentialMessage, }) @@ -302,7 +321,7 @@ export class CredentialFormatCoordinator { offerMessage.offerAttachments ) - const { attachment, format } = await formatService.acceptOffer({ + const { attachment, format } = await formatService.acceptOffer(agentContext, { offerAttachment, credentialRecord, credentialFormats, @@ -322,7 +341,7 @@ export class CredentialFormatCoordinator { message.setThread({ threadId: credentialRecord.threadId }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, @@ -338,23 +357,26 @@ export class CredentialFormatCoordinator { * @returns The created {@link V2RequestCredentialMessage} * */ - public async createRequest({ - credentialFormats, - formatServices, - credentialRecord, - comment, - }: { - formatServices: CredentialFormatService[] - credentialFormats: CredentialFormatPayload - credentialRecord: CredentialExchangeRecord - comment?: string - }): Promise { + public async createRequest( + agentContext: AgentContext, + { + credentialFormats, + formatServices, + credentialRecord, + comment, + }: { + formatServices: CredentialFormatService[] + credentialFormats: CredentialFormatPayload + credentialRecord: CredentialExchangeRecord + comment?: string + } + ): Promise { // create message. there are two arrays in each message, one for formats the other for attachments const formats: CredentialFormatSpec[] = [] const requestAttachments: Attachment[] = [] for (const formatService of formatServices) { - const { format, attachment } = await formatService.createRequest({ + const { format, attachment } = await formatService.createRequest(agentContext, { credentialFormats, credentialRecord, }) @@ -371,7 +393,7 @@ export class CredentialFormatCoordinator { message.setThread({ threadId: credentialRecord.threadId }) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -380,48 +402,54 @@ export class CredentialFormatCoordinator { return message } - public async processRequest({ - credentialRecord, - message, - formatServices, - }: { - credentialRecord: CredentialExchangeRecord - message: V2RequestCredentialMessage - formatServices: CredentialFormatService[] - }) { + public async processRequest( + agentContext: AgentContext, + { + credentialRecord, + message, + formatServices, + }: { + credentialRecord: CredentialExchangeRecord + message: V2RequestCredentialMessage + formatServices: CredentialFormatService[] + } + ) { for (const formatService of formatServices) { const attachment = this.getAttachmentForService(formatService, message.formats, message.requestAttachments) - await formatService.processRequest({ + await formatService.processRequest(agentContext, { attachment, credentialRecord, }) } - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, }) } - public async acceptRequest({ - credentialRecord, - credentialFormats, - formatServices, - comment, - }: { - credentialRecord: CredentialExchangeRecord - credentialFormats?: CredentialFormatPayload - formatServices: CredentialFormatService[] - comment?: string - }) { - const requestMessage = await this.didCommMessageRepository.getAgentMessage({ + public async acceptRequest( + agentContext: AgentContext, + { + credentialRecord, + credentialFormats, + formatServices, + comment, + }: { + credentialRecord: CredentialExchangeRecord + credentialFormats?: CredentialFormatPayload + formatServices: CredentialFormatService[] + comment?: string + } + ) { + const requestMessage = await this.didCommMessageRepository.getAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V2RequestCredentialMessage, }) - const offerMessage = await this.didCommMessageRepository.findAgentMessage({ + const offerMessage = await this.didCommMessageRepository.findAgentMessage(agentContext, { associatedRecordId: credentialRecord.id, messageClass: V2OfferCredentialMessage, }) @@ -441,7 +469,7 @@ export class CredentialFormatCoordinator { ? this.getAttachmentForService(formatService, offerMessage.formats, offerMessage.offerAttachments) : undefined - const { attachment, format } = await formatService.acceptRequest({ + const { attachment, format } = await formatService.acceptRequest(agentContext, { requestAttachment, offerAttachment, credentialRecord, @@ -461,7 +489,7 @@ export class CredentialFormatCoordinator { message.setThread({ threadId: credentialRecord.threadId }) message.setPleaseAck() - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, @@ -470,25 +498,28 @@ export class CredentialFormatCoordinator { return message } - public async processCredential({ - credentialRecord, - message, - formatServices, - }: { - credentialRecord: CredentialExchangeRecord - message: V2IssueCredentialMessage - formatServices: CredentialFormatService[] - }) { + public async processCredential( + agentContext: AgentContext, + { + credentialRecord, + message, + formatServices, + }: { + credentialRecord: CredentialExchangeRecord + message: V2IssueCredentialMessage + formatServices: CredentialFormatService[] + } + ) { for (const formatService of formatServices) { const attachment = this.getAttachmentForService(formatService, message.formats, message.credentialAttachments) - await formatService.processCredential({ + await formatService.processCredential(agentContext, { attachment, credentialRecord, }) } - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(agentContext, { agentMessage: message, role: DidCommMessageRole.Receiver, associatedRecordId: credentialRecord.id, diff --git a/packages/core/src/modules/credentials/protocol/v2/V2CredentialService.ts b/packages/core/src/modules/credentials/protocol/v2/V2CredentialService.ts index 0cea805e08..75023e3ddc 100644 --- a/packages/core/src/modules/credentials/protocol/v2/V2CredentialService.ts +++ b/packages/core/src/modules/credentials/protocol/v2/V2CredentialService.ts @@ -1,21 +1,22 @@ +import type { AgentContext } from '../../../../agent' import type { AgentMessage } from '../../../../agent/AgentMessage' import type { HandlerInboundMessage } from '../../../../agent/Handler' import type { InboundMessageContext } from '../../../../agent/models/InboundMessageContext' import type { ProblemReportMessage } from '../../../problem-reports' import type { - CreateProposalOptions, - CredentialProtocolMsgReturnType, + AcceptCredentialOptions, + AcceptOfferOptions, AcceptProposalOptions, - NegotiateProposalOptions, + AcceptRequestOptions, CreateOfferOptions, - AcceptOfferOptions, - NegotiateOfferOptions, + CreateProposalOptions, CreateRequestOptions, - AcceptRequestOptions, - AcceptCredentialOptions, - GetFormatDataReturn, + CredentialProtocolMsgReturnType, FormatDataMessagePayload, CreateProblemReportOptions, + GetFormatDataReturn, + NegotiateOfferOptions, + NegotiateProposalOptions, } from '../../CredentialServiceOptions' import type { CredentialFormat, @@ -25,11 +26,12 @@ import type { } from '../../formats' import type { CredentialFormatSpec } from '../../models' -import { AgentConfig } from '../../../../agent/AgentConfig' import { Dispatcher } from '../../../../agent/Dispatcher' import { EventEmitter } from '../../../../agent/EventEmitter' +import { InjectionSymbols } from '../../../../constants' import { AriesFrameworkError } from '../../../../error' -import { injectable } from '../../../../plugins' +import { Logger } from '../../../../logger' +import { injectable, inject } from '../../../../plugins' import { DidCommMessageRepository } from '../../../../storage' import { uuid } from '../../../../utils/uuid' import { AckStatus } from '../../../common' @@ -37,8 +39,8 @@ import { ConnectionService } from '../../../connections' import { RoutingService } from '../../../routing/services/RoutingService' import { CredentialProblemReportReason } from '../../errors' import { IndyCredentialFormatService } from '../../formats/indy/IndyCredentialFormatService' -import { CredentialState, AutoAcceptCredential } from '../../models' -import { CredentialRepository, CredentialExchangeRecord } from '../../repository' +import { AutoAcceptCredential, CredentialState } from '../../models' +import { CredentialExchangeRecord, CredentialRepository } from '../../repository' import { CredentialService } from '../../services/CredentialService' import { composeAutoAccept } from '../../util/composeAutoAccept' import { arePreviewAttributesEqual } from '../../util/previewAttributes' @@ -65,23 +67,21 @@ import { export class V2CredentialService extends CredentialService { private connectionService: ConnectionService private credentialFormatCoordinator: CredentialFormatCoordinator - protected didCommMessageRepository: DidCommMessageRepository private routingService: RoutingService private formatServiceMap: { [key: string]: CredentialFormatService } public constructor( connectionService: ConnectionService, didCommMessageRepository: DidCommMessageRepository, - agentConfig: AgentConfig, routingService: RoutingService, dispatcher: Dispatcher, eventEmitter: EventEmitter, credentialRepository: CredentialRepository, - indyCredentialFormatService: IndyCredentialFormatService + indyCredentialFormatService: IndyCredentialFormatService, + @inject(InjectionSymbols.Logger) logger: Logger ) { - super(credentialRepository, didCommMessageRepository, eventEmitter, dispatcher, agentConfig) + super(credentialRepository, didCommMessageRepository, eventEmitter, dispatcher, logger) this.connectionService = connectionService - this.didCommMessageRepository = didCommMessageRepository this.routingService = routingService this.credentialFormatCoordinator = new CredentialFormatCoordinator(didCommMessageRepository) @@ -121,12 +121,10 @@ export class V2CredentialService): Promise> { + public async createProposal( + agentContext: AgentContext, + { connection, credentialFormats, comment, autoAcceptCredential }: CreateProposalOptions + ): Promise> { this.logger.debug('Get the Format Service and Create Proposal Message') const formatServices = this.getFormatServices(credentialFormats) @@ -142,7 +140,7 @@ export class V2CredentialService): Promise> { + public async acceptProposal( + agentContext: AgentContext, + { credentialRecord, credentialFormats, autoAcceptCredential, comment }: AcceptProposalOptions + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v2') credentialRecord.assertState(CredentialState.ProposalReceived) @@ -246,7 +249,7 @@ export class V2CredentialService): Promise> { + public async negotiateProposal( + agentContext: AgentContext, + { credentialRecord, credentialFormats, autoAcceptCredential, comment }: NegotiateProposalOptions + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v2') credentialRecord.assertState(CredentialState.ProposalReceived) @@ -304,7 +305,7 @@ export class V2CredentialService): Promise> { + public async createOffer( + agentContext: AgentContext, + { credentialFormats, autoAcceptCredential, comment, connection }: CreateOfferOptions + ): Promise> { const formatServices = this.getFormatServices(credentialFormats) if (formatServices.length === 0) { throw new AriesFrameworkError(`Unable to create offer. No supported formats`) @@ -345,7 +344,7 @@ export class V2CredentialService) { + public async acceptOffer( + agentContext: AgentContext, + { credentialRecord, autoAcceptCredential, comment, credentialFormats }: AcceptOfferOptions + ) { // Assert credentialRecord.assertProtocolVersion('v2') credentialRecord.assertState(CredentialState.OfferReceived) @@ -449,7 +453,7 @@ export class V2CredentialService): Promise> { + public async negotiateOffer( + agentContext: AgentContext, + { credentialRecord, credentialFormats, autoAcceptCredential, comment }: NegotiateOfferOptions + ): Promise> { // Assert credentialRecord.assertProtocolVersion('v2') credentialRecord.assertState(CredentialState.OfferReceived) @@ -507,7 +509,7 @@ export class V2CredentialService): Promise> { + public async createRequest( + agentContext: AgentContext, + { credentialFormats, autoAcceptCredential, comment, connection }: CreateRequestOptions + ): Promise> { const formatServices = this.getFormatServices(credentialFormats) if (formatServices.length === 0) { throw new AriesFrameworkError(`Unable to create request. No supported formats`) @@ -544,7 +544,7 @@ export class V2CredentialService) { + public async acceptRequest( + agentContext: AgentContext, + { credentialRecord, autoAcceptCredential, comment, credentialFormats }: AcceptRequestOptions + ) { // Assert credentialRecord.assertProtocolVersion('v2') credentialRecord.assertState(CredentialState.RequestReceived) @@ -654,7 +656,7 @@ export class V2CredentialService> { + public async acceptCredential( + agentContext: AgentContext, + { credentialRecord }: AcceptCredentialOptions + ): Promise> { credentialRecord.assertProtocolVersion('v2') credentialRecord.assertState(CredentialState.CredentialReceived) @@ -755,7 +762,7 @@ export class V2CredentialService { + public async shouldAutoRespondToProposal( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + proposalMessage: V2ProposeCredentialMessage + } + ): Promise { const { credentialRecord, proposalMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const offerMessage = await this.findOfferMessage(credentialRecord.id) + const offerMessage = await this.findOfferMessage(agentContext, credentialRecord.id) if (!offerMessage) return false // NOTE: we take the formats from the offerMessage so we always check all services that we last sent @@ -850,7 +867,7 @@ export class V2CredentialService { + public async shouldAutoRespondToOffer( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + offerMessage: V2OfferCredentialMessage + } + ): Promise { const { credentialRecord, offerMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const proposalMessage = await this.findProposalMessage(credentialRecord.id) + const proposalMessage = await this.findProposalMessage(agentContext, credentialRecord.id) if (!proposalMessage) return false // NOTE: we take the formats from the proposalMessage so we always check all services that we last sent @@ -908,7 +931,7 @@ export class V2CredentialService { + public async shouldAutoRespondToRequest( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + requestMessage: V2RequestCredentialMessage + } + ): Promise { const { credentialRecord, requestMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const proposalMessage = await this.findProposalMessage(credentialRecord.id) + const proposalMessage = await this.findProposalMessage(agentContext, credentialRecord.id) - const offerMessage = await this.findOfferMessage(credentialRecord.id) + const offerMessage = await this.findOfferMessage(agentContext, credentialRecord.id) if (!offerMessage) return false // NOTE: we take the formats from the offerMessage so we always check all services that we last sent @@ -976,7 +1005,7 @@ export class V2CredentialService { + public async shouldAutoRespondToCredential( + agentContext: AgentContext, + options: { + credentialRecord: CredentialExchangeRecord + credentialMessage: V2IssueCredentialMessage + } + ): Promise { const { credentialRecord, credentialMessage } = options - const autoAccept = composeAutoAccept(credentialRecord.autoAcceptCredential, this.agentConfig.autoAcceptCredentials) + const autoAccept = composeAutoAccept( + credentialRecord.autoAcceptCredential, + agentContext.config.autoAcceptCredentials + ) // Handle always / never cases if (autoAccept === AutoAcceptCredential.Always) return true if (autoAccept === AutoAcceptCredential.Never) return false - const proposalMessage = await this.findProposalMessage(credentialRecord.id) - const offerMessage = await this.findOfferMessage(credentialRecord.id) + const proposalMessage = await this.findProposalMessage(agentContext, credentialRecord.id) + const offerMessage = await this.findOfferMessage(agentContext, credentialRecord.id) - const requestMessage = await this.findRequestMessage(credentialRecord.id) + const requestMessage = await this.findRequestMessage(agentContext, credentialRecord.id) if (!requestMessage) return false // NOTE: we take the formats from the requestMessage so we always check all services that we last sent @@ -1041,7 +1076,7 @@ export class V2CredentialService { + public async getFormatData(agentContext: AgentContext, credentialExchangeId: string): Promise { // TODO: we could looking at fetching all record using a single query and then filtering based on the type of the message. const [proposalMessage, offerMessage, requestMessage, credentialMessage] = await Promise.all([ - this.findProposalMessage(credentialExchangeId), - this.findOfferMessage(credentialExchangeId), - this.findRequestMessage(credentialExchangeId), - this.findCredentialMessage(credentialExchangeId), + this.findProposalMessage(agentContext, credentialExchangeId), + this.findOfferMessage(agentContext, credentialExchangeId), + this.findRequestMessage(agentContext, credentialExchangeId), + this.findCredentialMessage(agentContext, credentialExchangeId), ]) // Create object with the keys and the message formats/attachments. We can then loop over this in a generic @@ -1134,17 +1169,15 @@ export class V2CredentialService) => { +const getAgentMessageMock = async (agentContext: AgentContext, options: GetAgentMessageOptions) => { if (options.messageClass === V2ProposeCredentialMessage) { return credentialProposalMessage } @@ -231,12 +236,11 @@ const mockCredentialRecord = ({ describe('CredentialService', () => { let eventEmitter: EventEmitter - let agentConfig: AgentConfig + let credentialService: V2CredentialService beforeEach(async () => { - agentConfig = getAgentConfig('V2CredentialServiceCredTest') - eventEmitter = new EventEmitter(agentConfig) + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) // mock function implementations mockFunction(connectionService.getById).mockResolvedValue(connection) @@ -251,12 +255,12 @@ describe('CredentialService', () => { credentialService = new V2CredentialService( connectionService, didCommMessageRepository, - agentConfig, routingService, dispatcher, eventEmitter, credentialRepository, - indyCredentialFormatService + indyCredentialFormatService, + agentConfig.logger ) }) @@ -279,7 +283,7 @@ describe('CredentialService', () => { }) // when - await credentialService.acceptOffer({ + await credentialService.acceptOffer(agentContext, { credentialRecord, credentialFormats: { indy: { @@ -292,6 +296,7 @@ describe('CredentialService', () => { // then expect(credentialRepository.update).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ state: CredentialState.RequestSent, }) @@ -315,7 +320,10 @@ describe('CredentialService', () => { }) // when - const { message: credentialRequest } = await credentialService.acceptOffer({ credentialRecord, comment }) + const { message: credentialRequest } = await credentialService.acceptOffer(agentContext, { + credentialRecord, + comment, + }) // then expect(credentialRequest.toJSON()).toMatchObject({ @@ -336,7 +344,7 @@ describe('CredentialService', () => { await Promise.all( invalidCredentialStates.map(async (state) => { await expect( - credentialService.acceptOffer({ credentialRecord: mockCredentialRecord({ state }) }) + credentialService.acceptOffer(agentContext, { credentialRecord: mockCredentialRecord({ state }) }) ).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`) }) ) @@ -350,6 +358,7 @@ describe('CredentialService', () => { const credentialRecord = mockCredentialRecord({ state: CredentialState.OfferSent }) const messageContext = new InboundMessageContext(credentialRequestMessage, { connection, + agentContext, }) // given @@ -359,7 +368,7 @@ describe('CredentialService', () => { const returnedCredentialRecord = await credentialService.processRequest(messageContext) // then - expect(credentialRepository.findSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.findSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) @@ -373,6 +382,7 @@ describe('CredentialService', () => { const credentialRecord = mockCredentialRecord({ state: CredentialState.OfferSent }) const messageContext = new InboundMessageContext(credentialRequestMessage, { connection, + agentContext, }) const eventListenerMock = jest.fn() @@ -383,10 +393,15 @@ describe('CredentialService', () => { const returnedCredentialRecord = await credentialService.processRequest(messageContext) // then - expect(credentialRepository.findSingleByQuery).toHaveBeenNthCalledWith(1, { - threadId: 'somethreadid', - connectionId: connection.id, - }) + expect(credentialRepository.findSingleByQuery).toHaveBeenNthCalledWith( + 1, + agentContext, + + { + threadId: 'somethreadid', + connectionId: connection.id, + } + ) expect(eventListenerMock).toHaveBeenCalled() expect(returnedCredentialRecord.state).toEqual(CredentialState.RequestReceived) }) @@ -398,6 +413,7 @@ describe('CredentialService', () => { const messageContext = new InboundMessageContext(credentialRequestMessage, { connection, + agentContext, }) await Promise.all( @@ -426,7 +442,7 @@ describe('CredentialService', () => { connectionId: 'b1e2f039-aa39-40be-8643-6ce2797b5190', }) - await credentialService.acceptRequest({ + await credentialService.acceptRequest(agentContext, { credentialRecord, comment: 'credential response comment', }) @@ -434,6 +450,7 @@ describe('CredentialService', () => { // then expect(credentialRepository.update).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ state: CredentialState.CredentialIssued, }) @@ -459,7 +476,7 @@ describe('CredentialService', () => { eventEmitter.on(CredentialEventTypes.CredentialStateChanged, eventListenerMock) // when - await credentialService.acceptRequest({ + await credentialService.acceptRequest(agentContext, { credentialRecord, comment: 'credential response comment', }) @@ -493,7 +510,7 @@ describe('CredentialService', () => { const comment = 'credential response comment' // when - const { message: credentialResponse } = await credentialService.acceptRequest({ + const { message: credentialResponse } = await credentialService.acceptRequest(agentContext, { comment: 'credential response comment', credentialRecord, }) @@ -523,6 +540,7 @@ describe('CredentialService', () => { const messageContext = new InboundMessageContext(credentialIssueMessage, { connection, + agentContext, }) // given @@ -544,11 +562,12 @@ describe('CredentialService', () => { }) // when - await credentialService.acceptCredential({ credentialRecord }) + await credentialService.acceptCredential(agentContext, { credentialRecord }) // then expect(credentialRepository.update).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ state: CredentialState.Done, }) @@ -566,7 +585,7 @@ describe('CredentialService', () => { eventEmitter.on(CredentialEventTypes.CredentialStateChanged, eventListenerMock) // when - await credentialService.acceptCredential({ credentialRecord }) + await credentialService.acceptCredential(agentContext, { credentialRecord }) // then expect(eventListenerMock).toHaveBeenCalledWith({ @@ -591,7 +610,7 @@ describe('CredentialService', () => { mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) // when - const { message: ackMessage } = await credentialService.acceptCredential({ credentialRecord }) + const { message: ackMessage } = await credentialService.acceptCredential(agentContext, { credentialRecord }) // then expect(ackMessage.toJSON()).toMatchObject({ @@ -609,7 +628,7 @@ describe('CredentialService', () => { await Promise.all( invalidCredentialStates.map(async (state) => { await expect( - credentialService.acceptCredential({ + credentialService.acceptCredential(agentContext, { credentialRecord: mockCredentialRecord({ state, threadId: 'somethreadid', @@ -627,9 +646,7 @@ describe('CredentialService', () => { status: AckStatus.OK, threadId: 'somethreadid', }) - const messageContext = new InboundMessageContext(credentialRequest, { - connection, - }) + const messageContext = new InboundMessageContext(credentialRequest, { agentContext, connection }) test(`updates state to ${CredentialState.Done} and returns credential record`, async () => { const credentialRecord = mockCredentialRecord({ @@ -642,7 +659,7 @@ describe('CredentialService', () => { // when const returnedCredentialRecord = await credentialService.processAck(messageContext) - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) @@ -663,7 +680,7 @@ describe('CredentialService', () => { mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) // when - const credentialProblemReportMessage = credentialService.createProblemReport({ message }) + const credentialProblemReportMessage = credentialService.createProblemReport(agentContext, { message }) credentialProblemReportMessage.setThread({ threadId: 'somethreadid' }) // then @@ -691,6 +708,7 @@ describe('CredentialService', () => { credentialProblemReportMessage.setThread({ threadId: 'somethreadid' }) const messageContext = new InboundMessageContext(credentialProblemReportMessage, { connection, + agentContext, }) test(`updates problem report error message and returns credential record`, async () => { @@ -706,7 +724,7 @@ describe('CredentialService', () => { // then - expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) @@ -719,8 +737,8 @@ describe('CredentialService', () => { it('getById should return value from credentialRepository.getById', async () => { const expected = mockCredentialRecord() mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.getById(expected.id) - expect(credentialRepository.getById).toBeCalledWith(expected.id) + const result = await credentialService.getById(agentContext, expected.id) + expect(credentialRepository.getById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -728,8 +746,8 @@ describe('CredentialService', () => { it('getById should return value from credentialRepository.getSingleByQuery', async () => { const expected = mockCredentialRecord() mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.getByThreadAndConnectionId('threadId', 'connectionId') - expect(credentialRepository.getSingleByQuery).toBeCalledWith({ + const result = await credentialService.getByThreadAndConnectionId(agentContext, 'threadId', 'connectionId') + expect(credentialRepository.getSingleByQuery).toBeCalledWith(agentContext, { threadId: 'threadId', connectionId: 'connectionId', }) @@ -740,8 +758,8 @@ describe('CredentialService', () => { it('findById should return value from credentialRepository.findById', async () => { const expected = mockCredentialRecord() mockFunction(credentialRepository.findById).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.findById(expected.id) - expect(credentialRepository.findById).toBeCalledWith(expected.id) + const result = await credentialService.findById(agentContext, expected.id) + expect(credentialRepository.findById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -750,8 +768,8 @@ describe('CredentialService', () => { const expected = [mockCredentialRecord(), mockCredentialRecord()] mockFunction(credentialRepository.getAll).mockReturnValue(Promise.resolve(expected)) - const result = await credentialService.getAll() - expect(credentialRepository.getAll).toBeCalledWith() + const result = await credentialService.getAll(agentContext) + expect(credentialRepository.getAll).toBeCalledWith(agentContext) expect(result).toEqual(expect.arrayContaining(expected)) }) @@ -763,8 +781,8 @@ describe('CredentialService', () => { mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(credentialRecord)) const repositoryDeleteSpy = jest.spyOn(credentialRepository, 'delete') - await credentialService.delete(credentialRecord) - expect(repositoryDeleteSpy).toHaveBeenNthCalledWith(1, credentialRecord) + await credentialService.delete(agentContext, credentialRecord) + expect(repositoryDeleteSpy).toHaveBeenNthCalledWith(1, agentContext, credentialRecord) }) it('should call deleteCredentialById in indyCredentialFormatService if deleteAssociatedCredential is true', async () => { @@ -773,12 +791,16 @@ describe('CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord, { + await credentialService.delete(agentContext, credentialRecord, { deleteAssociatedCredentials: true, deleteAssociatedDidCommMessages: false, }) - expect(deleteCredentialMock).toHaveBeenNthCalledWith(1, credentialRecord.credentials[0].credentialRecordId) + expect(deleteCredentialMock).toHaveBeenNthCalledWith( + 1, + agentContext, + credentialRecord.credentials[0].credentialRecordId + ) }) it('should not call deleteCredentialById in indyCredentialFormatService if deleteAssociatedCredential is false', async () => { @@ -787,7 +809,7 @@ describe('CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord, { + await credentialService.delete(agentContext, credentialRecord, { deleteAssociatedCredentials: false, deleteAssociatedDidCommMessages: false, }) @@ -801,9 +823,13 @@ describe('CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord) + await credentialService.delete(agentContext, credentialRecord) - expect(deleteCredentialMock).toHaveBeenNthCalledWith(1, credentialRecord.credentials[0].credentialRecordId) + expect(deleteCredentialMock).toHaveBeenNthCalledWith( + 1, + agentContext, + credentialRecord.credentials[0].credentialRecordId + ) }) it('deleteAssociatedDidCommMessages should default to true', async () => { const deleteCredentialMock = mockFunction(indyCredentialFormatService.deleteCredentialById) @@ -811,9 +837,13 @@ describe('CredentialService', () => { const credentialRecord = mockCredentialRecord() mockFunction(credentialRepository.getById).mockResolvedValue(credentialRecord) - await credentialService.delete(credentialRecord) + await credentialService.delete(agentContext, credentialRecord) - expect(deleteCredentialMock).toHaveBeenNthCalledWith(1, credentialRecord.credentials[0].credentialRecordId) + expect(deleteCredentialMock).toHaveBeenNthCalledWith( + 1, + agentContext, + credentialRecord.credentials[0].credentialRecordId + ) expect(didCommMessageRepository.delete).toHaveBeenCalledTimes(3) }) }) @@ -825,12 +855,13 @@ describe('CredentialService', () => { }) // when - await credentialService.declineOffer(credentialRecord) + await credentialService.declineOffer(agentContext, credentialRecord) // then expect(credentialRepository.update).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ state: CredentialState.Declined, }) @@ -849,7 +880,7 @@ describe('CredentialService', () => { mockFunction(credentialRepository.getSingleByQuery).mockResolvedValue(credentialRecord) // when - await credentialService.declineOffer(credentialRecord) + await credentialService.declineOffer(agentContext, credentialRecord) // then expect(eventListenerMock).toHaveBeenCalledTimes(1) @@ -870,9 +901,9 @@ describe('CredentialService', () => { test(`throws an error when state transition is invalid`, async () => { await Promise.all( invalidCredentialStates.map(async (state) => { - await expect(credentialService.declineOffer(mockCredentialRecord({ state }))).rejects.toThrowError( - `Credential record is in invalid state ${state}. Valid states are: ${validState}.` - ) + await expect( + credentialService.declineOffer(agentContext, mockCredentialRecord({ state })) + ).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`) }) ) }) diff --git a/packages/core/src/modules/credentials/protocol/v2/__tests__/V2CredentialServiceOffer.test.ts b/packages/core/src/modules/credentials/protocol/v2/__tests__/V2CredentialServiceOffer.test.ts index 0a2aaf7aef..89670711fe 100644 --- a/packages/core/src/modules/credentials/protocol/v2/__tests__/V2CredentialServiceOffer.test.ts +++ b/packages/core/src/modules/credentials/protocol/v2/__tests__/V2CredentialServiceOffer.test.ts @@ -1,9 +1,10 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { CredentialStateChangedEvent } from '../../../CredentialEvents' import type { CreateOfferOptions } from '../../../CredentialServiceOptions' import type { IndyCredentialFormat } from '../../../formats/indy/IndyCredentialFormat' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../../../../tests/helpers' import { Dispatcher } from '../../../../../agent/Dispatcher' import { EventEmitter } from '../../../../../agent/EventEmitter' import { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' @@ -55,6 +56,9 @@ const connectionService = new ConnectionServiceMock() // @ts-ignore indyCredentialFormatService.formatKey = 'indy' +const agentConfig = getAgentConfig('V2CredentialServiceOfferTest') +const agentContext = getAgentContext() + const connection = getMockConnection({ id: '123', state: DidExchangeState.Completed, @@ -80,13 +84,11 @@ const offerAttachment = new Attachment({ describe('V2CredentialServiceOffer', () => { let eventEmitter: EventEmitter - let agentConfig: AgentConfig let credentialService: V2CredentialService beforeEach(async () => { // real objects - agentConfig = getAgentConfig('V2CredentialServiceOfferTest') - eventEmitter = new EventEmitter(agentConfig) + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) // mock function implementations mockFunction(connectionService.getById).mockResolvedValue(connection) @@ -96,12 +98,12 @@ describe('V2CredentialServiceOffer', () => { credentialService = new V2CredentialService( connectionService, didCommMessageRepository, - agentConfig, routingService, dispatcher, eventEmitter, credentialRepository, - indyCredentialFormatService + indyCredentialFormatService, + agentConfig.logger ) }) @@ -129,11 +131,12 @@ describe('V2CredentialServiceOffer', () => { }) // when - await credentialService.createOffer(offerOptions) + await credentialService.createOffer(agentContext, offerOptions) // then expect(credentialRepository.save).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ type: CredentialExchangeRecord.type, id: expect.any(String), @@ -154,7 +157,7 @@ describe('V2CredentialServiceOffer', () => { const eventListenerMock = jest.fn() eventEmitter.on(CredentialEventTypes.CredentialStateChanged, eventListenerMock) - await credentialService.createOffer(offerOptions) + await credentialService.createOffer(agentContext, offerOptions) expect(eventListenerMock).toHaveBeenCalledWith({ type: 'CredentialStateChanged', @@ -175,7 +178,7 @@ describe('V2CredentialServiceOffer', () => { previewAttributes: credentialPreview.attributes, }) - const { message: credentialOffer } = await credentialService.createOffer(offerOptions) + const { message: credentialOffer } = await credentialService.createOffer(agentContext, offerOptions) expect(credentialOffer.toJSON()).toMatchObject({ '@id': expect.any(String), @@ -210,9 +213,7 @@ describe('V2CredentialServiceOffer', () => { offerAttachments: [offerAttachment], }) - const messageContext = new InboundMessageContext(credentialOfferMessage, { - connection, - }) + const messageContext = new InboundMessageContext(credentialOfferMessage, { agentContext, connection }) test(`creates and return credential record in ${CredentialState.OfferReceived} state with offer, thread ID`, async () => { mockFunction(indyCredentialFormatService.supportsFormat).mockReturnValue(true) @@ -223,6 +224,7 @@ describe('V2CredentialServiceOffer', () => { // then expect(credentialRepository.save).toHaveBeenNthCalledWith( 1, + agentContext, expect.objectContaining({ type: CredentialExchangeRecord.type, id: expect.any(String), diff --git a/packages/core/src/modules/credentials/protocol/v2/__tests__/v2-credentials.e2e.test.ts b/packages/core/src/modules/credentials/protocol/v2/__tests__/v2-credentials.e2e.test.ts index 0fa0e2c9f2..c584f7fca6 100644 --- a/packages/core/src/modules/credentials/protocol/v2/__tests__/v2-credentials.e2e.test.ts +++ b/packages/core/src/modules/credentials/protocol/v2/__tests__/v2-credentials.e2e.test.ts @@ -115,7 +115,7 @@ describe('v2 credentials', () => { }) const didCommMessageRepository = faberAgent.dependencyManager.resolve(DidCommMessageRepository) - const offerMessage = await didCommMessageRepository.findAgentMessage({ + const offerMessage = await didCommMessageRepository.findAgentMessage(faberAgent.context, { associatedRecordId: faberCredentialRecord.id, messageClass: V2OfferCredentialMessage, }) @@ -227,7 +227,11 @@ describe('v2 credentials', () => { deleteAssociatedCredentials: true, deleteAssociatedDidCommMessages: true, }) - expect(deleteCredentialSpy).toHaveBeenNthCalledWith(1, holderCredential.credentials[0].credentialRecordId) + expect(deleteCredentialSpy).toHaveBeenNthCalledWith( + 1, + aliceAgent.context, + holderCredential.credentials[0].credentialRecordId + ) return expect(aliceAgent.credentials.getById(holderCredential.id)).rejects.toThrowError( `CredentialRecord: record with id ${holderCredential.id} not found.` diff --git a/packages/core/src/modules/credentials/protocol/v2/handlers/V2IssueCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v2/handlers/V2IssueCredentialHandler.ts index 402d6c6047..9329fa298a 100644 --- a/packages/core/src/modules/credentials/protocol/v2/handlers/V2IssueCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v2/handlers/V2IssueCredentialHandler.ts @@ -1,6 +1,6 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' import type { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' +import type { Logger } from '../../../../../logger' import type { DidCommMessageRepository } from '../../../../../storage' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' import type { V2CredentialService } from '../V2CredentialService' @@ -11,24 +11,25 @@ import { V2RequestCredentialMessage } from '../messages/V2RequestCredentialMessa export class V2IssueCredentialHandler implements Handler { private credentialService: V2CredentialService - private agentConfig: AgentConfig private didCommMessageRepository: DidCommMessageRepository + private logger: Logger public supportedMessages = [V2IssueCredentialMessage] public constructor( credentialService: V2CredentialService, - agentConfig: AgentConfig, - didCommMessageRepository: DidCommMessageRepository + didCommMessageRepository: DidCommMessageRepository, + logger: Logger ) { this.credentialService = credentialService - this.agentConfig = agentConfig this.didCommMessageRepository = didCommMessageRepository + this.logger = logger } + public async handle(messageContext: InboundMessageContext) { const credentialRecord = await this.credentialService.processCredential(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToCredential({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToCredential(messageContext.agentContext, { credentialRecord, credentialMessage: messageContext.message, }) @@ -42,16 +43,16 @@ export class V2IssueCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending acknowledgement with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending acknowledgement with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) - const requestMessage = await this.didCommMessageRepository.findAgentMessage({ + const requestMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V2RequestCredentialMessage, }) - const { message } = await this.credentialService.acceptCredential({ + const { message } = await this.credentialService.acceptCredential(messageContext.agentContext, { credentialRecord, }) @@ -68,6 +69,6 @@ export class V2IssueCredentialHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create credential ack`) + this.logger.error(`Could not automatically create credential ack`) } } diff --git a/packages/core/src/modules/credentials/protocol/v2/handlers/V2OfferCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v2/handlers/V2OfferCredentialHandler.ts index d938f2a19b..7d3c3b6419 100644 --- a/packages/core/src/modules/credentials/protocol/v2/handlers/V2OfferCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v2/handlers/V2OfferCredentialHandler.ts @@ -1,6 +1,6 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' import type { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' +import type { Logger } from '../../../../../logger' import type { DidCommMessageRepository } from '../../../../../storage' import type { RoutingService } from '../../../../routing/services/RoutingService' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' @@ -13,27 +13,28 @@ import { V2OfferCredentialMessage } from '../messages/V2OfferCredentialMessage' export class V2OfferCredentialHandler implements Handler { private credentialService: V2CredentialService - private agentConfig: AgentConfig private routingService: RoutingService + private logger: Logger + public supportedMessages = [V2OfferCredentialMessage] private didCommMessageRepository: DidCommMessageRepository public constructor( credentialService: V2CredentialService, - agentConfig: AgentConfig, routingService: RoutingService, - didCommMessageRepository: DidCommMessageRepository + didCommMessageRepository: DidCommMessageRepository, + logger: Logger ) { this.credentialService = credentialService - this.agentConfig = agentConfig this.routingService = routingService this.didCommMessageRepository = didCommMessageRepository + this.logger = logger } public async handle(messageContext: InboundMessageContext) { const credentialRecord = await this.credentialService.processOffer(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToOffer({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToOffer(messageContext.agentContext, { credentialRecord, offerMessage: messageContext.message, }) @@ -48,17 +49,17 @@ export class V2OfferCredentialHandler implements Handler { messageContext: HandlerInboundMessage, offerMessage?: V2OfferCredentialMessage ) { - this.agentConfig.logger.info( - `Automatically sending request with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending request with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) if (messageContext.connection) { - const { message } = await this.credentialService.acceptOffer({ + const { message } = await this.credentialService.acceptOffer(messageContext.agentContext, { credentialRecord, }) return createOutboundMessage(messageContext.connection, message) } else if (offerMessage?.service) { - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(messageContext.agentContext) const ourService = new ServiceDecorator({ serviceEndpoint: routing.endpoints[0], recipientKeys: [routing.recipientKey.publicKeyBase58], @@ -66,14 +67,14 @@ export class V2OfferCredentialHandler implements Handler { }) const recipientService = offerMessage.service - const { message } = await this.credentialService.acceptOffer({ + const { message } = await this.credentialService.acceptOffer(messageContext.agentContext, { credentialRecord, }) // Set and save ~service decorator to record (to remember our verkey) message.service = ourService - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(messageContext.agentContext, { agentMessage: message, role: DidCommMessageRole.Sender, associatedRecordId: credentialRecord.id, @@ -86,6 +87,6 @@ export class V2OfferCredentialHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create credential request`) + this.logger.error(`Could not automatically create credential request`) } } diff --git a/packages/core/src/modules/credentials/protocol/v2/handlers/V2ProposeCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v2/handlers/V2ProposeCredentialHandler.ts index 27a181ed67..9c63943302 100644 --- a/packages/core/src/modules/credentials/protocol/v2/handlers/V2ProposeCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v2/handlers/V2ProposeCredentialHandler.ts @@ -1,6 +1,6 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../../../agent/Handler' import type { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' +import type { Logger } from '../../../../../logger' import type { CredentialExchangeRecord } from '../../../repository/CredentialExchangeRecord' import type { V2CredentialService } from '../V2CredentialService' @@ -9,19 +9,19 @@ import { V2ProposeCredentialMessage } from '../messages/V2ProposeCredentialMessa export class V2ProposeCredentialHandler implements Handler { private credentialService: V2CredentialService - private agentConfig: AgentConfig + private logger: Logger public supportedMessages = [V2ProposeCredentialMessage] - public constructor(credentialService: V2CredentialService, agentConfig: AgentConfig) { + public constructor(credentialService: V2CredentialService, logger: Logger) { this.credentialService = credentialService - this.agentConfig = agentConfig + this.logger = logger } public async handle(messageContext: InboundMessageContext) { const credentialRecord = await this.credentialService.processProposal(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToProposal({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToProposal(messageContext.agentContext, { credentialRecord, proposalMessage: messageContext.message, }) @@ -35,16 +35,16 @@ export class V2ProposeCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending offer with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending offer with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) if (!messageContext.connection) { - this.agentConfig.logger.error('No connection on the messageContext, aborting auto accept') + this.logger.error('No connection on the messageContext, aborting auto accept') return } - const { message } = await this.credentialService.acceptProposal({ credentialRecord }) + const { message } = await this.credentialService.acceptProposal(messageContext.agentContext, { credentialRecord }) return createOutboundMessage(messageContext.connection, message) } diff --git a/packages/core/src/modules/credentials/protocol/v2/handlers/V2RequestCredentialHandler.ts b/packages/core/src/modules/credentials/protocol/v2/handlers/V2RequestCredentialHandler.ts index 7b137d8955..6f5145dedd 100644 --- a/packages/core/src/modules/credentials/protocol/v2/handlers/V2RequestCredentialHandler.ts +++ b/packages/core/src/modules/credentials/protocol/v2/handlers/V2RequestCredentialHandler.ts @@ -1,6 +1,6 @@ -import type { AgentConfig } from '../../../../../agent/AgentConfig' import type { Handler } from '../../../../../agent/Handler' import type { InboundMessageContext } from '../../../../../agent/models/InboundMessageContext' +import type { Logger } from '../../../../../logger/Logger' import type { DidCommMessageRepository } from '../../../../../storage' import type { CredentialExchangeRecord } from '../../../repository' import type { V2CredentialService } from '../V2CredentialService' @@ -12,24 +12,25 @@ import { V2RequestCredentialMessage } from '../messages/V2RequestCredentialMessa export class V2RequestCredentialHandler implements Handler { private credentialService: V2CredentialService - private agentConfig: AgentConfig private didCommMessageRepository: DidCommMessageRepository + private logger: Logger + public supportedMessages = [V2RequestCredentialMessage] public constructor( credentialService: V2CredentialService, - agentConfig: AgentConfig, - didCommMessageRepository: DidCommMessageRepository + didCommMessageRepository: DidCommMessageRepository, + logger: Logger ) { this.credentialService = credentialService - this.agentConfig = agentConfig this.didCommMessageRepository = didCommMessageRepository + this.logger = logger } public async handle(messageContext: InboundMessageContext) { const credentialRecord = await this.credentialService.processRequest(messageContext) - const shouldAutoRespond = await this.credentialService.shouldAutoRespondToRequest({ + const shouldAutoRespond = await this.credentialService.shouldAutoRespondToRequest(messageContext.agentContext, { credentialRecord, requestMessage: messageContext.message, }) @@ -43,16 +44,16 @@ export class V2RequestCredentialHandler implements Handler { credentialRecord: CredentialExchangeRecord, messageContext: InboundMessageContext ) { - this.agentConfig.logger.info( - `Automatically sending credential with autoAccept on ${this.agentConfig.autoAcceptCredentials}` + this.logger.info( + `Automatically sending credential with autoAccept on ${messageContext.agentContext.config.autoAcceptCredentials}` ) - const offerMessage = await this.didCommMessageRepository.findAgentMessage({ + const offerMessage = await this.didCommMessageRepository.findAgentMessage(messageContext.agentContext, { associatedRecordId: credentialRecord.id, messageClass: V2OfferCredentialMessage, }) - const { message } = await this.credentialService.acceptRequest({ + const { message } = await this.credentialService.acceptRequest(messageContext.agentContext, { credentialRecord, }) @@ -64,7 +65,7 @@ export class V2RequestCredentialHandler implements Handler { // Set ~service, update message in record (for later use) message.setService(ourService) - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(messageContext.agentContext, { agentMessage: message, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, @@ -77,6 +78,6 @@ export class V2RequestCredentialHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create credential request`) + this.logger.error(`Could not automatically create credential request`) } } diff --git a/packages/core/src/modules/credentials/services/CredentialService.ts b/packages/core/src/modules/credentials/services/CredentialService.ts index 2e305864ef..7642e4c4c5 100644 --- a/packages/core/src/modules/credentials/services/CredentialService.ts +++ b/packages/core/src/modules/credentials/services/CredentialService.ts @@ -1,4 +1,4 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' +import type { AgentContext } from '../../../agent' import type { AgentMessage } from '../../../agent/AgentMessage' import type { Dispatcher } from '../../../agent/Dispatcher' import type { EventEmitter } from '../../../agent/EventEmitter' @@ -35,7 +35,6 @@ export abstract class CredentialService // methods for proposal - abstract createProposal(options: CreateProposalOptions): Promise> + abstract createProposal( + agentContext: AgentContext, + options: CreateProposalOptions + ): Promise> abstract processProposal(messageContext: InboundMessageContext): Promise - abstract acceptProposal(options: AcceptProposalOptions): Promise> + abstract acceptProposal( + agentContext: AgentContext, + options: AcceptProposalOptions + ): Promise> abstract negotiateProposal( + agentContext: AgentContext, options: NegotiateProposalOptions ): Promise> // methods for offer - abstract createOffer(options: CreateOfferOptions): Promise> + abstract createOffer( + agentContext: AgentContext, + options: CreateOfferOptions + ): Promise> abstract processOffer(messageContext: InboundMessageContext): Promise - abstract acceptOffer(options: AcceptOfferOptions): Promise> - abstract negotiateOffer(options: NegotiateOfferOptions): Promise> + abstract acceptOffer( + agentContext: AgentContext, + options: AcceptOfferOptions + ): Promise> + abstract negotiateOffer( + agentContext: AgentContext, + options: NegotiateOfferOptions + ): Promise> // methods for request - abstract createRequest(options: CreateRequestOptions): Promise> + abstract createRequest( + agentContext: AgentContext, + options: CreateRequestOptions + ): Promise> abstract processRequest(messageContext: InboundMessageContext): Promise - abstract acceptRequest(options: AcceptRequestOptions): Promise> + abstract acceptRequest( + agentContext: AgentContext, + options: AcceptRequestOptions + ): Promise> // methods for issue abstract processCredential(messageContext: InboundMessageContext): Promise - abstract acceptCredential(options: AcceptCredentialOptions): Promise> + abstract acceptCredential( + agentContext: AgentContext, + options: AcceptCredentialOptions + ): Promise> // methods for ack abstract processAck(messageContext: InboundMessageContext): Promise // methods for problem-report - abstract createProblemReport(options: CreateProblemReportOptions): ProblemReportMessage + abstract createProblemReport(agentContext: AgentContext, options: CreateProblemReportOptions): ProblemReportMessage - abstract findProposalMessage(credentialExchangeId: string): Promise - abstract findOfferMessage(credentialExchangeId: string): Promise - abstract findRequestMessage(credentialExchangeId: string): Promise - abstract findCredentialMessage(credentialExchangeId: string): Promise - abstract getFormatData(credentialExchangeId: string): Promise> + abstract findProposalMessage(agentContext: AgentContext, credentialExchangeId: string): Promise + abstract findOfferMessage(agentContext: AgentContext, credentialExchangeId: string): Promise + abstract findRequestMessage(agentContext: AgentContext, credentialExchangeId: string): Promise + abstract findCredentialMessage(agentContext: AgentContext, credentialExchangeId: string): Promise + abstract getFormatData(agentContext: AgentContext, credentialExchangeId: string): Promise> /** * Decline a credential offer * @param credentialRecord The credential to be declined */ - public async declineOffer(credentialRecord: CredentialExchangeRecord): Promise { + public async declineOffer( + agentContext: AgentContext, + credentialRecord: CredentialExchangeRecord + ): Promise { credentialRecord.assertState(CredentialState.OfferReceived) - await this.updateState(credentialRecord, CredentialState.Declined) + await this.updateState(agentContext, credentialRecord, CredentialState.Declined) return credentialRecord } @@ -122,13 +148,14 @@ export abstract class CredentialService({ + this.eventEmitter.emit(agentContext, { type: CredentialEventTypes.CredentialStateChanged, payload: { credentialRecord: clonedCredential, @@ -172,8 +207,8 @@ export abstract class CredentialService { - return this.credentialRepository.getById(credentialRecordId) + public getById(agentContext: AgentContext, credentialRecordId: string): Promise { + return this.credentialRepository.getById(agentContext, credentialRecordId) } /** @@ -181,8 +216,8 @@ export abstract class CredentialService { - return this.credentialRepository.getAll() + public getAll(agentContext: AgentContext): Promise { + return this.credentialRepository.getAll(agentContext) } /** @@ -191,12 +226,16 @@ export abstract class CredentialService { - return this.credentialRepository.findById(connectionId) + public findById(agentContext: AgentContext, connectionId: string): Promise { + return this.credentialRepository.findById(agentContext, connectionId) } - public async delete(credentialRecord: CredentialExchangeRecord, options?: DeleteCredentialOptions): Promise { - await this.credentialRepository.delete(credentialRecord) + public async delete( + agentContext: AgentContext, + credentialRecord: CredentialExchangeRecord, + options?: DeleteCredentialOptions + ): Promise { + await this.credentialRepository.delete(agentContext, credentialRecord) const deleteAssociatedCredentials = options?.deleteAssociatedCredentials ?? true const deleteAssociatedDidCommMessages = options?.deleteAssociatedDidCommMessages ?? true @@ -204,16 +243,16 @@ export abstract class CredentialService { - return this.credentialRepository.getSingleByQuery({ + public getByThreadAndConnectionId( + agentContext: AgentContext, + threadId: string, + connectionId?: string + ): Promise { + return this.credentialRepository.getSingleByQuery(agentContext, { connectionId, threadId, }) @@ -242,16 +285,17 @@ export abstract class CredentialService { - return this.credentialRepository.findSingleByQuery({ + return this.credentialRepository.findSingleByQuery(agentContext, { connectionId, threadId, }) } - public async update(credentialRecord: CredentialExchangeRecord) { - return await this.credentialRepository.update(credentialRecord) + public async update(agentContext: AgentContext, credentialRecord: CredentialExchangeRecord) { + return await this.credentialRepository.update(agentContext, credentialRecord) } } diff --git a/packages/core/src/modules/credentials/services/index.ts b/packages/core/src/modules/credentials/services/index.ts index 05da1a90b5..3ef45ad8eb 100644 --- a/packages/core/src/modules/credentials/services/index.ts +++ b/packages/core/src/modules/credentials/services/index.ts @@ -1,2 +1 @@ export * from './CredentialService' -export * from '../protocol/revocation-notification/services/RevocationNotificationService' diff --git a/packages/core/src/modules/dids/DidsModule.ts b/packages/core/src/modules/dids/DidsModule.ts index 7fe57d25d6..d10ff463f1 100644 --- a/packages/core/src/modules/dids/DidsModule.ts +++ b/packages/core/src/modules/dids/DidsModule.ts @@ -2,6 +2,7 @@ import type { Key } from '../../crypto' import type { DependencyManager } from '../../plugins' import type { DidResolutionOptions } from './types' +import { AgentContext } from '../../agent' import { injectable, module } from '../../plugins' import { DidRepository } from './repository' @@ -12,26 +13,28 @@ import { DidResolverService } from './services/DidResolverService' export class DidsModule { private resolverService: DidResolverService private didRepository: DidRepository + private agentContext: AgentContext - public constructor(resolverService: DidResolverService, didRepository: DidRepository) { + public constructor(resolverService: DidResolverService, didRepository: DidRepository, agentContext: AgentContext) { this.resolverService = resolverService this.didRepository = didRepository + this.agentContext = agentContext } public resolve(didUrl: string, options?: DidResolutionOptions) { - return this.resolverService.resolve(didUrl, options) + return this.resolverService.resolve(this.agentContext, didUrl, options) } public resolveDidDocument(didUrl: string) { - return this.resolverService.resolveDidDocument(didUrl) + return this.resolverService.resolveDidDocument(this.agentContext, didUrl) } public findByRecipientKey(recipientKey: Key) { - return this.didRepository.findByRecipientKey(recipientKey) + return this.didRepository.findByRecipientKey(this.agentContext, recipientKey) } public findAllByRecipientKey(recipientKey: Key) { - return this.didRepository.findAllByRecipientKey(recipientKey) + return this.didRepository.findAllByRecipientKey(this.agentContext, recipientKey) } /** diff --git a/packages/core/src/modules/dids/__tests__/DidResolverService.test.ts b/packages/core/src/modules/dids/__tests__/DidResolverService.test.ts index 785c30d00c..7dff728532 100644 --- a/packages/core/src/modules/dids/__tests__/DidResolverService.test.ts +++ b/packages/core/src/modules/dids/__tests__/DidResolverService.test.ts @@ -1,7 +1,7 @@ import type { IndyLedgerService } from '../../ledger' import type { DidRepository } from '../repository' -import { getAgentConfig, mockProperty } from '../../../../tests/helpers' +import { getAgentConfig, getAgentContext, mockProperty } from '../../../../tests/helpers' import { JsonTransformer } from '../../../utils/JsonTransformer' import { DidDocument } from '../domain' import { parseDid } from '../domain/parse' @@ -13,11 +13,16 @@ import didKeyEd25519Fixture from './__fixtures__/didKeyEd25519.json' jest.mock('../methods/key/KeyDidResolver') const agentConfig = getAgentConfig('DidResolverService') +const agentContext = getAgentContext() describe('DidResolverService', () => { const indyLedgerServiceMock = jest.fn() as unknown as IndyLedgerService const didDocumentRepositoryMock = jest.fn() as unknown as DidRepository - const didResolverService = new DidResolverService(agentConfig, indyLedgerServiceMock, didDocumentRepositoryMock) + const didResolverService = new DidResolverService( + indyLedgerServiceMock, + didDocumentRepositoryMock, + agentConfig.logger + ) it('should correctly find and call the correct resolver for a specified did', async () => { const didKeyResolveSpy = jest.spyOn(KeyDidResolver.prototype, 'resolve') @@ -32,17 +37,19 @@ describe('DidResolverService', () => { } didKeyResolveSpy.mockResolvedValue(returnValue) - const result = await didResolverService.resolve('did:key:xxxx', { someKey: 'string' }) + const result = await didResolverService.resolve(agentContext, 'did:key:xxxx', { someKey: 'string' }) expect(result).toEqual(returnValue) expect(didKeyResolveSpy).toHaveBeenCalledTimes(1) - expect(didKeyResolveSpy).toHaveBeenCalledWith('did:key:xxxx', parseDid('did:key:xxxx'), { someKey: 'string' }) + expect(didKeyResolveSpy).toHaveBeenCalledWith(agentContext, 'did:key:xxxx', parseDid('did:key:xxxx'), { + someKey: 'string', + }) }) it("should return an error with 'invalidDid' if the did string couldn't be parsed", async () => { const did = 'did:__Asd:asdfa' - const result = await didResolverService.resolve(did) + const result = await didResolverService.resolve(agentContext, did) expect(result).toEqual({ didDocument: null, @@ -56,7 +63,7 @@ describe('DidResolverService', () => { it("should return an error with 'unsupportedDidMethod' if the did has no resolver", async () => { const did = 'did:example:asdfa' - const result = await didResolverService.resolve(did) + const result = await didResolverService.resolve(agentContext, did) expect(result).toEqual({ didDocument: null, diff --git a/packages/core/src/modules/dids/__tests__/peer-did.test.ts b/packages/core/src/modules/dids/__tests__/peer-did.test.ts index c5205f1e60..98a772fb07 100644 --- a/packages/core/src/modules/dids/__tests__/peer-did.test.ts +++ b/packages/core/src/modules/dids/__tests__/peer-did.test.ts @@ -1,6 +1,9 @@ +import type { AgentContext } from '../../../agent' import type { IndyLedgerService } from '../../ledger' -import { getAgentConfig } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext } from '../../../../tests/helpers' import { EventEmitter } from '../../../agent/EventEmitter' import { Key, KeyType } from '../../../crypto' import { IndyStorageService } from '../../../storage/IndyStorageService' @@ -24,19 +27,21 @@ describe('peer dids', () => { let didRepository: DidRepository let didResolverService: DidResolverService let wallet: IndyWallet + let agentContext: AgentContext let eventEmitter: EventEmitter beforeEach(async () => { - wallet = new IndyWallet(config) + wallet = new IndyWallet(config.agentDependencies, config.logger) + agentContext = getAgentContext({ wallet }) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(config.walletConfig!) - const storageService = new IndyStorageService(wallet, config) - eventEmitter = new EventEmitter(config) + const storageService = new IndyStorageService(config.agentDependencies) + eventEmitter = new EventEmitter(config.agentDependencies, new Subject()) didRepository = new DidRepository(storageService, eventEmitter) // Mocking IndyLedgerService as we're only interested in the did:peer resolver - didResolverService = new DidResolverService(config, {} as unknown as IndyLedgerService, didRepository) + didResolverService = new DidResolverService({} as unknown as IndyLedgerService, didRepository, config.logger) }) afterEach(async () => { @@ -124,7 +129,7 @@ describe('peer dids', () => { }, }) - await didRepository.save(didDocumentRecord) + await didRepository.save(agentContext, didDocumentRecord) }) test('receive a did and did document', async () => { @@ -161,13 +166,13 @@ describe('peer dids', () => { }, }) - await didRepository.save(didDocumentRecord) + await didRepository.save(agentContext, didDocumentRecord) // Then we save the did (not the did document) in the connection record // connectionRecord.theirDid = didPeer.did // Then when we want to send a message we can resolve the did document - const { didDocument: resolvedDidDocument } = await didResolverService.resolve(did) + const { didDocument: resolvedDidDocument } = await didResolverService.resolve(agentContext, did) expect(resolvedDidDocument).toBeInstanceOf(DidDocument) expect(resolvedDidDocument?.toJSON()).toMatchObject(didPeer1zQmY) }) diff --git a/packages/core/src/modules/dids/domain/DidResolver.ts b/packages/core/src/modules/dids/domain/DidResolver.ts index 6e0a98537f..050ea2cd97 100644 --- a/packages/core/src/modules/dids/domain/DidResolver.ts +++ b/packages/core/src/modules/dids/domain/DidResolver.ts @@ -1,6 +1,12 @@ +import type { AgentContext } from '../../../agent' import type { ParsedDid, DidResolutionResult, DidResolutionOptions } from '../types' export interface DidResolver { readonly supportedMethods: string[] - resolve(did: string, parsed: ParsedDid, didResolutionOptions: DidResolutionOptions): Promise + resolve( + agentContext: AgentContext, + did: string, + parsed: ParsedDid, + didResolutionOptions: DidResolutionOptions + ): Promise } diff --git a/packages/core/src/modules/dids/methods/key/KeyDidResolver.ts b/packages/core/src/modules/dids/methods/key/KeyDidResolver.ts index eb7d4ee5ae..41f4a0e221 100644 --- a/packages/core/src/modules/dids/methods/key/KeyDidResolver.ts +++ b/packages/core/src/modules/dids/methods/key/KeyDidResolver.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../../agent' import type { DidResolver } from '../../domain/DidResolver' import type { DidResolutionResult } from '../../types' @@ -6,7 +7,7 @@ import { DidKey } from './DidKey' export class KeyDidResolver implements DidResolver { public readonly supportedMethods = ['key'] - public async resolve(did: string): Promise { + public async resolve(agentContext: AgentContext, did: string): Promise { const didDocumentMetadata = {} try { diff --git a/packages/core/src/modules/dids/methods/key/__tests__/KeyDidResolver.test.ts b/packages/core/src/modules/dids/methods/key/__tests__/KeyDidResolver.test.ts index 7c12e9f110..08157cbdcb 100644 --- a/packages/core/src/modules/dids/methods/key/__tests__/KeyDidResolver.test.ts +++ b/packages/core/src/modules/dids/methods/key/__tests__/KeyDidResolver.test.ts @@ -1,3 +1,6 @@ +import type { AgentContext } from '../../../../../agent' + +import { getAgentContext } from '../../../../../../tests/helpers' import { JsonTransformer } from '../../../../../utils/JsonTransformer' import didKeyEd25519Fixture from '../../../__tests__/__fixtures__/didKeyEd25519.json' import { DidKey } from '../DidKey' @@ -6,14 +9,19 @@ import { KeyDidResolver } from '../KeyDidResolver' describe('DidResolver', () => { describe('KeyDidResolver', () => { let keyDidResolver: KeyDidResolver + let agentContext: AgentContext beforeEach(() => { keyDidResolver = new KeyDidResolver() + agentContext = getAgentContext() }) it('should correctly resolve a did:key document', async () => { const fromDidSpy = jest.spyOn(DidKey, 'fromDid') - const result = await keyDidResolver.resolve('did:key:z6MkmjY8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th') + const result = await keyDidResolver.resolve( + agentContext, + 'did:key:z6MkmjY8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th' + ) expect(JsonTransformer.toJSON(result)).toMatchObject({ didDocument: didKeyEd25519Fixture, @@ -26,7 +34,10 @@ describe('DidResolver', () => { }) it('should return did resolution metadata with error if the did contains an unsupported multibase', async () => { - const result = await keyDidResolver.resolve('did:key:asdfkmjY8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th') + const result = await keyDidResolver.resolve( + agentContext, + 'did:key:asdfkmjY8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th' + ) expect(result).toEqual({ didDocument: null, @@ -39,7 +50,10 @@ describe('DidResolver', () => { }) it('should return did resolution metadata with error if the did contains an unsupported multibase', async () => { - const result = await keyDidResolver.resolve('did:key:z6MkmjYasdfasfd8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th') + const result = await keyDidResolver.resolve( + agentContext, + 'did:key:z6MkmjYasdfasfd8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th' + ) expect(result).toEqual({ didDocument: null, diff --git a/packages/core/src/modules/dids/methods/peer/PeerDidResolver.ts b/packages/core/src/modules/dids/methods/peer/PeerDidResolver.ts index 6aebfda5f2..85fad84c54 100644 --- a/packages/core/src/modules/dids/methods/peer/PeerDidResolver.ts +++ b/packages/core/src/modules/dids/methods/peer/PeerDidResolver.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../../agent' import type { DidDocument } from '../../domain' import type { DidResolver } from '../../domain/DidResolver' import type { DidRepository } from '../../repository' @@ -18,7 +19,7 @@ export class PeerDidResolver implements DidResolver { this.didRepository = didRepository } - public async resolve(did: string): Promise { + public async resolve(agentContext: AgentContext, did: string): Promise { const didDocumentMetadata = {} try { @@ -36,7 +37,7 @@ export class PeerDidResolver implements DidResolver { } // For Method 1, retrieve from storage else if (numAlgo === PeerDidNumAlgo.GenesisDoc) { - const didDocumentRecord = await this.didRepository.getById(did) + const didDocumentRecord = await this.didRepository.getById(agentContext, did) if (!didDocumentRecord.didDocument) { throw new AriesFrameworkError(`Found did record for method 1 peer did (${did}), but no did document.`) diff --git a/packages/core/src/modules/dids/methods/sov/SovDidResolver.ts b/packages/core/src/modules/dids/methods/sov/SovDidResolver.ts index 5f02c8dd4c..325b5cf185 100644 --- a/packages/core/src/modules/dids/methods/sov/SovDidResolver.ts +++ b/packages/core/src/modules/dids/methods/sov/SovDidResolver.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../../agent' import type { IndyEndpointAttrib, IndyLedgerService } from '../../../ledger' import type { DidResolver } from '../../domain/DidResolver' import type { ParsedDid, DidResolutionResult } from '../../types' @@ -22,12 +23,12 @@ export class SovDidResolver implements DidResolver { public readonly supportedMethods = ['sov'] - public async resolve(did: string, parsed: ParsedDid): Promise { + public async resolve(agentContext: AgentContext, did: string, parsed: ParsedDid): Promise { const didDocumentMetadata = {} try { - const nym = await this.indyLedgerService.getPublicDid(parsed.id) - const endpoints = await this.indyLedgerService.getEndpointsForDid(did) + const nym = await this.indyLedgerService.getPublicDid(agentContext, parsed.id) + const endpoints = await this.indyLedgerService.getEndpointsForDid(agentContext, did) const verificationMethodId = `${parsed.did}#key-1` const keyAgreementId = `${parsed.did}#key-agreement-1` diff --git a/packages/core/src/modules/dids/methods/sov/__tests__/SovDidResolver.test.ts b/packages/core/src/modules/dids/methods/sov/__tests__/SovDidResolver.test.ts index ec20ac80be..b1dd46280f 100644 --- a/packages/core/src/modules/dids/methods/sov/__tests__/SovDidResolver.test.ts +++ b/packages/core/src/modules/dids/methods/sov/__tests__/SovDidResolver.test.ts @@ -1,7 +1,8 @@ +import type { AgentContext } from '../../../../../agent' import type { IndyEndpointAttrib } from '../../../../ledger/services/IndyLedgerService' import type { GetNymResponse } from 'indy-sdk' -import { mockFunction } from '../../../../../../tests/helpers' +import { getAgentContext, mockFunction } from '../../../../../../tests/helpers' import { JsonTransformer } from '../../../../../utils/JsonTransformer' import { IndyLedgerService } from '../../../../ledger/services/IndyLedgerService' import didSovR1xKJw17sUoXhejEpugMYJFixture from '../../../__tests__/__fixtures__/didSovR1xKJw17sUoXhejEpugMYJ.json' @@ -16,10 +17,12 @@ describe('DidResolver', () => { describe('SovDidResolver', () => { let ledgerService: IndyLedgerService let sovDidResolver: SovDidResolver + let agentContext: AgentContext beforeEach(() => { ledgerService = new IndyLedgerServiceMock() sovDidResolver = new SovDidResolver(ledgerService) + agentContext = getAgentContext() }) it('should correctly resolve a did:sov document', async () => { @@ -40,7 +43,7 @@ describe('DidResolver', () => { mockFunction(ledgerService.getPublicDid).mockResolvedValue(nymResponse) mockFunction(ledgerService.getEndpointsForDid).mockResolvedValue(endpoints) - const result = await sovDidResolver.resolve(did, parseDid(did)) + const result = await sovDidResolver.resolve(agentContext, did, parseDid(did)) expect(JsonTransformer.toJSON(result)).toMatchObject({ didDocument: didSovR1xKJw17sUoXhejEpugMYJFixture, @@ -69,7 +72,7 @@ describe('DidResolver', () => { mockFunction(ledgerService.getPublicDid).mockReturnValue(Promise.resolve(nymResponse)) mockFunction(ledgerService.getEndpointsForDid).mockReturnValue(Promise.resolve(endpoints)) - const result = await sovDidResolver.resolve(did, parseDid(did)) + const result = await sovDidResolver.resolve(agentContext, did, parseDid(did)) expect(JsonTransformer.toJSON(result)).toMatchObject({ didDocument: didSovWJz9mHyW9BZksioQnRsrAoFixture, @@ -85,7 +88,7 @@ describe('DidResolver', () => { mockFunction(ledgerService.getPublicDid).mockRejectedValue(new Error('Error retrieving did')) - const result = await sovDidResolver.resolve(did, parseDid(did)) + const result = await sovDidResolver.resolve(agentContext, did, parseDid(did)) expect(result).toMatchObject({ didDocument: null, diff --git a/packages/core/src/modules/dids/methods/web/WebDidResolver.ts b/packages/core/src/modules/dids/methods/web/WebDidResolver.ts index 628b2eb177..77d9b1e295 100644 --- a/packages/core/src/modules/dids/methods/web/WebDidResolver.ts +++ b/packages/core/src/modules/dids/methods/web/WebDidResolver.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../../agent' import type { DidResolver } from '../../domain/DidResolver' import type { ParsedDid, DidResolutionResult, DidResolutionOptions } from '../../types' @@ -19,6 +20,7 @@ export class WebDidResolver implements DidResolver { } public async resolve( + agentContext: AgentContext, did: string, parsed: ParsedDid, didResolutionOptions: DidResolutionOptions diff --git a/packages/core/src/modules/dids/repository/DidRepository.ts b/packages/core/src/modules/dids/repository/DidRepository.ts index cb397cd1fe..3384558c7a 100644 --- a/packages/core/src/modules/dids/repository/DidRepository.ts +++ b/packages/core/src/modules/dids/repository/DidRepository.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../agent' import type { Key } from '../../../crypto' import { EventEmitter } from '../../../agent/EventEmitter' @@ -17,11 +18,11 @@ export class DidRepository extends Repository { super(DidRecord, storageService, eventEmitter) } - public findByRecipientKey(recipientKey: Key) { - return this.findSingleByQuery({ recipientKeyFingerprints: [recipientKey.fingerprint] }) + public findByRecipientKey(agentContext: AgentContext, recipientKey: Key) { + return this.findSingleByQuery(agentContext, { recipientKeyFingerprints: [recipientKey.fingerprint] }) } - public findAllByRecipientKey(recipientKey: Key) { - return this.findByQuery({ recipientKeyFingerprints: [recipientKey.fingerprint] }) + public findAllByRecipientKey(agentContext: AgentContext, recipientKey: Key) { + return this.findByQuery(agentContext, { recipientKeyFingerprints: [recipientKey.fingerprint] }) } } diff --git a/packages/core/src/modules/dids/services/DidResolverService.ts b/packages/core/src/modules/dids/services/DidResolverService.ts index 83ab576e50..3a0020a8b5 100644 --- a/packages/core/src/modules/dids/services/DidResolverService.ts +++ b/packages/core/src/modules/dids/services/DidResolverService.ts @@ -1,10 +1,11 @@ -import type { Logger } from '../../../logger' +import type { AgentContext } from '../../../agent' import type { DidResolver } from '../domain/DidResolver' import type { DidResolutionOptions, DidResolutionResult, ParsedDid } from '../types' -import { AgentConfig } from '../../../agent/AgentConfig' +import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' import { IndyLedgerService } from '../../ledger' import { parseDid } from '../domain/parse' import { KeyDidResolver } from '../methods/key/KeyDidResolver' @@ -18,8 +19,12 @@ export class DidResolverService { private logger: Logger private resolvers: DidResolver[] - public constructor(agentConfig: AgentConfig, indyLedgerService: IndyLedgerService, didRepository: DidRepository) { - this.logger = agentConfig.logger + public constructor( + indyLedgerService: IndyLedgerService, + didRepository: DidRepository, + @inject(InjectionSymbols.Logger) logger: Logger + ) { + this.logger = logger this.resolvers = [ new SovDidResolver(indyLedgerService), @@ -29,7 +34,11 @@ export class DidResolverService { ] } - public async resolve(didUrl: string, options: DidResolutionOptions = {}): Promise { + public async resolve( + agentContext: AgentContext, + didUrl: string, + options: DidResolutionOptions = {} + ): Promise { this.logger.debug(`resolving didUrl ${didUrl}`) const result = { @@ -56,14 +65,14 @@ export class DidResolverService { } } - return resolver.resolve(parsed.did, parsed, options) + return resolver.resolve(agentContext, parsed.did, parsed, options) } - public async resolveDidDocument(did: string) { + public async resolveDidDocument(agentContext: AgentContext, did: string) { const { didDocument, didResolutionMetadata: { error, message }, - } = await this.resolve(did) + } = await this.resolve(agentContext, did) if (!didDocument) { throw new AriesFrameworkError(`Unable to resolve did document for did '${did}': ${error} ${message}`) diff --git a/packages/core/src/modules/discover-features/DiscoverFeaturesModule.ts b/packages/core/src/modules/discover-features/DiscoverFeaturesModule.ts index b722ab8501..f557d186dd 100644 --- a/packages/core/src/modules/discover-features/DiscoverFeaturesModule.ts +++ b/packages/core/src/modules/discover-features/DiscoverFeaturesModule.ts @@ -2,16 +2,17 @@ import type { AgentMessageProcessedEvent } from '../../agent/Events' import type { DependencyManager } from '../../plugins' import type { ParsedMessageType } from '../../utils/messageType' -import { firstValueFrom, of, ReplaySubject } from 'rxjs' -import { filter, takeUntil, timeout, catchError, map } from 'rxjs/operators' +import { firstValueFrom, of, ReplaySubject, Subject } from 'rxjs' +import { catchError, filter, map, takeUntil, timeout } from 'rxjs/operators' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { EventEmitter } from '../../agent/EventEmitter' import { AgentEventTypes } from '../../agent/Events' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' -import { injectable, module } from '../../plugins' +import { InjectionSymbols } from '../../constants' +import { inject, injectable, module } from '../../plugins' import { canHandleMessageType, parseMessageType } from '../../utils/messageType' import { ConnectionService } from '../connections/services' @@ -26,7 +27,8 @@ export class DiscoverFeaturesModule { private messageSender: MessageSender private discoverFeaturesService: DiscoverFeaturesService private eventEmitter: EventEmitter - private agentConfig: AgentConfig + private stop$: Subject + private agentContext: AgentContext public constructor( dispatcher: Dispatcher, @@ -34,14 +36,16 @@ export class DiscoverFeaturesModule { messageSender: MessageSender, discoverFeaturesService: DiscoverFeaturesService, eventEmitter: EventEmitter, - agentConfig: AgentConfig + @inject(InjectionSymbols.Stop$) stop$: Subject, + agentContext: AgentContext ) { this.connectionService = connectionService this.messageSender = messageSender this.discoverFeaturesService = discoverFeaturesService this.registerHandlers(dispatcher) this.eventEmitter = eventEmitter - this.agentConfig = agentConfig + this.stop$ = stop$ + this.agentContext = agentContext } public async isProtocolSupported(connectionId: string, message: { type: ParsedMessageType }) { @@ -53,7 +57,7 @@ export class DiscoverFeaturesModule { .observable(AgentEventTypes.AgentMessageProcessed) .pipe( // Stop when the agent shuts down - takeUntil(this.agentConfig.stop$), + takeUntil(this.stop$), // filter by connection id and query disclose message type filter( (e) => @@ -83,12 +87,12 @@ export class DiscoverFeaturesModule { } public async queryFeatures(connectionId: string, options: { query: string; comment?: string }) { - const connection = await this.connectionService.getById(connectionId) + const connection = await this.connectionService.getById(this.agentContext, connectionId) const queryMessage = await this.discoverFeaturesService.createQuery(options) const outbound = createOutboundMessage(connection, queryMessage) - await this.messageSender.sendMessage(outbound) + await this.messageSender.sendMessage(this.agentContext, outbound) } private registerHandlers(dispatcher: Dispatcher) { diff --git a/packages/core/src/modules/generic-records/GenericRecordsModule.ts b/packages/core/src/modules/generic-records/GenericRecordsModule.ts index 75dd500025..9da14b016a 100644 --- a/packages/core/src/modules/generic-records/GenericRecordsModule.ts +++ b/packages/core/src/modules/generic-records/GenericRecordsModule.ts @@ -1,9 +1,10 @@ -import type { Logger } from '../../logger' import type { DependencyManager } from '../../plugins' import type { GenericRecord, GenericRecordTags, SaveGenericRecordOption } from './repository/GenericRecord' -import { AgentConfig } from '../../agent/AgentConfig' -import { injectable, module } from '../../plugins' +import { AgentContext } from '../../agent' +import { InjectionSymbols } from '../../constants' +import { Logger } from '../../logger' +import { inject, injectable, module } from '../../plugins' import { GenericRecordsRepository } from './repository/GenericRecordsRepository' import { GenericRecordService } from './service/GenericRecordService' @@ -17,14 +18,21 @@ export type ContentType = { export class GenericRecordsModule { private genericRecordsService: GenericRecordService private logger: Logger - public constructor(agentConfig: AgentConfig, genericRecordsService: GenericRecordService) { + private agentContext: AgentContext + + public constructor( + genericRecordsService: GenericRecordService, + @inject(InjectionSymbols.Logger) logger: Logger, + agentContext: AgentContext + ) { this.genericRecordsService = genericRecordsService - this.logger = agentConfig.logger + this.logger = logger + this.agentContext = agentContext } public async save({ content, tags }: SaveGenericRecordOption) { try { - const record = await this.genericRecordsService.save({ + const record = await this.genericRecordsService.save(this.agentContext, { content: content, tags: tags, }) @@ -41,7 +49,7 @@ export class GenericRecordsModule { public async delete(record: GenericRecord): Promise { try { - await this.genericRecordsService.delete(record) + await this.genericRecordsService.delete(this.agentContext, record) } catch (error) { this.logger.error('Error while saving generic-record', { error, @@ -54,7 +62,7 @@ export class GenericRecordsModule { public async update(record: GenericRecord): Promise { try { - await this.genericRecordsService.update(record) + await this.genericRecordsService.update(this.agentContext, record) } catch (error) { this.logger.error('Error while update generic-record', { error, @@ -66,15 +74,15 @@ export class GenericRecordsModule { } public async findById(id: string) { - return this.genericRecordsService.findById(id) + return this.genericRecordsService.findById(this.agentContext, id) } public async findAllByQuery(query: Partial): Promise { - return this.genericRecordsService.findAllByQuery(query) + return this.genericRecordsService.findAllByQuery(this.agentContext, query) } public async getAll(): Promise { - return this.genericRecordsService.getAll() + return this.genericRecordsService.getAll(this.agentContext) } /** diff --git a/packages/core/src/modules/generic-records/service/GenericRecordService.ts b/packages/core/src/modules/generic-records/service/GenericRecordService.ts index 654f7e3323..e19115ca37 100644 --- a/packages/core/src/modules/generic-records/service/GenericRecordService.ts +++ b/packages/core/src/modules/generic-records/service/GenericRecordService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../agent' import type { GenericRecordTags, SaveGenericRecordOption } from '../repository/GenericRecord' import { AriesFrameworkError } from '../../../error' @@ -13,14 +14,14 @@ export class GenericRecordService { this.genericRecordsRepository = genericRecordsRepository } - public async save({ content, tags }: SaveGenericRecordOption) { + public async save(agentContext: AgentContext, { content, tags }: SaveGenericRecordOption) { const genericRecord = new GenericRecord({ content: content, tags: tags, }) try { - await this.genericRecordsRepository.save(genericRecord) + await this.genericRecordsRepository.save(agentContext, genericRecord) return genericRecord } catch (error) { throw new AriesFrameworkError( @@ -29,31 +30,31 @@ export class GenericRecordService { } } - public async delete(record: GenericRecord): Promise { + public async delete(agentContext: AgentContext, record: GenericRecord): Promise { try { - await this.genericRecordsRepository.delete(record) + await this.genericRecordsRepository.delete(agentContext, record) } catch (error) { throw new AriesFrameworkError(`Unable to delete the genericRecord record with id ${record.id}. Message: ${error}`) } } - public async update(record: GenericRecord): Promise { + public async update(agentContext: AgentContext, record: GenericRecord): Promise { try { - await this.genericRecordsRepository.update(record) + await this.genericRecordsRepository.update(agentContext, record) } catch (error) { throw new AriesFrameworkError(`Unable to update the genericRecord record with id ${record.id}. Message: ${error}`) } } - public async findAllByQuery(query: Partial) { - return this.genericRecordsRepository.findByQuery(query) + public async findAllByQuery(agentContext: AgentContext, query: Partial) { + return this.genericRecordsRepository.findByQuery(agentContext, query) } - public async findById(id: string): Promise { - return this.genericRecordsRepository.findById(id) + public async findById(agentContext: AgentContext, id: string): Promise { + return this.genericRecordsRepository.findById(agentContext, id) } - public async getAll() { - return this.genericRecordsRepository.getAll() + public async getAll(agentContext: AgentContext) { + return this.genericRecordsRepository.getAll(agentContext) } } diff --git a/packages/core/src/modules/indy/services/IndyHolderService.ts b/packages/core/src/modules/indy/services/IndyHolderService.ts index 763841c4ef..e92b20896f 100644 --- a/packages/core/src/modules/indy/services/IndyHolderService.ts +++ b/packages/core/src/modules/indy/services/IndyHolderService.ts @@ -1,12 +1,14 @@ -import type { Logger } from '../../../logger' +import type { AgentContext } from '../../../agent' import type { RequestedCredentials } from '../../proofs' import type * as Indy from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { AgentDependencies } from '../../../agent/AgentDependencies' +import { InjectionSymbols } from '../../../constants' import { IndySdkError } from '../../../error/IndySdkError' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' import { isIndyError } from '../../../utils/indyError' -import { IndyWallet } from '../../../wallet/IndyWallet' +import { assertIndyWallet } from '../../../wallet/util/assertIndyWallet' import { IndyRevocationService } from './IndyRevocationService' @@ -14,14 +16,16 @@ import { IndyRevocationService } from './IndyRevocationService' export class IndyHolderService { private indy: typeof Indy private logger: Logger - private wallet: IndyWallet private indyRevocationService: IndyRevocationService - public constructor(agentConfig: AgentConfig, indyRevocationService: IndyRevocationService, wallet: IndyWallet) { - this.indy = agentConfig.agentDependencies.indy - this.wallet = wallet + public constructor( + indyRevocationService: IndyRevocationService, + @inject(InjectionSymbols.Logger) logger: Logger, + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies + ) { + this.indy = agentDependencies.indy this.indyRevocationService = indyRevocationService - this.logger = agentConfig.logger + this.logger = logger } /** @@ -36,24 +40,24 @@ export class IndyHolderService { * * @todo support attribute non_revoked fields */ - public async createProof({ - proofRequest, - requestedCredentials, - schemas, - credentialDefinitions, - }: CreateProofOptions): Promise { + public async createProof( + agentContext: AgentContext, + { proofRequest, requestedCredentials, schemas, credentialDefinitions }: CreateProofOptions + ): Promise { + assertIndyWallet(agentContext.wallet) try { this.logger.debug('Creating Indy Proof') const revocationStates: Indy.RevStates = await this.indyRevocationService.createRevocationState( + agentContext, proofRequest, requestedCredentials ) const indyProof: Indy.IndyProof = await this.indy.proverCreateProof( - this.wallet.handle, + agentContext.wallet.handle, proofRequest, requestedCredentials.toJSON(), - this.wallet.masterSecretId, + agentContext.wallet.masterSecretId, schemas, credentialDefinitions, revocationStates @@ -80,16 +84,20 @@ export class IndyHolderService { * * @returns The credential id */ - public async storeCredential({ - credentialRequestMetadata, - credential, - credentialDefinition, - credentialId, - revocationRegistryDefinition, - }: StoreCredentialOptions): Promise { + public async storeCredential( + agentContext: AgentContext, + { + credentialRequestMetadata, + credential, + credentialDefinition, + credentialId, + revocationRegistryDefinition, + }: StoreCredentialOptions + ): Promise { + assertIndyWallet(agentContext.wallet) try { return await this.indy.proverStoreCredential( - this.wallet.handle, + agentContext.wallet.handle, credentialId ?? null, credentialRequestMetadata, credential, @@ -114,9 +122,13 @@ export class IndyHolderService { * * @todo handle record not found */ - public async getCredential(credentialId: Indy.CredentialId): Promise { + public async getCredential( + agentContext: AgentContext, + credentialId: Indy.CredentialId + ): Promise { + assertIndyWallet(agentContext.wallet) try { - return await this.indy.proverGetCredential(this.wallet.handle, credentialId) + return await this.indy.proverGetCredential(agentContext.wallet.handle, credentialId) } catch (error) { this.logger.error(`Error getting Indy Credential '${credentialId}'`, { error, @@ -131,18 +143,18 @@ export class IndyHolderService { * * @returns The credential request and the credential request metadata */ - public async createCredentialRequest({ - holderDid, - credentialOffer, - credentialDefinition, - }: CreateCredentialRequestOptions): Promise<[Indy.CredReq, Indy.CredReqMetadata]> { + public async createCredentialRequest( + agentContext: AgentContext, + { holderDid, credentialOffer, credentialDefinition }: CreateCredentialRequestOptions + ): Promise<[Indy.CredReq, Indy.CredReqMetadata]> { + assertIndyWallet(agentContext.wallet) try { return await this.indy.proverCreateCredentialReq( - this.wallet.handle, + agentContext.wallet.handle, holderDid, credentialOffer, credentialDefinition, - this.wallet.masterSecretId + agentContext.wallet.masterSecretId ) } catch (error) { this.logger.error(`Error creating Indy Credential Request`, { @@ -165,17 +177,15 @@ export class IndyHolderService { * @returns List of credentials that are available for building a proof for the given proof request * */ - public async getCredentialsForProofRequest({ - proofRequest, - attributeReferent, - start = 0, - limit = 256, - extraQuery, - }: GetCredentialForProofRequestOptions): Promise { + public async getCredentialsForProofRequest( + agentContext: AgentContext, + { proofRequest, attributeReferent, start = 0, limit = 256, extraQuery }: GetCredentialForProofRequestOptions + ): Promise { + assertIndyWallet(agentContext.wallet) try { // Open indy credential search const searchHandle = await this.indy.proverSearchCredentialsForProofReq( - this.wallet.handle, + agentContext.wallet.handle, proofRequest, extraQuery ?? null ) @@ -210,9 +220,10 @@ export class IndyHolderService { * @param credentialId the id (referent) of the credential * */ - public async deleteCredential(credentialId: Indy.CredentialId): Promise { + public async deleteCredential(agentContext: AgentContext, credentialId: Indy.CredentialId): Promise { + assertIndyWallet(agentContext.wallet) try { - return await this.indy.proverDeleteCredential(this.wallet.handle, credentialId) + return await this.indy.proverDeleteCredential(agentContext.wallet.handle, credentialId) } catch (error) { this.logger.error(`Error deleting Indy Credential from Wallet`, { error, diff --git a/packages/core/src/modules/indy/services/IndyIssuerService.ts b/packages/core/src/modules/indy/services/IndyIssuerService.ts index 9c0ed580a6..58e9917cf0 100644 --- a/packages/core/src/modules/indy/services/IndyIssuerService.ts +++ b/packages/core/src/modules/indy/services/IndyIssuerService.ts @@ -1,37 +1,37 @@ -import type { FileSystem } from '../../../storage/FileSystem' +import type { AgentContext } from '../../../agent' import type { - default as Indy, - CredDef, - Schema, Cred, + CredDef, CredDefId, CredOffer, CredReq, CredRevocId, CredValues, + default as Indy, + Schema, } from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { AgentDependencies } from '../../../agent/AgentDependencies' +import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error/AriesFrameworkError' import { IndySdkError } from '../../../error/IndySdkError' -import { injectable } from '../../../plugins' +import { injectable, inject } from '../../../plugins' import { isIndyError } from '../../../utils/indyError' -import { IndyWallet } from '../../../wallet/IndyWallet' +import { assertIndyWallet } from '../../../wallet/util/assertIndyWallet' import { IndyUtilitiesService } from './IndyUtilitiesService' @injectable() export class IndyIssuerService { private indy: typeof Indy - private wallet: IndyWallet private indyUtilitiesService: IndyUtilitiesService - private fileSystem: FileSystem - public constructor(agentConfig: AgentConfig, wallet: IndyWallet, indyUtilitiesService: IndyUtilitiesService) { - this.indy = agentConfig.agentDependencies.indy - this.wallet = wallet + public constructor( + indyUtilitiesService: IndyUtilitiesService, + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies + ) { + this.indy = agentDependencies.indy this.indyUtilitiesService = indyUtilitiesService - this.fileSystem = agentConfig.fileSystem } /** @@ -39,7 +39,11 @@ export class IndyIssuerService { * * @returns the schema. */ - public async createSchema({ originDid, name, version, attributes }: CreateSchemaOptions): Promise { + public async createSchema( + agentContext: AgentContext, + { originDid, name, version, attributes }: CreateSchemaOptions + ): Promise { + assertIndyWallet(agentContext.wallet) try { const [, schema] = await this.indy.issuerCreateSchema(originDid, name, version, attributes) @@ -54,16 +58,20 @@ export class IndyIssuerService { * * @returns the credential definition. */ - public async createCredentialDefinition({ - issuerDid, - schema, - tag = 'default', - signatureType = 'CL', - supportRevocation = false, - }: CreateCredentialDefinitionOptions): Promise { + public async createCredentialDefinition( + agentContext: AgentContext, + { + issuerDid, + schema, + tag = 'default', + signatureType = 'CL', + supportRevocation = false, + }: CreateCredentialDefinitionOptions + ): Promise { + assertIndyWallet(agentContext.wallet) try { const [, credentialDefinition] = await this.indy.issuerCreateAndStoreCredentialDef( - this.wallet.handle, + agentContext.wallet.handle, issuerDid, schema, tag, @@ -85,9 +93,10 @@ export class IndyIssuerService { * @param credentialDefinitionId The credential definition to create an offer for * @returns The created credential offer */ - public async createCredentialOffer(credentialDefinitionId: CredDefId) { + public async createCredentialOffer(agentContext: AgentContext, credentialDefinitionId: CredDefId) { + assertIndyWallet(agentContext.wallet) try { - return await this.indy.issuerCreateCredentialOffer(this.wallet.handle, credentialDefinitionId) + return await this.indy.issuerCreateCredentialOffer(agentContext.wallet.handle, credentialDefinitionId) } catch (error) { throw isIndyError(error) ? new IndySdkError(error) : error } @@ -98,13 +107,17 @@ export class IndyIssuerService { * * @returns Credential and revocation id */ - public async createCredential({ - credentialOffer, - credentialRequest, - credentialValues, - revocationRegistryId, - tailsFilePath, - }: CreateCredentialOptions): Promise<[Cred, CredRevocId]> { + public async createCredential( + agentContext: AgentContext, + { + credentialOffer, + credentialRequest, + credentialValues, + revocationRegistryId, + tailsFilePath, + }: CreateCredentialOptions + ): Promise<[Cred, CredRevocId]> { + assertIndyWallet(agentContext.wallet) try { // Indy SDK requires tailsReaderHandle. Use null if no tailsFilePath is present const tailsReaderHandle = tailsFilePath ? await this.indyUtilitiesService.createTailsReader(tailsFilePath) : 0 @@ -114,7 +127,7 @@ export class IndyIssuerService { } const [credential, credentialRevocationId] = await this.indy.issuerCreateCredential( - this.wallet.handle, + agentContext.wallet.handle, credentialOffer, credentialRequest, credentialValues, diff --git a/packages/core/src/modules/indy/services/IndyRevocationService.ts b/packages/core/src/modules/indy/services/IndyRevocationService.ts index 4431df8dde..fa84997876 100644 --- a/packages/core/src/modules/indy/services/IndyRevocationService.ts +++ b/packages/core/src/modules/indy/services/IndyRevocationService.ts @@ -1,15 +1,15 @@ -import type { Logger } from '../../../logger' -import type { FileSystem } from '../../../storage/FileSystem' +import type { AgentContext } from '../../../agent' import type { IndyRevocationInterval } from '../../credentials' import type { RequestedCredentials } from '../../proofs' import type { default as Indy } from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { AgentDependencies } from '../../../agent/AgentDependencies' +import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error/AriesFrameworkError' import { IndySdkError } from '../../../error/IndySdkError' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' import { isIndyError } from '../../../utils/indyError' -import { IndyWallet } from '../../../wallet/IndyWallet' import { IndyLedgerService } from '../../ledger' import { IndyUtilitiesService } from './IndyUtilitiesService' @@ -23,26 +23,23 @@ enum RequestReferentType { export class IndyRevocationService { private indy: typeof Indy private indyUtilitiesService: IndyUtilitiesService - private fileSystem: FileSystem private ledgerService: IndyLedgerService private logger: Logger - private wallet: IndyWallet public constructor( - agentConfig: AgentConfig, indyUtilitiesService: IndyUtilitiesService, ledgerService: IndyLedgerService, - wallet: IndyWallet + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies, + @inject(InjectionSymbols.Logger) logger: Logger ) { - this.fileSystem = agentConfig.fileSystem - this.indy = agentConfig.agentDependencies.indy + this.indy = agentDependencies.indy this.indyUtilitiesService = indyUtilitiesService - this.logger = agentConfig.logger + this.logger = logger this.ledgerService = ledgerService - this.wallet = wallet } public async createRevocationState( + agentContext: AgentContext, proofRequest: Indy.IndyProofRequest, requestedCredentials: RequestedCredentials ): Promise { @@ -100,10 +97,12 @@ export class IndyRevocationService { this.assertRevocationInterval(requestRevocationInterval) const { revocationRegistryDefinition } = await this.ledgerService.getRevocationRegistryDefinition( + agentContext, revocationRegistryId ) const { revocationRegistryDelta, deltaTimestamp } = await this.ledgerService.getRevocationRegistryDelta( + agentContext, revocationRegistryId, requestRevocationInterval?.to, 0 @@ -147,6 +146,7 @@ export class IndyRevocationService { // Get revocation status for credential (given a from-to) // Note from-to interval details: https://github.com/hyperledger/indy-hipe/blob/master/text/0011-cred-revocation/README.md#indy-node-revocation-registry-intervals public async getRevocationStatus( + agentContext: AgentContext, credentialRevocationId: string, revocationRegistryDefinitionId: string, requestRevocationInterval: IndyRevocationInterval @@ -158,6 +158,7 @@ export class IndyRevocationService { this.assertRevocationInterval(requestRevocationInterval) const { revocationRegistryDelta, deltaTimestamp } = await this.ledgerService.getRevocationRegistryDelta( + agentContext, revocationRegistryDefinitionId, requestRevocationInterval.to, 0 diff --git a/packages/core/src/modules/indy/services/IndyUtilitiesService.ts b/packages/core/src/modules/indy/services/IndyUtilitiesService.ts index 74784f2182..eef01ccfd2 100644 --- a/packages/core/src/modules/indy/services/IndyUtilitiesService.ts +++ b/packages/core/src/modules/indy/services/IndyUtilitiesService.ts @@ -1,11 +1,12 @@ -import type { Logger } from '../../../logger' -import type { FileSystem } from '../../../storage/FileSystem' -import type { default as Indy, BlobReaderHandle } from 'indy-sdk' +import type { BlobReaderHandle, default as Indy } from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { AgentDependencies } from '../../../agent/AgentDependencies' +import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error' import { IndySdkError } from '../../../error/IndySdkError' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' +import { FileSystem } from '../../../storage/FileSystem' import { isIndyError } from '../../../utils/indyError' import { getDirFromFilePath } from '../../../utils/path' @@ -15,10 +16,14 @@ export class IndyUtilitiesService { private logger: Logger private fileSystem: FileSystem - public constructor(agentConfig: AgentConfig) { - this.indy = agentConfig.agentDependencies.indy - this.logger = agentConfig.logger - this.fileSystem = agentConfig.fileSystem + public constructor( + @inject(InjectionSymbols.Logger) logger: Logger, + @inject(InjectionSymbols.FileSystem) fileSystem: FileSystem, + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies + ) { + this.indy = agentDependencies.indy + this.logger = logger + this.fileSystem = fileSystem } /** diff --git a/packages/core/src/modules/indy/services/IndyVerifierService.ts b/packages/core/src/modules/indy/services/IndyVerifierService.ts index b480288acc..b6cc387c31 100644 --- a/packages/core/src/modules/indy/services/IndyVerifierService.ts +++ b/packages/core/src/modules/indy/services/IndyVerifierService.ts @@ -1,8 +1,10 @@ +import type { AgentContext } from '../../../agent' import type * as Indy from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { AgentDependencies } from '../../../agent/AgentDependencies' +import { InjectionSymbols } from '../../../constants' import { IndySdkError } from '../../../error' -import { injectable } from '../../../plugins' +import { injectable, inject } from '../../../plugins' import { isIndyError } from '../../../utils/indyError' import { IndyLedgerService } from '../../ledger/services/IndyLedgerService' @@ -11,19 +13,23 @@ export class IndyVerifierService { private indy: typeof Indy private ledgerService: IndyLedgerService - public constructor(agentConfig: AgentConfig, ledgerService: IndyLedgerService) { - this.indy = agentConfig.agentDependencies.indy + public constructor( + ledgerService: IndyLedgerService, + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies + ) { + this.indy = agentDependencies.indy this.ledgerService = ledgerService } - public async verifyProof({ - proofRequest, - proof, - schemas, - credentialDefinitions, - }: VerifyProofOptions): Promise { + public async verifyProof( + agentContext: AgentContext, + { proofRequest, proof, schemas, credentialDefinitions }: VerifyProofOptions + ): Promise { try { - const { revocationRegistryDefinitions, revocationRegistryStates } = await this.getRevocationRegistries(proof) + const { revocationRegistryDefinitions, revocationRegistryStates } = await this.getRevocationRegistries( + agentContext, + proof + ) return await this.indy.verifierVerifyProof( proofRequest, @@ -38,7 +44,7 @@ export class IndyVerifierService { } } - private async getRevocationRegistries(proof: Indy.IndyProof) { + private async getRevocationRegistries(agentContext: AgentContext, proof: Indy.IndyProof) { const revocationRegistryDefinitions: Indy.RevocRegDefs = {} const revocationRegistryStates: Indy.RevStates = Object.create(null) for (const identifier of proof.identifiers) { @@ -48,6 +54,7 @@ export class IndyVerifierService { //Fetch Revocation Registry Definition if not already fetched if (revocationRegistryId && !revocationRegistryDefinitions[revocationRegistryId]) { const { revocationRegistryDefinition } = await this.ledgerService.getRevocationRegistryDefinition( + agentContext, revocationRegistryId ) revocationRegistryDefinitions[revocationRegistryId] = revocationRegistryDefinition @@ -58,7 +65,11 @@ export class IndyVerifierService { if (!revocationRegistryStates[revocationRegistryId]) { revocationRegistryStates[revocationRegistryId] = Object.create(null) } - const { revocationRegistry } = await this.ledgerService.getRevocationRegistry(revocationRegistryId, timestamp) + const { revocationRegistry } = await this.ledgerService.getRevocationRegistry( + agentContext, + revocationRegistryId, + timestamp + ) revocationRegistryStates[revocationRegistryId][timestamp] = revocationRegistry } } diff --git a/packages/core/src/modules/indy/services/__mocks__/IndyHolderService.ts b/packages/core/src/modules/indy/services/__mocks__/IndyHolderService.ts index 1d6ed433b6..35afdc14ab 100644 --- a/packages/core/src/modules/indy/services/__mocks__/IndyHolderService.ts +++ b/packages/core/src/modules/indy/services/__mocks__/IndyHolderService.ts @@ -1,11 +1,11 @@ import type { CreateCredentialRequestOptions, StoreCredentialOptions } from '../IndyHolderService' export const IndyHolderService = jest.fn(() => ({ - storeCredential: jest.fn(({ credentialId }: StoreCredentialOptions) => + storeCredential: jest.fn((_, { credentialId }: StoreCredentialOptions) => Promise.resolve(credentialId ?? 'some-random-uuid') ), deleteCredential: jest.fn(() => Promise.resolve()), - createCredentialRequest: jest.fn(({ holderDid, credentialDefinition }: CreateCredentialRequestOptions) => + createCredentialRequest: jest.fn((_, { holderDid, credentialDefinition }: CreateCredentialRequestOptions) => Promise.resolve([ { prover_did: holderDid, diff --git a/packages/core/src/modules/indy/services/__mocks__/IndyIssuerService.ts b/packages/core/src/modules/indy/services/__mocks__/IndyIssuerService.ts index b9b23337ba..823e961a15 100644 --- a/packages/core/src/modules/indy/services/__mocks__/IndyIssuerService.ts +++ b/packages/core/src/modules/indy/services/__mocks__/IndyIssuerService.ts @@ -13,7 +13,7 @@ export const IndyIssuerService = jest.fn(() => ({ ]) ), - createCredentialOffer: jest.fn((credentialDefinitionId: string) => + createCredentialOffer: jest.fn((_, credentialDefinitionId: string) => Promise.resolve({ schema_id: 'aaa', cred_def_id: credentialDefinitionId, diff --git a/packages/core/src/modules/ledger/IndyPool.ts b/packages/core/src/modules/ledger/IndyPool.ts index 860e2f4f51..d58f98449f 100644 --- a/packages/core/src/modules/ledger/IndyPool.ts +++ b/packages/core/src/modules/ledger/IndyPool.ts @@ -1,7 +1,8 @@ -import type { AgentConfig } from '../../agent/AgentConfig' +import type { AgentDependencies } from '../../agent/AgentDependencies' import type { Logger } from '../../logger' import type { FileSystem } from '../../storage/FileSystem' import type * as Indy from 'indy-sdk' +import type { Subject } from 'rxjs' import { AriesFrameworkError, IndySdkError } from '../../error' import { isIndyError } from '../../utils/indyError' @@ -31,14 +32,20 @@ export class IndyPool { private poolConnected?: Promise public authorAgreement?: AuthorAgreement | null - public constructor(agentConfig: AgentConfig, poolConfig: IndyPoolConfig) { - this.indy = agentConfig.agentDependencies.indy + public constructor( + poolConfig: IndyPoolConfig, + agentDependencies: AgentDependencies, + logger: Logger, + stop$: Subject, + fileSystem: FileSystem + ) { + this.indy = agentDependencies.indy + this.fileSystem = fileSystem this.poolConfig = poolConfig - this.fileSystem = agentConfig.fileSystem - this.logger = agentConfig.logger + this.logger = logger // Listen to stop$ (shutdown) and close pool - agentConfig.stop$.subscribe(async () => { + stop$.subscribe(async () => { if (this._poolHandle) { await this.close() } diff --git a/packages/core/src/modules/ledger/LedgerModule.ts b/packages/core/src/modules/ledger/LedgerModule.ts index cceaf5210e..9262511c8c 100644 --- a/packages/core/src/modules/ledger/LedgerModule.ts +++ b/packages/core/src/modules/ledger/LedgerModule.ts @@ -1,23 +1,27 @@ import type { DependencyManager } from '../../plugins' -import type { SchemaTemplate, CredentialDefinitionTemplate } from './services' +import type { IndyPoolConfig } from './IndyPool' +import type { CredentialDefinitionTemplate, SchemaTemplate } from './services' import type { NymRole } from 'indy-sdk' -import { InjectionSymbols } from '../../constants' +import { AgentContext } from '../../agent' import { AriesFrameworkError } from '../../error' -import { injectable, module, inject } from '../../plugins' -import { Wallet } from '../../wallet/Wallet' +import { injectable, module } from '../../plugins' -import { IndyPoolService, IndyLedgerService } from './services' +import { IndyLedgerService, IndyPoolService } from './services' @module() @injectable() export class LedgerModule { private ledgerService: IndyLedgerService - private wallet: Wallet + private agentContext: AgentContext - public constructor(@inject(InjectionSymbols.Wallet) wallet: Wallet, ledgerService: IndyLedgerService) { + public constructor(ledgerService: IndyLedgerService, agentContext: AgentContext) { this.ledgerService = ledgerService - this.wallet = wallet + this.agentContext = agentContext + } + + public setPools(poolConfigs: IndyPoolConfig[]) { + return this.ledgerService.setPools(poolConfigs) } /** @@ -28,54 +32,54 @@ export class LedgerModule { } public async registerPublicDid(did: string, verkey: string, alias: string, role?: NymRole) { - const myPublicDid = this.wallet.publicDid?.did + const myPublicDid = this.agentContext.wallet.publicDid?.did if (!myPublicDid) { throw new AriesFrameworkError('Agent has no public DID.') } - return this.ledgerService.registerPublicDid(myPublicDid, did, verkey, alias, role) + return this.ledgerService.registerPublicDid(this.agentContext, myPublicDid, did, verkey, alias, role) } public async getPublicDid(did: string) { - return this.ledgerService.getPublicDid(did) + return this.ledgerService.getPublicDid(this.agentContext, did) } public async registerSchema(schema: SchemaTemplate) { - const did = this.wallet.publicDid?.did + const did = this.agentContext.wallet.publicDid?.did if (!did) { throw new AriesFrameworkError('Agent has no public DID.') } - return this.ledgerService.registerSchema(did, schema) + return this.ledgerService.registerSchema(this.agentContext, did, schema) } public async getSchema(id: string) { - return this.ledgerService.getSchema(id) + return this.ledgerService.getSchema(this.agentContext, id) } public async registerCredentialDefinition( credentialDefinitionTemplate: Omit ) { - const did = this.wallet.publicDid?.did + const did = this.agentContext.wallet.publicDid?.did if (!did) { throw new AriesFrameworkError('Agent has no public DID.') } - return this.ledgerService.registerCredentialDefinition(did, { + return this.ledgerService.registerCredentialDefinition(this.agentContext, did, { ...credentialDefinitionTemplate, signatureType: 'CL', }) } public async getCredentialDefinition(id: string) { - return this.ledgerService.getCredentialDefinition(id) + return this.ledgerService.getCredentialDefinition(this.agentContext, id) } public async getRevocationRegistryDefinition(revocationRegistryDefinitionId: string) { - return this.ledgerService.getRevocationRegistryDefinition(revocationRegistryDefinitionId) + return this.ledgerService.getRevocationRegistryDefinition(this.agentContext, revocationRegistryDefinitionId) } public async getRevocationRegistryDelta( @@ -83,7 +87,12 @@ export class LedgerModule { fromSeconds = 0, toSeconds = new Date().getTime() ) { - return this.ledgerService.getRevocationRegistryDelta(revocationRegistryDefinitionId, fromSeconds, toSeconds) + return this.ledgerService.getRevocationRegistryDelta( + this.agentContext, + revocationRegistryDefinitionId, + fromSeconds, + toSeconds + ) } /** diff --git a/packages/core/src/modules/ledger/__tests__/IndyLedgerService.test.ts b/packages/core/src/modules/ledger/__tests__/IndyLedgerService.test.ts index 0929cad3dd..85d29438a3 100644 --- a/packages/core/src/modules/ledger/__tests__/IndyLedgerService.test.ts +++ b/packages/core/src/modules/ledger/__tests__/IndyLedgerService.test.ts @@ -1,7 +1,11 @@ +import type { AgentContext } from '../../../agent' import type { IndyPoolConfig } from '../IndyPool' import type { LedgerReadReplyResponse, LedgerWriteReplyResponse } from 'indy-sdk' -import { getAgentConfig, mockFunction } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { NodeFileSystem } from '../../../../../node/src/NodeFileSystem' +import { getAgentConfig, getAgentContext, mockFunction } from '../../../../tests/helpers' import { CacheRepository } from '../../../cache/CacheRepository' import { IndyWallet } from '../../../wallet/IndyWallet' import { IndyIssuerService } from '../../indy/services/IndyIssuerService' @@ -30,13 +34,15 @@ describe('IndyLedgerService', () => { indyLedgers: pools, }) let wallet: IndyWallet + let agentContext: AgentContext let poolService: IndyPoolService let cacheRepository: CacheRepository let indyIssuerService: IndyIssuerService let ledgerService: IndyLedgerService beforeAll(async () => { - wallet = new IndyWallet(config) + wallet = new IndyWallet(config.agentDependencies, config.logger) + agentContext = getAgentContext() // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(config.walletConfig!) }) @@ -50,7 +56,7 @@ describe('IndyLedgerService', () => { mockFunction(cacheRepository.findById).mockResolvedValue(null) indyIssuerService = new IndyIssuerServiceMock() poolService = new IndyPoolServiceMock() - const pool = new IndyPool(config, pools[0]) + const pool = new IndyPool(pools[0], config.agentDependencies, config.logger, new Subject(), new NodeFileSystem()) jest.spyOn(pool, 'submitWriteRequest').mockResolvedValue({} as LedgerWriteReplyResponse) jest.spyOn(pool, 'submitReadRequest').mockResolvedValue({} as LedgerReadReplyResponse) jest.spyOn(pool, 'connect').mockResolvedValue(0) @@ -58,7 +64,7 @@ describe('IndyLedgerService', () => { // @ts-ignore poolService.ledgerWritePool = pool - ledgerService = new IndyLedgerService(wallet, config, indyIssuerService, poolService) + ledgerService = new IndyLedgerService(config.agentDependencies, config.logger, indyIssuerService, poolService) }) describe('LedgerServiceWrite', () => { @@ -78,6 +84,7 @@ describe('IndyLedgerService', () => { } as never) await expect( ledgerService.registerPublicDid( + agentContext, 'BBPoJqRKatdcfLEAFL7exC', 'N8NQHLtCKfPmWMgCSdfa7h', 'GAb4NUvpBcHVCvtP45vTVa5Bp74vFg3iXzdp1Gbd68Wf', @@ -104,6 +111,7 @@ describe('IndyLedgerService', () => { } as never) await expect( ledgerService.registerPublicDid( + agentContext, 'BBPoJqRKatdcfLEAFL7exC', 'N8NQHLtCKfPmWMgCSdfa7h', 'GAb4NUvpBcHVCvtP45vTVa5Bp74vFg3iXzdp1Gbd68Wf', @@ -118,7 +126,7 @@ describe('IndyLedgerService', () => { poolService.ledgerWritePool.authorAgreement = undefined poolService.ledgerWritePool.config.transactionAuthorAgreement = undefined - ledgerService = new IndyLedgerService(wallet, config, indyIssuerService, poolService) + ledgerService = new IndyLedgerService(config.agentDependencies, config.logger, indyIssuerService, poolService) // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore jest.spyOn(ledgerService, 'getTransactionAuthorAgreement').mockResolvedValue({ @@ -134,6 +142,7 @@ describe('IndyLedgerService', () => { } as never) await expect( ledgerService.registerPublicDid( + agentContext, 'BBPoJqRKatdcfLEAFL7exC', 'N8NQHLtCKfPmWMgCSdfa7h', 'GAb4NUvpBcHVCvtP45vTVa5Bp74vFg3iXzdp1Gbd68Wf', diff --git a/packages/core/src/modules/ledger/__tests__/IndyPoolService.test.ts b/packages/core/src/modules/ledger/__tests__/IndyPoolService.test.ts index cf72f71cf7..eebbe4332c 100644 --- a/packages/core/src/modules/ledger/__tests__/IndyPoolService.test.ts +++ b/packages/core/src/modules/ledger/__tests__/IndyPoolService.test.ts @@ -1,7 +1,11 @@ +import type { AgentContext } from '../../../agent' import type { IndyPoolConfig } from '../IndyPool' import type { CachedDidResponse } from '../services/IndyPoolService' -import { getAgentConfig, mockFunction } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { NodeFileSystem } from '../../../../../node/src/NodeFileSystem' +import { agentDependencies, getAgentConfig, getAgentContext, mockFunction } from '../../../../tests/helpers' import { CacheRecord } from '../../../cache' import { CacheRepository } from '../../../cache/CacheRepository' import { AriesFrameworkError } from '../../../error/AriesFrameworkError' @@ -53,12 +57,14 @@ describe('IndyPoolService', () => { const config = getAgentConfig('IndyPoolServiceTest', { indyLedgers: pools, }) + let agentContext: AgentContext let wallet: IndyWallet let poolService: IndyPoolService let cacheRepository: CacheRepository beforeAll(async () => { - wallet = new IndyWallet(config) + wallet = new IndyWallet(config.agentDependencies, config.logger) + agentContext = getAgentContext() // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(config.walletConfig!) }) @@ -71,7 +77,15 @@ describe('IndyPoolService', () => { cacheRepository = new CacheRepositoryMock() mockFunction(cacheRepository.findById).mockResolvedValue(null) - poolService = new IndyPoolService(config, cacheRepository) + poolService = new IndyPoolService( + cacheRepository, + agentDependencies, + config.logger, + new Subject(), + new NodeFileSystem() + ) + + poolService.setPools(pools) }) describe('ledgerWritePool', () => { @@ -79,20 +93,18 @@ describe('IndyPoolService', () => { expect(poolService.ledgerWritePool).toBe(poolService.pools[0]) }) - it('should throw a LedgerNotConfiguredError error if no pools are configured on the agent', async () => { - const config = getAgentConfig('IndyPoolServiceTest', { indyLedgers: [] }) - poolService = new IndyPoolService(config, cacheRepository) + it('should throw a LedgerNotConfiguredError error if no pools are configured on the pool service', async () => { + poolService.setPools([]) expect(() => poolService.ledgerWritePool).toThrow(LedgerNotConfiguredError) }) }) describe('getPoolForDid', () => { - it('should throw a LedgerNotConfiguredError error if no pools are configured on the agent', async () => { - const config = getAgentConfig('IndyPoolServiceTest', { indyLedgers: [] }) - poolService = new IndyPoolService(config, cacheRepository) + it('should throw a LedgerNotConfiguredError error if no pools are configured on the pool service', async () => { + poolService.setPools([]) - expect(poolService.getPoolForDid('some-did')).rejects.toThrow(LedgerNotConfiguredError) + expect(poolService.getPoolForDid(agentContext, 'some-did')).rejects.toThrow(LedgerNotConfiguredError) }) it('should throw a LedgerError if all ledger requests throw an error other than NotFoundError', async () => { @@ -103,7 +115,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(() => Promise.reject(new AriesFrameworkError('Something went wrong'))) }) - expect(poolService.getPoolForDid(did)).rejects.toThrowError(LedgerError) + expect(poolService.getPoolForDid(agentContext, did)).rejects.toThrowError(LedgerError) }) it('should throw a LedgerNotFoundError if all pools did not find the did on the ledger', async () => { @@ -116,7 +128,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(responses[index]) }) - expect(poolService.getPoolForDid(did)).rejects.toThrowError(LedgerNotFoundError) + expect(poolService.getPoolForDid(agentContext, did)).rejects.toThrowError(LedgerNotFoundError) }) it('should return the pool if the did was only found on one ledger', async () => { @@ -131,7 +143,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(responses[index]) }) - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe('sovrinMain') }) @@ -150,7 +162,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(responses[index]) }) - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe('sovrinBuilder') }) @@ -168,7 +180,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(responses[index]) }) - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe('indicioMain') }) @@ -186,7 +198,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(responses[index]) }) - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe('sovrinMain') }) @@ -205,7 +217,7 @@ describe('IndyPoolService', () => { spy.mockImplementationOnce(responses[index]) }) - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe('sovrinBuilder') }) @@ -238,9 +250,7 @@ describe('IndyPoolService', () => { }) ) - poolService = new IndyPoolService(config, cacheRepository) - - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe(pool.id) }) @@ -261,17 +271,16 @@ describe('IndyPoolService', () => { const spy = mockFunction(cacheRepository.update).mockResolvedValue() - poolService = new IndyPoolService(config, cacheRepository) poolService.pools.forEach((pool, index) => { const spy = jest.spyOn(pool, 'submitReadRequest') spy.mockImplementationOnce(responses[index]) }) - const { pool } = await poolService.getPoolForDid(did) + const { pool } = await poolService.getPoolForDid(agentContext, did) expect(pool.config.id).toBe('sovrinBuilder') - const cacheRecord = spy.mock.calls[0][0] + const cacheRecord = spy.mock.calls[0][1] expect(cacheRecord.entries.length).toBe(1) expect(cacheRecord.entries[0].key).toBe(did) expect(cacheRecord.entries[0].value).toEqual({ diff --git a/packages/core/src/modules/ledger/services/IndyLedgerService.ts b/packages/core/src/modules/ledger/services/IndyLedgerService.ts index 60363f628e..99142004ef 100644 --- a/packages/core/src/modules/ledger/services/IndyLedgerService.ts +++ b/packages/core/src/modules/ledger/services/IndyLedgerService.ts @@ -1,8 +1,8 @@ -import type { Logger } from '../../../logger' -import type { AcceptanceMechanisms, AuthorAgreement, IndyPool } from '../IndyPool' +import type { AgentContext } from '../../../agent' +import type { AcceptanceMechanisms, AuthorAgreement, IndyPool, IndyPoolConfig } from '../IndyPool' import type { - default as Indy, CredDef, + default as Indy, LedgerReadReplyResponse, LedgerRequest, LedgerWriteReplyResponse, @@ -10,16 +10,18 @@ import type { Schema, } from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { AgentDependencies } from '../../../agent/AgentDependencies' +import { InjectionSymbols } from '../../../constants' import { IndySdkError } from '../../../error/IndySdkError' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' import { - didFromSchemaId, didFromCredentialDefinitionId, didFromRevocationRegistryDefinitionId, + didFromSchemaId, } from '../../../utils/did' import { isIndyError } from '../../../utils/indyError' -import { IndyWallet } from '../../../wallet/IndyWallet' +import { assertIndyWallet } from '../../../wallet/util/assertIndyWallet' import { IndyIssuerService } from '../../indy/services/IndyIssuerService' import { LedgerError } from '../error/LedgerError' @@ -27,7 +29,6 @@ import { IndyPoolService } from './IndyPoolService' @injectable() export class IndyLedgerService { - private wallet: IndyWallet private indy: typeof Indy private logger: Logger @@ -35,23 +36,27 @@ export class IndyLedgerService { private indyPoolService: IndyPoolService public constructor( - wallet: IndyWallet, - agentConfig: AgentConfig, + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies, + @inject(InjectionSymbols.Logger) logger: Logger, indyIssuer: IndyIssuerService, indyPoolService: IndyPoolService ) { - this.wallet = wallet - this.indy = agentConfig.agentDependencies.indy - this.logger = agentConfig.logger + this.indy = agentDependencies.indy + this.logger = logger this.indyIssuer = indyIssuer this.indyPoolService = indyPoolService } + public setPools(poolConfigs: IndyPoolConfig[]) { + return this.indyPoolService.setPools(poolConfigs) + } + public async connectToPools() { return this.indyPoolService.connectToPools() } public async registerPublicDid( + agentContext: AgentContext, submitterDid: string, targetDid: string, verkey: string, @@ -65,7 +70,7 @@ export class IndyLedgerService { const request = await this.indy.buildNymRequest(submitterDid, targetDid, verkey, alias, role || null) - const response = await this.submitWriteRequest(pool, request, submitterDid) + const response = await this.submitWriteRequest(agentContext, pool, request, submitterDid) this.logger.debug(`Registered public did '${targetDid}' on ledger '${pool.id}'`, { response, @@ -87,15 +92,15 @@ export class IndyLedgerService { } } - public async getPublicDid(did: string) { + public async getPublicDid(agentContext: AgentContext, did: string) { // Getting the pool for a did also retrieves the DID. We can just use that - const { did: didResponse } = await this.indyPoolService.getPoolForDid(did) + const { did: didResponse } = await this.indyPoolService.getPoolForDid(agentContext, did) return didResponse } - public async getEndpointsForDid(did: string) { - const { pool } = await this.indyPoolService.getPoolForDid(did) + public async getEndpointsForDid(agentContext: AgentContext, did: string) { + const { pool } = await this.indyPoolService.getPoolForDid(agentContext, did) try { this.logger.debug(`Get endpoints for did '${did}' from ledger '${pool.id}'`) @@ -123,17 +128,21 @@ export class IndyLedgerService { } } - public async registerSchema(did: string, schemaTemplate: SchemaTemplate): Promise { + public async registerSchema( + agentContext: AgentContext, + did: string, + schemaTemplate: SchemaTemplate + ): Promise { const pool = this.indyPoolService.ledgerWritePool try { this.logger.debug(`Register schema on ledger '${pool.id}' with did '${did}'`, schemaTemplate) const { name, attributes, version } = schemaTemplate - const schema = await this.indyIssuer.createSchema({ originDid: did, name, version, attributes }) + const schema = await this.indyIssuer.createSchema(agentContext, { originDid: did, name, version, attributes }) const request = await this.indy.buildSchemaRequest(did, schema) - const response = await this.submitWriteRequest(pool, request, did) + const response = await this.submitWriteRequest(agentContext, pool, request, did) this.logger.debug(`Registered schema '${schema.id}' on ledger '${pool.id}'`, { response, schema, @@ -153,9 +162,9 @@ export class IndyLedgerService { } } - public async getSchema(schemaId: string) { + public async getSchema(agentContext: AgentContext, schemaId: string) { const did = didFromSchemaId(schemaId) - const { pool } = await this.indyPoolService.getPoolForDid(did) + const { pool } = await this.indyPoolService.getPoolForDid(agentContext, did) try { this.logger.debug(`Getting schema '${schemaId}' from ledger '${pool.id}'`) @@ -186,6 +195,7 @@ export class IndyLedgerService { } public async registerCredentialDefinition( + agentContext: AgentContext, did: string, credentialDefinitionTemplate: CredentialDefinitionTemplate ): Promise { @@ -198,7 +208,7 @@ export class IndyLedgerService { ) const { schema, tag, signatureType, supportRevocation } = credentialDefinitionTemplate - const credentialDefinition = await this.indyIssuer.createCredentialDefinition({ + const credentialDefinition = await this.indyIssuer.createCredentialDefinition(agentContext, { issuerDid: did, schema, tag, @@ -208,7 +218,7 @@ export class IndyLedgerService { const request = await this.indy.buildCredDefRequest(did, credentialDefinition) - const response = await this.submitWriteRequest(pool, request, did) + const response = await this.submitWriteRequest(agentContext, pool, request, did) this.logger.debug(`Registered credential definition '${credentialDefinition.id}' on ledger '${pool.id}'`, { response, @@ -230,9 +240,9 @@ export class IndyLedgerService { } } - public async getCredentialDefinition(credentialDefinitionId: string) { + public async getCredentialDefinition(agentContext: AgentContext, credentialDefinitionId: string) { const did = didFromCredentialDefinitionId(credentialDefinitionId) - const { pool } = await this.indyPoolService.getPoolForDid(did) + const { pool } = await this.indyPoolService.getPoolForDid(agentContext, did) this.logger.debug(`Using ledger '${pool.id}' to retrieve credential definition '${credentialDefinitionId}'`) @@ -266,10 +276,11 @@ export class IndyLedgerService { } public async getRevocationRegistryDefinition( + agentContext: AgentContext, revocationRegistryDefinitionId: string ): Promise { const did = didFromRevocationRegistryDefinitionId(revocationRegistryDefinitionId) - const { pool } = await this.indyPoolService.getPoolForDid(did) + const { pool } = await this.indyPoolService.getPoolForDid(agentContext, did) this.logger.debug( `Using ledger '${pool.id}' to retrieve revocation registry definition '${revocationRegistryDefinitionId}'` @@ -313,15 +324,16 @@ export class IndyLedgerService { } } - //Retrieves the accumulated state of a revocation registry by id given a revocation interval from & to (used primarily for proof creation) + // Retrieves the accumulated state of a revocation registry by id given a revocation interval from & to (used primarily for proof creation) public async getRevocationRegistryDelta( + agentContext: AgentContext, revocationRegistryDefinitionId: string, to: number = new Date().getTime(), from = 0 ): Promise { //TODO - implement a cache const did = didFromRevocationRegistryDefinitionId(revocationRegistryDefinitionId) - const { pool } = await this.indyPoolService.getPoolForDid(did) + const { pool } = await this.indyPoolService.getPoolForDid(agentContext, did) this.logger.debug( `Using ledger '${pool.id}' to retrieve revocation registry delta with revocation registry definition id: '${revocationRegistryDefinitionId}'`, @@ -369,14 +381,15 @@ export class IndyLedgerService { } } - //Retrieves the accumulated state of a revocation registry by id given a timestamp (used primarily for verification) + // Retrieves the accumulated state of a revocation registry by id given a timestamp (used primarily for verification) public async getRevocationRegistry( + agentContext: AgentContext, revocationRegistryDefinitionId: string, timestamp: number ): Promise { //TODO - implement a cache const did = didFromRevocationRegistryDefinitionId(revocationRegistryDefinitionId) - const { pool } = await this.indyPoolService.getPoolForDid(did) + const { pool } = await this.indyPoolService.getPoolForDid(agentContext, did) this.logger.debug( `Using ledger '${pool.id}' to retrieve revocation registry accumulated state with revocation registry definition id: '${revocationRegistryDefinitionId}'`, @@ -417,13 +430,14 @@ export class IndyLedgerService { } private async submitWriteRequest( + agentContext: AgentContext, pool: IndyPool, request: LedgerRequest, signDid: string ): Promise { try { const requestWithTaa = await this.appendTaa(pool, request) - const signedRequestWithTaa = await this.signRequest(signDid, requestWithTaa) + const signedRequestWithTaa = await this.signRequest(agentContext, signDid, requestWithTaa) const response = await pool.submitWriteRequest(signedRequestWithTaa) @@ -443,9 +457,11 @@ export class IndyLedgerService { } } - private async signRequest(did: string, request: LedgerRequest): Promise { + private async signRequest(agentContext: AgentContext, did: string, request: LedgerRequest): Promise { + assertIndyWallet(agentContext.wallet) + try { - return this.indy.signRequest(this.wallet.handle, did, request) + return this.indy.signRequest(agentContext.wallet.handle, did, request) } catch (error) { throw isIndyError(error) ? new IndySdkError(error) : error } diff --git a/packages/core/src/modules/ledger/services/IndyPoolService.ts b/packages/core/src/modules/ledger/services/IndyPoolService.ts index bf0b461176..f45c338f2b 100644 --- a/packages/core/src/modules/ledger/services/IndyPoolService.ts +++ b/packages/core/src/modules/ledger/services/IndyPoolService.ts @@ -1,10 +1,16 @@ -import type { Logger } from '../../../logger/Logger' +import type { AgentContext } from '../../../agent' +import type { IndyPoolConfig } from '../IndyPool' import type * as Indy from 'indy-sdk' -import { AgentConfig } from '../../../agent/AgentConfig' +import { Subject } from 'rxjs' + +import { AgentDependencies } from '../../../agent/AgentDependencies' import { CacheRepository, PersistedLruCache } from '../../../cache' +import { InjectionSymbols } from '../../../constants' import { IndySdkError } from '../../../error/IndySdkError' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger/Logger' +import { injectable, inject } from '../../../plugins' +import { FileSystem } from '../../../storage/FileSystem' import { isSelfCertifiedDid } from '../../../utils/did' import { isIndyError } from '../../../utils/indyError' import { allSettled, onlyFulfilled, onlyRejected } from '../../../utils/promises' @@ -21,19 +27,36 @@ export interface CachedDidResponse { } @injectable() export class IndyPoolService { - public readonly pools: IndyPool[] + public pools: IndyPool[] = [] private logger: Logger private indy: typeof Indy + private agentDependencies: AgentDependencies + private stop$: Subject + private fileSystem: FileSystem private didCache: PersistedLruCache - public constructor(agentConfig: AgentConfig, cacheRepository: CacheRepository) { - this.pools = agentConfig.indyLedgers.map((poolConfig) => new IndyPool(agentConfig, poolConfig)) - this.logger = agentConfig.logger - this.indy = agentConfig.agentDependencies.indy + public constructor( + cacheRepository: CacheRepository, + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies, + @inject(InjectionSymbols.Logger) logger: Logger, + @inject(InjectionSymbols.Stop$) stop$: Subject, + @inject(InjectionSymbols.FileSystem) fileSystem: FileSystem + ) { + this.logger = logger + this.indy = agentDependencies.indy + this.agentDependencies = agentDependencies + this.fileSystem = fileSystem + this.stop$ = stop$ this.didCache = new PersistedLruCache(DID_POOL_CACHE_ID, DID_POOL_CACHE_LIMIT, cacheRepository) } + public setPools(poolConfigs: IndyPoolConfig[]) { + this.pools = poolConfigs.map( + (poolConfig) => new IndyPool(poolConfig, this.agentDependencies, this.logger, this.stop$, this.fileSystem) + ) + } + /** * Create connections to all ledger pools */ @@ -67,7 +90,10 @@ export class IndyPoolService { * Get the most appropriate pool for the given did. The algorithm is based on the approach as described in this document: * https://docs.google.com/document/d/109C_eMsuZnTnYe2OAd02jAts1vC4axwEKIq7_4dnNVA/edit */ - public async getPoolForDid(did: string): Promise<{ pool: IndyPool; did: Indy.GetNymResponse }> { + public async getPoolForDid( + agentContext: AgentContext, + did: string + ): Promise<{ pool: IndyPool; did: Indy.GetNymResponse }> { const pools = this.pools if (pools.length === 0) { @@ -76,7 +102,7 @@ export class IndyPoolService { ) } - const cachedNymResponse = await this.didCache.get(did) + const cachedNymResponse = await this.didCache.get(agentContext, did) const pool = this.pools.find((pool) => pool.id === cachedNymResponse?.poolId) // If we have the nym response with associated pool in the cache, we'll use that @@ -123,7 +149,7 @@ export class IndyPoolService { value = productionOrNonProduction[0].value } - await this.didCache.set(did, { + await this.didCache.set(agentContext, did, { nymResponse: value.did, poolId: value.pool.id, }) diff --git a/packages/core/src/modules/oob/OutOfBandModule.ts b/packages/core/src/modules/oob/OutOfBandModule.ts index b5875af5ce..8f0a28b784 100644 --- a/packages/core/src/modules/oob/OutOfBandModule.ts +++ b/packages/core/src/modules/oob/OutOfBandModule.ts @@ -2,7 +2,6 @@ import type { AgentMessage } from '../../agent/AgentMessage' import type { AgentMessageReceivedEvent } from '../../agent/Events' import type { Key } from '../../crypto' import type { Attachment } from '../../decorators/attachment/Attachment' -import type { Logger } from '../../logger' import type { ConnectionInvitationMessage, ConnectionRecord, Routing } from '../../modules/connections' import type { DependencyManager } from '../../plugins' import type { PlaintextMessage } from '../../types' @@ -10,16 +9,18 @@ import type { HandshakeReusedEvent } from './domain/OutOfBandEvents' import { catchError, EmptyError, first, firstValueFrom, map, of, timeout } from 'rxjs' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { EventEmitter } from '../../agent/EventEmitter' import { AgentEventTypes } from '../../agent/Events' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' +import { InjectionSymbols } from '../../constants' import { ServiceDecorator } from '../../decorators/service/ServiceDecorator' import { AriesFrameworkError } from '../../error' +import { Logger } from '../../logger' import { ConnectionsModule, DidExchangeState, HandshakeProtocol } from '../../modules/connections' -import { injectable, module } from '../../plugins' +import { inject, injectable, module } from '../../plugins' import { DidCommMessageRepository, DidCommMessageRole } from '../../storage' import { JsonEncoder, JsonTransformer } from '../../utils' import { parseMessageType, supportsIncomingMessageType } from '../../utils/messageType' @@ -87,22 +88,23 @@ export class OutOfBandModule { private dispatcher: Dispatcher private messageSender: MessageSender private eventEmitter: EventEmitter - private agentConfig: AgentConfig + private agentContext: AgentContext private logger: Logger public constructor( dispatcher: Dispatcher, - agentConfig: AgentConfig, outOfBandService: OutOfBandService, routingService: RoutingService, connectionsModule: ConnectionsModule, didCommMessageRepository: DidCommMessageRepository, messageSender: MessageSender, - eventEmitter: EventEmitter + eventEmitter: EventEmitter, + @inject(InjectionSymbols.Logger) logger: Logger, + agentContext: AgentContext ) { this.dispatcher = dispatcher - this.agentConfig = agentConfig - this.logger = agentConfig.logger + this.agentContext = agentContext + this.logger = logger this.outOfBandService = outOfBandService this.routingService = routingService this.connectionsModule = connectionsModule @@ -131,11 +133,11 @@ export class OutOfBandModule { const multiUseInvitation = config.multiUseInvitation ?? false const handshake = config.handshake ?? true const customHandshakeProtocols = config.handshakeProtocols - const autoAcceptConnection = config.autoAcceptConnection ?? this.agentConfig.autoAcceptConnections + const autoAcceptConnection = config.autoAcceptConnection ?? this.agentContext.config.autoAcceptConnections // We don't want to treat an empty array as messages being provided const messages = config.messages && config.messages.length > 0 ? config.messages : undefined - const label = config.label ?? this.agentConfig.label - const imageUrl = config.imageUrl ?? this.agentConfig.connectionImageUrl + const label = config.label ?? this.agentContext.config.label + const imageUrl = config.imageUrl ?? this.agentContext.config.connectionImageUrl const appendedAttachments = config.appendedAttachments && config.appendedAttachments.length > 0 ? config.appendedAttachments : undefined @@ -167,7 +169,7 @@ export class OutOfBandModule { } } - const routing = config.routing ?? (await this.routingService.getRouting({})) + const routing = config.routing ?? (await this.routingService.getRouting(this.agentContext, {})) const services = routing.endpoints.map((endpoint, index) => { return new OutOfBandDidCommService({ @@ -209,8 +211,8 @@ export class OutOfBandModule { autoAcceptConnection, }) - await this.outOfBandService.save(outOfBandRecord) - this.outOfBandService.emitStateChangedEvent(outOfBandRecord, null) + await this.outOfBandService.save(this.agentContext, outOfBandRecord) + this.outOfBandService.emitStateChangedEvent(this.agentContext, outOfBandRecord, null) return outOfBandRecord } @@ -239,7 +241,7 @@ export class OutOfBandModule { domain: string }): Promise<{ message: Message; invitationUrl: string }> { // Create keys (and optionally register them at the mediator) - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(this.agentContext) // Set the service on the message config.message.service = new ServiceDecorator({ @@ -250,7 +252,7 @@ export class OutOfBandModule { // We need to update the message with the new service, so we can // retrieve it from storage later on. - await this.didCommMessageRepository.saveOrUpdateAgentMessage({ + await this.didCommMessageRepository.saveOrUpdateAgentMessage(this.agentContext, { agentMessage: config.message, associatedRecordId: config.recordId, role: DidCommMessageRole.Sender, @@ -318,9 +320,9 @@ export class OutOfBandModule { const autoAcceptInvitation = config.autoAcceptInvitation ?? true const autoAcceptConnection = config.autoAcceptConnection ?? true const reuseConnection = config.reuseConnection ?? false - const label = config.label ?? this.agentConfig.label + const label = config.label ?? this.agentContext.config.label const alias = config.alias - const imageUrl = config.imageUrl ?? this.agentConfig.connectionImageUrl + const imageUrl = config.imageUrl ?? this.agentContext.config.connectionImageUrl const messages = outOfBandInvitation.getRequests() @@ -344,8 +346,8 @@ export class OutOfBandModule { outOfBandInvitation: outOfBandInvitation, autoAcceptConnection, }) - await this.outOfBandService.save(outOfBandRecord) - this.outOfBandService.emitStateChangedEvent(outOfBandRecord, null) + await this.outOfBandService.save(this.agentContext, outOfBandRecord) + this.outOfBandService.emitStateChangedEvent(this.agentContext, outOfBandRecord, null) if (autoAcceptInvitation) { return await this.acceptInvitation(outOfBandRecord.id, { @@ -387,7 +389,7 @@ export class OutOfBandModule { routing?: Routing } ) { - const outOfBandRecord = await this.outOfBandService.getById(outOfBandId) + const outOfBandRecord = await this.outOfBandService.getById(this.agentContext, outOfBandId) const { outOfBandInvitation } = outOfBandRecord const { label, alias, imageUrl, autoAcceptConnection, reuseConnection, routing } = config @@ -396,7 +398,7 @@ export class OutOfBandModule { const existingConnection = await this.findExistingConnection(services) - await this.outOfBandService.updateState(outOfBandRecord, OutOfBandState.PrepareResponse) + await this.outOfBandService.updateState(this.agentContext, outOfBandRecord, OutOfBandState.PrepareResponse) if (handshakeProtocols) { this.logger.debug('Out of band message contains handshake protocols.') @@ -478,11 +480,11 @@ export class OutOfBandModule { } public async findByRecipientKey(recipientKey: Key) { - return this.outOfBandService.findByRecipientKey(recipientKey) + return this.outOfBandService.findByRecipientKey(this.agentContext, recipientKey) } public async findByInvitationId(invitationId: string) { - return this.outOfBandService.findByInvitationId(invitationId) + return this.outOfBandService.findByInvitationId(this.agentContext, invitationId) } /** @@ -491,7 +493,7 @@ export class OutOfBandModule { * @returns List containing all out of band records */ public getAll() { - return this.outOfBandService.getAll() + return this.outOfBandService.getAll(this.agentContext) } /** @@ -503,7 +505,7 @@ export class OutOfBandModule { * */ public getById(outOfBandId: string): Promise { - return this.outOfBandService.getById(outOfBandId) + return this.outOfBandService.getById(this.agentContext, outOfBandId) } /** @@ -513,7 +515,7 @@ export class OutOfBandModule { * @returns The out of band record or null if not found */ public findById(outOfBandId: string): Promise { - return this.outOfBandService.findById(outOfBandId) + return this.outOfBandService.findById(this.agentContext, outOfBandId) } /** @@ -522,7 +524,7 @@ export class OutOfBandModule { * @param outOfBandId the out of band record id */ public async deleteById(outOfBandId: string) { - return this.outOfBandService.deleteById(outOfBandId) + return this.outOfBandService.deleteById(this.agentContext, outOfBandId) } private assertHandshakeProtocols(handshakeProtocols: HandshakeProtocol[]) { @@ -605,7 +607,7 @@ export class OutOfBandModule { this.logger.debug(`Message with type ${plaintextMessage['@type']} can be processed.`) - this.eventEmitter.emit({ + this.eventEmitter.emit(this.agentContext, { type: AgentEventTypes.AgentMessageReceived, payload: { message: plaintextMessage, @@ -646,7 +648,7 @@ export class OutOfBandModule { }) plaintextMessage['~service'] = JsonTransformer.toJSON(serviceDecorator) - this.eventEmitter.emit({ + this.eventEmitter.emit(this.agentContext, { type: AgentEventTypes.AgentMessageReceived, payload: { message: plaintextMessage, @@ -655,7 +657,11 @@ export class OutOfBandModule { } private async handleHandshakeReuse(outOfBandRecord: OutOfBandRecord, connectionRecord: ConnectionRecord) { - const reuseMessage = await this.outOfBandService.createHandShakeReuse(outOfBandRecord, connectionRecord) + const reuseMessage = await this.outOfBandService.createHandShakeReuse( + this.agentContext, + outOfBandRecord, + connectionRecord + ) const reuseAcceptedEventPromise = firstValueFrom( this.eventEmitter.observable(OutOfBandEventTypes.HandshakeReused).pipe( @@ -676,7 +682,7 @@ export class OutOfBandModule { ) const outbound = createOutboundMessage(connectionRecord, reuseMessage) - await this.messageSender.sendMessage(outbound) + await this.messageSender.sendMessage(this.agentContext, outbound) return reuseAcceptedEventPromise } diff --git a/packages/core/src/modules/oob/OutOfBandService.ts b/packages/core/src/modules/oob/OutOfBandService.ts index ce64b5513d..0c78517f99 100644 --- a/packages/core/src/modules/oob/OutOfBandService.ts +++ b/packages/core/src/modules/oob/OutOfBandService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../agent' import type { InboundMessageContext } from '../../agent/models/InboundMessageContext' import type { Key } from '../../crypto' import type { ConnectionRecord } from '../connections' @@ -34,7 +35,7 @@ export class OutOfBandService { throw new AriesFrameworkError('handshake-reuse message must have a parent thread id') } - const outOfBandRecord = await this.findByInvitationId(parentThreadId) + const outOfBandRecord = await this.findByInvitationId(messageContext.agentContext, parentThreadId) if (!outOfBandRecord) { throw new AriesFrameworkError('No out of band record found for handshake-reuse message') } @@ -49,7 +50,7 @@ export class OutOfBandService { } const reusedConnection = messageContext.assertReadyConnection() - this.eventEmitter.emit({ + this.eventEmitter.emit(messageContext.agentContext, { type: OutOfBandEventTypes.HandshakeReused, payload: { reuseThreadId: reuseMessage.threadId, @@ -60,7 +61,7 @@ export class OutOfBandService { // If the out of band record is not reusable we can set the state to done if (!outOfBandRecord.reusable) { - await this.updateState(outOfBandRecord, OutOfBandState.Done) + await this.updateState(messageContext.agentContext, outOfBandRecord, OutOfBandState.Done) } const reuseAcceptedMessage = new HandshakeReuseAcceptedMessage({ @@ -79,7 +80,7 @@ export class OutOfBandService { throw new AriesFrameworkError('handshake-reuse-accepted message must have a parent thread id') } - const outOfBandRecord = await this.findByInvitationId(parentThreadId) + const outOfBandRecord = await this.findByInvitationId(messageContext.agentContext, parentThreadId) if (!outOfBandRecord) { throw new AriesFrameworkError('No out of band record found for handshake-reuse-accepted message') } @@ -100,7 +101,7 @@ export class OutOfBandService { throw new AriesFrameworkError('handshake-reuse-accepted is not in response to a handshake-reuse message.') } - this.eventEmitter.emit({ + this.eventEmitter.emit(messageContext.agentContext, { type: OutOfBandEventTypes.HandshakeReused, payload: { reuseThreadId: reuseAcceptedMessage.threadId, @@ -110,35 +111,43 @@ export class OutOfBandService { }) // receiver role is never reusable, so we can set the state to done - await this.updateState(outOfBandRecord, OutOfBandState.Done) + await this.updateState(messageContext.agentContext, outOfBandRecord, OutOfBandState.Done) } - public async createHandShakeReuse(outOfBandRecord: OutOfBandRecord, connectionRecord: ConnectionRecord) { + public async createHandShakeReuse( + agentContext: AgentContext, + outOfBandRecord: OutOfBandRecord, + connectionRecord: ConnectionRecord + ) { const reuseMessage = new HandshakeReuseMessage({ parentThreadId: outOfBandRecord.outOfBandInvitation.id }) // Store the reuse connection id outOfBandRecord.reuseConnectionId = connectionRecord.id - await this.outOfBandRepository.update(outOfBandRecord) + await this.outOfBandRepository.update(agentContext, outOfBandRecord) return reuseMessage } - public async save(outOfBandRecord: OutOfBandRecord) { - return this.outOfBandRepository.save(outOfBandRecord) + public async save(agentContext: AgentContext, outOfBandRecord: OutOfBandRecord) { + return this.outOfBandRepository.save(agentContext, outOfBandRecord) } - public async updateState(outOfBandRecord: OutOfBandRecord, newState: OutOfBandState) { + public async updateState(agentContext: AgentContext, outOfBandRecord: OutOfBandRecord, newState: OutOfBandState) { const previousState = outOfBandRecord.state outOfBandRecord.state = newState - await this.outOfBandRepository.update(outOfBandRecord) + await this.outOfBandRepository.update(agentContext, outOfBandRecord) - this.emitStateChangedEvent(outOfBandRecord, previousState) + this.emitStateChangedEvent(agentContext, outOfBandRecord, previousState) } - public emitStateChangedEvent(outOfBandRecord: OutOfBandRecord, previousState: OutOfBandState | null) { + public emitStateChangedEvent( + agentContext: AgentContext, + outOfBandRecord: OutOfBandRecord, + previousState: OutOfBandState | null + ) { const clonedOutOfBandRecord = JsonTransformer.clone(outOfBandRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: OutOfBandEventTypes.OutOfBandStateChanged, payload: { outOfBandRecord: clonedOutOfBandRecord, @@ -147,28 +156,30 @@ export class OutOfBandService { }) } - public async findById(outOfBandRecordId: string) { - return this.outOfBandRepository.findById(outOfBandRecordId) + public async findById(agentContext: AgentContext, outOfBandRecordId: string) { + return this.outOfBandRepository.findById(agentContext, outOfBandRecordId) } - public async getById(outOfBandRecordId: string) { - return this.outOfBandRepository.getById(outOfBandRecordId) + public async getById(agentContext: AgentContext, outOfBandRecordId: string) { + return this.outOfBandRepository.getById(agentContext, outOfBandRecordId) } - public async findByInvitationId(invitationId: string) { - return this.outOfBandRepository.findSingleByQuery({ invitationId }) + public async findByInvitationId(agentContext: AgentContext, invitationId: string) { + return this.outOfBandRepository.findSingleByQuery(agentContext, { invitationId }) } - public async findByRecipientKey(recipientKey: Key) { - return this.outOfBandRepository.findSingleByQuery({ recipientKeyFingerprints: [recipientKey.fingerprint] }) + public async findByRecipientKey(agentContext: AgentContext, recipientKey: Key) { + return this.outOfBandRepository.findSingleByQuery(agentContext, { + recipientKeyFingerprints: [recipientKey.fingerprint], + }) } - public async getAll() { - return this.outOfBandRepository.getAll() + public async getAll(agentContext: AgentContext) { + return this.outOfBandRepository.getAll(agentContext) } - public async deleteById(outOfBandId: string) { - const outOfBandRecord = await this.getById(outOfBandId) - return this.outOfBandRepository.delete(outOfBandRecord) + public async deleteById(agentContext: AgentContext, outOfBandId: string) { + const outOfBandRecord = await this.getById(agentContext, outOfBandId) + return this.outOfBandRepository.delete(agentContext, outOfBandRecord) } } diff --git a/packages/core/src/modules/oob/__tests__/OutOfBandService.test.ts b/packages/core/src/modules/oob/__tests__/OutOfBandService.test.ts index dd1c98098b..d5c1e358ad 100644 --- a/packages/core/src/modules/oob/__tests__/OutOfBandService.test.ts +++ b/packages/core/src/modules/oob/__tests__/OutOfBandService.test.ts @@ -1,6 +1,16 @@ +import type { AgentContext } from '../../../agent' import type { Wallet } from '../../../wallet/Wallet' -import { getAgentConfig, getMockConnection, getMockOutOfBand, mockFunction } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { + agentDependencies, + getAgentConfig, + getAgentContext, + getMockConnection, + getMockOutOfBand, + mockFunction, +} from '../../../../tests/helpers' import { EventEmitter } from '../../../agent/EventEmitter' import { InboundMessageContext } from '../../../agent/models/InboundMessageContext' import { KeyType, Key } from '../../../crypto' @@ -26,9 +36,11 @@ describe('OutOfBandService', () => { let outOfBandRepository: OutOfBandRepository let outOfBandService: OutOfBandService let eventEmitter: EventEmitter + let agentContext: AgentContext beforeAll(async () => { - wallet = new IndyWallet(agentConfig) + wallet = new IndyWallet(agentConfig.agentDependencies, agentConfig.logger) + agentContext = getAgentContext() // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(agentConfig.walletConfig!) }) @@ -38,7 +50,7 @@ describe('OutOfBandService', () => { }) beforeEach(async () => { - eventEmitter = new EventEmitter(agentConfig) + eventEmitter = new EventEmitter(agentDependencies, new Subject()) outOfBandRepository = new OutOfBandRepositoryMock() outOfBandService = new OutOfBandService(outOfBandRepository, eventEmitter) }) @@ -54,6 +66,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -69,6 +82,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -84,6 +98,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -112,6 +127,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -134,6 +150,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -158,6 +175,7 @@ describe('OutOfBandService', () => { const connection = getMockConnection({ state: DidExchangeState.Completed }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, connection, @@ -192,6 +210,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, connection: getMockConnection({ state: DidExchangeState.Completed }), @@ -213,7 +232,7 @@ describe('OutOfBandService', () => { // Non-reusable should update state mockOob.reusable = false await outOfBandService.processHandshakeReuse(messageContext) - expect(updateStateSpy).toHaveBeenCalledWith(mockOob, OutOfBandState.Done) + expect(updateStateSpy).toHaveBeenCalledWith(agentContext, mockOob, OutOfBandState.Done) }) it('returns a handshake-reuse-accepted message', async () => { @@ -222,6 +241,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseMessage, { + agentContext, senderKey: key, recipientKey: key, connection: getMockConnection({ state: DidExchangeState.Completed }), @@ -255,6 +275,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -271,6 +292,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -287,6 +309,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -316,6 +339,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, }) @@ -338,6 +362,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, connection: getMockConnection({ state: DidExchangeState.Completed, id: 'connectionId' }), @@ -365,6 +390,7 @@ describe('OutOfBandService', () => { const connection = getMockConnection({ state: DidExchangeState.Completed, id: 'connectionId' }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, connection, @@ -401,6 +427,7 @@ describe('OutOfBandService', () => { }) const messageContext = new InboundMessageContext(reuseAcceptedMessage, { + agentContext, senderKey: key, recipientKey: key, connection: getMockConnection({ state: DidExchangeState.Completed, id: 'connectionId' }), @@ -417,7 +444,7 @@ describe('OutOfBandService', () => { const updateStateSpy = jest.spyOn(outOfBandService, 'updateState') await outOfBandService.processHandshakeReuseAccepted(messageContext) - expect(updateStateSpy).toHaveBeenCalledWith(mockOob, OutOfBandState.Done) + expect(updateStateSpy).toHaveBeenCalledWith(agentContext, mockOob, OutOfBandState.Done) }) }) @@ -427,7 +454,7 @@ describe('OutOfBandService', () => { state: OutOfBandState.Initial, }) - await outOfBandService.updateState(mockOob, OutOfBandState.Done) + await outOfBandService.updateState(agentContext, mockOob, OutOfBandState.Done) expect(mockOob.state).toEqual(OutOfBandState.Done) }) @@ -437,9 +464,9 @@ describe('OutOfBandService', () => { state: OutOfBandState.Initial, }) - await outOfBandService.updateState(mockOob, OutOfBandState.Done) + await outOfBandService.updateState(agentContext, mockOob, OutOfBandState.Done) - expect(outOfBandRepository.update).toHaveBeenCalledWith(mockOob) + expect(outOfBandRepository.update).toHaveBeenCalledWith(agentContext, mockOob) }) test('emits an OutOfBandStateChangedEvent', async () => { @@ -450,7 +477,7 @@ describe('OutOfBandService', () => { }) eventEmitter.on(OutOfBandEventTypes.OutOfBandStateChanged, stateChangedListener) - await outOfBandService.updateState(mockOob, OutOfBandState.Done) + await outOfBandService.updateState(agentContext, mockOob, OutOfBandState.Done) eventEmitter.off(OutOfBandEventTypes.OutOfBandStateChanged, stateChangedListener) expect(stateChangedListener).toHaveBeenCalledTimes(1) @@ -470,8 +497,8 @@ describe('OutOfBandService', () => { it('getById should return value from outOfBandRepository.getById', async () => { const expected = getMockOutOfBand() mockFunction(outOfBandRepository.getById).mockReturnValue(Promise.resolve(expected)) - const result = await outOfBandService.getById(expected.id) - expect(outOfBandRepository.getById).toBeCalledWith(expected.id) + const result = await outOfBandService.getById(agentContext, expected.id) + expect(outOfBandRepository.getById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -479,8 +506,8 @@ describe('OutOfBandService', () => { it('findById should return value from outOfBandRepository.findById', async () => { const expected = getMockOutOfBand() mockFunction(outOfBandRepository.findById).mockReturnValue(Promise.resolve(expected)) - const result = await outOfBandService.findById(expected.id) - expect(outOfBandRepository.findById).toBeCalledWith(expected.id) + const result = await outOfBandService.findById(agentContext, expected.id) + expect(outOfBandRepository.findById).toBeCalledWith(agentContext, expected.id) expect(result).toBe(expected) }) @@ -489,8 +516,8 @@ describe('OutOfBandService', () => { const expected = [getMockOutOfBand(), getMockOutOfBand()] mockFunction(outOfBandRepository.getAll).mockReturnValue(Promise.resolve(expected)) - const result = await outOfBandService.getAll() - expect(outOfBandRepository.getAll).toBeCalledWith() + const result = await outOfBandService.getAll(agentContext) + expect(outOfBandRepository.getAll).toBeCalledWith(agentContext) expect(result).toEqual(expect.arrayContaining(expected)) }) diff --git a/packages/core/src/modules/proofs/ProofResponseCoordinator.ts b/packages/core/src/modules/proofs/ProofResponseCoordinator.ts index d839edb646..7e95e73682 100644 --- a/packages/core/src/modules/proofs/ProofResponseCoordinator.ts +++ b/packages/core/src/modules/proofs/ProofResponseCoordinator.ts @@ -1,6 +1,6 @@ +import type { AgentContext } from '../../agent/AgentContext' import type { ProofRecord } from './repository' -import { AgentConfig } from '../../agent/AgentConfig' import { injectable } from '../../plugins' import { AutoAcceptProof } from './ProofAutoAcceptType' @@ -11,12 +11,6 @@ import { AutoAcceptProof } from './ProofAutoAcceptType' */ @injectable() export class ProofResponseCoordinator { - private agentConfig: AgentConfig - - public constructor(agentConfig: AgentConfig) { - this.agentConfig = agentConfig - } - /** * Returns the proof auto accept config based on priority: * - The record config takes first priority @@ -33,10 +27,10 @@ export class ProofResponseCoordinator { /** * Checks whether it should automatically respond to a proposal */ - public shouldAutoRespondToProposal(proofRecord: ProofRecord) { + public shouldAutoRespondToProposal(agentContext: AgentContext, proofRecord: ProofRecord) { const autoAccept = ProofResponseCoordinator.composeAutoAccept( proofRecord.autoAcceptProof, - this.agentConfig.autoAcceptProofs + agentContext.config.autoAcceptProofs ) if (autoAccept === AutoAcceptProof.Always) { @@ -48,10 +42,10 @@ export class ProofResponseCoordinator { /** * Checks whether it should automatically respond to a request */ - public shouldAutoRespondToRequest(proofRecord: ProofRecord) { + public shouldAutoRespondToRequest(agentContext: AgentContext, proofRecord: ProofRecord) { const autoAccept = ProofResponseCoordinator.composeAutoAccept( proofRecord.autoAcceptProof, - this.agentConfig.autoAcceptProofs + agentContext.config.autoAcceptProofs ) if ( @@ -67,10 +61,10 @@ export class ProofResponseCoordinator { /** * Checks whether it should automatically respond to a presentation of proof */ - public shouldAutoRespondToPresentation(proofRecord: ProofRecord) { + public shouldAutoRespondToPresentation(agentContext: AgentContext, proofRecord: ProofRecord) { const autoAccept = ProofResponseCoordinator.composeAutoAccept( proofRecord.autoAcceptProof, - this.agentConfig.autoAcceptProofs + agentContext.config.autoAcceptProofs ) if ( diff --git a/packages/core/src/modules/proofs/ProofsModule.ts b/packages/core/src/modules/proofs/ProofsModule.ts index c9fa80b8ef..e1d3a8d2f1 100644 --- a/packages/core/src/modules/proofs/ProofsModule.ts +++ b/packages/core/src/modules/proofs/ProofsModule.ts @@ -5,24 +5,26 @@ import type { RequestedCredentials, RetrievedCredentials } from './models' import type { ProofRequestOptions } from './models/ProofRequest' import type { ProofRecord } from './repository/ProofRecord' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' +import { InjectionSymbols } from '../../constants' import { ServiceDecorator } from '../../decorators/service/ServiceDecorator' import { AriesFrameworkError } from '../../error' -import { injectable, module } from '../../plugins' +import { Logger } from '../../logger' +import { inject, injectable, module } from '../../plugins' import { ConnectionService } from '../connections/services/ConnectionService' import { RoutingService } from '../routing/services/RoutingService' import { ProofResponseCoordinator } from './ProofResponseCoordinator' import { PresentationProblemReportReason } from './errors' import { - ProposePresentationHandler, - RequestPresentationHandler, PresentationAckHandler, PresentationHandler, PresentationProblemReportHandler, + ProposePresentationHandler, + RequestPresentationHandler, } from './handlers' import { PresentationProblemReportMessage } from './messages/PresentationProblemReportMessage' import { ProofRequest } from './models/ProofRequest' @@ -36,24 +38,27 @@ export class ProofsModule { private connectionService: ConnectionService private messageSender: MessageSender private routingService: RoutingService - private agentConfig: AgentConfig + private agentContext: AgentContext private proofResponseCoordinator: ProofResponseCoordinator + private logger: Logger public constructor( dispatcher: Dispatcher, proofService: ProofService, connectionService: ConnectionService, routingService: RoutingService, - agentConfig: AgentConfig, + agentContext: AgentContext, messageSender: MessageSender, - proofResponseCoordinator: ProofResponseCoordinator + proofResponseCoordinator: ProofResponseCoordinator, + @inject(InjectionSymbols.Logger) logger: Logger ) { this.proofService = proofService this.connectionService = connectionService this.messageSender = messageSender this.routingService = routingService - this.agentConfig = agentConfig + this.agentContext = agentContext this.proofResponseCoordinator = proofResponseCoordinator + this.logger = logger this.registerHandlers(dispatcher) } @@ -75,12 +80,17 @@ export class ProofsModule { autoAcceptProof?: AutoAcceptProof } ): Promise { - const connection = await this.connectionService.getById(connectionId) + const connection = await this.connectionService.getById(this.agentContext, connectionId) - const { message, proofRecord } = await this.proofService.createProposal(connection, presentationProposal, config) + const { message, proofRecord } = await this.proofService.createProposal( + this.agentContext, + connection, + presentationProposal, + config + ) const outbound = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outbound) + await this.messageSender.sendMessage(this.agentContext, outbound) return proofRecord } @@ -105,7 +115,7 @@ export class ProofsModule { comment?: string } ): Promise { - const proofRecord = await this.proofService.getById(proofRecordId) + const proofRecord = await this.proofService.getById(this.agentContext, proofRecordId) if (!proofRecord.connectionId) { throw new AriesFrameworkError( @@ -113,25 +123,29 @@ export class ProofsModule { ) } - const connection = await this.connectionService.getById(proofRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, proofRecord.connectionId) const presentationProposal = proofRecord.proposalMessage?.presentationProposal if (!presentationProposal) { throw new AriesFrameworkError(`Proof record with id ${proofRecordId} is missing required presentation proposal`) } - const proofRequest = await this.proofService.createProofRequestFromProposal(presentationProposal, { - name: config?.request?.name ?? 'proof-request', - version: config?.request?.version ?? '1.0', - nonce: config?.request?.nonce, - }) + const proofRequest = await this.proofService.createProofRequestFromProposal( + this.agentContext, + presentationProposal, + { + name: config?.request?.name ?? 'proof-request', + version: config?.request?.version ?? '1.0', + nonce: config?.request?.nonce, + } + ) - const { message } = await this.proofService.createRequestAsResponse(proofRecord, proofRequest, { + const { message } = await this.proofService.createRequestAsResponse(this.agentContext, proofRecord, proofRequest, { comment: config?.comment, }) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return proofRecord } @@ -150,9 +164,9 @@ export class ProofsModule { proofRequestOptions: CreateProofRequestOptions, config?: ProofRequestConfig ): Promise { - const connection = await this.connectionService.getById(connectionId) + const connection = await this.connectionService.getById(this.agentContext, connectionId) - const nonce = proofRequestOptions.nonce ?? (await this.proofService.generateProofRequestNonce()) + const nonce = proofRequestOptions.nonce ?? (await this.proofService.generateProofRequestNonce(this.agentContext)) const proofRequest = new ProofRequest({ name: proofRequestOptions.name ?? 'proof-request', @@ -162,10 +176,15 @@ export class ProofsModule { requestedPredicates: proofRequestOptions.requestedPredicates, }) - const { message, proofRecord } = await this.proofService.createRequest(proofRequest, connection, config) + const { message, proofRecord } = await this.proofService.createRequest( + this.agentContext, + proofRequest, + connection, + config + ) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return proofRecord } @@ -185,7 +204,7 @@ export class ProofsModule { requestMessage: RequestPresentationMessage proofRecord: ProofRecord }> { - const nonce = proofRequestOptions.nonce ?? (await this.proofService.generateProofRequestNonce()) + const nonce = proofRequestOptions.nonce ?? (await this.proofService.generateProofRequestNonce(this.agentContext)) const proofRequest = new ProofRequest({ name: proofRequestOptions.name ?? 'proof-request', @@ -195,10 +214,15 @@ export class ProofsModule { requestedPredicates: proofRequestOptions.requestedPredicates, }) - const { message, proofRecord } = await this.proofService.createRequest(proofRequest, undefined, config) + const { message, proofRecord } = await this.proofService.createRequest( + this.agentContext, + proofRequest, + undefined, + config + ) // Create and set ~service decorator - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(this.agentContext) message.service = new ServiceDecorator({ serviceEndpoint: routing.endpoints[0], recipientKeys: [routing.recipientKey.publicKeyBase58], @@ -207,7 +231,7 @@ export class ProofsModule { // Save ~service decorator to record (to remember our verkey) proofRecord.requestMessage = message - await this.proofService.update(proofRecord) + await this.proofService.update(this.agentContext, proofRecord) return { proofRecord, requestMessage: message } } @@ -229,22 +253,27 @@ export class ProofsModule { comment?: string } ): Promise { - const record = await this.proofService.getById(proofRecordId) - const { message, proofRecord } = await this.proofService.createPresentation(record, requestedCredentials, config) + const record = await this.proofService.getById(this.agentContext, proofRecordId) + const { message, proofRecord } = await this.proofService.createPresentation( + this.agentContext, + record, + requestedCredentials, + config + ) // Use connection if present if (proofRecord.connectionId) { - const connection = await this.connectionService.getById(proofRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, proofRecord.connectionId) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return proofRecord } // Use ~service decorator otherwise else if (proofRecord.requestMessage?.service) { // Create ~service decorator - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(this.agentContext) const ourService = new ServiceDecorator({ serviceEndpoint: routing.endpoints[0], recipientKeys: [routing.recipientKey.publicKeyBase58], @@ -256,9 +285,9 @@ export class ProofsModule { // Set and save ~service decorator to record (to remember our verkey) message.service = ourService proofRecord.presentationMessage = message - await this.proofService.update(proofRecord) + await this.proofService.update(this.agentContext, proofRecord) - await this.messageSender.sendMessageToService({ + await this.messageSender.sendMessageToService(this.agentContext, { message, service: recipientService.resolvedDidCommService, senderKey: ourService.resolvedDidCommService.recipientKeys[0], @@ -281,8 +310,8 @@ export class ProofsModule { * @returns proof record that was declined */ public async declineRequest(proofRecordId: string) { - const proofRecord = await this.proofService.getById(proofRecordId) - await this.proofService.declineRequest(proofRecord) + const proofRecord = await this.proofService.getById(this.agentContext, proofRecordId) + await this.proofService.declineRequest(this.agentContext, proofRecord) return proofRecord } @@ -295,21 +324,21 @@ export class ProofsModule { * */ public async acceptPresentation(proofRecordId: string): Promise { - const record = await this.proofService.getById(proofRecordId) - const { message, proofRecord } = await this.proofService.createAck(record) + const record = await this.proofService.getById(this.agentContext, proofRecordId) + const { message, proofRecord } = await this.proofService.createAck(this.agentContext, record) // Use connection if present if (proofRecord.connectionId) { - const connection = await this.connectionService.getById(proofRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, proofRecord.connectionId) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) } // Use ~service decorator otherwise else if (proofRecord.requestMessage?.service && proofRecord.presentationMessage?.service) { const recipientService = proofRecord.presentationMessage?.service const ourService = proofRecord.requestMessage.service - await this.messageSender.sendMessageToService({ + await this.messageSender.sendMessageToService(this.agentContext, { message, service: recipientService.resolvedDidCommService, senderKey: ourService.resolvedDidCommService.recipientKeys[0], @@ -343,7 +372,7 @@ export class ProofsModule { proofRecordId: string, config?: GetRequestedCredentialsConfig ): Promise { - const proofRecord = await this.proofService.getById(proofRecordId) + const proofRecord = await this.proofService.getById(this.agentContext, proofRecordId) const indyProofRequest = proofRecord.requestMessage?.indyProofRequest const presentationPreview = config?.filterByPresentationPreview @@ -356,7 +385,7 @@ export class ProofsModule { ) } - return this.proofService.getRequestedCredentialsForProofRequest(indyProofRequest, { + return this.proofService.getRequestedCredentialsForProofRequest(this.agentContext, indyProofRequest, { presentationProposal: presentationPreview, filterByNonRevocationRequirements: config?.filterByNonRevocationRequirements ?? true, }) @@ -383,11 +412,11 @@ export class ProofsModule { * @returns proof record associated with the proof problem report message */ public async sendProblemReport(proofRecordId: string, message: string) { - const record = await this.proofService.getById(proofRecordId) + const record = await this.proofService.getById(this.agentContext, proofRecordId) if (!record.connectionId) { throw new AriesFrameworkError(`No connectionId found for proof record '${record.id}'.`) } - const connection = await this.connectionService.getById(record.connectionId) + const connection = await this.connectionService.getById(this.agentContext, record.connectionId) const presentationProblemReportMessage = new PresentationProblemReportMessage({ description: { en: message, @@ -398,7 +427,7 @@ export class ProofsModule { threadId: record.threadId, }) const outboundMessage = createOutboundMessage(connection, presentationProblemReportMessage) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return record } @@ -409,7 +438,7 @@ export class ProofsModule { * @returns List containing all proof records */ public getAll(): Promise { - return this.proofService.getAll() + return this.proofService.getAll(this.agentContext) } /** @@ -422,7 +451,7 @@ export class ProofsModule { * */ public async getById(proofRecordId: string): Promise { - return this.proofService.getById(proofRecordId) + return this.proofService.getById(this.agentContext, proofRecordId) } /** @@ -433,7 +462,7 @@ export class ProofsModule { * */ public async findById(proofRecordId: string): Promise { - return this.proofService.findById(proofRecordId) + return this.proofService.findById(this.agentContext, proofRecordId) } /** @@ -442,24 +471,17 @@ export class ProofsModule { * @param proofId the proof record id */ public async deleteById(proofId: string) { - return this.proofService.deleteById(proofId) + return this.proofService.deleteById(this.agentContext, proofId) } private registerHandlers(dispatcher: Dispatcher) { dispatcher.registerHandler( - new ProposePresentationHandler(this.proofService, this.agentConfig, this.proofResponseCoordinator) - ) - dispatcher.registerHandler( - new RequestPresentationHandler( - this.proofService, - this.agentConfig, - this.proofResponseCoordinator, - this.routingService - ) + new ProposePresentationHandler(this.proofService, this.proofResponseCoordinator, this.logger) ) dispatcher.registerHandler( - new PresentationHandler(this.proofService, this.agentConfig, this.proofResponseCoordinator) + new RequestPresentationHandler(this.proofService, this.proofResponseCoordinator, this.routingService, this.logger) ) + dispatcher.registerHandler(new PresentationHandler(this.proofService, this.proofResponseCoordinator, this.logger)) dispatcher.registerHandler(new PresentationAckHandler(this.proofService)) dispatcher.registerHandler(new PresentationProblemReportHandler(this.proofService)) } diff --git a/packages/core/src/modules/proofs/__tests__/ProofService.test.ts b/packages/core/src/modules/proofs/__tests__/ProofService.test.ts index d654dd924a..554856172e 100644 --- a/packages/core/src/modules/proofs/__tests__/ProofService.test.ts +++ b/packages/core/src/modules/proofs/__tests__/ProofService.test.ts @@ -1,9 +1,11 @@ -import type { Wallet } from '../../../wallet/Wallet' +import type { AgentContext } from '../../../agent' import type { CredentialRepository } from '../../credentials/repository' import type { ProofStateChangedEvent } from '../ProofEvents' import type { CustomProofTags } from './../repository/ProofRecord' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../../tests/helpers' import { EventEmitter } from '../../../agent/EventEmitter' import { InboundMessageContext } from '../../../agent/models/InboundMessageContext' import { Attachment, AttachmentData } from '../../../decorators/attachment/Attachment' @@ -93,13 +95,13 @@ describe('ProofService', () => { let proofRepository: ProofRepository let proofService: ProofService let ledgerService: IndyLedgerService - let wallet: Wallet let indyVerifierService: IndyVerifierService let indyHolderService: IndyHolderService let indyRevocationService: IndyRevocationService let eventEmitter: EventEmitter let credentialRepository: CredentialRepository let connectionService: ConnectionService + let agentContext: AgentContext beforeEach(() => { const agentConfig = getAgentConfig('ProofServiceTest') @@ -108,20 +110,20 @@ describe('ProofService', () => { indyHolderService = new IndyHolderServiceMock() indyRevocationService = new IndyRevocationServiceMock() ledgerService = new IndyLedgerServiceMock() - eventEmitter = new EventEmitter(agentConfig) + eventEmitter = new EventEmitter(agentConfig.agentDependencies, new Subject()) connectionService = new connectionServiceMock() + agentContext = getAgentContext() proofService = new ProofService( proofRepository, ledgerService, - wallet, - agentConfig, indyHolderService, indyVerifierService, indyRevocationService, connectionService, eventEmitter, - credentialRepository + credentialRepository, + agentConfig.logger ) mockFunction(ledgerService.getCredentialDefinition).mockReturnValue(Promise.resolve(credDef)) @@ -138,6 +140,7 @@ describe('ProofService', () => { }) messageContext = new InboundMessageContext(presentationRequest, { connection, + agentContext, }) }) @@ -157,7 +160,7 @@ describe('ProofService', () => { connectionId: connection.id, } expect(repositorySaveSpy).toHaveBeenCalledTimes(1) - const [[createdProofRecord]] = repositorySaveSpy.mock.calls + const [[, createdProofRecord]] = repositorySaveSpy.mock.calls expect(createdProofRecord).toMatchObject(expectedProofRecord) expect(returnedProofRecord).toMatchObject(expectedProofRecord) }) @@ -236,6 +239,7 @@ describe('ProofService', () => { presentationProblemReportMessage.setThread({ threadId: 'somethreadid' }) messageContext = new InboundMessageContext(presentationProblemReportMessage, { connection, + agentContext, }) }) @@ -252,12 +256,12 @@ describe('ProofService', () => { const expectedCredentialRecord = { errorMessage: 'abandoned: Indy error', } - expect(proofRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, { + expect(proofRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, { threadId: 'somethreadid', connectionId: connection.id, }) expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1) - const [[updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls + const [[, updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls expect(updatedCredentialRecord).toMatchObject(expectedCredentialRecord) expect(returnedCredentialRecord).toMatchObject(expectedCredentialRecord) }) diff --git a/packages/core/src/modules/proofs/handlers/PresentationHandler.ts b/packages/core/src/modules/proofs/handlers/PresentationHandler.ts index c00fa139c7..991a3a550d 100644 --- a/packages/core/src/modules/proofs/handlers/PresentationHandler.ts +++ b/packages/core/src/modules/proofs/handlers/PresentationHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' +import type { Logger } from '../../../logger' import type { ProofResponseCoordinator } from '../ProofResponseCoordinator' import type { ProofRecord } from '../repository' import type { ProofService } from '../services' @@ -9,34 +9,30 @@ import { PresentationMessage } from '../messages' export class PresentationHandler implements Handler { private proofService: ProofService - private agentConfig: AgentConfig private proofResponseCoordinator: ProofResponseCoordinator + private logger: Logger public supportedMessages = [PresentationMessage] - public constructor( - proofService: ProofService, - agentConfig: AgentConfig, - proofResponseCoordinator: ProofResponseCoordinator - ) { + public constructor(proofService: ProofService, proofResponseCoordinator: ProofResponseCoordinator, logger: Logger) { this.proofService = proofService - this.agentConfig = agentConfig this.proofResponseCoordinator = proofResponseCoordinator + this.logger = logger } public async handle(messageContext: HandlerInboundMessage) { const proofRecord = await this.proofService.processPresentation(messageContext) - if (this.proofResponseCoordinator.shouldAutoRespondToPresentation(proofRecord)) { + if (this.proofResponseCoordinator.shouldAutoRespondToPresentation(messageContext.agentContext, proofRecord)) { return await this.createAck(proofRecord, messageContext) } } private async createAck(record: ProofRecord, messageContext: HandlerInboundMessage) { - this.agentConfig.logger.info( - `Automatically sending acknowledgement with autoAccept on ${this.agentConfig.autoAcceptProofs}` + this.logger.info( + `Automatically sending acknowledgement with autoAccept on ${messageContext.agentContext.config.autoAcceptProofs}` ) - const { message, proofRecord } = await this.proofService.createAck(record) + const { message, proofRecord } = await this.proofService.createAck(messageContext.agentContext, record) if (messageContext.connection) { return createOutboundMessage(messageContext.connection, message) @@ -51,6 +47,6 @@ export class PresentationHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create presentation ack`) + this.logger.error(`Could not automatically create presentation ack`) } } diff --git a/packages/core/src/modules/proofs/handlers/ProposePresentationHandler.ts b/packages/core/src/modules/proofs/handlers/ProposePresentationHandler.ts index de29cc2e1d..6ab1879fdb 100644 --- a/packages/core/src/modules/proofs/handlers/ProposePresentationHandler.ts +++ b/packages/core/src/modules/proofs/handlers/ProposePresentationHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' +import type { Logger } from '../../../logger' import type { ProofResponseCoordinator } from '../ProofResponseCoordinator' import type { ProofRecord } from '../repository' import type { ProofService } from '../services' @@ -9,24 +9,20 @@ import { ProposePresentationMessage } from '../messages' export class ProposePresentationHandler implements Handler { private proofService: ProofService - private agentConfig: AgentConfig private proofResponseCoordinator: ProofResponseCoordinator + private logger: Logger public supportedMessages = [ProposePresentationMessage] - public constructor( - proofService: ProofService, - agentConfig: AgentConfig, - proofResponseCoordinator: ProofResponseCoordinator - ) { + public constructor(proofService: ProofService, proofResponseCoordinator: ProofResponseCoordinator, logger: Logger) { this.proofService = proofService - this.agentConfig = agentConfig this.proofResponseCoordinator = proofResponseCoordinator + this.logger = logger } public async handle(messageContext: HandlerInboundMessage) { const proofRecord = await this.proofService.processProposal(messageContext) - if (this.proofResponseCoordinator.shouldAutoRespondToProposal(proofRecord)) { + if (this.proofResponseCoordinator.shouldAutoRespondToProposal(messageContext.agentContext, proofRecord)) { return await this.createRequest(proofRecord, messageContext) } } @@ -35,19 +31,20 @@ export class ProposePresentationHandler implements Handler { proofRecord: ProofRecord, messageContext: HandlerInboundMessage ) { - this.agentConfig.logger.info( - `Automatically sending request with autoAccept on ${this.agentConfig.autoAcceptProofs}` + this.logger.info( + `Automatically sending request with autoAccept on ${messageContext.agentContext.config.autoAcceptProofs}` ) if (!messageContext.connection) { - this.agentConfig.logger.error('No connection on the messageContext') + this.logger.error('No connection on the messageContext') return } if (!proofRecord.proposalMessage) { - this.agentConfig.logger.error(`Proof record with id ${proofRecord.id} is missing required credential proposal`) + this.logger.error(`Proof record with id ${proofRecord.id} is missing required credential proposal`) return } const proofRequest = await this.proofService.createProofRequestFromProposal( + messageContext.agentContext, proofRecord.proposalMessage.presentationProposal, { name: 'proof-request', @@ -55,7 +52,11 @@ export class ProposePresentationHandler implements Handler { } ) - const { message } = await this.proofService.createRequestAsResponse(proofRecord, proofRequest) + const { message } = await this.proofService.createRequestAsResponse( + messageContext.agentContext, + proofRecord, + proofRequest + ) return createOutboundMessage(messageContext.connection, message) } diff --git a/packages/core/src/modules/proofs/handlers/RequestPresentationHandler.ts b/packages/core/src/modules/proofs/handlers/RequestPresentationHandler.ts index e5b1ca8280..87b1445d94 100644 --- a/packages/core/src/modules/proofs/handlers/RequestPresentationHandler.ts +++ b/packages/core/src/modules/proofs/handlers/RequestPresentationHandler.ts @@ -1,5 +1,5 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' +import type { Logger } from '../../../logger' import type { RoutingService } from '../../routing/services/RoutingService' import type { ProofResponseCoordinator } from '../ProofResponseCoordinator' import type { ProofRecord } from '../repository' @@ -11,27 +11,27 @@ import { RequestPresentationMessage } from '../messages' export class RequestPresentationHandler implements Handler { private proofService: ProofService - private agentConfig: AgentConfig private proofResponseCoordinator: ProofResponseCoordinator private routingService: RoutingService + private logger: Logger public supportedMessages = [RequestPresentationMessage] public constructor( proofService: ProofService, - agentConfig: AgentConfig, proofResponseCoordinator: ProofResponseCoordinator, - routingService: RoutingService + routingService: RoutingService, + logger: Logger ) { this.proofService = proofService - this.agentConfig = agentConfig this.proofResponseCoordinator = proofResponseCoordinator this.routingService = routingService + this.logger = logger } public async handle(messageContext: HandlerInboundMessage) { const proofRecord = await this.proofService.processRequest(messageContext) - if (this.proofResponseCoordinator.shouldAutoRespondToRequest(proofRecord)) { + if (this.proofResponseCoordinator.shouldAutoRespondToRequest(messageContext.agentContext, proofRecord)) { return await this.createPresentation(proofRecord, messageContext) } } @@ -43,28 +43,36 @@ export class RequestPresentationHandler implements Handler { const indyProofRequest = record.requestMessage?.indyProofRequest const presentationProposal = record.proposalMessage?.presentationProposal - this.agentConfig.logger.info( - `Automatically sending presentation with autoAccept on ${this.agentConfig.autoAcceptProofs}` + this.logger.info( + `Automatically sending presentation with autoAccept on ${messageContext.agentContext.config.autoAcceptProofs}` ) if (!indyProofRequest) { - this.agentConfig.logger.error('Proof request is undefined.') + this.logger.error('Proof request is undefined.') return } - const retrievedCredentials = await this.proofService.getRequestedCredentialsForProofRequest(indyProofRequest, { - presentationProposal, - }) + const retrievedCredentials = await this.proofService.getRequestedCredentialsForProofRequest( + messageContext.agentContext, + indyProofRequest, + { + presentationProposal, + } + ) const requestedCredentials = this.proofService.autoSelectCredentialsForProofRequest(retrievedCredentials) - const { message, proofRecord } = await this.proofService.createPresentation(record, requestedCredentials) + const { message, proofRecord } = await this.proofService.createPresentation( + messageContext.agentContext, + record, + requestedCredentials + ) if (messageContext.connection) { return createOutboundMessage(messageContext.connection, message) } else if (proofRecord.requestMessage?.service) { // Create ~service decorator - const routing = await this.routingService.getRouting() + const routing = await this.routingService.getRouting(messageContext.agentContext) const ourService = new ServiceDecorator({ serviceEndpoint: routing.endpoints[0], recipientKeys: [routing.recipientKey.publicKeyBase58], @@ -76,7 +84,7 @@ export class RequestPresentationHandler implements Handler { // Set and save ~service decorator to record (to remember our verkey) message.service = ourService proofRecord.presentationMessage = message - await this.proofService.update(proofRecord) + await this.proofService.update(messageContext.agentContext, proofRecord) return createOutboundServiceMessage({ payload: message, @@ -85,6 +93,6 @@ export class RequestPresentationHandler implements Handler { }) } - this.agentConfig.logger.error(`Could not automatically create presentation`) + this.logger.error(`Could not automatically create presentation`) } } diff --git a/packages/core/src/modules/proofs/services/ProofService.ts b/packages/core/src/modules/proofs/services/ProofService.ts index 58611575d6..809035067a 100644 --- a/packages/core/src/modules/proofs/services/ProofService.ts +++ b/packages/core/src/modules/proofs/services/ProofService.ts @@ -1,6 +1,6 @@ +import type { AgentContext } from '../../../agent' import type { AgentMessage } from '../../../agent/AgentMessage' import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext' -import type { Logger } from '../../../logger' import type { ConnectionRecord } from '../../connections' import type { AutoAcceptProof } from '../ProofAutoAcceptType' import type { ProofStateChangedEvent } from '../ProofEvents' @@ -10,22 +10,21 @@ import type { CredDef, IndyProof, Schema } from 'indy-sdk' import { validateOrReject } from 'class-validator' -import { AgentConfig } from '../../../agent/AgentConfig' import { EventEmitter } from '../../../agent/EventEmitter' import { InjectionSymbols } from '../../../constants' import { Attachment, AttachmentData } from '../../../decorators/attachment/Attachment' import { AriesFrameworkError } from '../../../error' +import { Logger } from '../../../logger' import { inject, injectable } from '../../../plugins' import { JsonEncoder } from '../../../utils/JsonEncoder' import { JsonTransformer } from '../../../utils/JsonTransformer' import { checkProofRequestForDuplicates } from '../../../utils/indyProofRequest' import { uuid } from '../../../utils/uuid' -import { Wallet } from '../../../wallet/Wallet' import { AckStatus } from '../../common' import { ConnectionService } from '../../connections' -import { IndyCredential, CredentialRepository, IndyCredentialInfo } from '../../credentials' +import { CredentialRepository, IndyCredential, IndyCredentialInfo } from '../../credentials' import { IndyCredentialUtils } from '../../credentials/formats/indy/IndyCredentialUtils' -import { IndyHolderService, IndyVerifierService, IndyRevocationService } from '../../indy' +import { IndyHolderService, IndyRevocationService, IndyVerifierService } from '../../indy' import { IndyLedgerService } from '../../ledger/services/IndyLedgerService' import { ProofEventTypes } from '../ProofEvents' import { ProofState } from '../ProofState' @@ -62,7 +61,6 @@ export class ProofService { private proofRepository: ProofRepository private credentialRepository: CredentialRepository private ledgerService: IndyLedgerService - private wallet: Wallet private logger: Logger private indyHolderService: IndyHolderService private indyVerifierService: IndyVerifierService @@ -73,20 +71,18 @@ export class ProofService { public constructor( proofRepository: ProofRepository, ledgerService: IndyLedgerService, - @inject(InjectionSymbols.Wallet) wallet: Wallet, - agentConfig: AgentConfig, indyHolderService: IndyHolderService, indyVerifierService: IndyVerifierService, indyRevocationService: IndyRevocationService, connectionService: ConnectionService, eventEmitter: EventEmitter, - credentialRepository: CredentialRepository + credentialRepository: CredentialRepository, + @inject(InjectionSymbols.Logger) logger: Logger ) { this.proofRepository = proofRepository this.credentialRepository = credentialRepository this.ledgerService = ledgerService - this.wallet = wallet - this.logger = agentConfig.logger + this.logger = logger this.indyHolderService = indyHolderService this.indyVerifierService = indyVerifierService this.indyRevocationService = indyRevocationService @@ -105,6 +101,7 @@ export class ProofService { * */ public async createProposal( + agentContext: AgentContext, connectionRecord: ConnectionRecord, presentationProposal: PresentationPreview, config?: { @@ -129,8 +126,8 @@ export class ProofService { proposalMessage, autoAcceptProof: config?.autoAcceptProof, }) - await this.proofRepository.save(proofRecord) - this.emitStateChangedEvent(proofRecord, null) + await this.proofRepository.save(agentContext, proofRecord) + this.emitStateChangedEvent(agentContext, proofRecord, null) return { message: proposalMessage, proofRecord } } @@ -146,6 +143,7 @@ export class ProofService { * */ public async createProposalAsResponse( + agentContext: AgentContext, proofRecord: ProofRecord, presentationProposal: PresentationPreview, config?: { @@ -164,7 +162,7 @@ export class ProofService { // Update record proofRecord.proposalMessage = proposalMessage - await this.updateState(proofRecord, ProofState.ProposalSent) + await this.updateState(agentContext, proofRecord, ProofState.ProposalSent) return { message: proposalMessage, proofRecord } } @@ -173,10 +171,10 @@ export class ProofService { * Decline a proof request * @param proofRecord The proof request to be declined */ - public async declineRequest(proofRecord: ProofRecord): Promise { + public async declineRequest(agentContext: AgentContext, proofRecord: ProofRecord): Promise { proofRecord.assertState(ProofState.RequestReceived) - await this.updateState(proofRecord, ProofState.Declined) + await this.updateState(agentContext, proofRecord, ProofState.Declined) return proofRecord } @@ -201,7 +199,11 @@ export class ProofService { try { // Proof record already exists - proofRecord = await this.getByThreadAndConnectionId(proposalMessage.threadId, connection?.id) + proofRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + proposalMessage.threadId, + connection?.id + ) // Assert proofRecord.assertState(ProofState.RequestSent) @@ -212,7 +214,7 @@ export class ProofService { // Update record proofRecord.proposalMessage = proposalMessage - await this.updateState(proofRecord, ProofState.ProposalReceived) + await this.updateState(messageContext.agentContext, proofRecord, ProofState.ProposalReceived) } catch { // No proof record exists with thread id proofRecord = new ProofRecord({ @@ -226,8 +228,8 @@ export class ProofService { this.connectionService.assertConnectionOrServiceDecorator(messageContext) // Save record - await this.proofRepository.save(proofRecord) - this.emitStateChangedEvent(proofRecord, null) + await this.proofRepository.save(messageContext.agentContext, proofRecord) + this.emitStateChangedEvent(messageContext.agentContext, proofRecord, null) } return proofRecord @@ -244,6 +246,7 @@ export class ProofService { * */ public async createRequestAsResponse( + agentContext: AgentContext, proofRecord: ProofRecord, proofRequest: ProofRequest, config?: { @@ -274,7 +277,7 @@ export class ProofService { // Update record proofRecord.requestMessage = requestPresentationMessage - await this.updateState(proofRecord, ProofState.RequestSent) + await this.updateState(agentContext, proofRecord, ProofState.RequestSent) return { message: requestPresentationMessage, proofRecord } } @@ -289,6 +292,7 @@ export class ProofService { * */ public async createRequest( + agentContext: AgentContext, proofRequest: ProofRequest, connectionRecord?: ConnectionRecord, config?: { @@ -326,8 +330,8 @@ export class ProofService { autoAcceptProof: config?.autoAcceptProof, }) - await this.proofRepository.save(proofRecord) - this.emitStateChangedEvent(proofRecord, null) + await this.proofRepository.save(agentContext, proofRecord) + this.emitStateChangedEvent(agentContext, proofRecord, null) return { message: requestPresentationMessage, proofRecord } } @@ -366,7 +370,11 @@ export class ProofService { try { // Proof record already exists - proofRecord = await this.getByThreadAndConnectionId(proofRequestMessage.threadId, connection?.id) + proofRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + proofRequestMessage.threadId, + connection?.id + ) // Assert proofRecord.assertState(ProofState.ProposalSent) @@ -377,7 +385,7 @@ export class ProofService { // Update record proofRecord.requestMessage = proofRequestMessage - await this.updateState(proofRecord, ProofState.RequestReceived) + await this.updateState(messageContext.agentContext, proofRecord, ProofState.RequestReceived) } catch { // No proof record exists with thread id proofRecord = new ProofRecord({ @@ -391,8 +399,8 @@ export class ProofService { this.connectionService.assertConnectionOrServiceDecorator(messageContext) // Save in repository - await this.proofRepository.save(proofRecord) - this.emitStateChangedEvent(proofRecord, null) + await this.proofRepository.save(messageContext.agentContext, proofRecord) + this.emitStateChangedEvent(messageContext.agentContext, proofRecord, null) } return proofRecord @@ -408,6 +416,7 @@ export class ProofService { * */ public async createPresentation( + agentContext: AgentContext, proofRecord: ProofRecord, requestedCredentials: RequestedCredentials, config?: { @@ -429,12 +438,13 @@ export class ProofService { // Get the matching attachments to the requested credentials const attachments = await this.getRequestedAttachmentsForRequestedCredentials( + agentContext, indyProofRequest, requestedCredentials ) // Create proof - const proof = await this.createProof(indyProofRequest, requestedCredentials) + const proof = await this.createProof(agentContext, indyProofRequest, requestedCredentials) // Create message const attachment = new Attachment({ @@ -454,7 +464,7 @@ export class ProofService { // Update record proofRecord.presentationMessage = presentationMessage - await this.updateState(proofRecord, ProofState.PresentationSent) + await this.updateState(agentContext, proofRecord, ProofState.PresentationSent) return { message: presentationMessage, proofRecord } } @@ -474,7 +484,11 @@ export class ProofService { this.logger.debug(`Processing presentation with id ${presentationMessage.id}`) - const proofRecord = await this.getByThreadAndConnectionId(presentationMessage.threadId, connection?.id) + const proofRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + presentationMessage.threadId, + connection?.id + ) // Assert proofRecord.assertState(ProofState.RequestSent) @@ -501,12 +515,12 @@ export class ProofService { ) } - const isValid = await this.verifyProof(indyProofJson, indyProofRequest) + const isValid = await this.verifyProof(messageContext.agentContext, indyProofJson, indyProofRequest) // Update record proofRecord.isVerified = isValid proofRecord.presentationMessage = presentationMessage - await this.updateState(proofRecord, ProofState.PresentationReceived) + await this.updateState(messageContext.agentContext, proofRecord, ProofState.PresentationReceived) return proofRecord } @@ -518,7 +532,10 @@ export class ProofService { * @returns Object containing presentation acknowledgement message and associated proof record * */ - public async createAck(proofRecord: ProofRecord): Promise> { + public async createAck( + agentContext: AgentContext, + proofRecord: ProofRecord + ): Promise> { this.logger.debug(`Creating presentation ack for proof record with id ${proofRecord.id}`) // Assert @@ -531,7 +548,7 @@ export class ProofService { }) // Update record - await this.updateState(proofRecord, ProofState.Done) + await this.updateState(agentContext, proofRecord, ProofState.Done) return { message: ackMessage, proofRecord } } @@ -548,7 +565,11 @@ export class ProofService { this.logger.debug(`Processing presentation ack with id ${presentationAckMessage.id}`) - const proofRecord = await this.getByThreadAndConnectionId(presentationAckMessage.threadId, connection?.id) + const proofRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + presentationAckMessage.threadId, + connection?.id + ) // Assert proofRecord.assertState(ProofState.PresentationSent) @@ -558,7 +579,7 @@ export class ProofService { }) // Update record - await this.updateState(proofRecord, ProofState.Done) + await this.updateState(messageContext.agentContext, proofRecord, ProofState.Done) return proofRecord } @@ -579,15 +600,19 @@ export class ProofService { this.logger.debug(`Processing problem report with id ${presentationProblemReportMessage.id}`) - const proofRecord = await this.getByThreadAndConnectionId(presentationProblemReportMessage.threadId, connection?.id) + const proofRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, + presentationProblemReportMessage.threadId, + connection?.id + ) proofRecord.errorMessage = `${presentationProblemReportMessage.description.code}: ${presentationProblemReportMessage.description.en}` - await this.update(proofRecord) + await this.update(messageContext.agentContext, proofRecord) return proofRecord } - public async generateProofRequestNonce() { - return this.wallet.generateNonce() + public async generateProofRequestNonce(agentContext: AgentContext) { + return agentContext.wallet.generateNonce() } /** @@ -600,10 +625,11 @@ export class ProofService { * */ public async createProofRequestFromProposal( + agentContext: AgentContext, presentationProposal: PresentationPreview, config: { name: string; version: string; nonce?: string } ): Promise { - const nonce = config.nonce ?? (await this.generateProofRequestNonce()) + const nonce = config.nonce ?? (await this.generateProofRequestNonce(agentContext)) const proofRequest = new ProofRequest({ name: config.name, @@ -683,6 +709,7 @@ export class ProofService { * @returns a list of attachments that are linked to the requested credentials */ public async getRequestedAttachmentsForRequestedCredentials( + agentContext: AgentContext, indyProofRequest: ProofRequest, requestedCredentials: RequestedCredentials ): Promise { @@ -700,7 +727,10 @@ export class ProofService { //Get credentialInfo if (!requestedAttribute.credentialInfo) { - const indyCredentialInfo = await this.indyHolderService.getCredential(requestedAttribute.credentialId) + const indyCredentialInfo = await this.indyHolderService.getCredential( + agentContext, + requestedAttribute.credentialId + ) requestedAttribute.credentialInfo = JsonTransformer.fromJSON(indyCredentialInfo, IndyCredentialInfo) } @@ -716,7 +746,9 @@ export class ProofService { for (const credentialId of credentialIds) { // Get the credentialRecord that matches the ID - const credentialRecord = await this.credentialRepository.getSingleByQuery({ credentialIds: [credentialId] }) + const credentialRecord = await this.credentialRepository.getSingleByQuery(agentContext, { + credentialIds: [credentialId], + }) if (credentialRecord.linkedAttachments) { // Get the credentials that have a hashlink as value and are requested @@ -750,6 +782,7 @@ export class ProofService { * @returns RetrievedCredentials object */ public async getRequestedCredentialsForProofRequest( + agentContext: AgentContext, proofRequest: ProofRequest, config: { presentationProposal?: PresentationPreview @@ -760,7 +793,7 @@ export class ProofService { for (const [referent, requestedAttribute] of proofRequest.requestedAttributes.entries()) { let credentialMatch: IndyCredential[] = [] - const credentials = await this.getCredentialsForProofRequest(proofRequest, referent) + const credentials = await this.getCredentialsForProofRequest(agentContext, proofRequest, referent) // If we have exactly one credential, or no proposal to pick preferences // on the credentials to use, we will use the first one @@ -789,7 +822,7 @@ export class ProofService { retrievedCredentials.requestedAttributes[referent] = await Promise.all( credentialMatch.map(async (credential: IndyCredential) => { - const { revoked, deltaTimestamp } = await this.getRevocationStatusForRequestedItem({ + const { revoked, deltaTimestamp } = await this.getRevocationStatusForRequestedItem(agentContext, { proofRequest, requestedItem: requestedAttribute, credential, @@ -815,11 +848,11 @@ export class ProofService { } for (const [referent, requestedPredicate] of proofRequest.requestedPredicates.entries()) { - const credentials = await this.getCredentialsForProofRequest(proofRequest, referent) + const credentials = await this.getCredentialsForProofRequest(agentContext, proofRequest, referent) retrievedCredentials.requestedPredicates[referent] = await Promise.all( credentials.map(async (credential) => { - const { revoked, deltaTimestamp } = await this.getRevocationStatusForRequestedItem({ + const { revoked, deltaTimestamp } = await this.getRevocationStatusForRequestedItem(agentContext, { proofRequest, requestedItem: requestedPredicate, credential, @@ -890,7 +923,11 @@ export class ProofService { * @returns Boolean whether the proof is valid * */ - public async verifyProof(proofJson: IndyProof, proofRequest: ProofRequest): Promise { + public async verifyProof( + agentContext: AgentContext, + proofJson: IndyProof, + proofRequest: ProofRequest + ): Promise { const proof = JsonTransformer.fromJSON(proofJson, PartialProof) for (const [referent, attribute] of proof.requestedProof.revealedAttributes.entries()) { @@ -908,12 +945,13 @@ export class ProofService { // I'm not 100% sure how much indy does. Also if it checks whether the proof requests matches the proof // @see https://github.com/hyperledger/aries-cloudagent-python/blob/master/aries_cloudagent/indy/sdk/verifier.py#L79-L164 - const schemas = await this.getSchemas(new Set(proof.identifiers.map((i) => i.schemaId))) + const schemas = await this.getSchemas(agentContext, new Set(proof.identifiers.map((i) => i.schemaId))) const credentialDefinitions = await this.getCredentialDefinitions( + agentContext, new Set(proof.identifiers.map((i) => i.credentialDefinitionId)) ) - return await this.indyVerifierService.verifyProof({ + return await this.indyVerifierService.verifyProof(agentContext, { proofRequest: proofRequest.toJSON(), proof: proofJson, schemas, @@ -926,8 +964,8 @@ export class ProofService { * * @returns List containing all proof records */ - public async getAll(): Promise { - return this.proofRepository.getAll() + public async getAll(agentContext: AgentContext): Promise { + return this.proofRepository.getAll(agentContext) } /** @@ -938,8 +976,8 @@ export class ProofService { * @return The proof record * */ - public async getById(proofRecordId: string): Promise { - return this.proofRepository.getById(proofRecordId) + public async getById(agentContext: AgentContext, proofRecordId: string): Promise { + return this.proofRepository.getById(agentContext, proofRecordId) } /** @@ -949,8 +987,8 @@ export class ProofService { * @return The proof record or null if not found * */ - public async findById(proofRecordId: string): Promise { - return this.proofRepository.findById(proofRecordId) + public async findById(agentContext: AgentContext, proofRecordId: string): Promise { + return this.proofRepository.findById(agentContext, proofRecordId) } /** @@ -958,9 +996,9 @@ export class ProofService { * * @param proofId the proof record id */ - public async deleteById(proofId: string) { - const proofRecord = await this.getById(proofId) - return this.proofRepository.delete(proofRecord) + public async deleteById(agentContext: AgentContext, proofId: string) { + const proofRecord = await this.getById(agentContext, proofId) + return this.proofRepository.delete(agentContext, proofRecord) } /** @@ -972,12 +1010,16 @@ export class ProofService { * @throws {RecordDuplicateError} If multiple records are found * @returns The proof record */ - public async getByThreadAndConnectionId(threadId: string, connectionId?: string): Promise { - return this.proofRepository.getSingleByQuery({ threadId, connectionId }) + public async getByThreadAndConnectionId( + agentContext: AgentContext, + threadId: string, + connectionId?: string + ): Promise { + return this.proofRepository.getSingleByQuery(agentContext, { threadId, connectionId }) } - public update(proofRecord: ProofRecord) { - return this.proofRepository.update(proofRecord) + public update(agentContext: AgentContext, proofRecord: ProofRecord) { + return this.proofRepository.update(agentContext, proofRecord) } /** @@ -988,6 +1030,7 @@ export class ProofService { * @returns indy proof object */ private async createProof( + agentContext: AgentContext, proofRequest: ProofRequest, requestedCredentials: RequestedCredentials ): Promise { @@ -999,17 +1042,18 @@ export class ProofService { if (c.credentialInfo) { return c.credentialInfo } - const credentialInfo = await this.indyHolderService.getCredential(c.credentialId) + const credentialInfo = await this.indyHolderService.getCredential(agentContext, c.credentialId) return JsonTransformer.fromJSON(credentialInfo, IndyCredentialInfo) }) ) - const schemas = await this.getSchemas(new Set(credentialObjects.map((c) => c.schemaId))) + const schemas = await this.getSchemas(agentContext, new Set(credentialObjects.map((c) => c.schemaId))) const credentialDefinitions = await this.getCredentialDefinitions( + agentContext, new Set(credentialObjects.map((c) => c.credentialDefinitionId)) ) - return this.indyHolderService.createProof({ + return this.indyHolderService.createProof(agentContext, { proofRequest: proofRequest.toJSON(), requestedCredentials: requestedCredentials, schemas, @@ -1018,10 +1062,11 @@ export class ProofService { } private async getCredentialsForProofRequest( + agentContext: AgentContext, proofRequest: ProofRequest, attributeReferent: string ): Promise { - const credentialsJson = await this.indyHolderService.getCredentialsForProofRequest({ + const credentialsJson = await this.indyHolderService.getCredentialsForProofRequest(agentContext, { proofRequest: proofRequest.toJSON(), attributeReferent, }) @@ -1029,15 +1074,18 @@ export class ProofService { return JsonTransformer.fromJSON(credentialsJson, IndyCredential) as unknown as IndyCredential[] } - private async getRevocationStatusForRequestedItem({ - proofRequest, - requestedItem, - credential, - }: { - proofRequest: ProofRequest - requestedItem: ProofAttributeInfo | ProofPredicateInfo - credential: IndyCredential - }) { + private async getRevocationStatusForRequestedItem( + agentContext: AgentContext, + { + proofRequest, + requestedItem, + credential, + }: { + proofRequest: ProofRequest + requestedItem: ProofAttributeInfo | ProofPredicateInfo + credential: IndyCredential + } + ) { const requestNonRevoked = requestedItem.nonRevoked ?? proofRequest.nonRevoked const credentialRevocationId = credential.credentialInfo.credentialRevocationId const revocationRegistryId = credential.credentialInfo.revocationRegistryId @@ -1055,6 +1103,7 @@ export class ProofService { // Note presentation from-to's vs ledger from-to's: https://github.com/hyperledger/indy-hipe/blob/master/text/0011-cred-revocation/README.md#indy-node-revocation-registry-intervals const status = await this.indyRevocationService.getRevocationStatus( + agentContext, credentialRevocationId, revocationRegistryId, requestNonRevoked @@ -1074,18 +1123,22 @@ export class ProofService { * @param newState The state to update to * */ - private async updateState(proofRecord: ProofRecord, newState: ProofState) { + private async updateState(agentContext: AgentContext, proofRecord: ProofRecord, newState: ProofState) { const previousState = proofRecord.state proofRecord.state = newState - await this.proofRepository.update(proofRecord) + await this.proofRepository.update(agentContext, proofRecord) - this.emitStateChangedEvent(proofRecord, previousState) + this.emitStateChangedEvent(agentContext, proofRecord, previousState) } - private emitStateChangedEvent(proofRecord: ProofRecord, previousState: ProofState | null) { + private emitStateChangedEvent( + agentContext: AgentContext, + proofRecord: ProofRecord, + previousState: ProofState | null + ) { const clonedProof = JsonTransformer.clone(proofRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: ProofEventTypes.ProofStateChanged, payload: { proofRecord: clonedProof, @@ -1103,11 +1156,11 @@ export class ProofService { * @returns Object containing schemas for specified schema ids * */ - private async getSchemas(schemaIds: Set) { + private async getSchemas(agentContext: AgentContext, schemaIds: Set) { const schemas: { [key: string]: Schema } = {} for (const schemaId of schemaIds) { - const schema = await this.ledgerService.getSchema(schemaId) + const schema = await this.ledgerService.getSchema(agentContext, schemaId) schemas[schemaId] = schema } @@ -1123,11 +1176,11 @@ export class ProofService { * @returns Object containing credential definitions for specified credential definition ids * */ - private async getCredentialDefinitions(credentialDefinitionIds: Set) { + private async getCredentialDefinitions(agentContext: AgentContext, credentialDefinitionIds: Set) { const credentialDefinitions: { [key: string]: CredDef } = {} for (const credDefId of credentialDefinitionIds) { - const credDef = await this.ledgerService.getCredentialDefinition(credDefId) + const credDef = await this.ledgerService.getCredentialDefinition(agentContext, credDefId) credentialDefinitions[credDefId] = credDef } diff --git a/packages/core/src/modules/question-answer/QuestionAnswerModule.ts b/packages/core/src/modules/question-answer/QuestionAnswerModule.ts index 6eac5b48d4..bc17fe8ada 100644 --- a/packages/core/src/modules/question-answer/QuestionAnswerModule.ts +++ b/packages/core/src/modules/question-answer/QuestionAnswerModule.ts @@ -1,6 +1,7 @@ import type { DependencyManager } from '../../plugins' import type { ValidResponse } from './models' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' @@ -17,16 +18,19 @@ export class QuestionAnswerModule { private questionAnswerService: QuestionAnswerService private messageSender: MessageSender private connectionService: ConnectionService + private agentContext: AgentContext public constructor( dispatcher: Dispatcher, questionAnswerService: QuestionAnswerService, messageSender: MessageSender, - connectionService: ConnectionService + connectionService: ConnectionService, + agentContext: AgentContext ) { this.questionAnswerService = questionAnswerService this.messageSender = messageSender this.connectionService = connectionService + this.agentContext = agentContext this.registerHandlers(dispatcher) } @@ -46,16 +50,20 @@ export class QuestionAnswerModule { detail?: string } ) { - const connection = await this.connectionService.getById(connectionId) + const connection = await this.connectionService.getById(this.agentContext, connectionId) connection.assertReady() - const { questionMessage, questionAnswerRecord } = await this.questionAnswerService.createQuestion(connectionId, { - question: config.question, - validResponses: config.validResponses, - detail: config?.detail, - }) + const { questionMessage, questionAnswerRecord } = await this.questionAnswerService.createQuestion( + this.agentContext, + connectionId, + { + question: config.question, + validResponses: config.validResponses, + detail: config?.detail, + } + ) const outboundMessage = createOutboundMessage(connection, questionMessage) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return questionAnswerRecord } @@ -68,17 +76,18 @@ export class QuestionAnswerModule { * @returns QuestionAnswer record */ public async sendAnswer(questionRecordId: string, response: string) { - const questionRecord = await this.questionAnswerService.getById(questionRecordId) + const questionRecord = await this.questionAnswerService.getById(this.agentContext, questionRecordId) const { answerMessage, questionAnswerRecord } = await this.questionAnswerService.createAnswer( + this.agentContext, questionRecord, response ) - const connection = await this.connectionService.getById(questionRecord.connectionId) + const connection = await this.connectionService.getById(this.agentContext, questionRecord.connectionId) const outboundMessage = createOutboundMessage(connection, answerMessage) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return questionAnswerRecord } @@ -89,7 +98,7 @@ export class QuestionAnswerModule { * @returns list containing all QuestionAnswer records */ public getAll() { - return this.questionAnswerService.getAll() + return this.questionAnswerService.getAll(this.agentContext) } private registerHandlers(dispatcher: Dispatcher) { diff --git a/packages/core/src/modules/question-answer/__tests__/QuestionAnswerService.test.ts b/packages/core/src/modules/question-answer/__tests__/QuestionAnswerService.test.ts index 3b7f3982a1..c940c1e30c 100644 --- a/packages/core/src/modules/question-answer/__tests__/QuestionAnswerService.test.ts +++ b/packages/core/src/modules/question-answer/__tests__/QuestionAnswerService.test.ts @@ -1,9 +1,18 @@ +import type { AgentContext } from '../../../agent' import type { AgentConfig } from '../../../agent/AgentConfig' import type { Repository } from '../../../storage/Repository' import type { QuestionAnswerStateChangedEvent } from '../QuestionAnswerEvents' import type { ValidResponse } from '../models' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { + agentDependencies, + getAgentConfig, + getAgentContext, + getMockConnection, + mockFunction, +} from '../../../../tests/helpers' import { EventEmitter } from '../../../agent/EventEmitter' import { IndyWallet } from '../../../wallet/IndyWallet' import { QuestionAnswerEventTypes } from '../QuestionAnswerEvents' @@ -27,6 +36,7 @@ describe('QuestionAnswerService', () => { let questionAnswerRepository: Repository let questionAnswerService: QuestionAnswerService let eventEmitter: EventEmitter + let agentContext: AgentContext const mockQuestionAnswerRecord = (options: { questionText: string @@ -52,15 +62,16 @@ describe('QuestionAnswerService', () => { beforeAll(async () => { agentConfig = getAgentConfig('QuestionAnswerServiceTest') - wallet = new IndyWallet(agentConfig) + wallet = new IndyWallet(agentConfig.agentDependencies, agentConfig.logger) + agentContext = getAgentContext() // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(agentConfig.walletConfig!) }) beforeEach(async () => { questionAnswerRepository = new QuestionAnswerRepositoryMock() - eventEmitter = new EventEmitter(agentConfig) - questionAnswerService = new QuestionAnswerService(questionAnswerRepository, eventEmitter, agentConfig) + eventEmitter = new EventEmitter(agentDependencies, new Subject()) + questionAnswerService = new QuestionAnswerService(questionAnswerRepository, eventEmitter, agentConfig.logger) }) afterAll(async () => { @@ -81,7 +92,7 @@ describe('QuestionAnswerService', () => { validResponses: [{ text: 'Yes' }, { text: 'No' }], }) - await questionAnswerService.createQuestion(mockConnectionRecord.id, { + await questionAnswerService.createQuestion(agentContext, mockConnectionRecord.id, { question: questionMessage.questionText, validResponses: questionMessage.validResponses, }) @@ -117,7 +128,7 @@ describe('QuestionAnswerService', () => { }) it(`throws an error when invalid response is provided`, async () => { - expect(questionAnswerService.createAnswer(mockRecord, 'Maybe')).rejects.toThrowError( + expect(questionAnswerService.createAnswer(agentContext, mockRecord, 'Maybe')).rejects.toThrowError( `Response does not match valid responses` ) }) @@ -131,7 +142,7 @@ describe('QuestionAnswerService', () => { mockFunction(questionAnswerRepository.getSingleByQuery).mockReturnValue(Promise.resolve(mockRecord)) - await questionAnswerService.createAnswer(mockRecord, 'Yes') + await questionAnswerService.createAnswer(agentContext, mockRecord, 'Yes') expect(eventListenerMock).toHaveBeenCalledWith({ type: 'QuestionAnswerStateChanged', diff --git a/packages/core/src/modules/question-answer/services/QuestionAnswerService.ts b/packages/core/src/modules/question-answer/services/QuestionAnswerService.ts index e1ef6093e1..01539adf57 100644 --- a/packages/core/src/modules/question-answer/services/QuestionAnswerService.ts +++ b/packages/core/src/modules/question-answer/services/QuestionAnswerService.ts @@ -1,16 +1,17 @@ +import type { AgentContext } from '../../../agent' import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext' -import type { Logger } from '../../../logger' import type { QuestionAnswerStateChangedEvent } from '../QuestionAnswerEvents' import type { ValidResponse } from '../models' import type { QuestionAnswerTags } from '../repository' -import { AgentConfig } from '../../../agent/AgentConfig' import { EventEmitter } from '../../../agent/EventEmitter' +import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error' -import { injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' import { QuestionAnswerEventTypes } from '../QuestionAnswerEvents' import { QuestionAnswerRole } from '../QuestionAnswerRole' -import { QuestionMessage, AnswerMessage } from '../messages' +import { AnswerMessage, QuestionMessage } from '../messages' import { QuestionAnswerState } from '../models' import { QuestionAnswerRecord, QuestionAnswerRepository } from '../repository' @@ -23,11 +24,11 @@ export class QuestionAnswerService { public constructor( questionAnswerRepository: QuestionAnswerRepository, eventEmitter: EventEmitter, - agentConfig: AgentConfig + @inject(InjectionSymbols.Logger) logger: Logger ) { this.questionAnswerRepository = questionAnswerRepository this.eventEmitter = eventEmitter - this.logger = agentConfig.logger + this.logger = logger } /** * Create a question message and a new QuestionAnswer record for the questioner role @@ -39,6 +40,7 @@ export class QuestionAnswerService { * @returns question message and QuestionAnswer record */ public async createQuestion( + agentContext: AgentContext, connectionId: string, config: { question: string @@ -64,9 +66,9 @@ export class QuestionAnswerService { validResponses: questionMessage.validResponses, }) - await this.questionAnswerRepository.save(questionAnswerRecord) + await this.questionAnswerRepository.save(agentContext, questionAnswerRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: QuestionAnswerEventTypes.QuestionAnswerStateChanged, payload: { previousState: null, questionAnswerRecord }, }) @@ -88,7 +90,7 @@ export class QuestionAnswerService { this.logger.debug(`Receiving question message with id ${questionMessage.id}`) const connection = messageContext.assertReadyConnection() - const questionRecord = await this.getById(questionMessage.id) + const questionRecord = await this.getById(messageContext.agentContext, questionMessage.id) questionRecord.assertState(QuestionAnswerState.QuestionSent) const questionAnswerRecord = await this.createRecord({ @@ -102,9 +104,9 @@ export class QuestionAnswerService { validResponses: questionMessage.validResponses, }) - await this.questionAnswerRepository.save(questionAnswerRecord) + await this.questionAnswerRepository.save(messageContext.agentContext, questionAnswerRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(messageContext.agentContext, { type: QuestionAnswerEventTypes.QuestionAnswerStateChanged, payload: { previousState: null, questionAnswerRecord }, }) @@ -119,7 +121,7 @@ export class QuestionAnswerService { * @param response response used in answer message * @returns answer message and QuestionAnswer record */ - public async createAnswer(questionAnswerRecord: QuestionAnswerRecord, response: string) { + public async createAnswer(agentContext: AgentContext, questionAnswerRecord: QuestionAnswerRecord, response: string) { const answerMessage = new AnswerMessage({ response: response, threadId: questionAnswerRecord.threadId }) questionAnswerRecord.assertState(QuestionAnswerState.QuestionReceived) @@ -127,7 +129,7 @@ export class QuestionAnswerService { questionAnswerRecord.response = response if (questionAnswerRecord.validResponses.some((e) => e.text === response)) { - await this.updateState(questionAnswerRecord, QuestionAnswerState.AnswerSent) + await this.updateState(agentContext, questionAnswerRecord, QuestionAnswerState.AnswerSent) } else { throw new AriesFrameworkError(`Response does not match valid responses`) } @@ -146,17 +148,18 @@ export class QuestionAnswerService { this.logger.debug(`Receiving answer message with id ${answerMessage.id}`) const connection = messageContext.assertReadyConnection() - const answerRecord = await this.getById(answerMessage.id) + const answerRecord = await this.getById(messageContext.agentContext, answerMessage.id) answerRecord.assertState(QuestionAnswerState.AnswerSent) const questionAnswerRecord: QuestionAnswerRecord = await this.getByThreadAndConnectionId( + messageContext.agentContext, answerMessage.threadId, connection?.id ) questionAnswerRecord.response = answerMessage.response - await this.updateState(questionAnswerRecord, QuestionAnswerState.AnswerReceived) + await this.updateState(messageContext.agentContext, questionAnswerRecord, QuestionAnswerState.AnswerReceived) return questionAnswerRecord } @@ -169,12 +172,16 @@ export class QuestionAnswerService { * @param newState The state to update to * */ - private async updateState(questionAnswerRecord: QuestionAnswerRecord, newState: QuestionAnswerState) { + private async updateState( + agentContext: AgentContext, + questionAnswerRecord: QuestionAnswerRecord, + newState: QuestionAnswerState + ) { const previousState = questionAnswerRecord.state questionAnswerRecord.state = newState - await this.questionAnswerRepository.update(questionAnswerRecord) + await this.questionAnswerRepository.update(agentContext, questionAnswerRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: QuestionAnswerEventTypes.QuestionAnswerStateChanged, payload: { previousState, @@ -216,8 +223,12 @@ export class QuestionAnswerService { * @throws {RecordDuplicateError} If multiple records are found * @returns The credential record */ - public getByThreadAndConnectionId(connectionId: string, threadId: string): Promise { - return this.questionAnswerRepository.getSingleByQuery({ + public getByThreadAndConnectionId( + agentContext: AgentContext, + connectionId: string, + threadId: string + ): Promise { + return this.questionAnswerRepository.getSingleByQuery(agentContext, { connectionId, threadId, }) @@ -231,8 +242,8 @@ export class QuestionAnswerService { * @return The connection record * */ - public getById(questionAnswerId: string): Promise { - return this.questionAnswerRepository.getById(questionAnswerId) + public getById(agentContext: AgentContext, questionAnswerId: string): Promise { + return this.questionAnswerRepository.getById(agentContext, questionAnswerId) } /** @@ -240,11 +251,11 @@ export class QuestionAnswerService { * * @returns List containing all QuestionAnswer records */ - public getAll() { - return this.questionAnswerRepository.getAll() + public getAll(agentContext: AgentContext) { + return this.questionAnswerRepository.getAll(agentContext) } - public async findAllByQuery(query: Partial) { - return this.questionAnswerRepository.findByQuery(query) + public async findAllByQuery(agentContext: AgentContext, query: Partial) { + return this.questionAnswerRepository.findByQuery(agentContext, query) } } diff --git a/packages/core/src/modules/routing/MediatorModule.ts b/packages/core/src/modules/routing/MediatorModule.ts index 9d3c614a84..d7a8625e17 100644 --- a/packages/core/src/modules/routing/MediatorModule.ts +++ b/packages/core/src/modules/routing/MediatorModule.ts @@ -2,16 +2,15 @@ import type { DependencyManager } from '../../plugins' import type { EncryptedMessage } from '../../types' import type { MediationRecord } from './repository' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { EventEmitter } from '../../agent/EventEmitter' -import { MessageReceiver } from '../../agent/MessageReceiver' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' import { injectable, module } from '../../plugins' import { ConnectionService } from '../connections/services' -import { KeylistUpdateHandler, ForwardHandler, BatchPickupHandler, BatchHandler } from './handlers' +import { BatchHandler, BatchPickupHandler, ForwardHandler, KeylistUpdateHandler } from './handlers' import { MediationRequestHandler } from './handlers/MediationRequestHandler' import { MediatorService } from './services/MediatorService' import { MessagePickupService } from './services/MessagePickupService' @@ -23,7 +22,7 @@ export class MediatorModule { private messagePickupService: MessagePickupService private messageSender: MessageSender public eventEmitter: EventEmitter - public agentConfig: AgentConfig + public agentContext: AgentContext public connectionService: ConnectionService public constructor( @@ -31,28 +30,30 @@ export class MediatorModule { mediationService: MediatorService, messagePickupService: MessagePickupService, messageSender: MessageSender, - messageReceiver: MessageReceiver, eventEmitter: EventEmitter, - agentConfig: AgentConfig, + agentContext: AgentContext, connectionService: ConnectionService ) { this.mediatorService = mediationService this.messagePickupService = messagePickupService this.messageSender = messageSender this.eventEmitter = eventEmitter - this.agentConfig = agentConfig this.connectionService = connectionService + this.agentContext = agentContext this.registerHandlers(dispatcher) } public async grantRequestedMediation(mediatorId: string): Promise { - const record = await this.mediatorService.getById(mediatorId) - const connectionRecord = await this.connectionService.getById(record.connectionId) + const record = await this.mediatorService.getById(this.agentContext, mediatorId) + const connectionRecord = await this.connectionService.getById(this.agentContext, record.connectionId) - const { message, mediationRecord } = await this.mediatorService.createGrantMediationMessage(record) + const { message, mediationRecord } = await this.mediatorService.createGrantMediationMessage( + this.agentContext, + record + ) const outboundMessage = createOutboundMessage(connectionRecord, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(this.agentContext, outboundMessage) return mediationRecord } @@ -66,7 +67,7 @@ export class MediatorModule { dispatcher.registerHandler(new ForwardHandler(this.mediatorService, this.connectionService, this.messageSender)) dispatcher.registerHandler(new BatchPickupHandler(this.messagePickupService)) dispatcher.registerHandler(new BatchHandler(this.eventEmitter)) - dispatcher.registerHandler(new MediationRequestHandler(this.mediatorService, this.agentConfig)) + dispatcher.registerHandler(new MediationRequestHandler(this.mediatorService)) } /** diff --git a/packages/core/src/modules/routing/RecipientModule.ts b/packages/core/src/modules/routing/RecipientModule.ts index 575c7ede5b..1e892278d6 100644 --- a/packages/core/src/modules/routing/RecipientModule.ts +++ b/packages/core/src/modules/routing/RecipientModule.ts @@ -1,4 +1,3 @@ -import type { Logger } from '../../logger' import type { DependencyManager } from '../../plugins' import type { OutboundWebSocketClosedEvent } from '../../transport' import type { OutboundMessage } from '../../types' @@ -7,16 +6,18 @@ import type { MediationStateChangedEvent } from './RoutingEvents' import type { MediationRecord } from './index' import type { GetRoutingOptions } from './services/RoutingService' -import { firstValueFrom, interval, ReplaySubject, timer } from 'rxjs' -import { filter, first, takeUntil, throttleTime, timeout, tap, delayWhen } from 'rxjs/operators' +import { firstValueFrom, interval, ReplaySubject, Subject, timer } from 'rxjs' +import { delayWhen, filter, first, takeUntil, tap, throttleTime, timeout } from 'rxjs/operators' -import { AgentConfig } from '../../agent/AgentConfig' +import { AgentContext } from '../../agent' import { Dispatcher } from '../../agent/Dispatcher' import { EventEmitter } from '../../agent/EventEmitter' import { MessageSender } from '../../agent/MessageSender' import { createOutboundMessage } from '../../agent/helpers' +import { InjectionSymbols } from '../../constants' import { AriesFrameworkError } from '../../error' -import { injectable, module } from '../../plugins' +import { Logger } from '../../logger' +import { inject, injectable, module } from '../../plugins' import { TransportEventTypes } from '../../transport' import { ConnectionService } from '../connections/services' import { DidsModule } from '../dids' @@ -38,7 +39,6 @@ import { RoutingService } from './services/RoutingService' @module() @injectable() export class RecipientModule { - private agentConfig: AgentConfig private mediationRecipientService: MediationRecipientService private connectionService: ConnectionService private dids: DidsModule @@ -48,10 +48,11 @@ export class RecipientModule { private discoverFeaturesModule: DiscoverFeaturesModule private mediationRepository: MediationRepository private routingService: RoutingService + private agentContext: AgentContext + private stop$: Subject public constructor( dispatcher: Dispatcher, - agentConfig: AgentConfig, mediationRecipientService: MediationRecipientService, connectionService: ConnectionService, dids: DidsModule, @@ -59,32 +60,36 @@ export class RecipientModule { eventEmitter: EventEmitter, discoverFeaturesModule: DiscoverFeaturesModule, mediationRepository: MediationRepository, - routingService: RoutingService + routingService: RoutingService, + @inject(InjectionSymbols.Logger) logger: Logger, + agentContext: AgentContext, + @inject(InjectionSymbols.Stop$) stop$: Subject ) { - this.agentConfig = agentConfig this.connectionService = connectionService this.dids = dids this.mediationRecipientService = mediationRecipientService this.messageSender = messageSender this.eventEmitter = eventEmitter - this.logger = agentConfig.logger + this.logger = logger this.discoverFeaturesModule = discoverFeaturesModule this.mediationRepository = mediationRepository this.routingService = routingService + this.agentContext = agentContext + this.stop$ = stop$ this.registerHandlers(dispatcher) } public async initialize() { - const { defaultMediatorId, clearDefaultMediator } = this.agentConfig + const { defaultMediatorId, clearDefaultMediator } = this.agentContext.config // Set default mediator by id if (defaultMediatorId) { - const mediatorRecord = await this.mediationRecipientService.getById(defaultMediatorId) - await this.mediationRecipientService.setDefaultMediator(mediatorRecord) + const mediatorRecord = await this.mediationRecipientService.getById(this.agentContext, defaultMediatorId) + await this.mediationRecipientService.setDefaultMediator(this.agentContext, mediatorRecord) } // Clear the stored default mediator else if (clearDefaultMediator) { - await this.mediationRecipientService.clearDefaultMediator() + await this.mediationRecipientService.clearDefaultMediator(this.agentContext) } // Poll for messages from mediator @@ -95,13 +100,13 @@ export class RecipientModule { } private async sendMessage(outboundMessage: OutboundMessage) { - const { mediatorPickupStrategy } = this.agentConfig + const { mediatorPickupStrategy } = this.agentContext.config const transportPriority = mediatorPickupStrategy === MediatorPickupStrategy.Implicit ? { schemes: ['wss', 'ws'], restrictive: true } : undefined - await this.messageSender.sendMessage(outboundMessage, { + await this.messageSender.sendMessage(this.agentContext, outboundMessage, { transportPriority, // TODO: add keepAlive: true to enforce through the public api // we need to keep the socket alive. It already works this way, but would @@ -112,8 +117,8 @@ export class RecipientModule { } private async openMediationWebSocket(mediator: MediationRecord) { - const connection = await this.connectionService.getById(mediator.connectionId) - const { message, connectionRecord } = await this.connectionService.createTrustPing(connection, { + const connection = await this.connectionService.getById(this.agentContext, mediator.connectionId) + const { message, connectionRecord } = await this.connectionService.createTrustPing(this.agentContext, connection, { responseRequested: false, }) @@ -126,7 +131,7 @@ export class RecipientModule { throw new AriesFrameworkError('Cannot open websocket to connection without websocket service endpoint') } - await this.messageSender.sendMessage(createOutboundMessage(connectionRecord, message), { + await this.messageSender.sendMessage(this.agentContext, createOutboundMessage(connectionRecord, message), { transportPriority: { schemes: websocketSchemes, restrictive: true, @@ -150,7 +155,7 @@ export class RecipientModule { .observable(TransportEventTypes.OutboundWebSocketClosedEvent) .pipe( // Stop when the agent shuts down - takeUntil(this.agentConfig.stop$), + takeUntil(this.stop$), filter((e) => e.payload.connectionId === mediator.connectionId), // Make sure we're not reconnecting multiple times throttleTime(interval), @@ -184,21 +189,21 @@ export class RecipientModule { } public async initiateMessagePickup(mediator: MediationRecord) { - const { mediatorPollingInterval } = this.agentConfig + const { mediatorPollingInterval } = this.agentContext.config const mediatorPickupStrategy = await this.getPickupStrategyForMediator(mediator) - const mediatorConnection = await this.connectionService.getById(mediator.connectionId) + const mediatorConnection = await this.connectionService.getById(this.agentContext, mediator.connectionId) switch (mediatorPickupStrategy) { case MediatorPickupStrategy.PickUpV2: - this.agentConfig.logger.info(`Starting pickup of messages from mediator '${mediator.id}'`) + this.logger.info(`Starting pickup of messages from mediator '${mediator.id}'`) await this.openWebSocketAndPickUp(mediator, mediatorPickupStrategy) await this.sendStatusRequest({ mediatorId: mediator.id }) break case MediatorPickupStrategy.PickUpV1: { // Explicit means polling every X seconds with batch message - this.agentConfig.logger.info(`Starting explicit (batch) pickup of messages from mediator '${mediator.id}'`) + this.logger.info(`Starting explicit (batch) pickup of messages from mediator '${mediator.id}'`) const subscription = interval(mediatorPollingInterval) - .pipe(takeUntil(this.agentConfig.stop$)) + .pipe(takeUntil(this.stop$)) .subscribe(async () => { await this.pickupMessages(mediatorConnection) }) @@ -207,29 +212,30 @@ export class RecipientModule { case MediatorPickupStrategy.Implicit: // Implicit means sending ping once and keeping connection open. This requires a long-lived transport // such as WebSockets to work - this.agentConfig.logger.info(`Starting implicit pickup of messages from mediator '${mediator.id}'`) + this.logger.info(`Starting implicit pickup of messages from mediator '${mediator.id}'`) await this.openWebSocketAndPickUp(mediator, mediatorPickupStrategy) break default: - this.agentConfig.logger.info( - `Skipping pickup of messages from mediator '${mediator.id}' due to pickup strategy none` - ) + this.logger.info(`Skipping pickup of messages from mediator '${mediator.id}' due to pickup strategy none`) } } private async sendStatusRequest(config: { mediatorId: string; recipientKey?: string }) { - const mediationRecord = await this.mediationRecipientService.getById(config.mediatorId) + const mediationRecord = await this.mediationRecipientService.getById(this.agentContext, config.mediatorId) const statusRequestMessage = await this.mediationRecipientService.createStatusRequest(mediationRecord, { recipientKey: config.recipientKey, }) - const mediatorConnection = await this.connectionService.getById(mediationRecord.connectionId) - return this.messageSender.sendMessage(createOutboundMessage(mediatorConnection, statusRequestMessage)) + const mediatorConnection = await this.connectionService.getById(this.agentContext, mediationRecord.connectionId) + return this.messageSender.sendMessage( + this.agentContext, + createOutboundMessage(mediatorConnection, statusRequestMessage) + ) } private async getPickupStrategyForMediator(mediator: MediationRecord) { - let mediatorPickupStrategy = mediator.pickupStrategy ?? this.agentConfig.mediatorPickupStrategy + let mediatorPickupStrategy = mediator.pickupStrategy ?? this.agentContext.config.mediatorPickupStrategy // If mediator pickup strategy is not configured we try to query if batch pickup // is supported through the discover features protocol @@ -254,14 +260,14 @@ export class RecipientModule { // Store the result so it can be reused next time mediator.pickupStrategy = mediatorPickupStrategy - await this.mediationRepository.update(mediator) + await this.mediationRepository.update(this.agentContext, mediator) } return mediatorPickupStrategy } public async discoverMediation() { - return this.mediationRecipientService.discoverMediation() + return this.mediationRecipientService.discoverMediation(this.agentContext) } public async pickupMessages(mediatorConnection: ConnectionRecord) { @@ -273,11 +279,14 @@ export class RecipientModule { } public async setDefaultMediator(mediatorRecord: MediationRecord) { - return this.mediationRecipientService.setDefaultMediator(mediatorRecord) + return this.mediationRecipientService.setDefaultMediator(this.agentContext, mediatorRecord) } public async requestMediation(connection: ConnectionRecord): Promise { - const { mediationRecord, message } = await this.mediationRecipientService.createRequest(connection) + const { mediationRecord, message } = await this.mediationRecipientService.createRequest( + this.agentContext, + connection + ) const outboundMessage = createOutboundMessage(connection, message) await this.sendMessage(outboundMessage) @@ -291,29 +300,32 @@ export class RecipientModule { } public async findByConnectionId(connectionId: string) { - return await this.mediationRecipientService.findByConnectionId(connectionId) + return await this.mediationRecipientService.findByConnectionId(this.agentContext, connectionId) } public async getMediators() { - return await this.mediationRecipientService.getMediators() + return await this.mediationRecipientService.getMediators(this.agentContext) } public async findDefaultMediator(): Promise { - return this.mediationRecipientService.findDefaultMediator() + return this.mediationRecipientService.findDefaultMediator(this.agentContext) } public async findDefaultMediatorConnection(): Promise { const mediatorRecord = await this.findDefaultMediator() if (mediatorRecord) { - return this.connectionService.getById(mediatorRecord.connectionId) + return this.connectionService.getById(this.agentContext, mediatorRecord.connectionId) } return null } public async requestAndAwaitGrant(connection: ConnectionRecord, timeoutMs = 10000): Promise { - const { mediationRecord, message } = await this.mediationRecipientService.createRequest(connection) + const { mediationRecord, message } = await this.mediationRecipientService.createRequest( + this.agentContext, + connection + ) // Create observable for event const observable = this.eventEmitter.observable(RoutingEventTypes.MediationStateChanged) @@ -353,22 +365,20 @@ export class RecipientModule { let mediation = await this.findByConnectionId(connection.id) if (!mediation) { - this.agentConfig.logger.info(`Requesting mediation for connection ${connection.id}`) + this.logger.info(`Requesting mediation for connection ${connection.id}`) mediation = await this.requestAndAwaitGrant(connection, 60000) // TODO: put timeout as a config parameter this.logger.debug('Mediation granted, setting as default mediator') await this.setDefaultMediator(mediation) this.logger.debug('Default mediator set') } else { - this.agentConfig.logger.warn( - `Mediator invitation has already been ${mediation.isReady ? 'granted' : 'requested'}` - ) + this.logger.warn(`Mediator invitation has already been ${mediation.isReady ? 'granted' : 'requested'}`) } return mediation } public async getRouting(options: GetRoutingOptions) { - return this.routingService.getRouting(options) + return this.routingService.getRouting(this.agentContext, options) } // Register handlers for the several messages for the mediator. diff --git a/packages/core/src/modules/routing/__tests__/mediation.test.ts b/packages/core/src/modules/routing/__tests__/mediation.test.ts index 42265474fa..c616bdd522 100644 --- a/packages/core/src/modules/routing/__tests__/mediation.test.ts +++ b/packages/core/src/modules/routing/__tests__/mediation.test.ts @@ -7,6 +7,7 @@ import { SubjectInboundTransport } from '../../../../../../tests/transport/Subje import { SubjectOutboundTransport } from '../../../../../../tests/transport/SubjectOutboundTransport' import { getBaseConfig, waitForBasicMessage } from '../../../../tests/helpers' import { Agent } from '../../../agent/Agent' +import { InjectionSymbols } from '../../../constants' import { sleep } from '../../../utils/sleep' import { ConnectionRecord, HandshakeProtocol } from '../../connections' import { MediatorPickupStrategy } from '../MediatorPickupStrategy' @@ -35,7 +36,8 @@ describe('mediator establishment', () => { // We want to stop the mediator polling before the agent is shutdown. // FIXME: add a way to stop mediator polling from the public api, and make sure this is // being handled in the agent shutdown so we don't get any errors with wallets being closed. - recipientAgent.config.stop$.next(true) + const stop$ = recipientAgent.injectionContainer.resolve>(InjectionSymbols.Stop$) + stop$.next(true) await sleep(1000) await recipientAgent?.shutdown() diff --git a/packages/core/src/modules/routing/handlers/BatchHandler.ts b/packages/core/src/modules/routing/handlers/BatchHandler.ts index c18861a673..e5bcb31b1e 100644 --- a/packages/core/src/modules/routing/handlers/BatchHandler.ts +++ b/packages/core/src/modules/routing/handlers/BatchHandler.ts @@ -3,7 +3,6 @@ import type { AgentMessageReceivedEvent } from '../../../agent/Events' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import { AgentEventTypes } from '../../../agent/Events' -import { AriesFrameworkError } from '../../../error' import { BatchMessage } from '../messages' export class BatchHandler implements Handler { @@ -15,15 +14,13 @@ export class BatchHandler implements Handler { } public async handle(messageContext: HandlerInboundMessage) { - const { message, connection } = messageContext + const { message } = messageContext - if (!connection) { - throw new AriesFrameworkError(`No connection associated with incoming message with id ${message.id}`) - } + messageContext.assertReadyConnection() const forwardedMessages = message.messages forwardedMessages.forEach((message) => { - this.eventEmitter.emit({ + this.eventEmitter.emit(messageContext.agentContext, { type: AgentEventTypes.AgentMessageReceived, payload: { message: message.message, diff --git a/packages/core/src/modules/routing/handlers/BatchPickupHandler.ts b/packages/core/src/modules/routing/handlers/BatchPickupHandler.ts index 841ba039e8..8cd5f8dc62 100644 --- a/packages/core/src/modules/routing/handlers/BatchPickupHandler.ts +++ b/packages/core/src/modules/routing/handlers/BatchPickupHandler.ts @@ -1,7 +1,6 @@ import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { MessagePickupService } from '../services' -import { AriesFrameworkError } from '../../../error' import { BatchPickupMessage } from '../messages' export class BatchPickupHandler implements Handler { @@ -13,11 +12,7 @@ export class BatchPickupHandler implements Handler { } public async handle(messageContext: HandlerInboundMessage) { - const { message, connection } = messageContext - - if (!connection) { - throw new AriesFrameworkError(`No connection associated with incoming message with id ${message.id}`) - } + messageContext.assertReadyConnection() return this.messagePickupService.batch(messageContext) } diff --git a/packages/core/src/modules/routing/handlers/ForwardHandler.ts b/packages/core/src/modules/routing/handlers/ForwardHandler.ts index 8755f8c1f1..4407480175 100644 --- a/packages/core/src/modules/routing/handlers/ForwardHandler.ts +++ b/packages/core/src/modules/routing/handlers/ForwardHandler.ts @@ -25,10 +25,16 @@ export class ForwardHandler implements Handler { public async handle(messageContext: HandlerInboundMessage) { const { encryptedMessage, mediationRecord } = await this.mediatorService.processForwardMessage(messageContext) - const connectionRecord = await this.connectionService.getById(mediationRecord.connectionId) + const connectionRecord = await this.connectionService.getById( + messageContext.agentContext, + mediationRecord.connectionId + ) // The message inside the forward message is packed so we just send the packed // message to the connection associated with it - await this.messageSender.sendPackage({ connection: connectionRecord, encryptedMessage }) + await this.messageSender.sendPackage(messageContext.agentContext, { + connection: connectionRecord, + encryptedMessage, + }) } } diff --git a/packages/core/src/modules/routing/handlers/KeylistUpdateHandler.ts b/packages/core/src/modules/routing/handlers/KeylistUpdateHandler.ts index 9d68d453ca..09dc1398bc 100644 --- a/packages/core/src/modules/routing/handlers/KeylistUpdateHandler.ts +++ b/packages/core/src/modules/routing/handlers/KeylistUpdateHandler.ts @@ -2,7 +2,6 @@ import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { MediatorService } from '../services/MediatorService' import { createOutboundMessage } from '../../../agent/helpers' -import { AriesFrameworkError } from '../../../error' import { KeylistUpdateMessage } from '../messages' export class KeylistUpdateHandler implements Handler { @@ -14,11 +13,7 @@ export class KeylistUpdateHandler implements Handler { } public async handle(messageContext: HandlerInboundMessage) { - const { message, connection } = messageContext - - if (!connection) { - throw new AriesFrameworkError(`No connection associated with incoming message with id ${message.id}`) - } + const connection = messageContext.assertReadyConnection() const response = await this.mediatorService.processKeylistUpdateRequest(messageContext) return createOutboundMessage(connection, response) diff --git a/packages/core/src/modules/routing/handlers/KeylistUpdateResponseHandler.ts b/packages/core/src/modules/routing/handlers/KeylistUpdateResponseHandler.ts index 23a0c4a96f..1b637dbd6f 100644 --- a/packages/core/src/modules/routing/handlers/KeylistUpdateResponseHandler.ts +++ b/packages/core/src/modules/routing/handlers/KeylistUpdateResponseHandler.ts @@ -12,9 +12,8 @@ export class KeylistUpdateResponseHandler implements Handler { } public async handle(messageContext: HandlerInboundMessage) { - if (!messageContext.connection) { - throw new Error(`Connection for verkey ${messageContext.recipientKey} not found!`) - } + messageContext.assertReadyConnection() + return await this.mediationRecipientService.processKeylistUpdateResults(messageContext) } } diff --git a/packages/core/src/modules/routing/handlers/MediationDenyHandler.ts b/packages/core/src/modules/routing/handlers/MediationDenyHandler.ts index fa32169a7b..ca6e163c11 100644 --- a/packages/core/src/modules/routing/handlers/MediationDenyHandler.ts +++ b/packages/core/src/modules/routing/handlers/MediationDenyHandler.ts @@ -12,9 +12,8 @@ export class MediationDenyHandler implements Handler { } public async handle(messageContext: HandlerInboundMessage) { - if (!messageContext.connection) { - throw new Error(`Connection for verkey ${messageContext.recipientKey} not found!`) - } + messageContext.assertReadyConnection() + await this.mediationRecipientService.processMediationDeny(messageContext) } } diff --git a/packages/core/src/modules/routing/handlers/MediationGrantHandler.ts b/packages/core/src/modules/routing/handlers/MediationGrantHandler.ts index 5706216fbb..5ac69e7c3f 100644 --- a/packages/core/src/modules/routing/handlers/MediationGrantHandler.ts +++ b/packages/core/src/modules/routing/handlers/MediationGrantHandler.ts @@ -12,9 +12,8 @@ export class MediationGrantHandler implements Handler { } public async handle(messageContext: HandlerInboundMessage) { - if (!messageContext.connection) { - throw new Error(`Connection for key ${messageContext.recipientKey} not found!`) - } + messageContext.assertReadyConnection() + await this.mediationRecipientService.processMediationGrant(messageContext) } } diff --git a/packages/core/src/modules/routing/handlers/MediationRequestHandler.ts b/packages/core/src/modules/routing/handlers/MediationRequestHandler.ts index 9a4b90ca7c..2cc2944668 100644 --- a/packages/core/src/modules/routing/handlers/MediationRequestHandler.ts +++ b/packages/core/src/modules/routing/handlers/MediationRequestHandler.ts @@ -1,31 +1,28 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' import type { Handler, HandlerInboundMessage } from '../../../agent/Handler' import type { MediatorService } from '../services/MediatorService' import { createOutboundMessage } from '../../../agent/helpers' -import { AriesFrameworkError } from '../../../error' import { MediationRequestMessage } from '../messages/MediationRequestMessage' export class MediationRequestHandler implements Handler { private mediatorService: MediatorService - private agentConfig: AgentConfig public supportedMessages = [MediationRequestMessage] - public constructor(mediatorService: MediatorService, agentConfig: AgentConfig) { + public constructor(mediatorService: MediatorService) { this.mediatorService = mediatorService - this.agentConfig = agentConfig } public async handle(messageContext: HandlerInboundMessage) { - if (!messageContext.connection) { - throw new AriesFrameworkError(`Connection for verkey ${messageContext.recipientKey} not found!`) - } + const connection = messageContext.assertReadyConnection() const mediationRecord = await this.mediatorService.processMediationRequest(messageContext) - if (this.agentConfig.autoAcceptMediationRequests) { - const { message } = await this.mediatorService.createGrantMediationMessage(mediationRecord) - return createOutboundMessage(messageContext.connection, message) + if (messageContext.agentContext.config.autoAcceptMediationRequests) { + const { message } = await this.mediatorService.createGrantMediationMessage( + messageContext.agentContext, + mediationRecord + ) + return createOutboundMessage(connection, message) } } } diff --git a/packages/core/src/modules/routing/repository/MediationRepository.ts b/packages/core/src/modules/routing/repository/MediationRepository.ts index 9f149f46e0..e89c04aa11 100644 --- a/packages/core/src/modules/routing/repository/MediationRepository.ts +++ b/packages/core/src/modules/routing/repository/MediationRepository.ts @@ -1,3 +1,5 @@ +import type { AgentContext } from '../../../agent' + import { EventEmitter } from '../../../agent/EventEmitter' import { InjectionSymbols } from '../../../constants' import { inject, injectable } from '../../../plugins' @@ -15,13 +17,13 @@ export class MediationRepository extends Repository { super(MediationRecord, storageService, eventEmitter) } - public getSingleByRecipientKey(recipientKey: string) { - return this.getSingleByQuery({ + public getSingleByRecipientKey(agentContext: AgentContext, recipientKey: string) { + return this.getSingleByQuery(agentContext, { recipientKeys: [recipientKey], }) } - public async getByConnectionId(connectionId: string): Promise { - return this.getSingleByQuery({ connectionId }) + public async getByConnectionId(agentContext: AgentContext, connectionId: string): Promise { + return this.getSingleByQuery(agentContext, { connectionId }) } } diff --git a/packages/core/src/modules/routing/services/MediationRecipientService.ts b/packages/core/src/modules/routing/services/MediationRecipientService.ts index 06b8c0f15b..da8bae5ca7 100644 --- a/packages/core/src/modules/routing/services/MediationRecipientService.ts +++ b/packages/core/src/modules/routing/services/MediationRecipientService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../../agent' import type { AgentMessage } from '../../../agent/AgentMessage' import type { AgentMessageReceivedEvent } from '../../../agent/Events' import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext' @@ -17,7 +18,6 @@ import type { GetRoutingOptions } from './RoutingService' import { firstValueFrom, ReplaySubject } from 'rxjs' import { filter, first, timeout } from 'rxjs/operators' -import { AgentConfig } from '../../../agent/AgentConfig' import { EventEmitter } from '../../../agent/EventEmitter' import { AgentEventTypes } from '../../../agent/Events' import { MessageSender } from '../../../agent/MessageSender' @@ -48,16 +48,13 @@ export class MediationRecipientService { private eventEmitter: EventEmitter private connectionService: ConnectionService private messageSender: MessageSender - private config: AgentConfig public constructor( connectionService: ConnectionService, messageSender: MessageSender, - config: AgentConfig, mediatorRepository: MediationRepository, eventEmitter: EventEmitter ) { - this.config = config this.mediationRepository = mediatorRepository this.eventEmitter = eventEmitter this.connectionService = connectionService @@ -82,6 +79,7 @@ export class MediationRecipientService { } public async createRequest( + agentContext: AgentContext, connection: ConnectionRecord ): Promise> { const message = new MediationRequestMessage({}) @@ -92,8 +90,8 @@ export class MediationRecipientService { role: MediationRole.Recipient, connectionId: connection.id, }) - await this.mediationRepository.save(mediationRecord) - this.emitStateChangedEvent(mediationRecord, null) + await this.mediationRepository.save(agentContext, mediationRecord) + this.emitStateChangedEvent(agentContext, mediationRecord, null) return { mediationRecord, message } } @@ -103,7 +101,7 @@ export class MediationRecipientService { const connection = messageContext.assertReadyConnection() // Mediation record must already exists to be updated to granted status - const mediationRecord = await this.mediationRepository.getByConnectionId(connection.id) + const mediationRecord = await this.mediationRepository.getByConnectionId(messageContext.agentContext, connection.id) // Assert mediationRecord.assertState(MediationState.Requested) @@ -112,14 +110,14 @@ export class MediationRecipientService { // Update record mediationRecord.endpoint = messageContext.message.endpoint mediationRecord.routingKeys = messageContext.message.routingKeys - return await this.updateState(mediationRecord, MediationState.Granted) + return await this.updateState(messageContext.agentContext, mediationRecord, MediationState.Granted) } public async processKeylistUpdateResults(messageContext: InboundMessageContext) { // Assert ready connection const connection = messageContext.assertReadyConnection() - const mediationRecord = await this.mediationRepository.getByConnectionId(connection.id) + const mediationRecord = await this.mediationRepository.getByConnectionId(messageContext.agentContext, connection.id) // Assert mediationRecord.assertReady() @@ -136,8 +134,8 @@ export class MediationRecipientService { } } - await this.mediationRepository.update(mediationRecord) - this.eventEmitter.emit({ + await this.mediationRepository.update(messageContext.agentContext, mediationRecord) + this.eventEmitter.emit(messageContext.agentContext, { type: RoutingEventTypes.RecipientKeylistUpdated, payload: { mediationRecord, @@ -147,12 +145,13 @@ export class MediationRecipientService { } public async keylistUpdateAndAwait( + agentContext: AgentContext, mediationRecord: MediationRecord, verKey: string, timeoutMs = 15000 // TODO: this should be a configurable value in agent config ): Promise { const message = this.createKeylistUpdateMessage(verKey) - const connection = await this.connectionService.getById(mediationRecord.connectionId) + const connection = await this.connectionService.getById(agentContext, mediationRecord.connectionId) mediationRecord.assertReady() mediationRecord.assertRole(MediationRole.Recipient) @@ -174,7 +173,7 @@ export class MediationRecipientService { .subscribe(subject) const outboundMessage = createOutboundMessage(connection, message) - await this.messageSender.sendMessage(outboundMessage) + await this.messageSender.sendMessage(agentContext, outboundMessage) const keylistUpdate = await firstValueFrom(subject) return keylistUpdate.payload.mediationRecord @@ -193,24 +192,29 @@ export class MediationRecipientService { } public async addMediationRouting( + agentContext: AgentContext, routing: Routing, { mediatorId, useDefaultMediator = true }: GetRoutingOptions = {} ): Promise { let mediationRecord: MediationRecord | null = null if (mediatorId) { - mediationRecord = await this.getById(mediatorId) + mediationRecord = await this.getById(agentContext, mediatorId) } else if (useDefaultMediator) { // If no mediatorId is provided, and useDefaultMediator is true (default) // We use the default mediator if available - mediationRecord = await this.findDefaultMediator() + mediationRecord = await this.findDefaultMediator(agentContext) } // Return early if no mediation record if (!mediationRecord) return routing // new did has been created and mediator needs to be updated with the public key. - mediationRecord = await this.keylistUpdateAndAwait(mediationRecord, routing.recipientKey.publicKeyBase58) + mediationRecord = await this.keylistUpdateAndAwait( + agentContext, + mediationRecord, + routing.recipientKey.publicKeyBase58 + ) return { ...routing, @@ -223,7 +227,7 @@ export class MediationRecipientService { const connection = messageContext.assertReadyConnection() // Mediation record already exists - const mediationRecord = await this.findByConnectionId(connection.id) + const mediationRecord = await this.findByConnectionId(messageContext.agentContext, connection.id) if (!mediationRecord) { throw new Error(`No mediation has been requested for this connection id: ${connection.id}`) @@ -234,7 +238,7 @@ export class MediationRecipientService { mediationRecord.assertState(MediationState.Requested) // Update record - await this.updateState(mediationRecord, MediationState.Denied) + await this.updateState(messageContext.agentContext, mediationRecord, MediationState.Denied) return mediationRecord } @@ -244,33 +248,41 @@ export class MediationRecipientService { const { message: statusMessage } = messageContext const { messageCount, recipientKey } = statusMessage - const mediationRecord = await this.mediationRepository.getByConnectionId(connection.id) + const mediationRecord = await this.mediationRepository.getByConnectionId(messageContext.agentContext, connection.id) mediationRecord.assertReady() mediationRecord.assertRole(MediationRole.Recipient) //No messages to be sent if (messageCount === 0) { - const { message, connectionRecord } = await this.connectionService.createTrustPing(connection, { - responseRequested: false, - }) + const { message, connectionRecord } = await this.connectionService.createTrustPing( + messageContext.agentContext, + connection, + { + responseRequested: false, + } + ) const websocketSchemes = ['ws', 'wss'] - await this.messageSender.sendMessage(createOutboundMessage(connectionRecord, message), { - transportPriority: { - schemes: websocketSchemes, - restrictive: true, - // TODO: add keepAlive: true to enforce through the public api - // we need to keep the socket alive. It already works this way, but would - // be good to make more explicit from the public facing API. - // This would also make it easier to change the internal API later on. - // keepAlive: true, - }, - }) + await this.messageSender.sendMessage( + messageContext.agentContext, + createOutboundMessage(connectionRecord, message), + { + transportPriority: { + schemes: websocketSchemes, + restrictive: true, + // TODO: add keepAlive: true to enforce through the public api + // we need to keep the socket alive. It already works this way, but would + // be good to make more explicit from the public facing API. + // This would also make it easier to change the internal API later on. + // keepAlive: true, + }, + } + ) return null } - const { maximumMessagePickup } = this.config + const { maximumMessagePickup } = messageContext.agentContext.config const limit = messageCount < maximumMessagePickup ? messageCount : maximumMessagePickup const deliveryRequestMessage = new DeliveryRequestMessage({ @@ -286,7 +298,7 @@ export class MediationRecipientService { const { appendedAttachments } = messageContext.message - const mediationRecord = await this.mediationRepository.getByConnectionId(connection.id) + const mediationRecord = await this.mediationRepository.getByConnectionId(messageContext.agentContext, connection.id) mediationRecord.assertReady() mediationRecord.assertRole(MediationRole.Recipient) @@ -300,7 +312,7 @@ export class MediationRecipientService { for (const attachment of appendedAttachments) { ids.push(attachment.id) - this.eventEmitter.emit({ + this.eventEmitter.emit(messageContext.agentContext, { type: AgentEventTypes.AgentMessageReceived, payload: { message: attachment.getDataAsJson(), @@ -321,18 +333,22 @@ export class MediationRecipientService { * @param newState The state to update to * */ - private async updateState(mediationRecord: MediationRecord, newState: MediationState) { + private async updateState(agentContext: AgentContext, mediationRecord: MediationRecord, newState: MediationState) { const previousState = mediationRecord.state mediationRecord.state = newState - await this.mediationRepository.update(mediationRecord) + await this.mediationRepository.update(agentContext, mediationRecord) - this.emitStateChangedEvent(mediationRecord, previousState) + this.emitStateChangedEvent(agentContext, mediationRecord, previousState) return mediationRecord } - private emitStateChangedEvent(mediationRecord: MediationRecord, previousState: MediationState | null) { + private emitStateChangedEvent( + agentContext: AgentContext, + mediationRecord: MediationRecord, + previousState: MediationState | null + ) { const clonedMediationRecord = JsonTransformer.clone(mediationRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: RoutingEventTypes.MediationStateChanged, payload: { mediationRecord: clonedMediationRecord, @@ -341,29 +357,32 @@ export class MediationRecipientService { }) } - public async getById(id: string): Promise { - return this.mediationRepository.getById(id) + public async getById(agentContext: AgentContext, id: string): Promise { + return this.mediationRepository.getById(agentContext, id) } - public async findByConnectionId(connectionId: string): Promise { - return this.mediationRepository.findSingleByQuery({ connectionId }) + public async findByConnectionId(agentContext: AgentContext, connectionId: string): Promise { + return this.mediationRepository.findSingleByQuery(agentContext, { connectionId }) } - public async getMediators(): Promise { - return this.mediationRepository.getAll() + public async getMediators(agentContext: AgentContext): Promise { + return this.mediationRepository.getAll(agentContext) } - public async findDefaultMediator(): Promise { - return this.mediationRepository.findSingleByQuery({ default: true }) + public async findDefaultMediator(agentContext: AgentContext): Promise { + return this.mediationRepository.findSingleByQuery(agentContext, { default: true }) } - public async discoverMediation(mediatorId?: string): Promise { + public async discoverMediation( + agentContext: AgentContext, + mediatorId?: string + ): Promise { // If mediatorId is passed, always use it (and error if it is not found) if (mediatorId) { - return this.mediationRepository.getById(mediatorId) + return this.mediationRepository.getById(agentContext, mediatorId) } - const defaultMediator = await this.findDefaultMediator() + const defaultMediator = await this.findDefaultMediator(agentContext) if (defaultMediator) { if (defaultMediator.state !== MediationState.Granted) { throw new AriesFrameworkError( @@ -375,25 +394,25 @@ export class MediationRecipientService { } } - public async setDefaultMediator(mediator: MediationRecord) { - const mediationRecords = await this.mediationRepository.findByQuery({ default: true }) + public async setDefaultMediator(agentContext: AgentContext, mediator: MediationRecord) { + const mediationRecords = await this.mediationRepository.findByQuery(agentContext, { default: true }) for (const record of mediationRecords) { record.setTag('default', false) - await this.mediationRepository.update(record) + await this.mediationRepository.update(agentContext, record) } // Set record coming in tag to true and then update. mediator.setTag('default', true) - await this.mediationRepository.update(mediator) + await this.mediationRepository.update(agentContext, mediator) } - public async clearDefaultMediator() { - const mediationRecord = await this.findDefaultMediator() + public async clearDefaultMediator(agentContext: AgentContext) { + const mediationRecord = await this.findDefaultMediator(agentContext) if (mediationRecord) { mediationRecord.setTag('default', false) - await this.mediationRepository.update(mediationRecord) + await this.mediationRepository.update(agentContext, mediationRecord) } } } diff --git a/packages/core/src/modules/routing/services/MediatorService.ts b/packages/core/src/modules/routing/services/MediatorService.ts index 403d3424d8..021ccc44d2 100644 --- a/packages/core/src/modules/routing/services/MediatorService.ts +++ b/packages/core/src/modules/routing/services/MediatorService.ts @@ -1,22 +1,22 @@ +import type { AgentContext } from '../../../agent' import type { InboundMessageContext } from '../../../agent/models/InboundMessageContext' import type { EncryptedMessage } from '../../../types' import type { MediationStateChangedEvent } from '../RoutingEvents' import type { ForwardMessage, KeylistUpdateMessage, MediationRequestMessage } from '../messages' -import { AgentConfig } from '../../../agent/AgentConfig' import { EventEmitter } from '../../../agent/EventEmitter' import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error' -import { inject, injectable } from '../../../plugins' +import { Logger } from '../../../logger' +import { injectable, inject } from '../../../plugins' import { JsonTransformer } from '../../../utils/JsonTransformer' -import { Wallet } from '../../../wallet/Wallet' import { RoutingEventTypes } from '../RoutingEvents' import { KeylistUpdateAction, - KeylistUpdateResult, KeylistUpdated, - MediationGrantMessage, KeylistUpdateResponseMessage, + KeylistUpdateResult, + MediationGrantMessage, } from '../messages' import { MediationRole } from '../models/MediationRole' import { MediationState } from '../models/MediationState' @@ -27,54 +27,52 @@ import { MediatorRoutingRepository } from '../repository/MediatorRoutingReposito @injectable() export class MediatorService { - private agentConfig: AgentConfig + private logger: Logger private mediationRepository: MediationRepository private mediatorRoutingRepository: MediatorRoutingRepository - private wallet: Wallet private eventEmitter: EventEmitter private _mediatorRoutingRecord?: MediatorRoutingRecord public constructor( mediationRepository: MediationRepository, mediatorRoutingRepository: MediatorRoutingRepository, - agentConfig: AgentConfig, - @inject(InjectionSymbols.Wallet) wallet: Wallet, - eventEmitter: EventEmitter + eventEmitter: EventEmitter, + @inject(InjectionSymbols.Logger) logger: Logger ) { this.mediationRepository = mediationRepository this.mediatorRoutingRepository = mediatorRoutingRepository - this.agentConfig = agentConfig - this.wallet = wallet this.eventEmitter = eventEmitter + this.logger = logger } - private async getRoutingKeys() { - this.agentConfig.logger.debug('Retrieving mediator routing keys') + private async getRoutingKeys(agentContext: AgentContext) { + this.logger.debug('Retrieving mediator routing keys') // If the routing record is not loaded yet, retrieve it from storage if (!this._mediatorRoutingRecord) { - this.agentConfig.logger.debug('Mediator routing record not loaded yet, retrieving from storage') + this.logger.debug('Mediator routing record not loaded yet, retrieving from storage') let routingRecord = await this.mediatorRoutingRepository.findById( + agentContext, this.mediatorRoutingRepository.MEDIATOR_ROUTING_RECORD_ID ) // If we don't have a routing record yet, create it if (!routingRecord) { - this.agentConfig.logger.debug('Mediator routing record does not exist yet, creating routing keys and record') - const { verkey } = await this.wallet.createDid() + this.logger.debug('Mediator routing record does not exist yet, creating routing keys and record') + const { verkey } = await agentContext.wallet.createDid() routingRecord = new MediatorRoutingRecord({ id: this.mediatorRoutingRepository.MEDIATOR_ROUTING_RECORD_ID, routingKeys: [verkey], }) - await this.mediatorRoutingRepository.save(routingRecord) + await this.mediatorRoutingRepository.save(agentContext, routingRecord) } this._mediatorRoutingRecord = routingRecord } // Return the routing keys - this.agentConfig.logger.debug(`Returning mediator routing keys ${this._mediatorRoutingRecord.routingKeys}`) + this.logger.debug(`Returning mediator routing keys ${this._mediatorRoutingRecord.routingKeys}`) return this._mediatorRoutingRecord.routingKeys } @@ -88,7 +86,10 @@ export class MediatorService { throw new AriesFrameworkError('Invalid Message: Missing required attribute "to"') } - const mediationRecord = await this.mediationRepository.getSingleByRecipientKey(message.to) + const mediationRecord = await this.mediationRepository.getSingleByRecipientKey( + messageContext.agentContext, + message.to + ) // Assert mediation record is ready to be used mediationRecord.assertReady() @@ -107,7 +108,7 @@ export class MediatorService { const { message } = messageContext const keylist: KeylistUpdated[] = [] - const mediationRecord = await this.mediationRepository.getByConnectionId(connection.id) + const mediationRecord = await this.mediationRepository.getByConnectionId(messageContext.agentContext, connection.id) mediationRecord.assertReady() mediationRecord.assertRole(MediationRole.Mediator) @@ -130,21 +131,21 @@ export class MediatorService { } } - await this.mediationRepository.update(mediationRecord) + await this.mediationRepository.update(messageContext.agentContext, mediationRecord) return new KeylistUpdateResponseMessage({ keylist, threadId: message.threadId }) } - public async createGrantMediationMessage(mediationRecord: MediationRecord) { + public async createGrantMediationMessage(agentContext: AgentContext, mediationRecord: MediationRecord) { // Assert mediationRecord.assertState(MediationState.Requested) mediationRecord.assertRole(MediationRole.Mediator) - await this.updateState(mediationRecord, MediationState.Granted) + await this.updateState(agentContext, mediationRecord, MediationState.Granted) const message = new MediationGrantMessage({ - endpoint: this.agentConfig.endpoints[0], - routingKeys: await this.getRoutingKeys(), + endpoint: agentContext.config.endpoints[0], + routingKeys: await this.getRoutingKeys(agentContext), threadId: mediationRecord.threadId, }) @@ -162,37 +163,41 @@ export class MediatorService { threadId: messageContext.message.threadId, }) - await this.mediationRepository.save(mediationRecord) - this.emitStateChangedEvent(mediationRecord, null) + await this.mediationRepository.save(messageContext.agentContext, mediationRecord) + this.emitStateChangedEvent(messageContext.agentContext, mediationRecord, null) return mediationRecord } - public async findById(mediatorRecordId: string): Promise { - return this.mediationRepository.findById(mediatorRecordId) + public async findById(agentContext: AgentContext, mediatorRecordId: string): Promise { + return this.mediationRepository.findById(agentContext, mediatorRecordId) } - public async getById(mediatorRecordId: string): Promise { - return this.mediationRepository.getById(mediatorRecordId) + public async getById(agentContext: AgentContext, mediatorRecordId: string): Promise { + return this.mediationRepository.getById(agentContext, mediatorRecordId) } - public async getAll(): Promise { - return await this.mediationRepository.getAll() + public async getAll(agentContext: AgentContext): Promise { + return await this.mediationRepository.getAll(agentContext) } - private async updateState(mediationRecord: MediationRecord, newState: MediationState) { + private async updateState(agentContext: AgentContext, mediationRecord: MediationRecord, newState: MediationState) { const previousState = mediationRecord.state mediationRecord.state = newState - await this.mediationRepository.update(mediationRecord) + await this.mediationRepository.update(agentContext, mediationRecord) - this.emitStateChangedEvent(mediationRecord, previousState) + this.emitStateChangedEvent(agentContext, mediationRecord, previousState) } - private emitStateChangedEvent(mediationRecord: MediationRecord, previousState: MediationState | null) { + private emitStateChangedEvent( + agentContext: AgentContext, + mediationRecord: MediationRecord, + previousState: MediationState | null + ) { const clonedMediationRecord = JsonTransformer.clone(mediationRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: RoutingEventTypes.MediationStateChanged, payload: { mediationRecord: clonedMediationRecord, diff --git a/packages/core/src/modules/routing/services/RoutingService.ts b/packages/core/src/modules/routing/services/RoutingService.ts index 134507b528..357cb05d3d 100644 --- a/packages/core/src/modules/routing/services/RoutingService.ts +++ b/packages/core/src/modules/routing/services/RoutingService.ts @@ -1,12 +1,10 @@ +import type { AgentContext } from '../../../agent' import type { Routing } from '../../connections' import type { RoutingCreatedEvent } from '../RoutingEvents' -import { AgentConfig } from '../../../agent/AgentConfig' import { EventEmitter } from '../../../agent/EventEmitter' -import { InjectionSymbols } from '../../../constants' import { Key, KeyType } from '../../../crypto' -import { inject, injectable } from '../../../plugins' -import { Wallet } from '../../../wallet' +import { injectable } from '../../../plugins' import { RoutingEventTypes } from '../RoutingEvents' import { MediationRecipientService } from './MediationRecipientService' @@ -14,42 +12,38 @@ import { MediationRecipientService } from './MediationRecipientService' @injectable() export class RoutingService { private mediationRecipientService: MediationRecipientService - private agentConfig: AgentConfig - private wallet: Wallet + private eventEmitter: EventEmitter - public constructor( - mediationRecipientService: MediationRecipientService, - agentConfig: AgentConfig, - @inject(InjectionSymbols.Wallet) wallet: Wallet, - eventEmitter: EventEmitter - ) { + public constructor(mediationRecipientService: MediationRecipientService, eventEmitter: EventEmitter) { this.mediationRecipientService = mediationRecipientService - this.agentConfig = agentConfig - this.wallet = wallet + this.eventEmitter = eventEmitter } - public async getRouting({ mediatorId, useDefaultMediator = true }: GetRoutingOptions = {}): Promise { + public async getRouting( + agentContext: AgentContext, + { mediatorId, useDefaultMediator = true }: GetRoutingOptions = {} + ): Promise { // Create and store new key - const { verkey: publicKeyBase58 } = await this.wallet.createDid() + const { verkey: publicKeyBase58 } = await agentContext.wallet.createDid() const recipientKey = Key.fromPublicKeyBase58(publicKeyBase58, KeyType.Ed25519) let routing: Routing = { - endpoints: this.agentConfig.endpoints, + endpoints: agentContext.config.endpoints, routingKeys: [], recipientKey, } // Extend routing with mediator keys (if applicable) - routing = await this.mediationRecipientService.addMediationRouting(routing, { + routing = await this.mediationRecipientService.addMediationRouting(agentContext, routing, { mediatorId, useDefaultMediator, }) // Emit event so other parts of the framework can react on keys created - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: RoutingEventTypes.RoutingCreatedEvent, payload: { routing, diff --git a/packages/core/src/modules/routing/services/__tests__/MediationRecipientService.test.ts b/packages/core/src/modules/routing/services/__tests__/MediationRecipientService.test.ts index 5c08a1885c..9d5c715237 100644 --- a/packages/core/src/modules/routing/services/__tests__/MediationRecipientService.test.ts +++ b/packages/core/src/modules/routing/services/__tests__/MediationRecipientService.test.ts @@ -1,7 +1,8 @@ +import type { AgentContext } from '../../../../agent' import type { Wallet } from '../../../../wallet/Wallet' import type { Routing } from '../../../connections/services/ConnectionService' -import { getAgentConfig, getMockConnection, mockFunction } from '../../../../../tests/helpers' +import { getAgentConfig, getAgentContext, getMockConnection, mockFunction } from '../../../../../tests/helpers' import { EventEmitter } from '../../../../agent/EventEmitter' import { AgentEventTypes } from '../../../../agent/Events' import { MessageSender } from '../../../../agent/MessageSender' @@ -56,9 +57,13 @@ describe('MediationRecipientService', () => { let messageSender: MessageSender let mediationRecipientService: MediationRecipientService let mediationRecord: MediationRecord + let agentContext: AgentContext beforeAll(async () => { - wallet = new IndyWallet(config) + wallet = new IndyWallet(config.agentDependencies, config.logger) + agentContext = getAgentContext({ + agentConfig: config, + }) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(config.walletConfig!) }) @@ -71,7 +76,7 @@ describe('MediationRecipientService', () => { eventEmitter = new EventEmitterMock() connectionRepository = new ConnectionRepositoryMock() didRepository = new DidRepositoryMock() - connectionService = new ConnectionService(wallet, config, connectionRepository, didRepository, eventEmitter) + connectionService = new ConnectionService(config.logger, connectionRepository, didRepository, eventEmitter) mediationRepository = new MediationRepositoryMock() messageSender = new MessageSenderMock() @@ -87,7 +92,6 @@ describe('MediationRecipientService', () => { mediationRecipientService = new MediationRecipientService( connectionService, messageSender, - config, mediationRepository, eventEmitter ) @@ -126,7 +130,7 @@ describe('MediationRecipientService', () => { messageCount: 0, }) - const messageContext = new InboundMessageContext(status, { connection: mockConnection }) + const messageContext = new InboundMessageContext(status, { connection: mockConnection, agentContext }) const deliveryRequestMessage = await mediationRecipientService.processStatus(messageContext) expect(deliveryRequestMessage).toBeNull() }) @@ -135,7 +139,7 @@ describe('MediationRecipientService', () => { const status = new StatusMessage({ messageCount: 1, }) - const messageContext = new InboundMessageContext(status, { connection: mockConnection }) + const messageContext = new InboundMessageContext(status, { connection: mockConnection, agentContext }) const deliveryRequestMessage = await mediationRecipientService.processStatus(messageContext) expect(deliveryRequestMessage) @@ -146,7 +150,7 @@ describe('MediationRecipientService', () => { const status = new StatusMessage({ messageCount: 1, }) - const messageContext = new InboundMessageContext(status, { connection: mockConnection }) + const messageContext = new InboundMessageContext(status, { connection: mockConnection, agentContext }) mediationRecord.role = MediationRole.Mediator await expect(mediationRecipientService.processStatus(messageContext)).rejects.toThrowError( @@ -164,7 +168,10 @@ describe('MediationRecipientService', () => { describe('processDelivery', () => { it('if the delivery has no attachments expect an error', async () => { - const messageContext = new InboundMessageContext({} as MessageDeliveryMessage, { connection: mockConnection }) + const messageContext = new InboundMessageContext({} as MessageDeliveryMessage, { + connection: mockConnection, + agentContext, + }) await expect(mediationRecipientService.processDelivery(messageContext)).rejects.toThrowError( new AriesFrameworkError('Error processing attachments') @@ -184,7 +191,10 @@ describe('MediationRecipientService', () => { }), ], }) - const messageContext = new InboundMessageContext(messageDeliveryMessage, { connection: mockConnection }) + const messageContext = new InboundMessageContext(messageDeliveryMessage, { + connection: mockConnection, + agentContext, + }) const messagesReceivedMessage = await mediationRecipientService.processDelivery(messageContext) @@ -217,18 +227,21 @@ describe('MediationRecipientService', () => { }), ], }) - const messageContext = new InboundMessageContext(messageDeliveryMessage, { connection: mockConnection }) + const messageContext = new InboundMessageContext(messageDeliveryMessage, { + connection: mockConnection, + agentContext, + }) await mediationRecipientService.processDelivery(messageContext) expect(eventEmitter.emit).toHaveBeenCalledTimes(2) - expect(eventEmitter.emit).toHaveBeenNthCalledWith(1, { + expect(eventEmitter.emit).toHaveBeenNthCalledWith(1, agentContext, { type: AgentEventTypes.AgentMessageReceived, payload: { message: { first: 'value' }, }, }) - expect(eventEmitter.emit).toHaveBeenNthCalledWith(2, { + expect(eventEmitter.emit).toHaveBeenNthCalledWith(2, agentContext, { type: AgentEventTypes.AgentMessageReceived, payload: { message: { second: 'value' }, @@ -249,7 +262,10 @@ describe('MediationRecipientService', () => { }), ], }) - const messageContext = new InboundMessageContext(messageDeliveryMessage, { connection: mockConnection }) + const messageContext = new InboundMessageContext(messageDeliveryMessage, { + connection: mockConnection, + agentContext, + }) mediationRecord.role = MediationRole.Mediator await expect(mediationRecipientService.processDelivery(messageContext)).rejects.toThrowError( @@ -290,7 +306,7 @@ describe('MediationRecipientService', () => { test('adds mediation routing id mediator id is passed', async () => { mockFunction(mediationRepository.getById).mockResolvedValue(mediationRecord) - const extendedRouting = await mediationRecipientService.addMediationRouting(routing, { + const extendedRouting = await mediationRecipientService.addMediationRouting(agentContext, routing, { mediatorId: 'mediator-id', }) @@ -298,14 +314,14 @@ describe('MediationRecipientService', () => { endpoints: ['https://a-mediator-endpoint.com'], routingKeys: [routingKey], }) - expect(mediationRepository.getById).toHaveBeenCalledWith('mediator-id') + expect(mediationRepository.getById).toHaveBeenCalledWith(agentContext, 'mediator-id') }) test('adds mediation routing if useDefaultMediator is true and default mediation is found', async () => { mockFunction(mediationRepository.findSingleByQuery).mockResolvedValue(mediationRecord) jest.spyOn(mediationRecipientService, 'keylistUpdateAndAwait').mockResolvedValue(mediationRecord) - const extendedRouting = await mediationRecipientService.addMediationRouting(routing, { + const extendedRouting = await mediationRecipientService.addMediationRouting(agentContext, routing, { useDefaultMediator: true, }) @@ -313,14 +329,14 @@ describe('MediationRecipientService', () => { endpoints: ['https://a-mediator-endpoint.com'], routingKeys: [routingKey], }) - expect(mediationRepository.findSingleByQuery).toHaveBeenCalledWith({ default: true }) + expect(mediationRepository.findSingleByQuery).toHaveBeenCalledWith(agentContext, { default: true }) }) test('does not add mediation routing if no mediation is found', async () => { mockFunction(mediationRepository.findSingleByQuery).mockResolvedValue(mediationRecord) jest.spyOn(mediationRecipientService, 'keylistUpdateAndAwait').mockResolvedValue(mediationRecord) - const extendedRouting = await mediationRecipientService.addMediationRouting(routing, { + const extendedRouting = await mediationRecipientService.addMediationRouting(agentContext, routing, { useDefaultMediator: false, }) diff --git a/packages/core/src/modules/routing/services/__tests__/RoutingService.test.ts b/packages/core/src/modules/routing/services/__tests__/RoutingService.test.ts index 4a674a7f6d..c1360ed1df 100644 --- a/packages/core/src/modules/routing/services/__tests__/RoutingService.test.ts +++ b/packages/core/src/modules/routing/services/__tests__/RoutingService.test.ts @@ -1,4 +1,6 @@ -import { getAgentConfig, mockFunction } from '../../../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, mockFunction } from '../../../../../tests/helpers' import { EventEmitter } from '../../../../agent/EventEmitter' import { Key } from '../../../../crypto' import { IndyWallet } from '../../../../wallet/IndyWallet' @@ -15,10 +17,14 @@ const MediationRecipientServiceMock = MediationRecipientService as jest.Mock { describe('getRouting', () => { test('calls mediation recipient service', async () => { - const routing = await routingService.getRouting({ + const routing = await routingService.getRouting(agentContext, { mediatorId: 'mediator-id', useDefaultMediator: true, }) - expect(mediationRecipientService.addMediationRouting).toHaveBeenCalledWith(routing, { + expect(mediationRecipientService.addMediationRouting).toHaveBeenCalledWith(agentContext, routing, { mediatorId: 'mediator-id', useDefaultMediator: true, }) @@ -55,7 +61,7 @@ describe('RoutingService', () => { const routingListener = jest.fn() eventEmitter.on(RoutingEventTypes.RoutingCreatedEvent, routingListener) - const routing = await routingService.getRouting() + const routing = await routingService.getRouting(agentContext) expect(routing).toEqual(routing) expect(routingListener).toHaveBeenCalledWith({ diff --git a/packages/core/src/modules/vc/W3cCredentialService.ts b/packages/core/src/modules/vc/W3cCredentialService.ts index c1bf89d930..5c724388a3 100644 --- a/packages/core/src/modules/vc/W3cCredentialService.ts +++ b/packages/core/src/modules/vc/W3cCredentialService.ts @@ -1,5 +1,6 @@ +import type { AgentContext } from '../..' import type { Key } from '../../crypto/Key' -import type { DocumentLoaderResult } from './jsonldUtil' +import type { DocumentLoader } from './jsonldUtil' import type { W3cVerifyCredentialResult } from './models' import type { CreatePresentationOptions, @@ -12,15 +13,11 @@ import type { } from './models/W3cCredentialServiceOptions' import type { VerifyPresentationResult } from './models/presentation/VerifyPresentationResult' -import { inject } from 'tsyringe' - -import { InjectionSymbols } from '../../constants' import { createWalletKeyPairClass } from '../../crypto/WalletKeyPair' import { AriesFrameworkError } from '../../error' import { injectable } from '../../plugins' import { JsonTransformer } from '../../utils' import { isNodeJS, isReactNative } from '../../utils/environment' -import { Wallet } from '../../wallet' import { DidResolverService, VerificationMethod } from '../dids' import { getKeyDidMappingByVerificationMethod } from '../dids/domain/key-type' @@ -36,17 +33,11 @@ import { deriveProof } from './signature-suites/bbs' @injectable() export class W3cCredentialService { - private wallet: Wallet private w3cCredentialRepository: W3cCredentialRepository private didResolver: DidResolverService private suiteRegistry: SignatureSuiteRegistry - public constructor( - @inject(InjectionSymbols.Wallet) wallet: Wallet, - w3cCredentialRepository: W3cCredentialRepository, - didResolver: DidResolverService - ) { - this.wallet = wallet + public constructor(w3cCredentialRepository: W3cCredentialRepository, didResolver: DidResolverService) { this.w3cCredentialRepository = w3cCredentialRepository this.didResolver = didResolver this.suiteRegistry = new SignatureSuiteRegistry() @@ -58,10 +49,13 @@ export class W3cCredentialService { * @param credential the credential to be signed * @returns the signed credential */ - public async signCredential(options: SignCredentialOptions): Promise { - const WalletKeyPair = createWalletKeyPairClass(this.wallet) + public async signCredential( + agentContext: AgentContext, + options: SignCredentialOptions + ): Promise { + const WalletKeyPair = createWalletKeyPairClass(agentContext.wallet) - const signingKey = await this.getPublicKeyFromVerificationMethod(options.verificationMethod) + const signingKey = await this.getPublicKeyFromVerificationMethod(agentContext, options.verificationMethod) const suiteInfo = this.suiteRegistry.getByProofType(options.proofType) if (signingKey.keyType !== suiteInfo.keyType) { @@ -72,7 +66,7 @@ export class W3cCredentialService { controller: options.credential.issuerId, // should we check this against the verificationMethod.controller? id: options.verificationMethod, key: signingKey, - wallet: this.wallet, + wallet: agentContext.wallet, }) const SuiteClass = suiteInfo.suiteClass @@ -91,7 +85,7 @@ export class W3cCredentialService { credential: JsonTransformer.toJSON(options.credential), suite: suite, purpose: options.proofPurpose, - documentLoader: this.documentLoader, + documentLoader: this.documentLoaderWithContext(agentContext), }) return JsonTransformer.fromJSON(result, W3cVerifiableCredential) @@ -103,13 +97,16 @@ export class W3cCredentialService { * @param credential the credential to be verified * @returns the verification result */ - public async verifyCredential(options: VerifyCredentialOptions): Promise { - const suites = this.getSignatureSuitesForCredential(options.credential) + public async verifyCredential( + agentContext: AgentContext, + options: VerifyCredentialOptions + ): Promise { + const suites = this.getSignatureSuitesForCredential(agentContext, options.credential) const verifyOptions: Record = { credential: JsonTransformer.toJSON(options.credential), suite: suites, - documentLoader: this.documentLoader, + documentLoader: this.documentLoaderWithContext(agentContext), } // this is a hack because vcjs throws if purpose is passed as undefined or null @@ -152,9 +149,12 @@ export class W3cCredentialService { * @param presentation the presentation to be signed * @returns the signed presentation */ - public async signPresentation(options: SignPresentationOptions): Promise { + public async signPresentation( + agentContext: AgentContext, + options: SignPresentationOptions + ): Promise { // create keyPair - const WalletKeyPair = createWalletKeyPairClass(this.wallet) + const WalletKeyPair = createWalletKeyPairClass(agentContext.wallet) const suiteInfo = this.suiteRegistry.getByProofType(options.signatureType) @@ -162,13 +162,14 @@ export class W3cCredentialService { throw new AriesFrameworkError(`The requested proofType ${options.signatureType} is not supported`) } - const signingKey = await this.getPublicKeyFromVerificationMethod(options.verificationMethod) + const signingKey = await this.getPublicKeyFromVerificationMethod(agentContext, options.verificationMethod) if (signingKey.keyType !== suiteInfo.keyType) { throw new AriesFrameworkError('The key type of the verification method does not match the suite') } - const verificationMethodObject = (await this.documentLoader(options.verificationMethod)).document as Record< + const documentLoader = this.documentLoaderWithContext(agentContext) + const verificationMethodObject = (await documentLoader(options.verificationMethod)).document as Record< string, unknown > @@ -177,7 +178,7 @@ export class W3cCredentialService { controller: verificationMethodObject['controller'] as string, id: options.verificationMethod, key: signingKey, - wallet: this.wallet, + wallet: agentContext.wallet, }) const suite = new suiteInfo.suiteClass({ @@ -194,7 +195,7 @@ export class W3cCredentialService { presentation: JsonTransformer.toJSON(options.presentation), suite: suite, challenge: options.challenge, - documentLoader: this.documentLoader, + documentLoader: this.documentLoaderWithContext(agentContext), }) return JsonTransformer.fromJSON(result, W3cVerifiablePresentation) @@ -206,9 +207,12 @@ export class W3cCredentialService { * @param presentation the presentation to be verified * @returns the verification result */ - public async verifyPresentation(options: VerifyPresentationOptions): Promise { + public async verifyPresentation( + agentContext: AgentContext, + options: VerifyPresentationOptions + ): Promise { // create keyPair - const WalletKeyPair = createWalletKeyPairClass(this.wallet) + const WalletKeyPair = createWalletKeyPairClass(agentContext.wallet) let proofs = options.presentation.proof @@ -235,14 +239,16 @@ export class W3cCredentialService { ? options.presentation.verifiableCredential : [options.presentation.verifiableCredential] - const credentialSuites = credentials.map((credential) => this.getSignatureSuitesForCredential(credential)) + const credentialSuites = credentials.map((credential) => + this.getSignatureSuitesForCredential(agentContext, credential) + ) const allSuites = presentationSuites.concat(...credentialSuites) const verifyOptions: Record = { presentation: JsonTransformer.toJSON(options.presentation), suite: allSuites, challenge: options.challenge, - documentLoader: this.documentLoader, + documentLoader: this.documentLoaderWithContext(agentContext), } // this is a hack because vcjs throws if purpose is passed as undefined or null @@ -255,7 +261,7 @@ export class W3cCredentialService { return result as unknown as VerifyPresentationResult } - public async deriveProof(options: DeriveProofOptions): Promise { + public async deriveProof(agentContext: AgentContext, options: DeriveProofOptions): Promise { const suiteInfo = this.suiteRegistry.getByProofType('BbsBlsSignatureProof2020') const SuiteClass = suiteInfo.suiteClass @@ -263,48 +269,54 @@ export class W3cCredentialService { const proof = await deriveProof(JsonTransformer.toJSON(options.credential), options.revealDocument, { suite: suite, - documentLoader: this.documentLoader, + documentLoader: this.documentLoaderWithContext(agentContext), }) return proof } - public documentLoader = async (url: string): Promise => { - if (url.startsWith('did:')) { - const result = await this.didResolver.resolve(url) + public documentLoaderWithContext = (agentContext: AgentContext): DocumentLoader => { + return async (url: string) => { + if (url.startsWith('did:')) { + const result = await this.didResolver.resolve(agentContext, url) - if (result.didResolutionMetadata.error || !result.didDocument) { - throw new AriesFrameworkError(`Unable to resolve DID: ${url}`) - } + if (result.didResolutionMetadata.error || !result.didDocument) { + throw new AriesFrameworkError(`Unable to resolve DID: ${url}`) + } - const framed = await jsonld.frame(result.didDocument.toJSON(), { - '@context': result.didDocument.context, - '@embed': '@never', - id: url, - }) + const framed = await jsonld.frame(result.didDocument.toJSON(), { + '@context': result.didDocument.context, + '@embed': '@never', + id: url, + }) - return { - contextUrl: null, - documentUrl: url, - document: framed, + return { + contextUrl: null, + documentUrl: url, + document: framed, + } } - } - let loader + let loader - if (isNodeJS()) { - loader = documentLoaderNode.apply(jsonld, []) - } else if (isReactNative()) { - loader = documentLoaderXhr.apply(jsonld, []) - } else { - throw new AriesFrameworkError('Unsupported environment') - } + if (isNodeJS()) { + loader = documentLoaderNode.apply(jsonld, []) + } else if (isReactNative()) { + loader = documentLoaderXhr.apply(jsonld, []) + } else { + throw new AriesFrameworkError('Unsupported environment') + } - return await loader(url) + return await loader(url) + } } - private async getPublicKeyFromVerificationMethod(verificationMethod: string): Promise { - const verificationMethodObject = await this.documentLoader(verificationMethod) + private async getPublicKeyFromVerificationMethod( + agentContext: AgentContext, + verificationMethod: string + ): Promise { + const documentLoader = this.documentLoaderWithContext(agentContext) + const verificationMethodObject = await documentLoader(verificationMethod) const verificationMethodClass = JsonTransformer.fromJSON(verificationMethodObject.document, VerificationMethod) const key = getKeyDidMappingByVerificationMethod(verificationMethodClass) @@ -318,10 +330,15 @@ export class W3cCredentialService { * @param record the credential to be stored * @returns the credential record that was written to storage */ - public async storeCredential(options: StoreCredentialOptions): Promise { + public async storeCredential( + agentContext: AgentContext, + options: StoreCredentialOptions + ): Promise { // Get the expanded types const expandedTypes = ( - await jsonld.expand(JsonTransformer.toJSON(options.record), { documentLoader: this.documentLoader }) + await jsonld.expand(JsonTransformer.toJSON(options.record), { + documentLoader: this.documentLoaderWithContext(agentContext), + }) )[0]['@type'] // Create an instance of the w3cCredentialRecord @@ -331,36 +348,38 @@ export class W3cCredentialService { }) // Store the w3c credential record - await this.w3cCredentialRepository.save(w3cCredentialRecord) + await this.w3cCredentialRepository.save(agentContext, w3cCredentialRecord) return w3cCredentialRecord } - public async getAllCredentials(): Promise { - const allRecords = await this.w3cCredentialRepository.getAll() + public async getAllCredentials(agentContext: AgentContext): Promise { + const allRecords = await this.w3cCredentialRepository.getAll(agentContext) return allRecords.map((record) => record.credential) } - public async getCredentialById(id: string): Promise { - return (await this.w3cCredentialRepository.getById(id)).credential + public async getCredentialById(agentContext: AgentContext, id: string): Promise { + return (await this.w3cCredentialRepository.getById(agentContext, id)).credential } public async findCredentialsByQuery( - query: Parameters[0] + agentContext: AgentContext, + query: Parameters[1] ): Promise { - const result = await this.w3cCredentialRepository.findByQuery(query) + const result = await this.w3cCredentialRepository.findByQuery(agentContext, query) return result.map((record) => record.credential) } public async findSingleCredentialByQuery( - query: Parameters[0] + agentContext: AgentContext, + query: Parameters[1] ): Promise { - const result = await this.w3cCredentialRepository.findSingleByQuery(query) + const result = await this.w3cCredentialRepository.findSingleByQuery(agentContext, query) return result?.credential } - private getSignatureSuitesForCredential(credential: W3cVerifiableCredential) { - const WalletKeyPair = createWalletKeyPairClass(this.wallet) + private getSignatureSuitesForCredential(agentContext: AgentContext, credential: W3cVerifiableCredential) { + const WalletKeyPair = createWalletKeyPairClass(agentContext.wallet) let proofs = credential.proof diff --git a/packages/core/src/modules/vc/__tests__/W3cCredentialService.test.ts b/packages/core/src/modules/vc/__tests__/W3cCredentialService.test.ts index 20a28c5e19..b9a8e9d2a8 100644 --- a/packages/core/src/modules/vc/__tests__/W3cCredentialService.test.ts +++ b/packages/core/src/modules/vc/__tests__/W3cCredentialService.test.ts @@ -1,6 +1,6 @@ -import type { AgentConfig } from '../../../agent/AgentConfig' +import type { AgentContext } from '../../../agent' -import { getAgentConfig } from '../../../../tests/helpers' +import { getAgentConfig, getAgentContext } from '../../../../tests/helpers' import { KeyType } from '../../../crypto' import { Key } from '../../../crypto/Key' import { JsonTransformer } from '../../../utils/JsonTransformer' @@ -30,23 +30,32 @@ const DidRepositoryMock = DidRepository as unknown as jest.Mock jest.mock('../repository/W3cCredentialRepository') const W3cCredentialRepositoryMock = W3cCredentialRepository as jest.Mock +const agentConfig = getAgentConfig('W3cCredentialServiceTest') + describe('W3cCredentialService', () => { let wallet: IndyWallet - let agentConfig: AgentConfig + let agentContext: AgentContext let didResolverService: DidResolverService let w3cCredentialService: W3cCredentialService let w3cCredentialRepository: W3cCredentialRepository const seed = 'testseed000000000000000000000001' beforeAll(async () => { - agentConfig = getAgentConfig('W3cCredentialServiceTest') - wallet = new IndyWallet(agentConfig) + wallet = new IndyWallet(agentConfig.agentDependencies, agentConfig.logger) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await wallet.createAndOpen(agentConfig.walletConfig!) - didResolverService = new DidResolverService(agentConfig, new IndyLedgerServiceMock(), new DidRepositoryMock()) + agentContext = getAgentContext({ + agentConfig, + wallet, + }) + didResolverService = new DidResolverService( + new IndyLedgerServiceMock(), + new DidRepositoryMock(), + agentConfig.logger + ) w3cCredentialRepository = new W3cCredentialRepositoryMock() - w3cCredentialService = new W3cCredentialService(wallet, w3cCredentialRepository, didResolverService) - w3cCredentialService.documentLoader = customDocumentLoader + w3cCredentialService = new W3cCredentialService(w3cCredentialRepository, didResolverService) + w3cCredentialService.documentLoaderWithContext = () => customDocumentLoader }) afterAll(async () => { @@ -69,7 +78,7 @@ describe('W3cCredentialService', () => { const credential = JsonTransformer.fromJSON(credentialJson, W3cCredential) - const vc = await w3cCredentialService.signCredential({ + const vc = await w3cCredentialService.signCredential(agentContext, { credential, proofType: 'Ed25519Signature2018', verificationMethod: verificationMethod, @@ -91,7 +100,7 @@ describe('W3cCredentialService', () => { const credential = JsonTransformer.fromJSON(credentialJson, W3cCredential) expect(async () => { - await w3cCredentialService.signCredential({ + await w3cCredentialService.signCredential(agentContext, { credential, proofType: 'Ed25519Signature2018', verificationMethod: @@ -106,7 +115,7 @@ describe('W3cCredentialService', () => { Ed25519Signature2018Fixtures.TEST_LD_DOCUMENT_SIGNED, W3cVerifiableCredential ) - const result = await w3cCredentialService.verifyCredential({ credential: vc }) + const result = await w3cCredentialService.verifyCredential(agentContext, { credential: vc }) expect(result.verified).toBe(true) expect(result.error).toBeUndefined() @@ -121,7 +130,7 @@ describe('W3cCredentialService', () => { Ed25519Signature2018Fixtures.TEST_LD_DOCUMENT_BAD_SIGNED, W3cVerifiableCredential ) - const result = await w3cCredentialService.verifyCredential({ credential: vc }) + const result = await w3cCredentialService.verifyCredential(agentContext, { credential: vc }) expect(result.verified).toBe(false) expect(result.error).toBeDefined() @@ -141,7 +150,7 @@ describe('W3cCredentialService', () => { } const vc = JsonTransformer.fromJSON(vcJson, W3cVerifiableCredential) - const result = await w3cCredentialService.verifyCredential({ credential: vc }) + const result = await w3cCredentialService.verifyCredential(agentContext, { credential: vc }) expect(result.verified).toBe(false) @@ -163,7 +172,7 @@ describe('W3cCredentialService', () => { } const vc = JsonTransformer.fromJSON(vcJson, W3cVerifiableCredential) - const result = await w3cCredentialService.verifyCredential({ credential: vc }) + const result = await w3cCredentialService.verifyCredential(agentContext, { credential: vc }) expect(result.verified).toBe(false) @@ -220,7 +229,7 @@ describe('W3cCredentialService', () => { date: new Date().toISOString(), }) - const verifiablePresentation = await w3cCredentialService.signPresentation({ + const verifiablePresentation = await w3cCredentialService.signPresentation(agentContext, { presentation: presentation, purpose: purpose, signatureType: 'Ed25519Signature2018', @@ -238,7 +247,7 @@ describe('W3cCredentialService', () => { W3cVerifiablePresentation ) - const result = await w3cCredentialService.verifyPresentation({ + const result = await w3cCredentialService.verifyPresentation(agentContext, { presentation: vp, proofType: 'Ed25519Signature2018', challenge: '7bf32d0b-39d4-41f3-96b6-45de52988e4c', @@ -256,7 +265,7 @@ describe('W3cCredentialService', () => { W3cVerifiableCredential ) - const w3cCredentialRecord = await w3cCredentialService.storeCredential({ record: credential }) + const w3cCredentialRecord = await w3cCredentialService.storeCredential(agentContext, { record: credential }) expect(w3cCredentialRecord).toMatchObject({ type: 'W3cCredentialRecord', @@ -290,7 +299,7 @@ describe('W3cCredentialService', () => { const credential = JsonTransformer.fromJSON(credentialJson, W3cCredential) - const vc = await w3cCredentialService.signCredential({ + const vc = await w3cCredentialService.signCredential(agentContext, { credential, proofType: 'BbsBlsSignature2020', verificationMethod: verificationMethod, @@ -307,7 +316,7 @@ describe('W3cCredentialService', () => { }) describe('verifyCredential', () => { it('should verify the credential successfully', async () => { - const result = await w3cCredentialService.verifyCredential({ + const result = await w3cCredentialService.verifyCredential(agentContext, { credential: JsonTransformer.fromJSON( BbsBlsSignature2020Fixtures.TEST_LD_DOCUMENT_SIGNED, W3cVerifiableCredential @@ -340,7 +349,7 @@ describe('W3cCredentialService', () => { }, } - const result = await w3cCredentialService.deriveProof({ + const result = await w3cCredentialService.deriveProof(agentContext, { credential: vc, revealDocument: revealDocument, verificationMethod: verificationMethod, @@ -354,7 +363,7 @@ describe('W3cCredentialService', () => { }) describe('verifyDerived', () => { it('should verify the derived proof successfully', async () => { - const result = await w3cCredentialService.verifyCredential({ + const result = await w3cCredentialService.verifyCredential(agentContext, { credential: JsonTransformer.fromJSON(BbsBlsSignature2020Fixtures.TEST_VALID_DERIVED, W3cVerifiableCredential), proofPurpose: new purposes.AssertionProofPurpose(), }) @@ -388,7 +397,7 @@ describe('W3cCredentialService', () => { date: new Date().toISOString(), }) - const verifiablePresentation = await w3cCredentialService.signPresentation({ + const verifiablePresentation = await w3cCredentialService.signPresentation(agentContext, { presentation: presentation, purpose: purpose, signatureType: 'Ed25519Signature2018', @@ -406,7 +415,7 @@ describe('W3cCredentialService', () => { W3cVerifiablePresentation ) - const result = await w3cCredentialService.verifyPresentation({ + const result = await w3cCredentialService.verifyPresentation(agentContext, { presentation: vp, proofType: 'Ed25519Signature2018', challenge: 'e950bfe5-d7ec-4303-ad61-6983fb976ac9', diff --git a/packages/core/src/storage/InMemoryMessageRepository.ts b/packages/core/src/storage/InMemoryMessageRepository.ts index d2e404d40b..5cec5f5102 100644 --- a/packages/core/src/storage/InMemoryMessageRepository.ts +++ b/packages/core/src/storage/InMemoryMessageRepository.ts @@ -1,17 +1,17 @@ -import type { Logger } from '../logger' import type { EncryptedMessage } from '../types' import type { MessageRepository } from './MessageRepository' -import { AgentConfig } from '../agent/AgentConfig' -import { injectable } from '../plugins' +import { InjectionSymbols } from '../constants' +import { Logger } from '../logger' +import { injectable, inject } from '../plugins' @injectable() export class InMemoryMessageRepository implements MessageRepository { private logger: Logger private messages: { [key: string]: EncryptedMessage[] } = {} - public constructor(agentConfig: AgentConfig) { - this.logger = agentConfig.logger + public constructor(@inject(InjectionSymbols.Logger) logger: Logger) { + this.logger = logger } public takeFromQueue(connectionId: string, limit?: number) { diff --git a/packages/core/src/storage/IndyStorageService.ts b/packages/core/src/storage/IndyStorageService.ts index 65dfccd953..3ceac062d3 100644 --- a/packages/core/src/storage/IndyStorageService.ts +++ b/packages/core/src/storage/IndyStorageService.ts @@ -1,18 +1,20 @@ +import type { AgentContext } from '../agent' +import type { IndyWallet } from '../wallet/IndyWallet' import type { BaseRecord, TagsBase } from './BaseRecord' -import type { StorageService, BaseRecordConstructor, Query } from './StorageService' +import type { BaseRecordConstructor, Query, StorageService } from './StorageService' import type { default as Indy, WalletQuery, WalletRecord, WalletSearchOptions } from 'indy-sdk' -import { AgentConfig } from '../agent/AgentConfig' -import { RecordNotFoundError, RecordDuplicateError, IndySdkError } from '../error' -import { injectable } from '../plugins' +import { AgentDependencies } from '../agent/AgentDependencies' +import { InjectionSymbols } from '../constants' +import { IndySdkError, RecordDuplicateError, RecordNotFoundError } from '../error' +import { injectable, inject } from '../plugins' import { JsonTransformer } from '../utils/JsonTransformer' import { isIndyError } from '../utils/indyError' import { isBoolean } from '../utils/type' -import { IndyWallet } from '../wallet/IndyWallet' +import { assertIndyWallet } from '../wallet/util/assertIndyWallet' @injectable() export class IndyStorageService implements StorageService { - private wallet: IndyWallet private indy: typeof Indy private static DEFAULT_QUERY_OPTIONS = { @@ -20,9 +22,8 @@ export class IndyStorageService implements StorageService< retrieveTags: true, } - public constructor(wallet: IndyWallet, agentConfig: AgentConfig) { - this.wallet = wallet - this.indy = agentConfig.agentDependencies.indy + public constructor(@inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies) { + this.indy = agentDependencies.indy } private transformToRecordTagValues(tags: { [key: number]: string | undefined }): TagsBase { @@ -133,12 +134,14 @@ export class IndyStorageService implements StorageService< } /** @inheritDoc */ - public async save(record: T) { + public async save(agentContext: AgentContext, record: T) { + assertIndyWallet(agentContext.wallet) + const value = JsonTransformer.serialize(record) const tags = this.transformFromRecordTagValues(record.getTags()) as Record try { - await this.indy.addWalletRecord(this.wallet.handle, record.type, record.id, value, tags) + await this.indy.addWalletRecord(agentContext.wallet.handle, record.type, record.id, value, tags) } catch (error) { // Record already exists if (isIndyError(error, 'WalletItemAlreadyExists')) { @@ -150,13 +153,15 @@ export class IndyStorageService implements StorageService< } /** @inheritDoc */ - public async update(record: T): Promise { + public async update(agentContext: AgentContext, record: T): Promise { + assertIndyWallet(agentContext.wallet) + const value = JsonTransformer.serialize(record) const tags = this.transformFromRecordTagValues(record.getTags()) as Record try { - await this.indy.updateWalletRecordValue(this.wallet.handle, record.type, record.id, value) - await this.indy.updateWalletRecordTags(this.wallet.handle, record.type, record.id, tags) + await this.indy.updateWalletRecordValue(agentContext.wallet.handle, record.type, record.id, value) + await this.indy.updateWalletRecordTags(agentContext.wallet.handle, record.type, record.id, tags) } catch (error) { // Record does not exist if (isIndyError(error, 'WalletItemNotFound')) { @@ -171,9 +176,11 @@ export class IndyStorageService implements StorageService< } /** @inheritDoc */ - public async delete(record: T) { + public async delete(agentContext: AgentContext, record: T) { + assertIndyWallet(agentContext.wallet) + try { - await this.indy.deleteWalletRecord(this.wallet.handle, record.type, record.id) + await this.indy.deleteWalletRecord(agentContext.wallet.handle, record.type, record.id) } catch (error) { // Record does not exist if (isIndyError(error, 'WalletItemNotFound')) { @@ -188,10 +195,12 @@ export class IndyStorageService implements StorageService< } /** @inheritDoc */ - public async getById(recordClass: BaseRecordConstructor, id: string): Promise { + public async getById(agentContext: AgentContext, recordClass: BaseRecordConstructor, id: string): Promise { + assertIndyWallet(agentContext.wallet) + try { const record = await this.indy.getWalletRecord( - this.wallet.handle, + agentContext.wallet.handle, recordClass.type, id, IndyStorageService.DEFAULT_QUERY_OPTIONS @@ -210,8 +219,15 @@ export class IndyStorageService implements StorageService< } /** @inheritDoc */ - public async getAll(recordClass: BaseRecordConstructor): Promise { - const recordIterator = this.search(recordClass.type, {}, IndyStorageService.DEFAULT_QUERY_OPTIONS) + public async getAll(agentContext: AgentContext, recordClass: BaseRecordConstructor): Promise { + assertIndyWallet(agentContext.wallet) + + const recordIterator = this.search( + agentContext.wallet, + recordClass.type, + {}, + IndyStorageService.DEFAULT_QUERY_OPTIONS + ) const records = [] for await (const record of recordIterator) { records.push(this.recordToInstance(record, recordClass)) @@ -220,10 +236,21 @@ export class IndyStorageService implements StorageService< } /** @inheritDoc */ - public async findByQuery(recordClass: BaseRecordConstructor, query: Query): Promise { + public async findByQuery( + agentContext: AgentContext, + recordClass: BaseRecordConstructor, + query: Query + ): Promise { + assertIndyWallet(agentContext.wallet) + const indyQuery = this.indyQueryFromSearchQuery(query) - const recordIterator = this.search(recordClass.type, indyQuery, IndyStorageService.DEFAULT_QUERY_OPTIONS) + const recordIterator = this.search( + agentContext.wallet, + recordClass.type, + indyQuery, + IndyStorageService.DEFAULT_QUERY_OPTIONS + ) const records = [] for await (const record of recordIterator) { records.push(this.recordToInstance(record, recordClass)) @@ -232,12 +259,13 @@ export class IndyStorageService implements StorageService< } private async *search( + wallet: IndyWallet, type: string, query: WalletQuery, { limit = Infinity, ...options }: WalletSearchOptions & { limit?: number } ) { try { - const searchHandle = await this.indy.openWalletSearch(this.wallet.handle, type, query, options) + const searchHandle = await this.indy.openWalletSearch(wallet.handle, type, query, options) let records: Indy.WalletRecord[] = [] @@ -247,7 +275,7 @@ export class IndyStorageService implements StorageService< // Loop while limit not reached (or no limit specified) while (!limit || records.length < limit) { // Retrieve records - const recordsJson = await this.indy.fetchWalletSearchNextRecords(this.wallet.handle, searchHandle, chunk) + const recordsJson = await this.indy.fetchWalletSearchNextRecords(wallet.handle, searchHandle, chunk) if (recordsJson.records) { records = [...records, ...recordsJson.records] diff --git a/packages/core/src/storage/Repository.ts b/packages/core/src/storage/Repository.ts index 49b4898753..cfbcc4193d 100644 --- a/packages/core/src/storage/Repository.ts +++ b/packages/core/src/storage/Repository.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../agent' import type { EventEmitter } from '../agent/EventEmitter' import type { BaseRecord } from './BaseRecord' import type { RecordSavedEvent, RecordUpdatedEvent, RecordDeletedEvent } from './RepositoryEvents' @@ -24,9 +25,9 @@ export class Repository> { } /** @inheritDoc {StorageService#save} */ - public async save(record: T): Promise { - await this.storageService.save(record) - this.eventEmitter.emit>({ + public async save(agentContext: AgentContext, record: T): Promise { + await this.storageService.save(agentContext, record) + this.eventEmitter.emit>(agentContext, { type: RepositoryEventTypes.RecordSaved, payload: { record, @@ -35,9 +36,9 @@ export class Repository> { } /** @inheritDoc {StorageService#update} */ - public async update(record: T): Promise { - await this.storageService.update(record) - this.eventEmitter.emit>({ + public async update(agentContext: AgentContext, record: T): Promise { + await this.storageService.update(agentContext, record) + this.eventEmitter.emit>(agentContext, { type: RepositoryEventTypes.RecordUpdated, payload: { record, @@ -46,9 +47,9 @@ export class Repository> { } /** @inheritDoc {StorageService#delete} */ - public async delete(record: T): Promise { - await this.storageService.delete(record) - this.eventEmitter.emit>({ + public async delete(agentContext: AgentContext, record: T): Promise { + await this.storageService.delete(agentContext, record) + this.eventEmitter.emit>(agentContext, { type: RepositoryEventTypes.RecordDeleted, payload: { record, @@ -57,8 +58,8 @@ export class Repository> { } /** @inheritDoc {StorageService#getById} */ - public async getById(id: string): Promise { - return this.storageService.getById(this.recordClass, id) + public async getById(agentContext: AgentContext, id: string): Promise { + return this.storageService.getById(agentContext, this.recordClass, id) } /** @@ -66,9 +67,9 @@ export class Repository> { * @param id the id of the record to retrieve * @returns */ - public async findById(id: string): Promise { + public async findById(agentContext: AgentContext, id: string): Promise { try { - return await this.storageService.getById(this.recordClass, id) + return await this.storageService.getById(agentContext, this.recordClass, id) } catch (error) { if (error instanceof RecordNotFoundError) return null @@ -77,13 +78,13 @@ export class Repository> { } /** @inheritDoc {StorageService#getAll} */ - public async getAll(): Promise { - return this.storageService.getAll(this.recordClass) + public async getAll(agentContext: AgentContext): Promise { + return this.storageService.getAll(agentContext, this.recordClass) } /** @inheritDoc {StorageService#findByQuery} */ - public async findByQuery(query: Query): Promise { - return this.storageService.findByQuery(this.recordClass, query) + public async findByQuery(agentContext: AgentContext, query: Query): Promise { + return this.storageService.findByQuery(agentContext, this.recordClass, query) } /** @@ -92,8 +93,8 @@ export class Repository> { * @returns the record, or null if not found * @throws {RecordDuplicateError} if multiple records are found for the given query */ - public async findSingleByQuery(query: Query): Promise { - const records = await this.findByQuery(query) + public async findSingleByQuery(agentContext: AgentContext, query: Query): Promise { + const records = await this.findByQuery(agentContext, query) if (records.length > 1) { throw new RecordDuplicateError(`Multiple records found for given query '${JSON.stringify(query)}'`, { @@ -115,8 +116,8 @@ export class Repository> { * @throws {RecordDuplicateError} if multiple records are found for the given query * @throws {RecordNotFoundError} if no record is found for the given query */ - public async getSingleByQuery(query: Query): Promise { - const record = await this.findSingleByQuery(query) + public async getSingleByQuery(agentContext: AgentContext, query: Query): Promise { + const record = await this.findSingleByQuery(agentContext, query) if (!record) { throw new RecordNotFoundError(`No record found for given query '${JSON.stringify(query)}'`, { diff --git a/packages/core/src/storage/StorageService.ts b/packages/core/src/storage/StorageService.ts index 1491c94764..2925269c22 100644 --- a/packages/core/src/storage/StorageService.ts +++ b/packages/core/src/storage/StorageService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../agent' import type { Constructor } from '../utils/mixins' import type { BaseRecord, TagsBase } from './BaseRecord' @@ -24,7 +25,7 @@ export interface StorageService> { * @param record the record to store * @throws {RecordDuplicateError} if a record with this id already exists */ - save(record: T): Promise + save(agentContext: AgentContext, record: T): Promise /** * Update record in storage @@ -32,7 +33,7 @@ export interface StorageService> { * @param record the record to update * @throws {RecordNotFoundError} if a record with this id and type does not exist */ - update(record: T): Promise + update(agentContext: AgentContext, record: T): Promise /** * Delete record from storage @@ -40,7 +41,7 @@ export interface StorageService> { * @param record the record to delete * @throws {RecordNotFoundError} if a record with this id and type does not exist */ - delete(record: T): Promise + delete(agentContext: AgentContext, record: T): Promise /** * Get record by id. @@ -49,14 +50,14 @@ export interface StorageService> { * @param id the id of the record to retrieve from storage * @throws {RecordNotFoundError} if a record with this id and type does not exist */ - getById(recordClass: BaseRecordConstructor, id: string): Promise + getById(agentContext: AgentContext, recordClass: BaseRecordConstructor, id: string): Promise /** * Get all records by specified record class. * * @param recordClass the record class to get records for */ - getAll(recordClass: BaseRecordConstructor): Promise + getAll(agentContext: AgentContext, recordClass: BaseRecordConstructor): Promise /** * Find all records by specified record class and query. @@ -64,5 +65,5 @@ export interface StorageService> { * @param recordClass the record class to find records for * @param query the query to use for finding records */ - findByQuery(recordClass: BaseRecordConstructor, query: Query): Promise + findByQuery(agentContext: AgentContext, recordClass: BaseRecordConstructor, query: Query): Promise } diff --git a/packages/core/src/storage/__tests__/DidCommMessageRepository.test.ts b/packages/core/src/storage/__tests__/DidCommMessageRepository.test.ts index 3b6816e546..067a290dcb 100644 --- a/packages/core/src/storage/__tests__/DidCommMessageRepository.test.ts +++ b/packages/core/src/storage/__tests__/DidCommMessageRepository.test.ts @@ -1,4 +1,6 @@ -import { getAgentConfig, mockFunction } from '../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, mockFunction } from '../../../tests/helpers' import { EventEmitter } from '../../agent/EventEmitter' import { ConnectionInvitationMessage } from '../../modules/connections' import { JsonTransformer } from '../../utils/JsonTransformer' @@ -17,14 +19,17 @@ const invitationJson = { label: 'test', } -describe('Repository', () => { +const config = getAgentConfig('DidCommMessageRepository') +const agentContext = getAgentContext() + +describe('DidCommMessageRepository', () => { let repository: DidCommMessageRepository let storageMock: IndyStorageService let eventEmitter: EventEmitter beforeEach(async () => { storageMock = new StorageMock() - eventEmitter = new EventEmitter(getAgentConfig('DidCommMessageRepositoryTest')) + eventEmitter = new EventEmitter(config.agentDependencies, new Subject()) repository = new DidCommMessageRepository(storageMock, eventEmitter) }) @@ -42,12 +47,12 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record])) - const invitation = await repository.findAgentMessage({ + const invitation = await repository.findAgentMessage(agentContext, { messageClass: ConnectionInvitationMessage, associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', }) - expect(storageMock.findByQuery).toBeCalledWith(DidCommMessageRecord, { + expect(storageMock.findByQuery).toBeCalledWith(agentContext, DidCommMessageRecord, { associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', messageName: 'invitation', protocolName: 'connections', @@ -61,12 +66,12 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record])) - const invitation = await repository.findAgentMessage({ + const invitation = await repository.findAgentMessage(agentContext, { messageClass: ConnectionInvitationMessage, associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', }) - expect(storageMock.findByQuery).toBeCalledWith(DidCommMessageRecord, { + expect(storageMock.findByQuery).toBeCalledWith(agentContext, DidCommMessageRecord, { associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', messageName: 'invitation', protocolName: 'connections', @@ -78,12 +83,12 @@ describe('Repository', () => { it("should return null because the record doesn't exist", async () => { mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([])) - const invitation = await repository.findAgentMessage({ + const invitation = await repository.findAgentMessage(agentContext, { messageClass: ConnectionInvitationMessage, associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', }) - expect(storageMock.findByQuery).toBeCalledWith(DidCommMessageRecord, { + expect(storageMock.findByQuery).toBeCalledWith(agentContext, DidCommMessageRecord, { associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', messageName: 'invitation', protocolName: 'connections', @@ -95,13 +100,14 @@ describe('Repository', () => { describe('saveAgentMessage()', () => { it('should transform and save the agent message', async () => { - await repository.saveAgentMessage({ + await repository.saveAgentMessage(agentContext, { role: DidCommMessageRole.Receiver, agentMessage: JsonTransformer.fromJSON(invitationJson, ConnectionInvitationMessage), associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', }) expect(storageMock.save).toBeCalledWith( + agentContext, expect.objectContaining({ role: DidCommMessageRole.Receiver, message: invitationJson, @@ -114,13 +120,14 @@ describe('Repository', () => { describe('saveOrUpdateAgentMessage()', () => { it('should transform and save the agent message', async () => { mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([])) - await repository.saveOrUpdateAgentMessage({ + await repository.saveOrUpdateAgentMessage(agentContext, { role: DidCommMessageRole.Receiver, agentMessage: JsonTransformer.fromJSON(invitationJson, ConnectionInvitationMessage), associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', }) expect(storageMock.save).toBeCalledWith( + agentContext, expect.objectContaining({ role: DidCommMessageRole.Receiver, message: invitationJson, @@ -132,19 +139,19 @@ describe('Repository', () => { it('should transform and update the agent message', async () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record])) - await repository.saveOrUpdateAgentMessage({ + await repository.saveOrUpdateAgentMessage(agentContext, { role: DidCommMessageRole.Receiver, agentMessage: JsonTransformer.fromJSON(invitationJson, ConnectionInvitationMessage), associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', }) - expect(storageMock.findByQuery).toBeCalledWith(DidCommMessageRecord, { + expect(storageMock.findByQuery).toBeCalledWith(agentContext, DidCommMessageRecord, { associatedRecordId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', messageName: 'invitation', protocolName: 'connections', protocolMajorVersion: '1', }) - expect(storageMock.update).toBeCalledWith(record) + expect(storageMock.update).toBeCalledWith(agentContext, record) }) }) }) diff --git a/packages/core/src/storage/__tests__/IndyStorageService.test.ts b/packages/core/src/storage/__tests__/IndyStorageService.test.ts index dde5754cc4..08f2df3fa7 100644 --- a/packages/core/src/storage/__tests__/IndyStorageService.test.ts +++ b/packages/core/src/storage/__tests__/IndyStorageService.test.ts @@ -1,8 +1,8 @@ +import type { AgentContext } from '../../agent' import type { TagsBase } from '../BaseRecord' import type * as Indy from 'indy-sdk' -import { agentDependencies, getAgentConfig } from '../../../tests/helpers' -import { AgentConfig } from '../../agent/AgentConfig' +import { agentDependencies, getAgentConfig, getAgentContext } from '../../../tests/helpers' import { RecordDuplicateError, RecordNotFoundError } from '../../error' import { IndyWallet } from '../../wallet/IndyWallet' import { IndyStorageService } from '../IndyStorageService' @@ -13,14 +13,19 @@ describe('IndyStorageService', () => { let wallet: IndyWallet let indy: typeof Indy let storageService: IndyStorageService + let agentContext: AgentContext beforeEach(async () => { - const config = getAgentConfig('IndyStorageServiceTest') - indy = config.agentDependencies.indy - wallet = new IndyWallet(config) + const agentConfig = getAgentConfig('IndyStorageServiceTest') + indy = agentConfig.agentDependencies.indy + wallet = new IndyWallet(agentConfig.agentDependencies, agentConfig.logger) + agentContext = getAgentContext({ + wallet, + agentConfig, + }) // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await wallet.createAndOpen(config.walletConfig!) - storageService = new IndyStorageService(wallet, config) + await wallet.createAndOpen(agentConfig.walletConfig!) + storageService = new IndyStorageService(agentConfig.agentDependencies) }) afterEach(async () => { @@ -34,7 +39,7 @@ describe('IndyStorageService', () => { tags: tags ?? { myTag: 'foobar' }, } const record = new TestRecord(props) - await storageService.save(record) + await storageService.save(agentContext, record) return record } @@ -81,7 +86,7 @@ describe('IndyStorageService', () => { anotherStringNumberValue: 'n__0', }) - const record = await storageService.getById(TestRecord, 'some-id') + const record = await storageService.getById(agentContext, TestRecord, 'some-id') expect(record.getTags()).toEqual({ someBoolean: true, @@ -98,12 +103,12 @@ describe('IndyStorageService', () => { it('should throw RecordDuplicateError if a record with the id already exists', async () => { const record = await insertRecord({ id: 'test-id' }) - return expect(() => storageService.save(record)).rejects.toThrowError(RecordDuplicateError) + return expect(() => storageService.save(agentContext, record)).rejects.toThrowError(RecordDuplicateError) }) it('should save the record', async () => { const record = await insertRecord({ id: 'test-id' }) - const found = await storageService.getById(TestRecord, 'test-id') + const found = await storageService.getById(agentContext, TestRecord, 'test-id') expect(record).toEqual(found) }) @@ -111,14 +116,14 @@ describe('IndyStorageService', () => { describe('getById()', () => { it('should throw RecordNotFoundError if the record does not exist', async () => { - return expect(() => storageService.getById(TestRecord, 'does-not-exist')).rejects.toThrowError( + return expect(() => storageService.getById(agentContext, TestRecord, 'does-not-exist')).rejects.toThrowError( RecordNotFoundError ) }) it('should return the record by id', async () => { const record = await insertRecord({ id: 'test-id' }) - const found = await storageService.getById(TestRecord, 'test-id') + const found = await storageService.getById(agentContext, TestRecord, 'test-id') expect(found).toEqual(record) }) @@ -132,7 +137,7 @@ describe('IndyStorageService', () => { tags: { some: 'tag' }, }) - return expect(() => storageService.update(record)).rejects.toThrowError(RecordNotFoundError) + return expect(() => storageService.update(agentContext, record)).rejects.toThrowError(RecordNotFoundError) }) it('should update the record', async () => { @@ -140,9 +145,9 @@ describe('IndyStorageService', () => { record.replaceTags({ ...record.getTags(), foo: 'bar' }) record.foo = 'foobaz' - await storageService.update(record) + await storageService.update(agentContext, record) - const retrievedRecord = await storageService.getById(TestRecord, record.id) + const retrievedRecord = await storageService.getById(agentContext, TestRecord, record.id) expect(retrievedRecord).toEqual(record) }) }) @@ -155,14 +160,16 @@ describe('IndyStorageService', () => { tags: { some: 'tag' }, }) - return expect(() => storageService.delete(record)).rejects.toThrowError(RecordNotFoundError) + return expect(() => storageService.delete(agentContext, record)).rejects.toThrowError(RecordNotFoundError) }) it('should delete the record', async () => { const record = await insertRecord({ id: 'test-id' }) - await storageService.delete(record) + await storageService.delete(agentContext, record) - return expect(() => storageService.getById(TestRecord, record.id)).rejects.toThrowError(RecordNotFoundError) + return expect(() => storageService.getById(agentContext, TestRecord, record.id)).rejects.toThrowError( + RecordNotFoundError + ) }) }) @@ -174,7 +181,7 @@ describe('IndyStorageService', () => { .map((_, index) => insertRecord({ id: `record-${index}` })) ) - const records = await storageService.getAll(TestRecord) + const records = await storageService.getAll(agentContext, TestRecord) expect(records).toEqual(expect.arrayContaining(createdRecords)) }) @@ -185,7 +192,7 @@ describe('IndyStorageService', () => { const expectedRecord = await insertRecord({ tags: { myTag: 'foobar' } }) await insertRecord({ tags: { myTag: 'notfoobar' } }) - const records = await storageService.findByQuery(TestRecord, { myTag: 'foobar' }) + const records = await storageService.findByQuery(agentContext, TestRecord, { myTag: 'foobar' }) expect(records.length).toBe(1) expect(records[0]).toEqual(expectedRecord) @@ -195,7 +202,7 @@ describe('IndyStorageService', () => { const expectedRecord = await insertRecord({ tags: { myTag: 'foo', anotherTag: 'bar' } }) await insertRecord({ tags: { myTag: 'notfoobar' } }) - const records = await storageService.findByQuery(TestRecord, { + const records = await storageService.findByQuery(agentContext, TestRecord, { $and: [{ myTag: 'foo' }, { anotherTag: 'bar' }], }) @@ -208,7 +215,7 @@ describe('IndyStorageService', () => { const expectedRecord2 = await insertRecord({ tags: { anotherTag: 'bar' } }) await insertRecord({ tags: { myTag: 'notfoobar' } }) - const records = await storageService.findByQuery(TestRecord, { + const records = await storageService.findByQuery(agentContext, TestRecord, { $or: [{ myTag: 'foo' }, { anotherTag: 'bar' }], }) @@ -221,7 +228,7 @@ describe('IndyStorageService', () => { const expectedRecord2 = await insertRecord({ tags: { anotherTag: 'bar' } }) await insertRecord({ tags: { myTag: 'notfoobar' } }) - const records = await storageService.findByQuery(TestRecord, { + const records = await storageService.findByQuery(agentContext, TestRecord, { $not: { myTag: 'notfoobar' }, }) @@ -231,22 +238,16 @@ describe('IndyStorageService', () => { it('correctly transforms an advanced query into a valid WQL query', async () => { const indySpy = jest.fn() - const storageServiceWithoutIndy = new IndyStorageService( - wallet, - new AgentConfig( - { label: 'hello' }, - { - ...agentDependencies, - indy: { - openWalletSearch: indySpy, - fetchWalletSearchNextRecords: jest.fn(() => ({ records: undefined })), - closeWalletSearch: jest.fn(), - } as unknown as typeof Indy, - } - ) - ) + const storageServiceWithoutIndy = new IndyStorageService({ + ...agentDependencies, + indy: { + openWalletSearch: indySpy, + fetchWalletSearchNextRecords: jest.fn(() => ({ records: undefined })), + closeWalletSearch: jest.fn(), + } as unknown as typeof Indy, + }) - await storageServiceWithoutIndy.findByQuery(TestRecord, { + await storageServiceWithoutIndy.findByQuery(agentContext, TestRecord, { $and: [ { $or: [{ myTag: true }, { myTag: false }], diff --git a/packages/core/src/storage/__tests__/Repository.test.ts b/packages/core/src/storage/__tests__/Repository.test.ts index 9952646aa0..6ff81c3c64 100644 --- a/packages/core/src/storage/__tests__/Repository.test.ts +++ b/packages/core/src/storage/__tests__/Repository.test.ts @@ -1,7 +1,10 @@ +import type { AgentContext } from '../../agent' import type { TagsBase } from '../BaseRecord' import type { RecordDeletedEvent, RecordSavedEvent, RecordUpdatedEvent } from '../RepositoryEvents' -import { getAgentConfig, mockFunction } from '../../../tests/helpers' +import { Subject } from 'rxjs' + +import { getAgentConfig, getAgentContext, mockFunction } from '../../../tests/helpers' import { EventEmitter } from '../../agent/EventEmitter' import { AriesFrameworkError, RecordDuplicateError, RecordNotFoundError } from '../../error' import { IndyStorageService } from '../IndyStorageService' @@ -14,15 +17,19 @@ jest.mock('../IndyStorageService') const StorageMock = IndyStorageService as unknown as jest.Mock> +const config = getAgentConfig('Repository') + describe('Repository', () => { let repository: Repository let storageMock: IndyStorageService + let agentContext: AgentContext let eventEmitter: EventEmitter beforeEach(async () => { storageMock = new StorageMock() - eventEmitter = new EventEmitter(getAgentConfig('RepositoryTest')) + eventEmitter = new EventEmitter(config.agentDependencies, new Subject()) repository = new Repository(TestRecord, storageMock, eventEmitter) + agentContext = getAgentContext() }) const getRecord = ({ id, tags }: { id?: string; tags?: TagsBase } = {}) => { @@ -36,9 +43,9 @@ describe('Repository', () => { describe('save()', () => { it('should save the record using the storage service', async () => { const record = getRecord({ id: 'test-id' }) - await repository.save(record) + await repository.save(agentContext, record) - expect(storageMock.save).toBeCalledWith(record) + expect(storageMock.save).toBeCalledWith(agentContext, record) }) it(`should emit saved event`, async () => { @@ -49,7 +56,7 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) // when - await repository.save(record) + await repository.save(agentContext, record) // then expect(eventListenerMock).toHaveBeenCalledWith({ @@ -66,9 +73,9 @@ describe('Repository', () => { describe('update()', () => { it('should update the record using the storage service', async () => { const record = getRecord({ id: 'test-id' }) - await repository.update(record) + await repository.update(agentContext, record) - expect(storageMock.update).toBeCalledWith(record) + expect(storageMock.update).toBeCalledWith(agentContext, record) }) it(`should emit updated event`, async () => { @@ -79,7 +86,7 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) // when - await repository.update(record) + await repository.update(agentContext, record) // then expect(eventListenerMock).toHaveBeenCalledWith({ @@ -96,9 +103,9 @@ describe('Repository', () => { describe('delete()', () => { it('should delete the record using the storage service', async () => { const record = getRecord({ id: 'test-id' }) - await repository.delete(record) + await repository.delete(agentContext, record) - expect(storageMock.delete).toBeCalledWith(record) + expect(storageMock.delete).toBeCalledWith(agentContext, record) }) it(`should emit deleted event`, async () => { @@ -109,7 +116,7 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) // when - await repository.delete(record) + await repository.delete(agentContext, record) // then expect(eventListenerMock).toHaveBeenCalledWith({ @@ -128,9 +135,9 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.getById).mockReturnValue(Promise.resolve(record)) - const returnValue = await repository.getById('test-id') + const returnValue = await repository.getById(agentContext, 'test-id') - expect(storageMock.getById).toBeCalledWith(TestRecord, 'test-id') + expect(storageMock.getById).toBeCalledWith(agentContext, TestRecord, 'test-id') expect(returnValue).toBe(record) }) }) @@ -140,9 +147,9 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.getById).mockReturnValue(Promise.resolve(record)) - const returnValue = await repository.findById('test-id') + const returnValue = await repository.findById(agentContext, 'test-id') - expect(storageMock.getById).toBeCalledWith(TestRecord, 'test-id') + expect(storageMock.getById).toBeCalledWith(agentContext, TestRecord, 'test-id') expect(returnValue).toBe(record) }) @@ -151,17 +158,17 @@ describe('Repository', () => { Promise.reject(new RecordNotFoundError('Not found', { recordType: TestRecord.type })) ) - const returnValue = await repository.findById('test-id') + const returnValue = await repository.findById(agentContext, 'test-id') - expect(storageMock.getById).toBeCalledWith(TestRecord, 'test-id') + expect(storageMock.getById).toBeCalledWith(agentContext, TestRecord, 'test-id') expect(returnValue).toBeNull() }) it('should return null if the storage service throws an error that is not RecordNotFoundError', async () => { mockFunction(storageMock.getById).mockReturnValue(Promise.reject(new AriesFrameworkError('Not found'))) - expect(repository.findById('test-id')).rejects.toThrowError(AriesFrameworkError) - expect(storageMock.getById).toBeCalledWith(TestRecord, 'test-id') + expect(repository.findById(agentContext, 'test-id')).rejects.toThrowError(AriesFrameworkError) + expect(storageMock.getById).toBeCalledWith(agentContext, TestRecord, 'test-id') }) }) @@ -171,9 +178,9 @@ describe('Repository', () => { const record2 = getRecord({ id: 'test-id2' }) mockFunction(storageMock.getAll).mockReturnValue(Promise.resolve([record, record2])) - const returnValue = await repository.getAll() + const returnValue = await repository.getAll(agentContext) - expect(storageMock.getAll).toBeCalledWith(TestRecord) + expect(storageMock.getAll).toBeCalledWith(agentContext, TestRecord) expect(returnValue).toEqual(expect.arrayContaining([record, record2])) }) }) @@ -184,9 +191,9 @@ describe('Repository', () => { const record2 = getRecord({ id: 'test-id2' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record, record2])) - const returnValue = await repository.findByQuery({ something: 'interesting' }) + const returnValue = await repository.findByQuery(agentContext, { something: 'interesting' }) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) expect(returnValue).toEqual(expect.arrayContaining([record, record2])) }) }) @@ -196,18 +203,18 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record])) - const returnValue = await repository.findSingleByQuery({ something: 'interesting' }) + const returnValue = await repository.findSingleByQuery(agentContext, { something: 'interesting' }) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) expect(returnValue).toBe(record) }) it('should return null if the no records are returned by the storage service', async () => { mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([])) - const returnValue = await repository.findSingleByQuery({ something: 'interesting' }) + const returnValue = await repository.findSingleByQuery(agentContext, { something: 'interesting' }) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) expect(returnValue).toBeNull() }) @@ -216,8 +223,10 @@ describe('Repository', () => { const record2 = getRecord({ id: 'test-id2' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record, record2])) - expect(repository.findSingleByQuery({ something: 'interesting' })).rejects.toThrowError(RecordDuplicateError) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(repository.findSingleByQuery(agentContext, { something: 'interesting' })).rejects.toThrowError( + RecordDuplicateError + ) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) }) }) @@ -226,17 +235,19 @@ describe('Repository', () => { const record = getRecord({ id: 'test-id' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record])) - const returnValue = await repository.getSingleByQuery({ something: 'interesting' }) + const returnValue = await repository.getSingleByQuery(agentContext, { something: 'interesting' }) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) expect(returnValue).toBe(record) }) it('should throw RecordNotFoundError if no records are returned by the storage service', async () => { mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([])) - expect(repository.getSingleByQuery({ something: 'interesting' })).rejects.toThrowError(RecordNotFoundError) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(repository.getSingleByQuery(agentContext, { something: 'interesting' })).rejects.toThrowError( + RecordNotFoundError + ) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) }) it('should throw RecordDuplicateError if more than one record is returned by the storage service', async () => { @@ -244,8 +255,10 @@ describe('Repository', () => { const record2 = getRecord({ id: 'test-id2' }) mockFunction(storageMock.findByQuery).mockReturnValue(Promise.resolve([record, record2])) - expect(repository.getSingleByQuery({ something: 'interesting' })).rejects.toThrowError(RecordDuplicateError) - expect(storageMock.findByQuery).toBeCalledWith(TestRecord, { something: 'interesting' }) + expect(repository.getSingleByQuery(agentContext, { something: 'interesting' })).rejects.toThrowError( + RecordDuplicateError + ) + expect(storageMock.findByQuery).toBeCalledWith(agentContext, TestRecord, { something: 'interesting' }) }) }) }) diff --git a/packages/core/src/storage/didcomm/DidCommMessageRepository.ts b/packages/core/src/storage/didcomm/DidCommMessageRepository.ts index f964c62554..db2db2d04f 100644 --- a/packages/core/src/storage/didcomm/DidCommMessageRepository.ts +++ b/packages/core/src/storage/didcomm/DidCommMessageRepository.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../../agent' import type { AgentMessage, ConstructableAgentMessage } from '../../agent/AgentMessage' import type { JsonObject } from '../../types' import type { DidCommMessageRole } from './DidCommMessageRole' @@ -20,20 +21,23 @@ export class DidCommMessageRepository extends Repository { super(DidCommMessageRecord, storageService, eventEmitter) } - public async saveAgentMessage({ role, agentMessage, associatedRecordId }: SaveAgentMessageOptions) { + public async saveAgentMessage( + agentContext: AgentContext, + { role, agentMessage, associatedRecordId }: SaveAgentMessageOptions + ) { const didCommMessageRecord = new DidCommMessageRecord({ message: agentMessage.toJSON() as JsonObject, role, associatedRecordId, }) - await this.save(didCommMessageRecord) + await this.save(agentContext, didCommMessageRecord) } - public async saveOrUpdateAgentMessage(options: SaveAgentMessageOptions) { + public async saveOrUpdateAgentMessage(agentContext: AgentContext, options: SaveAgentMessageOptions) { const { messageName, protocolName, protocolMajorVersion } = parseMessageType(options.agentMessage.type) - const record = await this.findSingleByQuery({ + const record = await this.findSingleByQuery(agentContext, { associatedRecordId: options.associatedRecordId, messageName: messageName, protocolName: protocolName, @@ -43,18 +47,18 @@ export class DidCommMessageRepository extends Repository { if (record) { record.message = options.agentMessage.toJSON() as JsonObject record.role = options.role - await this.update(record) + await this.update(agentContext, record) return } - await this.saveAgentMessage(options) + await this.saveAgentMessage(agentContext, options) } - public async getAgentMessage({ - associatedRecordId, - messageClass, - }: GetAgentMessageOptions): Promise> { - const record = await this.getSingleByQuery({ + public async getAgentMessage( + agentContext: AgentContext, + { associatedRecordId, messageClass }: GetAgentMessageOptions + ): Promise> { + const record = await this.getSingleByQuery(agentContext, { associatedRecordId, messageName: messageClass.type.messageName, protocolName: messageClass.type.protocolName, @@ -63,11 +67,11 @@ export class DidCommMessageRepository extends Repository { return record.getMessageInstance(messageClass) } - public async findAgentMessage({ - associatedRecordId, - messageClass, - }: GetAgentMessageOptions): Promise | null> { - const record = await this.findSingleByQuery({ + public async findAgentMessage( + agentContext: AgentContext, + { associatedRecordId, messageClass }: GetAgentMessageOptions + ): Promise | null> { + const record = await this.findSingleByQuery(agentContext, { associatedRecordId, messageName: messageClass.type.messageName, protocolName: messageClass.type.protocolName, diff --git a/packages/core/src/storage/migration/StorageUpdateService.ts b/packages/core/src/storage/migration/StorageUpdateService.ts index 8d755dd133..2c8991d319 100644 --- a/packages/core/src/storage/migration/StorageUpdateService.ts +++ b/packages/core/src/storage/migration/StorageUpdateService.ts @@ -1,8 +1,9 @@ -import type { Logger } from '../../logger' +import type { AgentContext } from '../../agent' import type { VersionString } from '../../utils/version' -import { AgentConfig } from '../../agent/AgentConfig' -import { injectable } from '../../plugins' +import { InjectionSymbols } from '../../constants' +import { Logger } from '../../logger' +import { injectable, inject } from '../../plugins' import { StorageVersionRecord } from './repository/StorageVersionRecord' import { StorageVersionRepository } from './repository/StorageVersionRepository' @@ -15,34 +16,39 @@ export class StorageUpdateService { private logger: Logger private storageVersionRepository: StorageVersionRepository - public constructor(agentConfig: AgentConfig, storageVersionRepository: StorageVersionRepository) { + public constructor( + @inject(InjectionSymbols.Logger) logger: Logger, + storageVersionRepository: StorageVersionRepository + ) { + this.logger = logger this.storageVersionRepository = storageVersionRepository - this.logger = agentConfig.logger } - public async isUpToDate() { - const currentStorageVersion = await this.getCurrentStorageVersion() + public async isUpToDate(agentContext: AgentContext) { + const currentStorageVersion = await this.getCurrentStorageVersion(agentContext) const isUpToDate = CURRENT_FRAMEWORK_STORAGE_VERSION === currentStorageVersion return isUpToDate } - public async getCurrentStorageVersion(): Promise { - const storageVersionRecord = await this.getStorageVersionRecord() + public async getCurrentStorageVersion(agentContext: AgentContext): Promise { + const storageVersionRecord = await this.getStorageVersionRecord(agentContext) return storageVersionRecord.storageVersion } - public async setCurrentStorageVersion(storageVersion: VersionString) { + public async setCurrentStorageVersion(agentContext: AgentContext, storageVersion: VersionString) { this.logger.debug(`Setting current agent storage version to ${storageVersion}`) const storageVersionRecord = await this.storageVersionRepository.findById( + agentContext, StorageUpdateService.STORAGE_VERSION_RECORD_ID ) if (!storageVersionRecord) { this.logger.trace('Storage upgrade record does not exist yet. Creating.') await this.storageVersionRepository.save( + agentContext, new StorageVersionRecord({ id: StorageUpdateService.STORAGE_VERSION_RECORD_ID, storageVersion, @@ -51,7 +57,7 @@ export class StorageUpdateService { } else { this.logger.trace('Storage upgrade record already exists. Updating.') storageVersionRecord.storageVersion = storageVersion - await this.storageVersionRepository.update(storageVersionRecord) + await this.storageVersionRepository.update(agentContext, storageVersionRecord) } } @@ -61,8 +67,9 @@ export class StorageUpdateService { * The storageVersion will be set to the INITIAL_STORAGE_VERSION if it doesn't exist yet, * as we can assume the wallet was created before the udpate record existed */ - public async getStorageVersionRecord() { + public async getStorageVersionRecord(agentContext: AgentContext) { let storageVersionRecord = await this.storageVersionRepository.findById( + agentContext, StorageUpdateService.STORAGE_VERSION_RECORD_ID ) @@ -71,7 +78,7 @@ export class StorageUpdateService { id: StorageUpdateService.STORAGE_VERSION_RECORD_ID, storageVersion: INITIAL_STORAGE_VERSION, }) - await this.storageVersionRepository.save(storageVersionRecord) + await this.storageVersionRepository.save(agentContext, storageVersionRecord) } return storageVersionRecord diff --git a/packages/core/src/storage/migration/UpdateAssistant.ts b/packages/core/src/storage/migration/UpdateAssistant.ts index 084a72b584..cb07798529 100644 --- a/packages/core/src/storage/migration/UpdateAssistant.ts +++ b/packages/core/src/storage/migration/UpdateAssistant.ts @@ -1,6 +1,9 @@ import type { Agent } from '../../agent/Agent' +import type { FileSystem } from '../FileSystem' import type { UpdateConfig } from './updates' +import { AgentContext } from '../../agent' +import { InjectionSymbols } from '../../constants' import { AriesFrameworkError } from '../../error' import { isFirstVersionHigherThanSecond, parseVersionString } from '../../utils/version' import { WalletError } from '../../wallet/error/WalletError' @@ -13,12 +16,16 @@ export class UpdateAssistant { private agent: Agent private storageUpdateService: StorageUpdateService private updateConfig: UpdateConfig + private agentContext: AgentContext + private fileSystem: FileSystem public constructor(agent: Agent, updateConfig: UpdateConfig) { this.agent = agent this.updateConfig = updateConfig this.storageUpdateService = this.agent.dependencyManager.resolve(StorageUpdateService) + this.agentContext = this.agent.dependencyManager.resolve(AgentContext) + this.fileSystem = this.agent.dependencyManager.resolve(InjectionSymbols.FileSystem) } public async initialize() { @@ -39,11 +46,11 @@ export class UpdateAssistant { } public async isUpToDate() { - return this.storageUpdateService.isUpToDate() + return this.storageUpdateService.isUpToDate(this.agentContext) } public async getCurrentAgentStorageVersion() { - return this.storageUpdateService.getCurrentStorageVersion() + return this.storageUpdateService.getCurrentStorageVersion(this.agentContext) } public static get frameworkStorageVersion() { @@ -51,7 +58,9 @@ export class UpdateAssistant { } public async getNeededUpdates() { - const currentStorageVersion = parseVersionString(await this.storageUpdateService.getCurrentStorageVersion()) + const currentStorageVersion = parseVersionString( + await this.storageUpdateService.getCurrentStorageVersion(this.agentContext) + ) // Filter updates. We don't want older updates we already applied // or aren't needed because the wallet was created after the update script was made @@ -104,7 +113,7 @@ export class UpdateAssistant { await update.doUpdate(this.agent, this.updateConfig) // Update the framework version in storage - await this.storageUpdateService.setCurrentStorageVersion(update.toVersion) + await this.storageUpdateService.setCurrentStorageVersion(this.agentContext, update.toVersion) this.agent.config.logger.info( `Successfully updated agent storage from version ${update.fromVersion} to version ${update.toVersion}` ) @@ -132,8 +141,7 @@ export class UpdateAssistant { } private getBackupPath(backupIdentifier: string) { - const fileSystem = this.agent.config.fileSystem - return `${fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` + return `${this.fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` } private async createBackup(backupIdentifier: string) { diff --git a/packages/core/src/storage/migration/__tests__/0.1.test.ts b/packages/core/src/storage/migration/__tests__/0.1.test.ts index 71231222bb..87a6d030c0 100644 --- a/packages/core/src/storage/migration/__tests__/0.1.test.ts +++ b/packages/core/src/storage/migration/__tests__/0.1.test.ts @@ -1,3 +1,4 @@ +import type { FileSystem } from '../../../../src' import type { V0_1ToV0_2UpdateConfig } from '../updates/0.1-0.2' import { unlinkSync, readFileSync } from 'fs' @@ -48,6 +49,8 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { container ) + const fileSystem = agent.injectionContainer.resolve(InjectionSymbols.FileSystem) + const updateAssistant = new UpdateAssistant(agent, { v0_1ToV0_2: { mediationRoleUpdateStrategy, @@ -75,7 +78,7 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { expect(storageService.records).toMatchSnapshot(mediationRoleUpdateStrategy) // Need to remove backupFiles after each run so we don't get IOErrors - const backupPath = `${agent.config.fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` + const backupPath = `${fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` unlinkSync(backupPath) await agent.shutdown() @@ -107,6 +110,8 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { container ) + const fileSystem = agent.injectionContainer.resolve(InjectionSymbols.FileSystem) + const updateAssistant = new UpdateAssistant(agent, { v0_1ToV0_2: { mediationRoleUpdateStrategy: 'doNotChange', @@ -135,7 +140,7 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { expect(storageService.records).toMatchSnapshot() // Need to remove backupFiles after each run so we don't get IOErrors - const backupPath = `${agent.config.fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` + const backupPath = `${fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` unlinkSync(backupPath) await agent.shutdown() @@ -169,6 +174,8 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { container ) + const fileSystem = agent.injectionContainer.resolve(InjectionSymbols.FileSystem) + // We need to manually initialize the wallet as we're using the in memory wallet service // When we call agent.initialize() it will create the wallet and store the current framework // version in the in memory storage service. We need to manually set the records between initializing @@ -184,7 +191,7 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { expect(storageService.records).toMatchSnapshot() // Need to remove backupFiles after each run so we don't get IOErrors - const backupPath = `${agent.config.fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` + const backupPath = `${fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` unlinkSync(backupPath) await agent.shutdown() @@ -218,6 +225,8 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { container ) + const fileSystem = agent.injectionContainer.resolve(InjectionSymbols.FileSystem) + // We need to manually initialize the wallet as we're using the in memory wallet service // When we call agent.initialize() it will create the wallet and store the current framework // version in the in memory storage service. We need to manually set the records between initializing @@ -233,7 +242,7 @@ describe('UpdateAssistant | v0.1 - v0.2', () => { expect(storageService.records).toMatchSnapshot() // Need to remove backupFiles after each run so we don't get IOErrors - const backupPath = `${agent.config.fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` + const backupPath = `${fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` unlinkSync(backupPath) await agent.shutdown() diff --git a/packages/core/src/storage/migration/__tests__/backup.test.ts b/packages/core/src/storage/migration/__tests__/backup.test.ts index 02033a656d..557e521549 100644 --- a/packages/core/src/storage/migration/__tests__/backup.test.ts +++ b/packages/core/src/storage/migration/__tests__/backup.test.ts @@ -1,3 +1,4 @@ +import type { FileSystem } from '../../FileSystem' import type { StorageUpdateError } from '../error/StorageUpdateError' import { readFileSync, unlinkSync } from 'fs' @@ -5,6 +6,7 @@ import path from 'path' import { getBaseConfig } from '../../../../tests/helpers' import { Agent } from '../../../agent/Agent' +import { InjectionSymbols } from '../../../constants' import { AriesFrameworkError } from '../../../error' import { CredentialExchangeRecord, CredentialRepository } from '../../../modules/credentials' import { JsonTransformer } from '../../../utils' @@ -29,13 +31,14 @@ describe('UpdateAssistant | Backup', () => { beforeEach(async () => { agent = new Agent(config, agentDependencies) - backupPath = `${agent.config.fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` + const fileSystem = agent.dependencyManager.resolve(InjectionSymbols.FileSystem) + backupPath = `${fileSystem.basePath}/afj/migration/backup/${backupIdentifier}` // If tests fail it's possible the cleanup has been skipped. So remove before running tests - if (await agent.config.fileSystem.exists(backupPath)) { + if (await fileSystem.exists(backupPath)) { unlinkSync(backupPath) } - if (await agent.config.fileSystem.exists(`${backupPath}-error`)) { + if (await fileSystem.exists(`${backupPath}-error`)) { unlinkSync(`${backupPath}-error`) } @@ -69,14 +72,14 @@ describe('UpdateAssistant | Backup', () => { // Add 0.1 data and set version to 0.1 for (const credentialRecord of aliceCredentialRecords) { - await credentialRepository.save(credentialRecord) + await credentialRepository.save(agent.context, credentialRecord) } - await storageUpdateService.setCurrentStorageVersion('0.1') + await storageUpdateService.setCurrentStorageVersion(agent.context, '0.1') // Expect an update is needed expect(await updateAssistant.isUpToDate()).toBe(false) - const fileSystem = agent.config.fileSystem + const fileSystem = agent.injectionContainer.resolve(InjectionSymbols.FileSystem) // Backup should not exist before update expect(await fileSystem.exists(backupPath)).toBe(false) @@ -86,7 +89,9 @@ describe('UpdateAssistant | Backup', () => { // Backup should exist after update expect(await fileSystem.exists(backupPath)).toBe(true) - expect((await credentialRepository.getAll()).sort((a, b) => a.id.localeCompare(b.id))).toMatchSnapshot() + expect( + (await credentialRepository.getAll(agent.context)).sort((a, b) => a.id.localeCompare(b.id)) + ).toMatchSnapshot() }) it('should restore the backup if an error occurs during the update', async () => { @@ -105,9 +110,9 @@ describe('UpdateAssistant | Backup', () => { // Add 0.1 data and set version to 0.1 for (const credentialRecord of aliceCredentialRecords) { - await credentialRepository.save(credentialRecord) + await credentialRepository.save(agent.context, credentialRecord) } - await storageUpdateService.setCurrentStorageVersion('0.1') + await storageUpdateService.setCurrentStorageVersion(agent.context, '0.1') // Expect an update is needed expect(await updateAssistant.isUpToDate()).toBe(false) @@ -121,7 +126,7 @@ describe('UpdateAssistant | Backup', () => { }, ]) - const fileSystem = agent.config.fileSystem + const fileSystem = agent.injectionContainer.resolve(InjectionSymbols.FileSystem) // Backup should not exist before update expect(await fileSystem.exists(backupPath)).toBe(false) @@ -140,7 +145,7 @@ describe('UpdateAssistant | Backup', () => { expect(await fileSystem.exists(`${backupPath}-error`)).toBe(true) // Wallet should be same as when we started because of backup - expect((await credentialRepository.getAll()).sort((a, b) => a.id.localeCompare(b.id))).toEqual( + expect((await credentialRepository.getAll(agent.context)).sort((a, b) => a.id.localeCompare(b.id))).toEqual( aliceCredentialRecords.sort((a, b) => a.id.localeCompare(b.id)) ) }) diff --git a/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/connection.test.ts b/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/connection.test.ts index c68f5e14d1..520bf571aa 100644 --- a/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/connection.test.ts +++ b/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/connection.test.ts @@ -1,4 +1,4 @@ -import { getAgentConfig, mockFunction } from '../../../../../../tests/helpers' +import { getAgentConfig, getAgentContext, mockFunction } from '../../../../../../tests/helpers' import { Agent } from '../../../../../agent/Agent' import { ConnectionRecord, @@ -24,6 +24,7 @@ import legacyDidPeer4kgVt6CidfKgo1MoWMqsQX from './__fixtures__/legacyDidPeer4kg import legacyDidPeerR1xKJw17sUoXhejEpugMYJ from './__fixtures__/legacyDidPeerR1xKJw17sUoXhejEpugMYJ.json' const agentConfig = getAgentConfig('Migration ConnectionRecord 0.1-0.2') +const agentContext = getAgentContext() jest.mock('../../../../../modules/connections/repository/ConnectionRepository') const ConnectionRepositoryMock = ConnectionRepository as jest.Mock @@ -41,6 +42,7 @@ jest.mock('../../../../../agent/Agent', () => { return { Agent: jest.fn(() => ({ config: agentConfig, + context: agentContext, dependencyManager: { resolve: jest.fn((cls) => { if (cls === ConnectionRepository) { @@ -122,7 +124,7 @@ describe('0.1-0.2 | Connection', () => { expect(connectionRepository.getAll).toHaveBeenCalledTimes(1) expect(connectionRepository.update).toHaveBeenCalledTimes(records.length) - const [[updatedConnectionRecord]] = mockFunction(connectionRepository.update).mock.calls + const [[, updatedConnectionRecord]] = mockFunction(connectionRepository.update).mock.calls // Check first object is transformed correctly. // - removed invitation, theirDidDoc, didDoc @@ -210,7 +212,7 @@ describe('0.1-0.2 | Connection', () => { expect(didRepository.save).toHaveBeenCalledTimes(2) - const [[didRecord], [theirDidRecord]] = mockFunction(didRepository.save).mock.calls + const [[, didRecord], [, theirDidRecord]] = mockFunction(didRepository.save).mock.calls expect(didRecord.toJSON()).toMatchObject({ id: didPeerR1xKJw17sUoXhejEpugMYJ.id, @@ -314,15 +316,15 @@ describe('0.1-0.2 | Connection', () => { ) // Both did records already exist - mockFunction(didRepository.findById).mockImplementation((id) => + mockFunction(didRepository.findById).mockImplementation((_, id) => Promise.resolve(id === didPeerR1xKJw17sUoXhejEpugMYJ.id ? didRecord : theirDidRecord) ) await testModule.extractDidDocument(agent, connectionRecord) expect(didRepository.save).not.toHaveBeenCalled() - expect(didRepository.findById).toHaveBeenNthCalledWith(1, didPeerR1xKJw17sUoXhejEpugMYJ.id) - expect(didRepository.findById).toHaveBeenNthCalledWith(2, didPeer4kgVt6CidfKgo1MoWMqsQX.id) + expect(didRepository.findById).toHaveBeenNthCalledWith(1, agentContext, didPeerR1xKJw17sUoXhejEpugMYJ.id) + expect(didRepository.findById).toHaveBeenNthCalledWith(2, agentContext, didPeer4kgVt6CidfKgo1MoWMqsQX.id) expect(connectionRecord.toJSON()).toEqual({ _tags: {}, @@ -376,7 +378,7 @@ describe('0.1-0.2 | Connection', () => { await testModule.migrateToOobRecord(agent, connectionRecord) - const [[outOfBandRecord]] = mockFunction(outOfBandRepository.save).mock.calls + const [[, outOfBandRecord]] = mockFunction(outOfBandRepository.save).mock.calls expect(outOfBandRepository.save).toHaveBeenCalledTimes(1) expect(connectionRecord.outOfBandId).toEqual(outOfBandRecord.id) @@ -419,7 +421,7 @@ describe('0.1-0.2 | Connection', () => { await testModule.migrateToOobRecord(agent, connectionRecord) expect(outOfBandRepository.findByQuery).toHaveBeenCalledTimes(1) - expect(outOfBandRepository.findByQuery).toHaveBeenNthCalledWith(1, { + expect(outOfBandRepository.findByQuery).toHaveBeenNthCalledWith(1, agentContext, { invitationId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', recipientKeyFingerprints: ['z6MksYU4MHtfmNhNm1uGMvANr9j4CBv2FymjiJtRgA36bSVH'], }) @@ -469,7 +471,7 @@ describe('0.1-0.2 | Connection', () => { await testModule.migrateToOobRecord(agent, connectionRecord) expect(outOfBandRepository.findByQuery).toHaveBeenCalledTimes(1) - expect(outOfBandRepository.findByQuery).toHaveBeenNthCalledWith(1, { + expect(outOfBandRepository.findByQuery).toHaveBeenNthCalledWith(1, agentContext, { invitationId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', recipientKeyFingerprints: ['z6MksYU4MHtfmNhNm1uGMvANr9j4CBv2FymjiJtRgA36bSVH'], }) @@ -535,13 +537,13 @@ describe('0.1-0.2 | Connection', () => { await testModule.migrateToOobRecord(agent, connectionRecord) expect(outOfBandRepository.findByQuery).toHaveBeenCalledTimes(1) - expect(outOfBandRepository.findByQuery).toHaveBeenNthCalledWith(1, { + expect(outOfBandRepository.findByQuery).toHaveBeenNthCalledWith(1, agentContext, { invitationId: '04a2c382-999e-4de9-a1d2-9dec0b2fa5e4', recipientKeyFingerprints: ['z6MksYU4MHtfmNhNm1uGMvANr9j4CBv2FymjiJtRgA36bSVH'], }) expect(outOfBandRepository.save).not.toHaveBeenCalled() - expect(outOfBandRepository.update).toHaveBeenCalledWith(outOfBandRecord) - expect(connectionRepository.delete).toHaveBeenCalledWith(connectionRecord) + expect(outOfBandRepository.update).toHaveBeenCalledWith(agentContext, outOfBandRecord) + expect(connectionRepository.delete).toHaveBeenCalledWith(agentContext, connectionRecord) expect(outOfBandRecord.toJSON()).toEqual({ id: '3c52cc26-577d-4200-8753-05f1f425c342', diff --git a/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/credential.test.ts b/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/credential.test.ts index 00df7457da..c4c3434b77 100644 --- a/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/credential.test.ts +++ b/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/credential.test.ts @@ -1,7 +1,7 @@ import type { CredentialRecordBinding } from '../../../../../../src/modules/credentials' import { CredentialExchangeRecord, CredentialState } from '../../../../../../src/modules/credentials' -import { getAgentConfig, mockFunction } from '../../../../../../tests/helpers' +import { getAgentConfig, getAgentContext, mockFunction } from '../../../../../../tests/helpers' import { Agent } from '../../../../../agent/Agent' import { CredentialRepository } from '../../../../../modules/credentials/repository/CredentialRepository' import { JsonTransformer } from '../../../../../utils' @@ -10,6 +10,7 @@ import { DidCommMessageRepository } from '../../../../didcomm/DidCommMessageRepo import * as testModule from '../credential' const agentConfig = getAgentConfig('Migration CredentialRecord 0.1-0.2') +const agentContext = getAgentContext() jest.mock('../../../../../modules/credentials/repository/CredentialRepository') const CredentialRepositoryMock = CredentialRepository as jest.Mock @@ -23,6 +24,7 @@ jest.mock('../../../../../agent/Agent', () => { return { Agent: jest.fn(() => ({ config: agentConfig, + context: agentContext, dependencyManager: { resolve: jest.fn((token) => token === CredentialRepositoryMock ? credentialRepository : didCommMessageRepository @@ -75,7 +77,7 @@ describe('0.1-0.2 | Credential', () => { expect(credentialRepository.getAll).toHaveBeenCalledTimes(1) expect(credentialRepository.update).toHaveBeenCalledTimes(records.length) - const updatedRecord = mockFunction(credentialRepository.update).mock.calls[0][0] + const updatedRecord = mockFunction(credentialRepository.update).mock.calls[0][1] // Check first object is transformed correctly expect(updatedRecord.toJSON()).toMatchObject({ @@ -277,7 +279,7 @@ describe('0.1-0.2 | Credential', () => { await testModule.moveDidCommMessages(agent, credentialRecord) expect(didCommMessageRepository.save).toHaveBeenCalledTimes(4) - const [[proposalMessageRecord], [offerMessageRecord], [requestMessageRecord], [credentialMessageRecord]] = + const [[, proposalMessageRecord], [, offerMessageRecord], [, requestMessageRecord], [, credentialMessageRecord]] = mockFunction(didCommMessageRepository.save).mock.calls expect(proposalMessageRecord).toMatchObject({ @@ -340,7 +342,7 @@ describe('0.1-0.2 | Credential', () => { await testModule.moveDidCommMessages(agent, credentialRecord) expect(didCommMessageRepository.save).toHaveBeenCalledTimes(2) - const [[proposalMessageRecord], [offerMessageRecord]] = mockFunction(didCommMessageRepository.save).mock.calls + const [[, proposalMessageRecord], [, offerMessageRecord]] = mockFunction(didCommMessageRepository.save).mock.calls expect(proposalMessageRecord).toMatchObject({ role: DidCommMessageRole.Sender, @@ -388,7 +390,7 @@ describe('0.1-0.2 | Credential', () => { await testModule.moveDidCommMessages(agent, credentialRecord) expect(didCommMessageRepository.save).toHaveBeenCalledTimes(4) - const [[proposalMessageRecord], [offerMessageRecord], [requestMessageRecord], [credentialMessageRecord]] = + const [[, proposalMessageRecord], [, offerMessageRecord], [, requestMessageRecord], [, credentialMessageRecord]] = mockFunction(didCommMessageRepository.save).mock.calls expect(proposalMessageRecord).toMatchObject({ diff --git a/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/mediation.test.ts b/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/mediation.test.ts index 9f0ccd49f7..b5616578e2 100644 --- a/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/mediation.test.ts +++ b/packages/core/src/storage/migration/updates/0.1-0.2/__tests__/mediation.test.ts @@ -1,4 +1,4 @@ -import { getAgentConfig, mockFunction } from '../../../../../../tests/helpers' +import { getAgentConfig, getAgentContext, mockFunction } from '../../../../../../tests/helpers' import { Agent } from '../../../../../agent/Agent' import { MediationRole, MediationRecord } from '../../../../../modules/routing' import { MediationRepository } from '../../../../../modules/routing/repository/MediationRepository' @@ -6,6 +6,7 @@ import { JsonTransformer } from '../../../../../utils' import * as testModule from '../mediation' const agentConfig = getAgentConfig('Migration MediationRecord 0.1-0.2') +const agentContext = getAgentContext() jest.mock('../../../../../modules/routing/repository/MediationRepository') const MediationRepositoryMock = MediationRepository as jest.Mock @@ -15,6 +16,7 @@ jest.mock('../../../../../agent/Agent', () => { return { Agent: jest.fn(() => ({ config: agentConfig, + context: agentContext, dependencyManager: { resolve: jest.fn(() => mediationRepository), }, @@ -57,6 +59,7 @@ describe('0.1-0.2 | Mediation', () => { // Check second object is transformed correctly expect(mediationRepository.update).toHaveBeenNthCalledWith( 2, + agentContext, getMediationRecord({ role: MediationRole.Mediator, endpoint: 'secondEndpoint', diff --git a/packages/core/src/storage/migration/updates/0.1-0.2/connection.ts b/packages/core/src/storage/migration/updates/0.1-0.2/connection.ts index 30d5058729..0c66521d5c 100644 --- a/packages/core/src/storage/migration/updates/0.1-0.2/connection.ts +++ b/packages/core/src/storage/migration/updates/0.1-0.2/connection.ts @@ -36,7 +36,7 @@ export async function migrateConnectionRecordToV0_2(agent: Agent) { const connectionRepository = agent.dependencyManager.resolve(ConnectionRepository) agent.config.logger.debug(`Fetching all connection records from storage`) - const allConnections = await connectionRepository.getAll() + const allConnections = await connectionRepository.getAll(agent.context) agent.config.logger.debug(`Found a total of ${allConnections.length} connection records to update.`) for (const connectionRecord of allConnections) { @@ -53,7 +53,7 @@ export async function migrateConnectionRecordToV0_2(agent: Agent) { // migrateToOobRecord will return the connection record if it has not been deleted. When using multiUseInvitation the connection record // will be removed after processing, in which case the update method will throw an error. if (_connectionRecord) { - await connectionRepository.update(connectionRecord) + await connectionRepository.update(agent.context, connectionRecord) } agent.config.logger.debug( @@ -161,7 +161,7 @@ export async function extractDidDocument(agent: Agent, connectionRecord: Connect const newDidDocument = convertToNewDidDocument(oldDidDoc) // Maybe we already have a record for this did because the migration failed previously - let didRecord = await didRepository.findById(newDidDocument.id) + let didRecord = await didRepository.findById(agent.context, newDidDocument.id) if (!didRecord) { agent.config.logger.debug(`Creating did record for did ${newDidDocument.id}`) @@ -180,7 +180,7 @@ export async function extractDidDocument(agent: Agent, connectionRecord: Connect didDocumentString: JsonEncoder.toString(oldDidDocJson), }) - await didRepository.save(didRecord) + await didRepository.save(agent.context, didRecord) agent.config.logger.debug(`Successfully saved did record for did ${newDidDocument.id}`) } else { @@ -207,7 +207,7 @@ export async function extractDidDocument(agent: Agent, connectionRecord: Connect const newTheirDidDocument = convertToNewDidDocument(oldTheirDidDoc) // Maybe we already have a record for this did because the migration failed previously - let didRecord = await didRepository.findById(newTheirDidDocument.id) + let didRecord = await didRepository.findById(agent.context, newTheirDidDocument.id) if (!didRecord) { agent.config.logger.debug(`Creating did record for theirDid ${newTheirDidDocument.id}`) @@ -227,7 +227,7 @@ export async function extractDidDocument(agent: Agent, connectionRecord: Connect didDocumentString: JsonEncoder.toString(oldTheirDidDocJson), }) - await didRepository.save(didRecord) + await didRepository.save(agent.context, didRecord) agent.config.logger.debug(`Successfully saved did record for theirDid ${newTheirDidDocument.id}`) } else { @@ -310,7 +310,7 @@ export async function migrateToOobRecord( const outOfBandInvitation = convertToNewInvitation(oldInvitation) // If both the recipientKeys and the @id match we assume the connection was created using the same invitation. - const oobRecords = await oobRepository.findByQuery({ + const oobRecords = await oobRepository.findByQuery(agent.context, { invitationId: oldInvitation.id, recipientKeyFingerprints: outOfBandInvitation.getRecipientKeys().map((key) => key.fingerprint), }) @@ -337,7 +337,7 @@ export async function migrateToOobRecord( createdAt: connectionRecord.createdAt, }) - await oobRepository.save(oobRecord) + await oobRepository.save(agent.context, oobRecord) agent.config.logger.debug(`Successfully saved out of band record for invitation @id ${oldInvitation.id}`) } else { agent.config.logger.debug( @@ -353,8 +353,8 @@ export async function migrateToOobRecord( oobRecord.mediatorId = connectionRecord.mediatorId oobRecord.autoAcceptConnection = connectionRecord.autoAcceptConnection - await oobRepository.update(oobRecord) - await connectionRepository.delete(connectionRecord) + await oobRepository.update(agent.context, oobRecord) + await connectionRepository.delete(agent.context, connectionRecord) agent.config.logger.debug( `Set reusable=true for out of band record with invitation @id ${oobRecord.outOfBandInvitation.id}.` ) diff --git a/packages/core/src/storage/migration/updates/0.1-0.2/credential.ts b/packages/core/src/storage/migration/updates/0.1-0.2/credential.ts index 548f9a6b15..2f59d915ed 100644 --- a/packages/core/src/storage/migration/updates/0.1-0.2/credential.ts +++ b/packages/core/src/storage/migration/updates/0.1-0.2/credential.ts @@ -21,7 +21,7 @@ export async function migrateCredentialRecordToV0_2(agent: Agent) { const credentialRepository = agent.dependencyManager.resolve(CredentialRepository) agent.config.logger.debug(`Fetching all credential records from storage`) - const allCredentials = await credentialRepository.getAll() + const allCredentials = await credentialRepository.getAll(agent.context) agent.config.logger.debug(`Found a total of ${allCredentials.length} credential records to update.`) for (const credentialRecord of allCredentials) { @@ -31,7 +31,7 @@ export async function migrateCredentialRecordToV0_2(agent: Agent) { await migrateInternalCredentialRecordProperties(agent, credentialRecord) await moveDidCommMessages(agent, credentialRecord) - await credentialRepository.update(credentialRecord) + await credentialRepository.update(agent.context, credentialRecord) agent.config.logger.debug( `Successfully migrated credential record with id ${credentialRecord.id} to storage version 0.2` @@ -232,7 +232,7 @@ export async function moveDidCommMessages(agent: Agent, credentialRecord: Creden associatedRecordId: credentialRecord.id, message, }) - await didCommMessageRepository.save(didCommMessageRecord) + await didCommMessageRepository.save(agent.context, didCommMessageRecord) agent.config.logger.debug( `Successfully moved ${messageKey} from credential record with id ${credentialRecord.id} to DIDCommMessageRecord` diff --git a/packages/core/src/storage/migration/updates/0.1-0.2/mediation.ts b/packages/core/src/storage/migration/updates/0.1-0.2/mediation.ts index 62f9da238d..e6d3447a1c 100644 --- a/packages/core/src/storage/migration/updates/0.1-0.2/mediation.ts +++ b/packages/core/src/storage/migration/updates/0.1-0.2/mediation.ts @@ -1,6 +1,6 @@ -import type { V0_1ToV0_2UpdateConfig } from '.' import type { Agent } from '../../../../agent/Agent' import type { MediationRecord } from '../../../../modules/routing' +import type { V0_1ToV0_2UpdateConfig } from './index' import { MediationRepository, MediationRole } from '../../../../modules/routing' @@ -17,7 +17,7 @@ export async function migrateMediationRecordToV0_2(agent: Agent, upgradeConfig: const mediationRepository = agent.dependencyManager.resolve(MediationRepository) agent.config.logger.debug(`Fetching all mediation records from storage`) - const allMediationRecords = await mediationRepository.getAll() + const allMediationRecords = await mediationRepository.getAll(agent.context) agent.config.logger.debug(`Found a total of ${allMediationRecords.length} mediation records to update.`) for (const mediationRecord of allMediationRecords) { @@ -25,7 +25,7 @@ export async function migrateMediationRecordToV0_2(agent: Agent, upgradeConfig: await updateMediationRole(agent, mediationRecord, upgradeConfig) - await mediationRepository.update(mediationRecord) + await mediationRepository.update(agent.context, mediationRecord) agent.config.logger.debug( `Successfully migrated mediation record with id ${mediationRecord.id} to storage version 0.2` diff --git a/packages/core/src/transport/HttpOutboundTransport.ts b/packages/core/src/transport/HttpOutboundTransport.ts index 5cdf4e1dbb..3f331d9ded 100644 --- a/packages/core/src/transport/HttpOutboundTransport.ts +++ b/packages/core/src/transport/HttpOutboundTransport.ts @@ -6,23 +6,20 @@ import type fetch from 'node-fetch' import { AbortController } from 'abort-controller' -import { AgentConfig } from '../agent/AgentConfig' import { AriesFrameworkError } from '../error/AriesFrameworkError' import { isValidJweStructure, JsonEncoder } from '../utils' export class HttpOutboundTransport implements OutboundTransport { private agent!: Agent private logger!: Logger - private agentConfig!: AgentConfig private fetch!: typeof fetch public supportedSchemes = ['http', 'https'] public async start(agent: Agent): Promise { this.agent = agent - this.agentConfig = agent.dependencyManager.resolve(AgentConfig) - this.logger = this.agentConfig.logger - this.fetch = this.agentConfig.agentDependencies.fetch + this.logger = this.agent.config.logger + this.fetch = this.agent.config.agentDependencies.fetch this.logger.debug('Starting HTTP outbound transport') } @@ -53,7 +50,7 @@ export class HttpOutboundTransport implements OutboundTransport { response = await this.fetch(endpoint, { method: 'POST', body: JSON.stringify(payload), - headers: { 'Content-Type': this.agentConfig.didCommMimeType }, + headers: { 'Content-Type': this.agent.config.didCommMimeType }, signal: abortController.signal, }) clearTimeout(id) @@ -96,7 +93,7 @@ export class HttpOutboundTransport implements OutboundTransport { error, message: error.message, body: payload, - didCommMimeType: this.agentConfig.didCommMimeType, + didCommMimeType: this.agent.config.didCommMimeType, }) throw new AriesFrameworkError(`Error sending message to ${endpoint}: ${error.message}`, { cause: error }) } diff --git a/packages/core/src/transport/WsOutboundTransport.ts b/packages/core/src/transport/WsOutboundTransport.ts index 9c4a6edbf8..cd5f9ffb6c 100644 --- a/packages/core/src/transport/WsOutboundTransport.ts +++ b/packages/core/src/transport/WsOutboundTransport.ts @@ -6,8 +6,6 @@ import type { OutboundTransport } from './OutboundTransport' import type { OutboundWebSocketClosedEvent } from './TransportEventTypes' import type WebSocket from 'ws' -import { AgentConfig } from '../agent/AgentConfig' -import { EventEmitter } from '../agent/EventEmitter' import { AgentEventTypes } from '../agent/Events' import { AriesFrameworkError } from '../error/AriesFrameworkError' import { isValidJweStructure, JsonEncoder } from '../utils' @@ -19,18 +17,16 @@ export class WsOutboundTransport implements OutboundTransport { private transportTable: Map = new Map() private agent!: Agent private logger!: Logger - private eventEmitter!: EventEmitter private WebSocketClass!: typeof WebSocket public supportedSchemes = ['ws', 'wss'] public async start(agent: Agent): Promise { this.agent = agent - const agentConfig = agent.dependencyManager.resolve(AgentConfig) - this.logger = agentConfig.logger - this.eventEmitter = agent.dependencyManager.resolve(EventEmitter) + this.logger = agent.config.logger + this.logger.debug('Starting WS outbound transport') - this.WebSocketClass = agentConfig.agentDependencies.WebSocketClass + this.WebSocketClass = agent.config.agentDependencies.WebSocketClass } public async stop() { @@ -111,7 +107,8 @@ export class WsOutboundTransport implements OutboundTransport { ) } this.logger.debug('Payload received from mediator:', payload) - this.eventEmitter.emit({ + + this.agent.events.emit(this.agent.context, { type: AgentEventTypes.AgentMessageReceived, payload: { message: payload, @@ -153,7 +150,7 @@ export class WsOutboundTransport implements OutboundTransport { socket.removeEventListener('message', this.handleMessageEvent) this.transportTable.delete(socketId) - this.eventEmitter.emit({ + this.agent.events.emit(this.agent.context, { type: TransportEventTypes.OutboundWebSocketClosedEvent, payload: { socketId, diff --git a/packages/core/src/wallet/IndyWallet.test.ts b/packages/core/src/wallet/IndyWallet.test.ts index a1147a6260..a59cd82f60 100644 --- a/packages/core/src/wallet/IndyWallet.test.ts +++ b/packages/core/src/wallet/IndyWallet.test.ts @@ -1,33 +1,41 @@ +import type { WalletConfig } from '../types' + import { BBS_SIGNATURE_LENGTH } from '@mattrglobal/bbs-signatures' import { SIGNATURE_LENGTH as ED25519_SIGNATURE_LENGTH } from '@stablelib/ed25519' -import { getBaseConfig } from '../../tests/helpers' -import { Agent } from '../agent/Agent' +import { agentDependencies } from '../../tests/helpers' +import testLogger from '../../tests/logger' import { KeyType } from '../crypto' +import { KeyDerivationMethod } from '../types' import { TypedArrayEncoder } from '../utils' import { IndyWallet } from './IndyWallet' import { WalletError } from './error' +// use raw key derivation method to speed up wallet creating / opening / closing between tests +const walletConfig: WalletConfig = { + id: 'Wallet: IndyWalletTest', + // generated using indy.generateWalletKey + key: 'CwNJroKHTSSj3XvE7ZAnuKiTn2C4QkFvxEqfm5rzhNrb', + keyDerivationMethod: KeyDerivationMethod.Raw, +} + describe('IndyWallet', () => { let indyWallet: IndyWallet - let agent: Agent const seed = 'sample-seed' const message = TypedArrayEncoder.fromString('sample-message') beforeEach(async () => { - const { config, agentDependencies } = getBaseConfig('IndyWallettest') - agent = new Agent(config, agentDependencies) - indyWallet = agent.injectionContainer.resolve(IndyWallet) - await agent.initialize() + indyWallet = new IndyWallet(agentDependencies, testLogger) + await indyWallet.createAndOpen(walletConfig) }) afterEach(async () => { - await agent.shutdown() - await agent.wallet.delete() + await indyWallet.delete() }) - test('Get the public DID', () => { + test('Get the public DID', async () => { + await indyWallet.initPublicDid({ seed: '000000000000000000000000Trustee9' }) expect(indyWallet.publicDid).toMatchObject({ did: expect.any(String), verkey: expect.any(String), @@ -35,7 +43,7 @@ describe('IndyWallet', () => { }) test('Get the Master Secret', () => { - expect(indyWallet.masterSecretId).toEqual('Wallet: IndyWallettest') + expect(indyWallet.masterSecretId).toEqual('Wallet: IndyWalletTest') }) test('Get the wallet handle', () => { diff --git a/packages/core/src/wallet/IndyWallet.ts b/packages/core/src/wallet/IndyWallet.ts index c354d5ef5b..e99143dc7b 100644 --- a/packages/core/src/wallet/IndyWallet.ts +++ b/packages/core/src/wallet/IndyWallet.ts @@ -1,5 +1,4 @@ import type { BlsKeyPair } from '../crypto/BbsService' -import type { Logger } from '../logger' import type { EncryptedMessage, KeyDerivationMethod, @@ -19,12 +18,14 @@ import type { } from './Wallet' import type { default as Indy, WalletStorageConfig } from 'indy-sdk' -import { AgentConfig } from '../agent/AgentConfig' +import { AgentDependencies } from '../agent/AgentDependencies' +import { InjectionSymbols } from '../constants' import { BbsService } from '../crypto/BbsService' import { Key } from '../crypto/Key' import { KeyType } from '../crypto/KeyType' import { AriesFrameworkError, IndySdkError, RecordDuplicateError, RecordNotFoundError } from '../error' -import { injectable } from '../plugins' +import { Logger } from '../logger' +import { inject, injectable } from '../plugins' import { JsonEncoder, TypedArrayEncoder } from '../utils' import { isError } from '../utils/error' import { isIndyError } from '../utils/indyError' @@ -41,9 +42,12 @@ export class IndyWallet implements Wallet { private publicDidInfo: DidInfo | undefined private indy: typeof Indy - public constructor(agentConfig: AgentConfig) { - this.logger = agentConfig.logger - this.indy = agentConfig.agentDependencies.indy + public constructor( + @inject(InjectionSymbols.AgentDependencies) agentDependencies: AgentDependencies, + @inject(InjectionSymbols.Logger) logger: Logger + ) { + this.logger = logger + this.indy = agentDependencies.indy } public get isProvisioned() { diff --git a/packages/core/src/wallet/WalletModule.ts b/packages/core/src/wallet/WalletModule.ts index 89c301b26a..7d4bd0f739 100644 --- a/packages/core/src/wallet/WalletModule.ts +++ b/packages/core/src/wallet/WalletModule.ts @@ -1,33 +1,35 @@ -import type { Logger } from '../logger' import type { DependencyManager } from '../plugins' import type { WalletConfig, WalletConfigRekey, WalletExportImportConfig } from '../types' +import type { Wallet } from './Wallet' -import { AgentConfig } from '../agent/AgentConfig' +import { AgentContext } from '../agent' import { InjectionSymbols } from '../constants' +import { Logger } from '../logger' import { inject, injectable, module } from '../plugins' import { StorageUpdateService } from '../storage' import { CURRENT_FRAMEWORK_STORAGE_VERSION } from '../storage/migration/updates' -import { Wallet } from './Wallet' import { WalletError } from './error/WalletError' import { WalletNotFoundError } from './error/WalletNotFoundError' @module() @injectable() export class WalletModule { + private agentContext: AgentContext private wallet: Wallet private storageUpdateService: StorageUpdateService private logger: Logger private _walletConfig?: WalletConfig public constructor( - @inject(InjectionSymbols.Wallet) wallet: Wallet, storageUpdateService: StorageUpdateService, - agentConfig: AgentConfig + agentContext: AgentContext, + @inject(InjectionSymbols.Logger) logger: Logger ) { - this.wallet = wallet this.storageUpdateService = storageUpdateService - this.logger = agentConfig.logger + this.logger = logger + this.wallet = agentContext.wallet + this.agentContext = agentContext } public get isInitialized() { @@ -73,7 +75,7 @@ export class WalletModule { this._walletConfig = walletConfig // Store the storage version in the wallet - await this.storageUpdateService.setCurrentStorageVersion(CURRENT_FRAMEWORK_STORAGE_VERSION) + await this.storageUpdateService.setCurrentStorageVersion(this.agentContext, CURRENT_FRAMEWORK_STORAGE_VERSION) } public async create(walletConfig: WalletConfig): Promise { diff --git a/packages/core/src/wallet/util/assertIndyWallet.ts b/packages/core/src/wallet/util/assertIndyWallet.ts new file mode 100644 index 0000000000..6c6ac4a4eb --- /dev/null +++ b/packages/core/src/wallet/util/assertIndyWallet.ts @@ -0,0 +1,10 @@ +import type { Wallet } from '../Wallet' + +import { AriesFrameworkError } from '../../error' +import { IndyWallet } from '../IndyWallet' + +export function assertIndyWallet(wallet: Wallet): asserts wallet is IndyWallet { + if (!(wallet instanceof IndyWallet)) { + throw new AriesFrameworkError(`Expected wallet to be instance of IndyWallet, found ${wallet}`) + } +} diff --git a/packages/core/tests/connectionless-proofs.test.ts b/packages/core/tests/connectionless-proofs.test.ts index b38a30c36b..ab49b8c838 100644 --- a/packages/core/tests/connectionless-proofs.test.ts +++ b/packages/core/tests/connectionless-proofs.test.ts @@ -5,6 +5,7 @@ import { Subject, ReplaySubject } from 'rxjs' import { SubjectInboundTransport } from '../../../tests/transport/SubjectInboundTransport' import { SubjectOutboundTransport } from '../../../tests/transport/SubjectOutboundTransport' +import { InjectionSymbols } from '../src' import { Agent } from '../src/agent/Agent' import { Attachment, AttachmentData } from '../src/decorators/attachment/Attachment' import { HandshakeProtocol } from '../src/modules/connections' @@ -345,8 +346,10 @@ describe('Present Proof', () => { // We want to stop the mediator polling before the agent is shutdown. // FIXME: add a way to stop mediator polling from the public api, and make sure this is // being handled in the agent shutdown so we don't get any errors with wallets being closed. - faberAgent.config.stop$.next(true) - aliceAgent.config.stop$.next(true) + const faberStop$ = faberAgent.injectionContainer.resolve>(InjectionSymbols.Stop$) + const aliceStop$ = aliceAgent.injectionContainer.resolve>(InjectionSymbols.Stop$) + faberStop$.next(true) + aliceStop$.next(true) await sleep(2000) }) }) diff --git a/packages/core/tests/helpers.ts b/packages/core/tests/helpers.ts index 292b2f426d..cbd9742e9d 100644 --- a/packages/core/tests/helpers.ts +++ b/packages/core/tests/helpers.ts @@ -12,6 +12,7 @@ import type { ProofPredicateInfo, ProofStateChangedEvent, SchemaTemplate, + Wallet, } from '../src' import type { AcceptOfferOptions } from '../src/modules/credentials' import type { IndyOfferCredentialFormat } from '../src/modules/credentials/formats/indy/IndyCredentialFormat' @@ -28,14 +29,17 @@ import { agentDependencies, WalletScheme } from '../../node/src' import { Agent, AgentConfig, + AgentContext, AriesFrameworkError, BasicMessageEventTypes, ConnectionRecord, CredentialEventTypes, CredentialState, + DependencyManager, DidExchangeRole, DidExchangeState, HandshakeProtocol, + InjectionSymbols, LogLevel, PredicateType, PresentationPreview, @@ -134,6 +138,20 @@ export function getAgentConfig(name: string, extraConfig: Partial = return new AgentConfig(config, agentDependencies) } +export function getAgentContext({ + dependencyManager = new DependencyManager(), + wallet, + agentConfig, +}: { + dependencyManager?: DependencyManager + wallet?: Wallet + agentConfig?: AgentConfig +} = {}) { + if (wallet) dependencyManager.registerInstance(InjectionSymbols.Wallet, wallet) + if (agentConfig) dependencyManager.registerInstance(AgentConfig, agentConfig) + return new AgentContext({ dependencyManager }) +} + export async function waitForProofRecord( agent: Agent, options: { diff --git a/packages/core/tests/ledger.test.ts b/packages/core/tests/ledger.test.ts index 992feeab30..ce0802353d 100644 --- a/packages/core/tests/ledger.test.ts +++ b/packages/core/tests/ledger.test.ts @@ -4,7 +4,6 @@ import * as indy from 'indy-sdk' import { Agent } from '../src/agent/Agent' import { DID_IDENTIFIER_REGEX, isAbbreviatedVerkey, isFullVerkey, VERKEY_REGEX } from '../src/utils/did' import { sleep } from '../src/utils/sleep' -import { IndyWallet } from '../src/wallet/IndyWallet' import { genesisPath, getBaseConfig } from './helpers' import testLogger from './logger' @@ -65,7 +64,7 @@ describe('ledger', () => { throw new Error('Agent does not have public did.') } - const faberWallet = faberAgent.dependencyManager.resolve(IndyWallet) + const faberWallet = faberAgent.context.wallet const didInfo = await faberWallet.createDid() const result = await faberAgent.ledger.registerPublicDid(didInfo.did, didInfo.verkey, 'alias', 'TRUST_ANCHOR') diff --git a/packages/core/tests/mocks/MockWallet.ts b/packages/core/tests/mocks/MockWallet.ts new file mode 100644 index 0000000000..83132e1303 --- /dev/null +++ b/packages/core/tests/mocks/MockWallet.ts @@ -0,0 +1,74 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import type { Wallet } from '../../src' +import type { Key } from '../../src/crypto' +import type { EncryptedMessage, WalletConfig, WalletExportImportConfig, WalletConfigRekey } from '../../src/types' +import type { Buffer } from '../../src/utils/buffer' +import type { + DidInfo, + UnpackedMessageContext, + DidConfig, + CreateKeyOptions, + SignOptions, + VerifyOptions, +} from '../../src/wallet' + +export class MockWallet implements Wallet { + public publicDid = undefined + public isInitialized = true + public isProvisioned = true + + public create(walletConfig: WalletConfig): Promise { + throw new Error('Method not implemented.') + } + public createAndOpen(walletConfig: WalletConfig): Promise { + throw new Error('Method not implemented.') + } + public open(walletConfig: WalletConfig): Promise { + throw new Error('Method not implemented.') + } + public rotateKey(walletConfig: WalletConfigRekey): Promise { + throw new Error('Method not implemented.') + } + public close(): Promise { + throw new Error('Method not implemented.') + } + public delete(): Promise { + throw new Error('Method not implemented.') + } + public export(exportConfig: WalletExportImportConfig): Promise { + throw new Error('Method not implemented.') + } + public import(walletConfig: WalletConfig, importConfig: WalletExportImportConfig): Promise { + throw new Error('Method not implemented.') + } + public initPublicDid(didConfig: DidConfig): Promise { + throw new Error('Method not implemented.') + } + public createDid(didConfig?: DidConfig): Promise { + throw new Error('Method not implemented.') + } + public pack( + payload: Record, + recipientKeys: string[], + senderVerkey?: string + ): Promise { + throw new Error('Method not implemented.') + } + public unpack(encryptedMessage: EncryptedMessage): Promise { + throw new Error('Method not implemented.') + } + public sign(options: SignOptions): Promise { + throw new Error('Method not implemented.') + } + public verify(options: VerifyOptions): Promise { + throw new Error('Method not implemented.') + } + + public createKey(options: CreateKeyOptions): Promise { + throw new Error('Method not implemented.') + } + + public generateNonce(): Promise { + throw new Error('Method not implemented.') + } +} diff --git a/packages/core/tests/mocks/index.ts b/packages/core/tests/mocks/index.ts new file mode 100644 index 0000000000..3dbf2226a2 --- /dev/null +++ b/packages/core/tests/mocks/index.ts @@ -0,0 +1 @@ +export * from './MockWallet' diff --git a/packages/core/tests/multi-protocol-version.test.ts b/packages/core/tests/multi-protocol-version.test.ts index 413ec53db7..0a2f86aa93 100644 --- a/packages/core/tests/multi-protocol-version.test.ts +++ b/packages/core/tests/multi-protocol-version.test.ts @@ -84,7 +84,7 @@ describe('multi version protocols', () => { ) ) - await bobMessageSender.sendMessage(createOutboundMessage(bobConnection, new TestMessageV11())) + await bobMessageSender.sendMessage(bobAgent.context, createOutboundMessage(bobConnection, new TestMessageV11())) // Wait for the agent message processed event to be called await agentMessageV11ProcessedPromise @@ -99,7 +99,7 @@ describe('multi version protocols', () => { ) ) - await bobMessageSender.sendMessage(createOutboundMessage(bobConnection, new TestMessageV15())) + await bobMessageSender.sendMessage(bobAgent.context, createOutboundMessage(bobConnection, new TestMessageV15())) await agentMessageV15ProcessedPromise expect(mockHandle).toHaveBeenCalledTimes(2) diff --git a/packages/core/tests/oob.test.ts b/packages/core/tests/oob.test.ts index cf9b3233ad..03e33462ea 100644 --- a/packages/core/tests/oob.test.ts +++ b/packages/core/tests/oob.test.ts @@ -711,7 +711,7 @@ describe('out of band', () => { message, }) - expect(saveOrUpdateSpy).toHaveBeenCalledWith({ + expect(saveOrUpdateSpy).toHaveBeenCalledWith(expect.anything(), { agentMessage: message, associatedRecordId: credentialRecord.id, role: DidCommMessageRole.Sender, diff --git a/packages/core/tests/wallet.test.ts b/packages/core/tests/wallet.test.ts index 9d06608a4a..2d6d718d0c 100644 --- a/packages/core/tests/wallet.test.ts +++ b/packages/core/tests/wallet.test.ts @@ -124,7 +124,7 @@ describe('wallet', () => { }) // Save in wallet - await bobBasicMessageRepository.save(basicMessageRecord) + await bobBasicMessageRepository.save(bobAgent.context, basicMessageRecord) if (!bobAgent.config.walletConfig) { throw new Error('No wallet config on bobAgent') @@ -142,7 +142,7 @@ describe('wallet', () => { // This should create a new wallet // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await bobAgent.wallet.initialize(bobConfig.config.walletConfig!) - expect(await bobBasicMessageRepository.findById(basicMessageRecord.id)).toBeNull() + expect(await bobBasicMessageRepository.findById(bobAgent.context, basicMessageRecord.id)).toBeNull() await bobAgent.wallet.delete() // Import backup with different wallet id and initialize @@ -150,7 +150,9 @@ describe('wallet', () => { await bobAgent.wallet.initialize({ id: backupWalletName, key: backupWalletName }) // Expect same basic message record to exist in new wallet - expect(await bobBasicMessageRepository.getById(basicMessageRecord.id)).toMatchObject(basicMessageRecord) + expect(await bobBasicMessageRepository.getById(bobAgent.context, basicMessageRecord.id)).toMatchObject( + basicMessageRecord + ) }) test('changing wallet key', async () => { diff --git a/packages/node/src/transport/HttpInboundTransport.ts b/packages/node/src/transport/HttpInboundTransport.ts index 2835ce2a84..59144ee392 100644 --- a/packages/node/src/transport/HttpInboundTransport.ts +++ b/packages/node/src/transport/HttpInboundTransport.ts @@ -2,7 +2,7 @@ import type { InboundTransport, Agent, TransportSession, EncryptedMessage } from import type { Express, Request, Response } from 'express' import type { Server } from 'http' -import { DidCommMimeType, AriesFrameworkError, AgentConfig, TransportService, utils } from '@aries-framework/core' +import { DidCommMimeType, AriesFrameworkError, TransportService, utils } from '@aries-framework/core' import express, { text } from 'express' export class HttpInboundTransport implements InboundTransport { @@ -30,9 +30,8 @@ export class HttpInboundTransport implements InboundTransport { public async start(agent: Agent) { const transportService = agent.dependencyManager.resolve(TransportService) - const config = agent.dependencyManager.resolve(AgentConfig) - config.logger.debug(`Starting HTTP inbound transport`, { + agent.config.logger.debug(`Starting HTTP inbound transport`, { port: this.port, }) @@ -48,7 +47,7 @@ export class HttpInboundTransport implements InboundTransport { res.status(200).end() } } catch (error) { - config.logger.error(`Error processing inbound message: ${error.message}`, error) + agent.config.logger.error(`Error processing inbound message: ${error.message}`, error) if (!res.headersSent) { res.status(500).send('Error processing message') diff --git a/packages/node/src/transport/WsInboundTransport.ts b/packages/node/src/transport/WsInboundTransport.ts index a81f48aa72..c87fc24a39 100644 --- a/packages/node/src/transport/WsInboundTransport.ts +++ b/packages/node/src/transport/WsInboundTransport.ts @@ -1,6 +1,6 @@ import type { Agent, InboundTransport, Logger, TransportSession, EncryptedMessage } from '@aries-framework/core' -import { AriesFrameworkError, AgentConfig, TransportService, utils } from '@aries-framework/core' +import { AriesFrameworkError, TransportService, utils } from '@aries-framework/core' import WebSocket, { Server } from 'ws' export class WsInboundTransport implements InboundTransport { @@ -16,11 +16,10 @@ export class WsInboundTransport implements InboundTransport { public async start(agent: Agent) { const transportService = agent.dependencyManager.resolve(TransportService) - const config = agent.dependencyManager.resolve(AgentConfig) - this.logger = config.logger + this.logger = agent.config.logger - const wsEndpoint = config.endpoints.find((e) => e.startsWith('ws')) + const wsEndpoint = agent.config.endpoints.find((e) => e.startsWith('ws')) this.logger.debug(`Starting WS inbound transport`, { endpoint: wsEndpoint, }) diff --git a/samples/extension-module/dummy/DummyApi.ts b/samples/extension-module/dummy/DummyApi.ts index b15735148a..82ac700f7b 100644 --- a/samples/extension-module/dummy/DummyApi.ts +++ b/samples/extension-module/dummy/DummyApi.ts @@ -1,6 +1,6 @@ import type { DummyRecord } from './repository/DummyRecord' -import { injectable, ConnectionService, Dispatcher, MessageSender } from '@aries-framework/core' +import { AgentContext, ConnectionService, Dispatcher, injectable, MessageSender } from '@aries-framework/core' import { DummyRequestHandler, DummyResponseHandler } from './handlers' import { DummyState } from './repository' @@ -11,16 +11,20 @@ export class DummyApi { private messageSender: MessageSender private dummyService: DummyService private connectionService: ConnectionService + private agentContext: AgentContext public constructor( dispatcher: Dispatcher, messageSender: MessageSender, dummyService: DummyService, - connectionService: ConnectionService + connectionService: ConnectionService, + agentContext: AgentContext ) { this.messageSender = messageSender this.dummyService = dummyService this.connectionService = connectionService + this.agentContext = agentContext + this.registerHandlers(dispatcher) } @@ -31,12 +35,12 @@ export class DummyApi { * @returns created Dummy Record */ public async request(connectionId: string) { - const connection = await this.connectionService.getById(connectionId) - const { record, message: payload } = await this.dummyService.createRequest(connection) + const connection = await this.connectionService.getById(this.agentContext, connectionId) + const { record, message: payload } = await this.dummyService.createRequest(this.agentContext, connection) - await this.messageSender.sendMessage({ connection, payload }) + await this.messageSender.sendMessage(this.agentContext, { connection, payload }) - await this.dummyService.updateState(record, DummyState.RequestSent) + await this.dummyService.updateState(this.agentContext, record, DummyState.RequestSent) return record } @@ -48,14 +52,14 @@ export class DummyApi { * @returns Updated dummy record */ public async respond(dummyId: string) { - const record = await this.dummyService.getById(dummyId) - const connection = await this.connectionService.getById(record.connectionId) + const record = await this.dummyService.getById(this.agentContext, dummyId) + const connection = await this.connectionService.getById(this.agentContext, record.connectionId) - const payload = await this.dummyService.createResponse(record) + const payload = await this.dummyService.createResponse(this.agentContext, record) - await this.messageSender.sendMessage({ connection, payload }) + await this.messageSender.sendMessage(this.agentContext, { connection, payload }) - await this.dummyService.updateState(record, DummyState.ResponseSent) + await this.dummyService.updateState(this.agentContext, record, DummyState.ResponseSent) return record } @@ -66,7 +70,7 @@ export class DummyApi { * @returns List containing all records */ public getAll(): Promise { - return this.dummyService.getAll() + return this.dummyService.getAll(this.agentContext) } private registerHandlers(dispatcher: Dispatcher) { diff --git a/samples/extension-module/dummy/services/DummyService.ts b/samples/extension-module/dummy/services/DummyService.ts index 3cc73eba9f..2defd9d393 100644 --- a/samples/extension-module/dummy/services/DummyService.ts +++ b/samples/extension-module/dummy/services/DummyService.ts @@ -1,5 +1,5 @@ import type { DummyStateChangedEvent } from './DummyEvents' -import type { ConnectionRecord, InboundMessageContext } from '@aries-framework/core' +import type { AgentContext, ConnectionRecord, InboundMessageContext } from '@aries-framework/core' import { injectable, JsonTransformer, EventEmitter } from '@aries-framework/core' @@ -27,7 +27,7 @@ export class DummyService { * @returns Object containing dummy request message and associated dummy record * */ - public async createRequest(connectionRecord: ConnectionRecord) { + public async createRequest(agentContext: AgentContext, connectionRecord: ConnectionRecord) { // Create message const message = new DummyRequestMessage({}) @@ -38,9 +38,9 @@ export class DummyService { state: DummyState.Init, }) - await this.dummyRepository.save(record) + await this.dummyRepository.save(agentContext, record) - this.emitStateChangedEvent(record, null) + this.emitStateChangedEvent(agentContext, record, null) return { record, message } } @@ -51,7 +51,7 @@ export class DummyService { * @param record the dummy record for which to create a dummy response * @returns outbound message containing dummy response */ - public async createResponse(record: DummyRecord) { + public async createResponse(agentContext: AgentContext, record: DummyRecord) { const responseMessage = new DummyResponseMessage({ threadId: record.threadId, }) @@ -76,9 +76,9 @@ export class DummyService { state: DummyState.RequestReceived, }) - await this.dummyRepository.save(record) + await this.dummyRepository.save(messageContext.agentContext, record) - this.emitStateChangedEvent(record, null) + this.emitStateChangedEvent(messageContext.agentContext, record, null) return record } @@ -96,13 +96,13 @@ export class DummyService { const connection = messageContext.assertReadyConnection() // Dummy record already exists - const record = await this.findByThreadAndConnectionId(message.threadId, connection.id) + const record = await this.findByThreadAndConnectionId(messageContext.agentContext, message.threadId, connection.id) if (record) { // Check current state record.assertState(DummyState.RequestSent) - await this.updateState(record, DummyState.ResponseReceived) + await this.updateState(messageContext.agentContext, record, DummyState.ResponseReceived) } else { throw new Error(`Dummy record not found with threadId ${message.threadId}`) } @@ -115,8 +115,8 @@ export class DummyService { * * @returns List containing all dummy records */ - public getAll(): Promise { - return this.dummyRepository.getAll() + public getAll(agentContext: AgentContext): Promise { + return this.dummyRepository.getAll(agentContext) } /** @@ -127,8 +127,8 @@ export class DummyService { * @return The dummy record * */ - public getById(dummyRecordId: string): Promise { - return this.dummyRepository.getById(dummyRecordId) + public getById(agentContext: AgentContext, dummyRecordId: string): Promise { + return this.dummyRepository.getById(agentContext, dummyRecordId) } /** @@ -140,8 +140,12 @@ export class DummyService { * @throws {RecordDuplicateError} If multiple records are found * @returns The dummy record */ - public async findByThreadAndConnectionId(threadId: string, connectionId?: string): Promise { - return this.dummyRepository.findSingleByQuery({ threadId, connectionId }) + public async findByThreadAndConnectionId( + agentContext: AgentContext, + threadId: string, + connectionId?: string + ): Promise { + return this.dummyRepository.findSingleByQuery(agentContext, { threadId, connectionId }) } /** @@ -152,19 +156,23 @@ export class DummyService { * @param newState The state to update to * */ - public async updateState(dummyRecord: DummyRecord, newState: DummyState) { + public async updateState(agentContext: AgentContext, dummyRecord: DummyRecord, newState: DummyState) { const previousState = dummyRecord.state dummyRecord.state = newState - await this.dummyRepository.update(dummyRecord) + await this.dummyRepository.update(agentContext, dummyRecord) - this.emitStateChangedEvent(dummyRecord, previousState) + this.emitStateChangedEvent(agentContext, dummyRecord, previousState) } - private emitStateChangedEvent(dummyRecord: DummyRecord, previousState: DummyState | null) { + private emitStateChangedEvent( + agentContext: AgentContext, + dummyRecord: DummyRecord, + previousState: DummyState | null + ) { // we need to clone the dummy record to avoid mutating records after they're emitted in an event const clonedDummyRecord = JsonTransformer.clone(dummyRecord) - this.eventEmitter.emit({ + this.eventEmitter.emit(agentContext, { type: DummyEventTypes.StateChanged, payload: { dummyRecord: clonedDummyRecord, previousState: previousState }, }) diff --git a/samples/mediator.ts b/samples/mediator.ts index ec57dc253b..da4dd15293 100644 --- a/samples/mediator.ts +++ b/samples/mediator.ts @@ -25,7 +25,6 @@ import { Agent, ConnectionInvitationMessage, LogLevel, - AgentConfig, WsOutboundTransport, } from '@aries-framework/core' import { HttpInboundTransport, agentDependencies, WsInboundTransport } from '@aries-framework/node' @@ -55,7 +54,7 @@ const agentConfig: InitConfig = { // Set up agent const agent = new Agent(agentConfig, agentDependencies) -const config = agent.dependencyManager.resolve(AgentConfig) +const config = agent.config // Create all transports const httpInboundTransport = new HttpInboundTransport({ app, port }) diff --git a/tests/InMemoryStorageService.ts b/tests/InMemoryStorageService.ts index 6c2d383eda..fbdccba2cf 100644 --- a/tests/InMemoryStorageService.ts +++ b/tests/InMemoryStorageService.ts @@ -1,3 +1,4 @@ +import type { AgentContext } from '../packages/core/src/agent' import type { BaseRecord, TagsBase } from '../packages/core/src/storage/BaseRecord' import type { StorageService, BaseRecordConstructor, Query } from '../packages/core/src/storage/StorageService' @@ -33,7 +34,7 @@ export class InMemoryStorageService implement } /** @inheritDoc */ - public async save(record: T) { + public async save(agentContext: AgentContext, record: T) { const value = JsonTransformer.toJSON(record) if (this.records[record.id]) { @@ -49,7 +50,7 @@ export class InMemoryStorageService implement } /** @inheritDoc */ - public async update(record: T): Promise { + public async update(agentContext: AgentContext, record: T): Promise { const value = JsonTransformer.toJSON(record) delete value._tags @@ -68,7 +69,7 @@ export class InMemoryStorageService implement } /** @inheritDoc */ - public async delete(record: T) { + public async delete(agentContext: AgentContext, record: T) { if (!this.records[record.id]) { throw new RecordNotFoundError(`record with id ${record.id} not found.`, { recordType: record.type, @@ -79,7 +80,7 @@ export class InMemoryStorageService implement } /** @inheritDoc */ - public async getById(recordClass: BaseRecordConstructor, id: string): Promise { + public async getById(agentContext: AgentContext, recordClass: BaseRecordConstructor, id: string): Promise { const record = this.records[id] if (!record) { @@ -92,7 +93,7 @@ export class InMemoryStorageService implement } /** @inheritDoc */ - public async getAll(recordClass: BaseRecordConstructor): Promise { + public async getAll(agentContext: AgentContext, recordClass: BaseRecordConstructor): Promise { const records = Object.values(this.records) .filter((record) => record.type === recordClass.type) .map((record) => this.recordToInstance(record, recordClass)) @@ -101,7 +102,11 @@ export class InMemoryStorageService implement } /** @inheritDoc */ - public async findByQuery(recordClass: BaseRecordConstructor, query: Query): Promise { + public async findByQuery( + agentContext: AgentContext, + recordClass: BaseRecordConstructor, + query: Query + ): Promise { if (query.$and || query.$or || query.$not) { throw new AriesFrameworkError( 'Advanced wallet query features $and, $or or $not not supported in in memory storage' diff --git a/tests/e2e-test.ts b/tests/e2e-test.ts index d074e0aa6c..7c42b6a13d 100644 --- a/tests/e2e-test.ts +++ b/tests/e2e-test.ts @@ -1,9 +1,11 @@ import type { Agent } from '@aries-framework/core' +import type { Subject } from 'rxjs' import { sleep } from '../packages/core/src/utils/sleep' import { issueCredential, makeConnection, prepareForIssuance, presentProof } from '../packages/core/tests/helpers' import { + InjectionSymbols, V1CredentialPreview, AttributeFilter, CredentialState, @@ -95,6 +97,7 @@ export async function e2eTest({ // We want to stop the mediator polling before the agent is shutdown. // FIXME: add a way to stop mediator polling from the public api, and make sure this is // being handled in the agent shutdown so we don't get any errors with wallets being closed. - recipientAgent.config.stop$.next(true) + const recipientStop$ = recipientAgent.injectionContainer.resolve>(InjectionSymbols.Stop$) + recipientStop$.next(true) await sleep(2000) } diff --git a/tests/transport/SubjectInboundTransport.ts b/tests/transport/SubjectInboundTransport.ts index 6611d616ae..cd713f7d3f 100644 --- a/tests/transport/SubjectInboundTransport.ts +++ b/tests/transport/SubjectInboundTransport.ts @@ -3,7 +3,6 @@ import type { TransportSession } from '../../packages/core/src/agent/TransportSe import type { EncryptedMessage } from '../../packages/core/src/types' import type { Subject, Subscription } from 'rxjs' -import { AgentConfig } from '../../packages/core/src/agent/AgentConfig' import { TransportService } from '../../packages/core/src/agent/TransportService' import { uuid } from '../../packages/core/src/utils/uuid' @@ -26,7 +25,7 @@ export class SubjectInboundTransport implements InboundTransport { } private subscribe(agent: Agent) { - const logger = agent.dependencyManager.resolve(AgentConfig).logger + const logger = agent.config.logger const transportService = agent.dependencyManager.resolve(TransportService) this.subscription = this.ourSubject.subscribe({ diff --git a/tests/transport/SubjectOutboundTransport.ts b/tests/transport/SubjectOutboundTransport.ts index 7adc82b10d..1754dbe067 100644 --- a/tests/transport/SubjectOutboundTransport.ts +++ b/tests/transport/SubjectOutboundTransport.ts @@ -9,6 +9,7 @@ export class SubjectOutboundTransport implements OutboundTransport { private logger!: Logger private subjectMap: { [key: string]: Subject | undefined } private agent!: Agent + private stop$!: Subject public supportedSchemes = ['rxjs'] @@ -20,6 +21,7 @@ export class SubjectOutboundTransport implements OutboundTransport { this.agent = agent this.logger = agent.dependencyManager.resolve(InjectionSymbols.Logger) + this.stop$ = agent.dependencyManager.resolve(InjectionSymbols.Stop$) } public async stop(): Promise { @@ -45,9 +47,9 @@ export class SubjectOutboundTransport implements OutboundTransport { // Create a replySubject just for this session. Both ends will be able to close it, // mimicking a transport like http or websocket. Close session automatically when agent stops const replySubject = new Subject() - this.agent.config.stop$.pipe(take(1)).subscribe(() => !replySubject.closed && replySubject.complete()) + this.stop$.pipe(take(1)).subscribe(() => !replySubject.closed && replySubject.complete()) - replySubject.pipe(takeUntil(this.agent.config.stop$)).subscribe({ + replySubject.pipe(takeUntil(this.stop$)).subscribe({ next: async ({ message }: SubjectMessage) => { this.logger.test('Received message')