From 7366ca7b6ba2925a28020d5d063272505d53b0d5 Mon Sep 17 00:00:00 2001 From: jakubkoci Date: Wed, 7 Jul 2021 17:11:10 +0200 Subject: [PATCH] feat: Use session to send outbound message (#362) Signed-off-by: Jakub Koci --- src/__tests__/helpers.ts | 23 ++++- src/agent/Agent.ts | 12 +++ src/agent/Dispatcher.ts | 16 +--- src/agent/MessageReceiver.ts | 20 +++- src/agent/MessageSender.ts | 16 +++- src/agent/TransportService.ts | 36 ++++--- src/agent/__tests__/MessageSender.test.ts | 94 +++++++++++++++---- src/agent/__tests__/TransportService.test.ts | 25 ++++- src/agent/__tests__/stubs.ts | 20 ++++ .../transport/TransportDecoratorExtension.ts | 5 + src/transport/WsOutboundTransporter.ts | 26 +---- tests/mediator-ws.ts | 42 ++++++--- tests/mediator.ts | 46 +++++++-- 13 files changed, 283 insertions(+), 98 deletions(-) create mode 100644 src/agent/__tests__/stubs.ts diff --git a/src/__tests__/helpers.ts b/src/__tests__/helpers.ts index a1ae959c5f..2cbfd26b46 100644 --- a/src/__tests__/helpers.ts +++ b/src/__tests__/helpers.ts @@ -1,4 +1,5 @@ import type { Agent } from '../agent/Agent' +import type { TransportSession } from '../agent/TransportService' import type { BasicMessage, BasicMessageReceivedEvent } from '../modules/basic-messages' import type { ConnectionRecordProps } from '../modules/connections' import type { CredentialRecord, CredentialOfferTemplate, CredentialStateChangedEvent } from '../modules/credentials' @@ -133,6 +134,22 @@ export async function waitForBasicMessage(agent: Agent, { content }: { content?: }) } +class SubjectTransportSession implements TransportSession { + public id: string + public readonly type = 'subject' + private theirSubject: Subject + + public constructor(id: string, theirSubject: Subject) { + this.id = id + this.theirSubject = theirSubject + } + + public send(outboundMessage: OutboundPackage): Promise { + this.theirSubject.next(outboundMessage.payload) + return Promise.resolve() + } +} + export class SubjectInboundTransporter implements InboundTransporter { private subject: Subject private theirSubject: Subject @@ -149,10 +166,8 @@ export class SubjectInboundTransporter implements InboundTransporter { private subscribe(agent: Agent) { this.subject.subscribe({ next: async (message: WireMessage) => { - const outboundMessage = await agent.receiveMessage(message) - if (outboundMessage) { - this.theirSubject.next(outboundMessage.payload) - } + const session = new SubjectTransportSession('subject-session-1', this.theirSubject) + await agent.receiveMessage(message, session) }, }) } diff --git a/src/agent/Agent.ts b/src/agent/Agent.ts index fe423eab61..914a8d98ef 100644 --- a/src/agent/Agent.ts +++ b/src/agent/Agent.ts @@ -30,6 +30,7 @@ import { EventEmitter } from './EventEmitter' import { AgentEventTypes } from './Events' import { MessageReceiver } from './MessageReceiver' import { MessageSender } from './MessageSender' +import { TransportService } from './TransportService' export class Agent { protected agentConfig: AgentConfig @@ -38,6 +39,7 @@ export class Agent { protected eventEmitter: EventEmitter protected wallet: Wallet protected messageReceiver: MessageReceiver + protected transportService: TransportService protected messageSender: MessageSender public inboundTransporter?: InboundTransporter private _isInitialized = false @@ -96,6 +98,7 @@ export class Agent { this.eventEmitter = this.container.resolve(EventEmitter) this.messageSender = this.container.resolve(MessageSender) this.messageReceiver = this.container.resolve(MessageReceiver) + this.transportService = this.container.resolve(TransportService) this.wallet = this.container.resolve(InjectionSymbols.Wallet) // We set the modules in the constructor because that allows to set them as read-only @@ -176,6 +179,15 @@ export class Agent { return await this.messageReceiver.receiveMessage(inboundPackedMessage, session) } + public async closeAndDeleteWallet() { + await this.wallet.close() + await this.wallet.delete() + } + + public removeSession(session: TransportSession) { + this.transportService.removeSession(session) + } + public get injectionContainer() { return this.container } diff --git a/src/agent/Dispatcher.ts b/src/agent/Dispatcher.ts index 9d611f46c2..3110f27e2c 100644 --- a/src/agent/Dispatcher.ts +++ b/src/agent/Dispatcher.ts @@ -1,4 +1,3 @@ -import type { OutboundMessage, OutboundPackage } from '../types' import type { AgentMessage } from './AgentMessage' import type { Handler } from './Handler' import type { InboundMessageContext } from './models/InboundMessageContext' @@ -26,7 +25,7 @@ class Dispatcher { this.handlers.push(handler) } - public async dispatch(messageContext: InboundMessageContext): Promise { + public async dispatch(messageContext: InboundMessageContext): Promise { const message = messageContext.message const handler = this.getHandlerForType(message.type) @@ -37,22 +36,9 @@ class Dispatcher { const outboundMessage = await handler.handle(messageContext) if (outboundMessage) { - const threadId = outboundMessage.payload.threadId - if (!this.transportService.hasInboundEndpoint(outboundMessage.connection)) { outboundMessage.payload.setReturnRouting(ReturnRouteTypes.all) } - - // Check for return routing, with thread id - if (message.hasReturnRouting(threadId)) { - const keys = { - recipientKeys: messageContext.senderVerkey ? [messageContext.senderVerkey] : [], - routingKeys: [], - senderKey: messageContext.connection?.verkey || null, - } - return await this.messageSender.packMessage(outboundMessage, keys) - } - await this.messageSender.sendMessage(outboundMessage) } } diff --git a/src/agent/MessageReceiver.ts b/src/agent/MessageReceiver.ts index 115a5d8635..1eaba0a94d 100644 --- a/src/agent/MessageReceiver.ts +++ b/src/agent/MessageReceiver.ts @@ -73,10 +73,6 @@ export class MessageReceiver { } } - if (connection && session) { - this.transportService.saveSession(connection.id, session) - } - this.logger.info(`Received message with type '${unpackedMessage.message['@type']}'`, unpackedMessage.message) const message = await this.transformMessage(unpackedMessage) @@ -86,6 +82,22 @@ export class MessageReceiver { recipientVerkey: unpackedMessage.recipient_verkey, }) + // We want to save a session if there is a chance of returning outbound message via inbound transport. + // That can happen when inbound message has `return_route` set to `all` or `thread`. + // If `return_route` defines just `thread`, we decide later whether to use session according to outbound message `threadId`. + if (connection && message.hasAnyReturnRoute() && session) { + const keys = { + // TODO handle the case when senderKey is missing + recipientKeys: senderKey ? [senderKey] : [], + routingKeys: [], + senderKey: connection?.verkey || null, + } + session.keys = keys + session.inboundMessage = message + session.connection = connection + this.transportService.saveSession(session) + } + return await this.dispatcher.dispatch(messageContext) } diff --git a/src/agent/MessageSender.ts b/src/agent/MessageSender.ts index 809cd97555..8d9ece301b 100644 --- a/src/agent/MessageSender.ts +++ b/src/agent/MessageSender.ts @@ -55,6 +55,21 @@ export class MessageSender { connection: { id, verkey, theirKey }, }) + const session = this.transportService.findSessionByConnectionId(connection.id) + if (session?.inboundMessage?.hasReturnRouting(outboundMessage.payload.threadId)) { + this.logger.debug(`Existing ${session.type} transport session has been found.`) + try { + if (!session.keys) { + throw new AriesFrameworkError(`There are no keys for the given ${session.type} transport session.`) + } + const outboundPackage = await this.packMessage(outboundMessage, session.keys) + await session.send(outboundPackage) + return + } catch (error) { + this.logger.info(`Sending an outbound message via session failed with error: ${error.message}.`, error) + } + } + const services = this.transportService.findDidCommServices(connection) if (services.length === 0) { throw new AriesFrameworkError(`Connection with id ${connection.id} has no service!`) @@ -69,7 +84,6 @@ export class MessageSender { senderKey: connection.verkey, } const outboundPackage = await this.packMessage(outboundMessage, keys) - outboundPackage.session = this.transportService.findSession(connection.id) outboundPackage.endpoint = service.serviceEndpoint outboundPackage.responseRequested = outboundMessage.payload.hasReturnRouting() diff --git a/src/agent/TransportService.ts b/src/agent/TransportService.ts index 980f34b4a2..4b17f81e71 100644 --- a/src/agent/TransportService.ts +++ b/src/agent/TransportService.ts @@ -1,30 +1,35 @@ import type { ConnectionRecord } from '../modules/connections/repository' +import type { OutboundPackage } from '../types' +import type { AgentMessage } from './AgentMessage' +import type { EnvelopeKeys } from './EnvelopeService' -import { Lifecycle, scoped, inject } from 'tsyringe' +import { Lifecycle, scoped } from 'tsyringe' -import { DID_COMM_TRANSPORT_QUEUE, InjectionSymbols } from '../constants' -import { Logger } from '../logger' +import { DID_COMM_TRANSPORT_QUEUE } from '../constants' import { ConnectionRole, DidCommService } from '../modules/connections/models' @scoped(Lifecycle.ContainerScoped) export class TransportService { private transportSessionTable: TransportSessionTable = {} - private logger: Logger - public constructor(@inject(InjectionSymbols.Logger) logger: Logger) { - this.logger = logger + public saveSession(session: TransportSession) { + this.transportSessionTable[session.id] = session } - public saveSession(connectionId: string, transport: TransportSession) { - this.transportSessionTable[connectionId] = transport + public findSessionByConnectionId(connectionId: string) { + return Object.values(this.transportSessionTable).find((session) => session.connection?.id === connectionId) } - public hasInboundEndpoint(connection: ConnectionRecord) { - return connection.didDoc.didCommServices.find((s) => s.serviceEndpoint !== DID_COMM_TRANSPORT_QUEUE) + public findSessionById(sessionId: string) { + return this.transportSessionTable[sessionId] + } + + public removeSession(session: TransportSession) { + delete this.transportSessionTable[session.id] } - public findSession(connectionId: string) { - return this.transportSessionTable[connectionId] + public hasInboundEndpoint(connection: ConnectionRecord) { + return connection.didDoc.didCommServices.find((s) => s.serviceEndpoint !== DID_COMM_TRANSPORT_QUEUE) } public findDidCommServices(connection: ConnectionRecord): DidCommService[] { @@ -49,9 +54,14 @@ export class TransportService { } interface TransportSessionTable { - [connectionRecordId: string]: TransportSession + [sessionId: string]: TransportSession } export interface TransportSession { + id: string type: string + keys?: EnvelopeKeys + inboundMessage?: AgentMessage + connection?: ConnectionRecord + send(outboundMessage: OutboundPackage): Promise } diff --git a/src/agent/__tests__/MessageSender.test.ts b/src/agent/__tests__/MessageSender.test.ts index a48cbcacec..862c2e3422 100644 --- a/src/agent/__tests__/MessageSender.test.ts +++ b/src/agent/__tests__/MessageSender.test.ts @@ -1,7 +1,6 @@ import type { ConnectionRecord } from '../../modules/connections' import type { OutboundTransporter } from '../../transport' import type { OutboundMessage } from '../../types' -import type { TransportSession } from '../TransportService' import { getMockConnection, mockFunction } from '../../__tests__/helpers' import testLogger from '../../__tests__/logger' @@ -13,6 +12,8 @@ import { MessageSender } from '../MessageSender' import { TransportService as TransportServiceImpl } from '../TransportService' import { createOutboundMessage } from '../helpers' +import { DummyTransportSession } from './stubs' + jest.mock('../TransportService') jest.mock('../EnvelopeService') @@ -34,10 +35,6 @@ class DummyOutboundTransporter implements OutboundTransporter { } } -class DummyTransportSession implements TransportSession { - public readonly type = 'dummy' -} - describe('MessageSender', () => { const TransportService = >(TransportServiceImpl) const EnvelopeService = >(EnvelopeServiceImpl) @@ -54,9 +51,24 @@ describe('MessageSender', () => { const enveloperService = new EnvelopeService() const envelopeServicePackMessageMock = mockFunction(enveloperService.packMessage) + const inboundMessage = new AgentMessage() + inboundMessage.setReturnRouting(ReturnRouteTypes.all) + + const session = new DummyTransportSession('session-123') + session.keys = { + recipientKeys: ['verkey'], + routingKeys: [], + senderKey: 'senderKey', + } + session.inboundMessage = inboundMessage + session.send = jest.fn() + + const sessionWithoutKeys = new DummyTransportSession('sessionWithoutKeys-123') + sessionWithoutKeys.inboundMessage = inboundMessage + sessionWithoutKeys.send = jest.fn() + const transportService = new TransportService() - const session = new DummyTransportSession() - const transportServiceFindSessionMock = mockFunction(transportService.findSession) + const transportServiceFindSessionMock = mockFunction(transportService.findSessionByConnectionId) const firstDidCommService = new DidCommService({ id: `;indy`, @@ -85,18 +97,17 @@ describe('MessageSender', () => { envelopeServicePackMessageMock.mockReturnValue(Promise.resolve(wireMessage)) transportServiceFindServicesMock.mockReturnValue([firstDidCommService, secondDidCommService]) - transportServiceFindSessionMock.mockReturnValue(session) }) afterEach(() => { jest.resetAllMocks() }) - test('throws error when there is no outbound transport', async () => { + test('throw error when there is no outbound transport', async () => { await expect(messageSender.sendMessage(outboundMessage)).rejects.toThrow(`Agent has no outbound transporter!`) }) - test('throws error when there is no service', async () => { + test('throw error when there is no service', async () => { messageSender.setOutboundTransporter(outboundTransporter) transportServiceFindServicesMock.mockReturnValue([]) @@ -105,7 +116,59 @@ describe('MessageSender', () => { ) }) - test('calls send message with connection, payload and endpoint from first DidComm service', async () => { + test('call send message when session send method fails', async () => { + messageSender.setOutboundTransporter(outboundTransporter) + transportServiceFindSessionMock.mockReturnValue(session) + session.send = jest.fn().mockRejectedValue(new Error('some error')) + + messageSender.setOutboundTransporter(outboundTransporter) + const sendMessageSpy = jest.spyOn(outboundTransporter, 'sendMessage') + + await messageSender.sendMessage(outboundMessage) + + expect(sendMessageSpy).toHaveBeenCalledWith({ + connection, + payload: wireMessage, + endpoint: firstDidCommService.serviceEndpoint, + responseRequested: false, + }) + expect(sendMessageSpy).toHaveBeenCalledTimes(1) + }) + + test('call send message when session send method fails with missing keys', async () => { + messageSender.setOutboundTransporter(outboundTransporter) + transportServiceFindSessionMock.mockReturnValue(sessionWithoutKeys) + + messageSender.setOutboundTransporter(outboundTransporter) + const sendMessageSpy = jest.spyOn(outboundTransporter, 'sendMessage') + + await messageSender.sendMessage(outboundMessage) + + expect(sendMessageSpy).toHaveBeenCalledWith({ + connection, + payload: wireMessage, + endpoint: firstDidCommService.serviceEndpoint, + responseRequested: false, + }) + expect(sendMessageSpy).toHaveBeenCalledTimes(1) + }) + + test('call send message on session when there is a session for a given connection', async () => { + messageSender.setOutboundTransporter(outboundTransporter) + const sendMessageSpy = jest.spyOn(outboundTransporter, 'sendMessage') + session.connection = connection + transportServiceFindSessionMock.mockReturnValue(session) + + await messageSender.sendMessage(outboundMessage) + + expect(session.send).toHaveBeenCalledWith({ + connection, + payload: wireMessage, + }) + expect(sendMessageSpy).toHaveBeenCalledTimes(0) + }) + + test('call send message with connection, payload and endpoint from first DidComm service', async () => { messageSender.setOutboundTransporter(outboundTransporter) const sendMessageSpy = jest.spyOn(outboundTransporter, 'sendMessage') @@ -116,12 +179,11 @@ describe('MessageSender', () => { payload: wireMessage, endpoint: firstDidCommService.serviceEndpoint, responseRequested: false, - session, }) expect(sendMessageSpy).toHaveBeenCalledTimes(1) }) - test('calls send message with connection, payload and endpoint from second DidComm service when the first fails', async () => { + test('call send message with connection, payload and endpoint from second DidComm service when the first fails', async () => { messageSender.setOutboundTransporter(outboundTransporter) const sendMessageSpy = jest.spyOn(outboundTransporter, 'sendMessage') @@ -135,12 +197,11 @@ describe('MessageSender', () => { payload: wireMessage, endpoint: secondDidCommService.serviceEndpoint, responseRequested: false, - session, }) expect(sendMessageSpy).toHaveBeenCalledTimes(2) }) - test('calls send message with responseRequested when message has return route', async () => { + test('call send message with responseRequested when message has return route', async () => { messageSender.setOutboundTransporter(outboundTransporter) const sendMessageSpy = jest.spyOn(outboundTransporter, 'sendMessage') @@ -155,7 +216,6 @@ describe('MessageSender', () => { payload: wireMessage, endpoint: firstDidCommService.serviceEndpoint, responseRequested: true, - session, }) expect(sendMessageSpy).toHaveBeenCalledTimes(1) }) @@ -174,7 +234,7 @@ describe('MessageSender', () => { jest.resetAllMocks() }) - test('returns outbound message context with connection, payload and endpoint', async () => { + test('return outbound message context with connection, payload and endpoint', async () => { const message = new AgentMessage() const outboundMessage = createOutboundMessage(connection, message) diff --git a/src/agent/__tests__/TransportService.test.ts b/src/agent/__tests__/TransportService.test.ts index 22e3d536f2..3dd12a1c7c 100644 --- a/src/agent/__tests__/TransportService.test.ts +++ b/src/agent/__tests__/TransportService.test.ts @@ -1,9 +1,8 @@ import { getMockConnection } from '../../__tests__/helpers' -import testLogger from '../../__tests__/logger' import { ConnectionInvitationMessage, ConnectionRole, DidCommService, DidDoc } from '../../modules/connections' import { TransportService } from '../TransportService' -const logger = testLogger +import { DummyTransportSession } from './stubs' describe('TransportService', () => { describe('findServices', () => { @@ -23,7 +22,7 @@ describe('TransportService', () => { service: [testDidCommService], }) - transportService = new TransportService(logger) + transportService = new TransportService() }) test(`returns empty array when there is no their DidDoc and role is ${ConnectionRole.Inviter}`, () => { @@ -62,4 +61,24 @@ describe('TransportService', () => { ]) }) }) + + describe('removeSession', () => { + let transportService: TransportService + + beforeEach(() => { + transportService = new TransportService() + }) + + test(`remove session saved for a given connection`, () => { + const connection = getMockConnection({ id: 'test-123', role: ConnectionRole.Inviter }) + const session = new DummyTransportSession('dummy-session-123') + session.connection = connection + + transportService.saveSession(session) + expect(transportService.findSessionByConnectionId(connection.id)).toEqual(session) + + transportService.removeSession(session) + expect(transportService.findSessionByConnectionId(connection.id)).toEqual(undefined) + }) + }) }) diff --git a/src/agent/__tests__/stubs.ts b/src/agent/__tests__/stubs.ts new file mode 100644 index 0000000000..5bdb3b5bb6 --- /dev/null +++ b/src/agent/__tests__/stubs.ts @@ -0,0 +1,20 @@ +import type { ConnectionRecord } from '../../modules/connections' +import type { AgentMessage } from '../AgentMessage' +import type { EnvelopeKeys } from '../EnvelopeService' +import type { TransportSession } from '../TransportService' + +export class DummyTransportSession implements TransportSession { + public id: string + public readonly type = 'http' + public keys?: EnvelopeKeys + public inboundMessage?: AgentMessage + public connection?: ConnectionRecord + + public constructor(id: string) { + this.id = id + } + + public send(): Promise { + throw new Error('Method not implemented.') + } +} diff --git a/src/decorators/transport/TransportDecoratorExtension.ts b/src/decorators/transport/TransportDecoratorExtension.ts index fff46a7ccd..0a30b008b7 100644 --- a/src/decorators/transport/TransportDecoratorExtension.ts +++ b/src/decorators/transport/TransportDecoratorExtension.ts @@ -31,6 +31,11 @@ export function TransportDecorated(Base: T) { // transport is thread but threadId is either missing or doesn't match. Return false return false } + + public hasAnyReturnRoute() { + const returnRoute = this.transport?.returnRoute + return returnRoute && (returnRoute === ReturnRouteTypes.all || returnRoute === ReturnRouteTypes.thread) + } } return TransportDecoratorExtension diff --git a/src/transport/WsOutboundTransporter.ts b/src/transport/WsOutboundTransporter.ts index 2e21713c5d..58d4bb78c6 100644 --- a/src/transport/WsOutboundTransporter.ts +++ b/src/transport/WsOutboundTransporter.ts @@ -1,5 +1,4 @@ import type { Agent } from '../agent/Agent' -import type { TransportSession } from '../agent/TransportService' import type { Logger } from '../logger' import type { ConnectionRecord } from '../modules/connections' import type { OutboundPackage } from '../types' @@ -8,15 +7,6 @@ import type { OutboundTransporter } from './OutboundTransporter' import { InjectionSymbols } from '../constants' import { WebSocket } from '../utils/ws' -export class WebSocketTransportSession implements TransportSession { - public readonly type = 'websocket' - public socket?: WebSocket - - public constructor(socket?: WebSocket) { - this.socket = socket - } -} - export class WsOutboundTransporter implements OutboundTransporter { private transportTable: Map = new Map() private agent: Agent @@ -41,18 +31,10 @@ export class WsOutboundTransporter implements OutboundTransporter { } public async sendMessage(outboundPackage: OutboundPackage) { - const { connection, payload, endpoint, session } = outboundPackage - this.logger.debug( - `Sending outbound message to connection ${connection.id} over ${session?.type} transport.`, - payload - ) - - if (session instanceof WebSocketTransportSession && session.socket?.readyState === WebSocket.OPEN) { - session.socket.send(JSON.stringify(payload)) - } else { - const socket = await this.resolveSocket(connection, endpoint) - socket.send(JSON.stringify(payload)) - } + const { connection, payload, endpoint } = outboundPackage + this.logger.debug(`Sending outbound message to connection ${connection.id} over websocket transport.`, payload) + const socket = await this.resolveSocket(connection, endpoint) + socket.send(JSON.stringify(payload)) } private async resolveSocket(connection: ConnectionRecord, endpoint?: string) { diff --git a/tests/mediator-ws.ts b/tests/mediator-ws.ts index c3abdb2c1f..e06583f5e8 100644 --- a/tests/mediator-ws.ts +++ b/tests/mediator-ws.ts @@ -1,19 +1,40 @@ import type { InboundTransporter } from '../src' +import type { TransportSession } from '../src/agent/TransportService' +import type { OutboundPackage } from '../src/types' import cors from 'cors' import express from 'express' -import { v4 as uuid } from 'uuid' import WebSocket from 'ws' -import { Agent, WebSocketTransportSession, WsOutboundTransporter } from '../src' +import { Agent, WsOutboundTransporter, AriesFrameworkError } from '../src' import testLogger from '../src/__tests__/logger' import { InMemoryMessageRepository } from '../src/storage/InMemoryMessageRepository' import { DidCommMimeType } from '../src/types' +import { uuid } from '../src/utils/uuid' import config from './config' const logger = testLogger +class WebSocketTransportSession implements TransportSession { + public id: string + public readonly type = 'websocket' + public socket: WebSocket + + public constructor(id: string, socket: WebSocket) { + this.id = id + this.socket = socket + } + + public async send(outboundMessage: OutboundPackage): Promise { + // logger.debug(`Sending outbound message via ${this.type} transport session`) + if (this.socket.readyState !== WebSocket.OPEN) { + throw new AriesFrameworkError(`${this.type} transport session has been closed.`) + } + this.socket.send(JSON.stringify(outboundMessage.payload)) + } +} + class WsInboundTransporter implements InboundTransporter { private socketServer: WebSocket.Server @@ -31,24 +52,23 @@ class WsInboundTransporter implements InboundTransporter { if (!this.socketIds[socketId]) { logger.debug(`Saving new socket with id ${socketId}.`) this.socketIds[socketId] = socket - this.listenOnWebSocketMessages(agent, socket) - socket.on('close', () => logger.debug('Socket closed.')) + const session = new WebSocketTransportSession(socketId, socket) + this.listenOnWebSocketMessages(agent, socket, session) + socket.on('close', () => { + logger.debug('Socket closed.') + agent.removeSession(session) + }) } else { logger.debug(`Socket with id ${socketId} already exists.`) } }) } - private listenOnWebSocketMessages(agent: Agent, socket: WebSocket) { + private listenOnWebSocketMessages(agent: Agent, socket: WebSocket, session: TransportSession) { // eslint-disable-next-line @typescript-eslint/no-explicit-any socket.addEventListener('message', async (event: any) => { logger.debug('WebSocket message event received.', { url: event.target.url, data: event.data }) - // @ts-expect-error Property 'dispatchEvent' is missing in type WebSocket imported from 'ws' module but required in type 'WebSocket'. - const session = new WebSocketTransportSession(socket) - const outboundMessage = await agent.receiveMessage(JSON.parse(event.data), session) - if (outboundMessage) { - socket.send(JSON.stringify(outboundMessage.payload)) - } + await agent.receiveMessage(JSON.parse(event.data), session) }) } } diff --git a/tests/mediator.ts b/tests/mediator.ts index f706ea24cc..7f986f7364 100644 --- a/tests/mediator.ts +++ b/tests/mediator.ts @@ -1,7 +1,8 @@ import type { InboundTransporter, OutboundTransporter } from '../src' +import type { TransportSession } from '../src/agent/TransportService' import type { MessageRepository } from '../src/storage/MessageRepository' import type { OutboundPackage } from '../src/types' -import type { Express } from 'express' +import type { Express, Request, Response } from 'express' import cors from 'cors' import express from 'express' @@ -10,9 +11,35 @@ import { Agent, AriesFrameworkError } from '../src' import testLogger from '../src/__tests__/logger' import { InMemoryMessageRepository } from '../src/storage/InMemoryMessageRepository' import { DidCommMimeType } from '../src/types' +import { uuid } from '../src/utils/uuid' import config from './config' +const logger = testLogger + +class HttpTransportSession implements TransportSession { + public id: string + public readonly type = 'http' + public req: Request + public res: Response + + public constructor(id: string, req: Request, res: Response) { + this.id = id + this.req = req + this.res = res + } + + public async send(outboundMessage: OutboundPackage): Promise { + logger.debug(`Sending outbound message via ${this.type} transport session`) + + if (this.res.headersSent) { + throw new AriesFrameworkError(`${this.type} transport session has been closed.`) + } + + this.res.status(200).json(outboundMessage.payload).end() + } +} + class HttpInboundTransporter implements InboundTransporter { private app: Express @@ -22,18 +49,21 @@ class HttpInboundTransporter implements InboundTransporter { public async start(agent: Agent) { this.app.post('/msg', async (req, res) => { + const session = new HttpTransportSession(uuid(), req, res) try { const message = req.body const packedMessage = JSON.parse(message) + await agent.receiveMessage(packedMessage, session) - const outboundMessage = await agent.receiveMessage(packedMessage) - if (outboundMessage) { - res.status(200).json(outboundMessage.payload).end() - } else { + // If agent did not use session when processing message we need to send response here. + if (!res.headersSent) { res.status(200).end() } - } catch { + } catch (error) { + logger.error(`Error processing message in mediator: ${error.message}`, error) res.status(500).send('Error processing message') + } finally { + agent.removeSession(session) } }) } @@ -67,7 +97,7 @@ class StorageOutboundTransporter implements OutboundTransporter { throw new AriesFrameworkError('Trying to save message without theirKey!') } - testLogger.debug('Storing message', { connection, payload }) + logger.debug('Storing message', { connection, payload }) this.messageRepository.save(connection.theirKey, payload) } @@ -125,5 +155,5 @@ app.get('/api/routes', async (req, res) => { app.listen(PORT, async () => { await agent.initialize() messageReceiver.start(agent) - testLogger.info(`Application started on port ${PORT}`) + logger.info(`Application started on port ${PORT}`) })