Skip to content

Commit

Permalink
feat(tenants): tenant lifecycle (#942)
Browse files Browse the repository at this point in the history
* fix: make event listeners tenant aware
* chore(deps): update tsyringe
* feat: add agent context disposal
* feat(tenants): with tenant agent method
* test(tenants): add tests for session mutex
* feat(tenants): use RAW key derivation
* test(tenants): add e2e session tests
* feat(tenants): destroy and end session

Signed-off-by: Timo Glastra <[email protected]>
  • Loading branch information
TimoGlastra committed Aug 26, 2022
1 parent 7cbd08c commit adfa65b
Show file tree
Hide file tree
Showing 42 changed files with 941 additions and 196 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"ts-jest": "^27.0.3",
"ts-node": "^10.0.0",
"tsconfig-paths": "^3.9.0",
"tsyringe": "^4.6.0",
"tsyringe": "^4.7.0",
"typescript": "~4.3.0",
"ws": "^7.4.6"
},
Expand Down
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"query-string": "^7.0.1",
"reflect-metadata": "^0.1.13",
"rxjs": "^7.2.0",
"tsyringe": "^4.5.0",
"tsyringe": "^4.7.0",
"uuid": "^8.3.2",
"varint": "^6.0.0",
"web-did-resolver": "^2.0.8"
Expand Down
7 changes: 4 additions & 3 deletions packages/core/src/agent/Agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,10 @@ export class Agent extends BaseAgent {
const transportPromises = allTransports.map((transport) => transport.stop())
await Promise.all(transportPromises)

// close wallet if still initialized
if (this.wallet.isInitialized) {
await this.wallet.close()
}

await super.shutdown()
this._isInitialized = false
}

Expand Down Expand Up @@ -205,7 +203,10 @@ export class Agent extends BaseAgent {
// Bind the default agent context to the container for use in modules etc.
dependencyManager.registerInstance(
AgentContext,
new AgentContext({ dependencyManager, contextCorrelationId: 'default' })
new AgentContext({
dependencyManager,
contextCorrelationId: 'default',
})
)

// If no agent context provider has been registered we use the default agent context provider.
Expand Down
4 changes: 0 additions & 4 deletions packages/core/src/agent/BaseAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,6 @@ export abstract class BaseAgent {
}
}

public async shutdown() {
// No logic required at the moment
}

public get publicDid() {
return this.agentContext.wallet.publicDid
}
Expand Down
9 changes: 9 additions & 0 deletions packages/core/src/agent/Events.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import type { ConnectionRecord } from '../modules/connections'
import type { AgentMessage } from './AgentMessage'
import type { Observable } from 'rxjs'

import { filter } from 'rxjs'

export function filterContextCorrelationId(contextCorrelationId: string) {
return <T extends BaseEvent>(source: Observable<T>) => {
return source.pipe(filter((event) => event.metadata.contextCorrelationId === contextCorrelationId))
}
}

export enum AgentEventTypes {
AgentMessageReceived = 'AgentMessageReceived',
Expand Down
17 changes: 11 additions & 6 deletions packages/core/src/agent/MessageReceiver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,17 @@ export class MessageReceiver {
contextCorrelationId,
})

if (this.isEncryptedMessage(inboundMessage)) {
await this.receiveEncryptedMessage(agentContext, inboundMessage as EncryptedMessage, session)
} else if (this.isPlaintextMessage(inboundMessage)) {
await this.receivePlaintextMessage(agentContext, inboundMessage, connection)
} else {
throw new AriesFrameworkError('Unable to parse incoming message: unrecognized format')
try {
if (this.isEncryptedMessage(inboundMessage)) {
await this.receiveEncryptedMessage(agentContext, inboundMessage as EncryptedMessage, session)
} else if (this.isPlaintextMessage(inboundMessage)) {
await this.receivePlaintextMessage(agentContext, inboundMessage, connection)
} else {
throw new AriesFrameworkError('Unable to parse incoming message: unrecognized format')
}
} finally {
// Always end the session for the agent context after handling the message.
await agentContext.endSession()
}
}

Expand Down
12 changes: 12 additions & 0 deletions packages/core/src/agent/context/AgentContext.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { DependencyManager } from '../../plugins'
import type { Wallet } from '../../wallet'
import type { AgentContextProvider } from './AgentContextProvider'

import { InjectionSymbols } from '../../constants'
import { AgentConfig } from '../AgentConfig'
Expand Down Expand Up @@ -47,6 +48,17 @@ export class AgentContext {
return this.dependencyManager.resolve<Wallet>(InjectionSymbols.Wallet)
}

/**
* End session the current agent context
*/
public async endSession() {
const agentContextProvider = this.dependencyManager.resolve<AgentContextProvider>(
InjectionSymbols.AgentContextProvider
)

await agentContextProvider.endSessionForAgentContext(this)
}

public toJSON() {
return {
contextCorrelationId: this.contextCorrelationId,
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/agent/context/AgentContextProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ export interface AgentContextProvider {
* for the specified contextCorrelationId.
*/
getAgentContextForContextCorrelationId(contextCorrelationId: string): Promise<AgentContext>

/**
* End sessions for the provided agent context. This does not necessarily mean the wallet will be closed or the dependency manager will
* be disposed, it is to inform the agent context provider this session for the agent context is no longer in use. This should only be
* called once for every session and the agent context MUST not be used after this method is called.
*/
endSessionForAgentContext(agentContext: AgentContext): Promise<void>
}
11 changes: 11 additions & 0 deletions packages/core/src/agent/context/DefaultAgentContextProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,15 @@ export class DefaultAgentContextProvider implements AgentContextProvider {

return this.agentContext
}

public async endSessionForAgentContext(agentContext: AgentContext) {
// Throw an error if the context correlation id does not match to prevent misuse.
if (agentContext.contextCorrelationId !== this.agentContext.contextCorrelationId) {
throw new AriesFrameworkError(
`Could not end session for agent context with contextCorrelationId '${agentContext.contextCorrelationId}'. Only contextCorrelationId '${this.agentContext.contextCorrelationId}' is provided by this provider.`
)
}

// We won't dispose the agent context as we don't keep track of the total number of sessions for the root agent context.65
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,23 @@ describe('DefaultAgentContextProvider', () => {
)
})
})

describe('endSessionForAgentContext()', () => {
test('resolves when the correct agent context is passed', async () => {
const agentContextProvider: AgentContextProvider = new DefaultAgentContextProvider(agentContext)

await expect(agentContextProvider.endSessionForAgentContext(agentContext)).resolves.toBeUndefined()
})

test('throws an error if the contextCorrelationId does not match with the contextCorrelationId from the constructor agent context', async () => {
const agentContextProvider: AgentContextProvider = new DefaultAgentContextProvider(agentContext)
const agentContext2 = getAgentContext({
contextCorrelationId: 'mock2',
})

await expect(agentContextProvider.endSessionForAgentContext(agentContext2)).rejects.toThrowError(
`Could not end session for agent context with contextCorrelationId 'mock2'. Only contextCorrelationId 'mock' is provided by this provider.`
)
})
})
})
2 changes: 1 addition & 1 deletion packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export { Dispatcher } from './agent/Dispatcher'
export { MessageSender } from './agent/MessageSender'
export type { AgentDependencies } from './agent/AgentDependencies'
export type { InitConfig, OutboundPackage, EncryptedMessage, WalletConfig } from './types'
export { KeyDerivationMethod, DidCommMimeType } from './types'
export { DidCommMimeType, KeyDerivationMethod } from './types'
export type { FileSystem } from './storage/FileSystem'
export * from './storage/BaseRecord'
export { InMemoryMessageRepository } from './storage/InMemoryMessageRepository'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { firstValueFrom, ReplaySubject } from 'rxjs'
import { first, map, timeout } from 'rxjs/operators'

import { EventEmitter } from '../../../agent/EventEmitter'
import { filterContextCorrelationId } from '../../../agent/Events'
import { InjectionSymbols } from '../../../constants'
import { Key } from '../../../crypto'
import { signData, unpackAndVerifySignatureDecorator } from '../../../decorators/signature/SignatureDecoratorUtils'
Expand Down Expand Up @@ -749,6 +750,7 @@ export class ConnectionService {

observable
.pipe(
filterContextCorrelationId(agentContext.contextCorrelationId),
map((e) => e.payload.connectionRecord),
first(isConnected), // Do not wait for longer than specified timeout
timeout(timeoutMs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { catchError, filter, map, takeUntil, timeout } from 'rxjs/operators'
import { AgentContext } from '../../agent'
import { Dispatcher } from '../../agent/Dispatcher'
import { EventEmitter } from '../../agent/EventEmitter'
import { AgentEventTypes } from '../../agent/Events'
import { filterContextCorrelationId, AgentEventTypes } from '../../agent/Events'
import { MessageSender } from '../../agent/MessageSender'
import { createOutboundMessage } from '../../agent/helpers'
import { InjectionSymbols } from '../../constants'
Expand Down Expand Up @@ -58,6 +58,7 @@ export class DiscoverFeaturesModule {
.pipe(
// Stop when the agent shuts down
takeUntil(this.stop$),
filterContextCorrelationId(this.agentContext.contextCorrelationId),
// filter by connection id and query disclose message type
filter(
(e) =>
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/modules/oob/OutOfBandModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { catchError, EmptyError, first, firstValueFrom, map, of, timeout } from
import { AgentContext } from '../../agent'
import { Dispatcher } from '../../agent/Dispatcher'
import { EventEmitter } from '../../agent/EventEmitter'
import { AgentEventTypes } from '../../agent/Events'
import { filterContextCorrelationId, AgentEventTypes } from '../../agent/Events'
import { MessageSender } from '../../agent/MessageSender'
import { createOutboundMessage } from '../../agent/helpers'
import { InjectionSymbols } from '../../constants'
Expand Down Expand Up @@ -681,6 +681,7 @@ export class OutOfBandModule {

const reuseAcceptedEventPromise = firstValueFrom(
this.eventEmitter.observable<HandshakeReusedEvent>(OutOfBandEventTypes.HandshakeReused).pipe(
filterContextCorrelationId(this.agentContext.contextCorrelationId),
// Find the first reuse event where the handshake reuse accepted matches the reuse message thread
// TODO: Should we store the reuse state? Maybe we can keep it in memory for now
first(
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/modules/routing/RecipientModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { delayWhen, filter, first, takeUntil, tap, throttleTime, timeout } from
import { AgentContext } from '../../agent'
import { Dispatcher } from '../../agent/Dispatcher'
import { EventEmitter } from '../../agent/EventEmitter'
import { filterContextCorrelationId } from '../../agent/Events'
import { MessageSender } from '../../agent/MessageSender'
import { createOutboundMessage } from '../../agent/helpers'
import { InjectionSymbols } from '../../constants'
Expand Down Expand Up @@ -145,6 +146,11 @@ export class RecipientModule {
private async openWebSocketAndPickUp(mediator: MediationRecord, pickupStrategy: MediatorPickupStrategy) {
let interval = 50

// FIXME: this won't work for tenant agents created by the tenants module as the agent context session
// could be closed. I'm not sure we want to support this as you probably don't want different tenants opening
// various websocket connections to mediators. However we should look at throwing an error or making sure
// it is not possible to use the mediation module with tenant agents.

// Listens to Outbound websocket closed events and will reopen the websocket connection
// in a recursive back off strategy if it matches the following criteria:
// - Agent is not shutdown
Expand Down Expand Up @@ -335,6 +341,7 @@ export class RecipientModule {
// Apply required filters to observable stream subscribe to replay subject
observable
.pipe(
filterContextCorrelationId(this.agentContext.contextCorrelationId),
// Only take event for current mediation record
filter((event) => event.payload.mediationRecord.id === mediationRecord.id),
// Only take event for previous state requested, current state granted
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { firstValueFrom, ReplaySubject } from 'rxjs'
import { filter, first, timeout } from 'rxjs/operators'

import { EventEmitter } from '../../../agent/EventEmitter'
import { AgentEventTypes } from '../../../agent/Events'
import { filterContextCorrelationId, AgentEventTypes } from '../../../agent/Events'
import { MessageSender } from '../../../agent/MessageSender'
import { createOutboundMessage } from '../../../agent/helpers'
import { Key, KeyType } from '../../../crypto'
Expand Down Expand Up @@ -157,6 +157,7 @@ export class MediationRecipientService {
// Apply required filters to observable stream and create promise to subscribe to observable
observable
.pipe(
filterContextCorrelationId(agentContext.contextCorrelationId),
// Only take event for current mediation record
filter((event) => mediationRecord.id === event.payload.mediationRecord.id),
// Only wait for first event that matches the criteria
Expand Down
9 changes: 9 additions & 0 deletions packages/core/src/plugins/DependencyManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ export class DependencyManager {
else this.container.register(token, token, { lifecycle: Lifecycle.ContainerScoped })
}

/**
* Dispose the dependency manager. Calls `.dispose()` on all instances that implement the `Disposable` interface and have
* been constructed by the `DependencyManager`. This means all instances registered using `registerInstance` won't have the
* dispose method called.
*/
public async dispose() {
await this.container.dispose()
}

public createChild() {
return new DependencyManager(this.container.createChildContainer())
}
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/plugins/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
export * from './DependencyManager'
export * from './Module'
export { inject, injectable } from 'tsyringe'
export { inject, injectable, Disposable } from 'tsyringe'
9 changes: 9 additions & 0 deletions packages/core/src/wallet/IndyWallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ export class IndyWallet implements Wallet {
return this.walletConfig.id
}

/**
* Dispose method is called when an agent context is disposed.
*/
public async dispose() {
if (this.isInitialized) {
await this.close()
}
}

private walletStorageConfig(walletConfig: WalletConfig): Indy.WalletConfig {
const walletStorageConfig: Indy.WalletConfig = {
id: walletConfig.id,
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/wallet/Wallet.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Key, KeyType } from '../crypto'
import type { Disposable } from '../plugins'
import type {
EncryptedMessage,
WalletConfig,
Expand All @@ -8,7 +9,7 @@ import type {
} from '../types'
import type { Buffer } from '../utils/buffer'

export interface Wallet {
export interface Wallet extends Disposable {
publicDid: DidInfo | undefined
isInitialized: boolean
isProvisioned: boolean
Expand Down
4 changes: 4 additions & 0 deletions packages/core/tests/mocks/MockWallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,8 @@ export class MockWallet implements Wallet {
public generateWalletKey(): Promise<string> {
throw new Error('Method not implemented.')
}

public dispose() {
// Nothing to do here
}
}
6 changes: 3 additions & 3 deletions packages/module-tenants/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "@aries-framework/module-tenants",
"main": "build/index",
"types": "build/index",
"version": "0.2.0",
"version": "0.2.2",
"files": [
"build"
],
Expand All @@ -24,11 +24,11 @@
"test": "jest"
},
"dependencies": {
"@aries-framework/core": "0.2.0",
"@aries-framework/core": "0.2.2",
"async-mutex": "^0.3.2"
},
"devDependencies": {
"@aries-framework/node": "0.2.0",
"@aries-framework/node": "0.2.2",
"reflect-metadata": "^0.1.13",
"rimraf": "~3.0.2",
"typescript": "~4.3.0"
Expand Down
16 changes: 13 additions & 3 deletions packages/module-tenants/src/TenantAgent.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
import type { AgentContext } from '@aries-framework/core'

import { BaseAgent } from '@aries-framework/core'
import { AriesFrameworkError, BaseAgent } from '@aries-framework/core'

export class TenantAgent extends BaseAgent {
private sessionHasEnded = false

public constructor(agentContext: AgentContext) {
super(agentContext.config, agentContext.dependencyManager)
}

public async initialize() {
if (this.sessionHasEnded) {
throw new AriesFrameworkError("Can't initialize agent after tenant sessions has been ended.")
}

await super.initialize()
this._isInitialized = true
}

public async shutdown() {
await super.shutdown()
public async endSession() {
this.logger.trace(
`Ending session for agent context with contextCorrelationId '${this.agentContext.contextCorrelationId}'`
)
await this.agentContext.endSession()
this._isInitialized = false
this.sessionHasEnded = true
}

protected registerDependencies() {
Expand Down
Loading

0 comments on commit adfa65b

Please sign in to comment.