Skip to content

Commit

Permalink
feat: support new did docoument for msg receiver
Browse files Browse the repository at this point in the history
Signed-off-by: Timo Glastra <[email protected]>
  • Loading branch information
TimoGlastra committed Jan 27, 2022
1 parent 4c7bc43 commit 557e9ee
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 97 deletions.
80 changes: 61 additions & 19 deletions packages/core/src/agent/MessageReceiver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import type { TransportSession } from './TransportService'
import { Lifecycle, scoped } from 'tsyringe'

import { AriesFrameworkError } from '../error'
import { ConnectionService } from '../modules/connections/services/ConnectionService'
import { ConnectionRepository } from '../modules/connections'
import { DidRepository } from '../modules/dids/repository/DidRepository'
import { ProblemReportError, ProblemReportMessage, ProblemReportReason } from '../modules/problem-reports'
import { JsonTransformer } from '../utils/JsonTransformer'
import { MessageValidator } from '../utils/MessageValidator'
Expand All @@ -28,25 +29,28 @@ export class MessageReceiver {
private envelopeService: EnvelopeService
private transportService: TransportService
private messageSender: MessageSender
private connectionService: ConnectionService
private dispatcher: Dispatcher
private logger: Logger
private didRepository: DidRepository
private connectionRepository: ConnectionRepository
public readonly inboundTransports: InboundTransport[] = []

public constructor(
config: AgentConfig,
envelopeService: EnvelopeService,
transportService: TransportService,
messageSender: MessageSender,
connectionService: ConnectionService,
dispatcher: Dispatcher
connectionRepository: ConnectionRepository,
dispatcher: Dispatcher,
didRepository: DidRepository
) {
this.config = config
this.envelopeService = envelopeService
this.transportService = transportService
this.messageSender = messageSender
this.connectionService = connectionService
this.connectionRepository = connectionRepository
this.dispatcher = dispatcher
this.didRepository = didRepository
this.logger = this.config.logger
}

Expand Down Expand Up @@ -77,21 +81,10 @@ export class MessageReceiver {
}

private async receiveEncryptedMessage(encryptedMessage: EncryptedMessage, session?: TransportSession) {
const { plaintextMessage, senderKey, recipientKey } = await this.decryptMessage(encryptedMessage)
const decryptedMessage = await this.decryptMessage(encryptedMessage)
const { plaintextMessage, senderKey, recipientKey } = decryptedMessage

let connection: ConnectionRecord | null = null

// Only fetch connection if recipientKey and senderKey are present (AuthCrypt)
if (senderKey && recipientKey) {
connection = await this.connectionService.findByVerkey(recipientKey)

// Throw error if the recipient key (ourKey) does not match the key of the connection record
if (connection && connection.theirKey !== null && connection.theirKey !== senderKey) {
throw new AriesFrameworkError(
`Inbound message senderKey '${senderKey}' is different from connection.theirKey '${connection.theirKey}'`
)
}
}
const connection = await this.findConnectionByMessageKeys(decryptedMessage)

this.logger.info(
`Received message with type '${plaintextMessage['@type']}' from connection ${connection?.id} (${connection?.theirLabel})`,
Expand Down Expand Up @@ -171,6 +164,55 @@ export class MessageReceiver {
return message
}

private async findConnectionByMessageKeys({
recipientKey,
senderKey,
}: DecryptedMessageContext): Promise<ConnectionRecord | null> {
// We only fetch connections that are sent in AuthCrypt mode
if (!recipientKey || !senderKey) return null

let connection: ConnectionRecord | null = null

// Try to find the did records that holds the sender and recipient keys
const ourDidRecord = await this.didRepository.findByVerkey(recipientKey)

// If both our did record and their did record is available we can find a matching did record
if (ourDidRecord) {
const theirDidRecord = await this.didRepository.findByVerkey(senderKey)

if (theirDidRecord) {
connection = await this.connectionRepository.findSingleByQuery({
did: ourDidRecord.id,
theirDid: theirDidRecord.id,
})
} else {
connection = await this.connectionRepository.findSingleByQuery({
did: ourDidRecord.id,
})

// If theirDidRecord was not found, and connection.theirDid is set, it means the sender is not authenticated
// to send messages to use
if (connection && connection.theirDid) {
throw new AriesFrameworkError(`Inbound message senderKey '${senderKey}' is different from connection did`)
}
}
}

// If no connection was found, we search in the connection record, where legacy did documents are stored
if (!connection) {
connection = await this.connectionRepository.findByVerkey(recipientKey)

// Throw error if the recipient key (ourKey) does not match the key of the connection record
if (connection && connection.theirKey !== null && connection.theirKey !== senderKey) {
throw new AriesFrameworkError(
`Inbound message senderKey '${senderKey}' is different from connection.theirKey '${connection.theirKey}'`
)
}
}

return connection
}

/**
* Transform an plaintext DIDComm message into it's corresponding message class. Will look at all message types in the registered handlers.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ describe('ConnectionService', () => {
verkey: 'my-key',
role: ConnectionRole.Inviter,
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connectionRecord))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connectionRecord))

const theirDid = 'their-did'
const theirVerkey = 'their-verkey'
Expand Down Expand Up @@ -395,7 +395,7 @@ describe('ConnectionService', () => {
senderVerkey: 'sender-verkey',
})

mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(null))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(null))
return expect(connectionService.processRequest(messageContext)).rejects.toThrowError(
'Unable to process connection request: connection for verkey test-verkey not found'
)
Expand All @@ -411,7 +411,7 @@ describe('ConnectionService', () => {
role: ConnectionRole.Inviter,
multiUseInvitation: true,
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connectionRecord))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connectionRecord))

const theirDid = 'their-did'
const theirVerkey = 'their-verkey'
Expand Down Expand Up @@ -458,7 +458,7 @@ describe('ConnectionService', () => {
it(`throws an error when connection role is ${ConnectionRole.Invitee} and not ${ConnectionRole.Inviter}`, async () => {
expect.assertions(1)

mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(
mockFunction(connectionRepository.findByVerkey).mockReturnValue(
Promise.resolve(getMockConnection({ role: ConnectionRole.Invitee }))
)

Expand All @@ -482,7 +482,7 @@ describe('ConnectionService', () => {
verkey: recipientVerkey,
})

mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connection))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connection))

const connectionRequest = new ConnectionRequestMessage({
did: 'did',
Expand Down Expand Up @@ -512,7 +512,7 @@ describe('ConnectionService', () => {
role: ConnectionRole.Inviter,
multiUseInvitation: true,
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connectionRecord))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connectionRecord))

const theirDidDoc = new DidDoc({
id: 'their-did',
Expand Down Expand Up @@ -618,7 +618,7 @@ describe('ConnectionService', () => {
serviceEndpoint: 'test',
}),
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connectionRecord))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connectionRecord))

const otherPartyConnection = new Connection({
did: theirDid,
Expand Down Expand Up @@ -667,7 +667,7 @@ describe('ConnectionService', () => {
recipientVerkey: 'recipientVerkey',
})

mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(
mockFunction(connectionRepository.findByVerkey).mockReturnValue(
Promise.resolve(
getMockConnection({
role: ConnectionRole.Inviter,
Expand All @@ -692,7 +692,7 @@ describe('ConnectionService', () => {
role: ConnectionRole.Invitee,
state: ConnectionState.Requested,
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connectionRecord))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connectionRecord))

const otherPartyConnection = new Connection({
did: theirDid,
Expand Down Expand Up @@ -749,7 +749,7 @@ describe('ConnectionService', () => {
recipientVerkey: 'test-verkey',
senderVerkey: 'sender-verkey',
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(null))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(null))

return expect(connectionService.processResponse(messageContext)).rejects.toThrowError(
'Unable to process connection response: connection for verkey test-verkey not found'
Expand All @@ -774,7 +774,7 @@ describe('ConnectionService', () => {
theirDid: undefined,
theirDidDoc: undefined,
})
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(connectionRecord))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(connectionRecord))

const otherPartyConnection = new Connection({
did: theirDid,
Expand Down Expand Up @@ -1071,11 +1071,11 @@ describe('ConnectionService', () => {
expect(result).toBe(expected)
})

it('getById should return value from connectionRepository.getSingleByQuery', async () => {
it('getByThreadId should return value from connectionRepository.getSingleByQuery', async () => {
const expected = getMockConnection()
mockFunction(connectionRepository.getSingleByQuery).mockReturnValue(Promise.resolve(expected))
mockFunction(connectionRepository.getByThreadId).mockReturnValue(Promise.resolve(expected))
const result = await connectionService.getByThreadId('threadId')
expect(connectionRepository.getSingleByQuery).toBeCalledWith({ threadId: 'threadId' })
expect(connectionRepository.getByThreadId).toBeCalledWith('threadId')

expect(result).toBe(expected)
})
Expand All @@ -1091,18 +1091,18 @@ describe('ConnectionService', () => {

it('findByVerkey should return value from connectionRepository.findSingleByQuery', async () => {
const expected = getMockConnection()
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(expected))
mockFunction(connectionRepository.findByVerkey).mockReturnValue(Promise.resolve(expected))
const result = await connectionService.findByVerkey('verkey')
expect(connectionRepository.findSingleByQuery).toBeCalledWith({ verkey: 'verkey' })
expect(connectionRepository.findByVerkey).toBeCalledWith('verkey')

expect(result).toBe(expected)
})

it('findByTheirKey should return value from connectionRepository.findSingleByQuery', async () => {
const expected = getMockConnection()
mockFunction(connectionRepository.findSingleByQuery).mockReturnValue(Promise.resolve(expected))
mockFunction(connectionRepository.findByTheirKey).mockReturnValue(Promise.resolve(expected))
const result = await connectionService.findByTheirKey('theirKey')
expect(connectionRepository.findSingleByQuery).toBeCalledWith({ theirKey: 'theirKey' })
expect(connectionRepository.findByTheirKey).toBeCalledWith('theirKey')

expect(result).toBe(expected)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ export type DefaultConnectionTags = {
verkey?: string
theirKey?: string
mediatorId?: string
did: string
theirDid?: string
}

export class ConnectionRecord
Expand Down Expand Up @@ -112,6 +114,8 @@ export class ConnectionRecord
verkey: this.verkey,
theirKey: this.theirKey || undefined,
mediatorId: this.mediatorId,
did: this.did,
theirDid: this.theirDid,
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,33 @@ export class ConnectionRepository extends Repository<ConnectionRecord> {
public constructor(@inject(InjectionSymbols.StorageService) storageService: StorageService<ConnectionRecord>) {
super(ConnectionRecord, storageService)
}

public async findByDids({ ourDid, theirDid }: { ourDid: string; theirDid: string }) {
return this.findSingleByQuery({
did: ourDid,
theirDid,
})
}

public findByVerkey(verkey: string): Promise<ConnectionRecord | null> {
return this.findSingleByQuery({
verkey,
})
}

public findByTheirKey(verkey: string): Promise<ConnectionRecord | null> {
return this.findSingleByQuery({
theirKey: verkey,
})
}

public findByInvitationKey(key: string): Promise<ConnectionRecord | null> {
return this.findSingleByQuery({
invitationKey: key,
})
}

public getByThreadId(threadId: string): Promise<ConnectionRecord> {
return this.getSingleByQuery({ threadId })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,7 @@ export class ConnectionService {
* @throws {RecordDuplicateError} if multiple connections are found for the given verkey
*/
public findByVerkey(verkey: string): Promise<ConnectionRecord | null> {
return this.connectionRepository.findSingleByQuery({
verkey,
})
return this.connectionRepository.findByVerkey(verkey)
}

/**
Expand All @@ -616,9 +614,7 @@ export class ConnectionService {
* @throws {RecordDuplicateError} if multiple connections are found for the given verkey
*/
public findByTheirKey(verkey: string): Promise<ConnectionRecord | null> {
return this.connectionRepository.findSingleByQuery({
theirKey: verkey,
})
return this.connectionRepository.findByTheirKey(verkey)
}

/**
Expand All @@ -629,9 +625,7 @@ export class ConnectionService {
* @throws {RecordDuplicateError} if multiple connections are found for the given verkey
*/
public findByInvitationKey(key: string): Promise<ConnectionRecord | null> {
return this.connectionRepository.findSingleByQuery({
invitationKey: key,
})
return this.connectionRepository.findByInvitationKey(key)
}

/**
Expand All @@ -643,7 +637,7 @@ export class ConnectionService {
* @returns The connection record
*/
public getByThreadId(threadId: string): Promise<ConnectionRecord> {
return this.connectionRepository.getSingleByQuery({ threadId })
return this.connectionRepository.getByThreadId(threadId)
}

private async createConnection(options: {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { IndyLedgerService } from '../../ledger'
import type { DidDocumentRepository } from '../repository'
import type { DidRepository } from '../repository'

import { getAgentConfig, mockProperty } from '../../../../tests/helpers'
import { JsonTransformer } from '../../../utils/JsonTransformer'
Expand All @@ -16,7 +16,7 @@ const agentConfig = getAgentConfig('DidResolverService')

describe('DidResolverService', () => {
const indyLedgerServiceMock = jest.fn() as unknown as IndyLedgerService
const didDocumentRepositoryMock = jest.fn() as unknown as DidDocumentRepository
const didDocumentRepositoryMock = jest.fn() as unknown as DidRepository
const didResolverService = new DidResolverService(agentConfig, indyLedgerServiceMock, didDocumentRepositoryMock)

it('should correctly find and call the correct resolver for a specified did', async () => {
Expand Down
Loading

0 comments on commit 557e9ee

Please sign in to comment.