Skip to content

Commit

Permalink
fix: query the record by credential and proof role
Browse files Browse the repository at this point in the history
Signed-off-by: Berend Sliedrecht <[email protected]>
  • Loading branch information
berendsliedrecht committed Mar 1, 2024
1 parent add7e09 commit 6f0a374
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,10 @@ export class V1CredentialProtocol

agentContext.config.logger.debug(`Processing credential proposal with message id ${proposalMessage.id}`)

let credentialRecord = await this.findByThreadAndConnectionId(
let credentialRecord = await this.findByThreadIdConnectionIdAndRole(
messageContext.agentContext,
proposalMessage.threadId,
CredentialRole.Issuer,
connection?.id
)

Expand Down Expand Up @@ -503,7 +504,12 @@ export class V1CredentialProtocol

agentContext.config.logger.debug(`Processing credential offer with id ${offerMessage.id}`)

let credentialRecord = await this.findByThreadAndConnectionId(agentContext, offerMessage.threadId, connection?.id)
let credentialRecord = await this.findByThreadIdConnectionIdAndRole(
agentContext,
offerMessage.threadId,
CredentialRole.Holder,
connection?.id
)

const offerAttachment = offerMessage.getOfferAttachmentById(INDY_CREDENTIAL_OFFER_ATTACHMENT_ID)
if (!offerAttachment) {
Expand Down Expand Up @@ -739,7 +745,11 @@ export class V1CredentialProtocol

agentContext.config.logger.debug(`Processing credential request with id ${requestMessage.id}`)

const credentialRecord = await this.getByThreadAndConnectionId(messageContext.agentContext, requestMessage.threadId)
const credentialRecord = await this.getByThreadIdConnectionIdAndRole(
messageContext.agentContext,
requestMessage.threadId,
CredentialRole.Issuer
)
agentContext.config.logger.trace('Credential record found when processing credential request', credentialRecord)

const proposalMessage = await didCommMessageRepository.findAgentMessage(messageContext.agentContext, {
Expand Down Expand Up @@ -885,9 +895,10 @@ export class V1CredentialProtocol
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const credentialRecord = await this.getByThreadAndConnectionId(
const credentialRecord = await this.getByThreadIdConnectionIdAndRole(
messageContext.agentContext,
issueMessage.threadId,
CredentialRole.Holder,
connection?.id
)

Expand Down Expand Up @@ -990,9 +1001,10 @@ export class V1CredentialProtocol
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const credentialRecord = await this.getByThreadAndConnectionId(
const credentialRecord = await this.getByThreadIdConnectionIdAndRole(
messageContext.agentContext,
ackMessage.threadId,
CredentialRole.Issuer,
connection?.id
)

Expand Down Expand Up @@ -1029,7 +1041,7 @@ export class V1CredentialProtocol
*
*/
public async createProblemReport(
agentContext: AgentContext,
_agentContext: AgentContext,
{ credentialRecord, description }: CredentialProtocolOptions.CreateCredentialProblemReportOptions
): Promise<CredentialProtocolOptions.CredentialProtocolMsgReturnType<ProblemReportMessage>> {
const message = new V1CredentialProblemReportMessage({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ const didCommMessageRecord = new DidCommMessageRecord({
})

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const getAgentMessageMock = async (agentContext: AgentContext, options: { messageClass: any }) => {
const getAgentMessageMock = async (_agentContext: AgentContext, options: { messageClass: any }) => {
if (options.messageClass === V1ProposeCredentialMessage) {
return credentialProposalMessage
}
Expand Down Expand Up @@ -312,7 +312,7 @@ describe('V1CredentialProtocol', () => {
invalidCredentialStates.map(async (state) => {
await expect(
credentialProtocol.acceptOffer(agentContext, { credentialRecord: mockCredentialRecord({ state }) })
).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
).rejects.toThrow(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
})
)
})
Expand Down Expand Up @@ -347,6 +347,7 @@ describe('V1CredentialProtocol', () => {
// then
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
role: CredentialRole.Issuer,
})
expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1)
expect(returnedCredentialRecord.state).toEqual(CredentialState.RequestReceived)
Expand All @@ -363,6 +364,7 @@ describe('V1CredentialProtocol', () => {
// then
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
role: CredentialRole.Issuer,
})
expect(returnedCredentialRecord.state).toEqual(CredentialState.RequestReceived)
})
Expand All @@ -375,7 +377,7 @@ describe('V1CredentialProtocol', () => {
mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(
Promise.resolve(mockCredentialRecord({ state }))
)
await expect(credentialProtocol.processRequest(messageContext)).rejects.toThrowError(
await expect(credentialProtocol.processRequest(messageContext)).rejects.toThrow(
`Credential record is in invalid state ${state}. Valid states are: ${validState}.`
)
})
Expand Down Expand Up @@ -516,6 +518,7 @@ describe('V1CredentialProtocol', () => {
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
connectionId: connection.id,
role: CredentialRole.Holder,
})

expect(didCommMessageRepository.saveAgentMessage).toHaveBeenCalledWith(agentContext, {
Expand Down Expand Up @@ -614,7 +617,7 @@ describe('V1CredentialProtocol', () => {
connectionId: 'b1e2f039-aa39-40be-8643-6ce2797b5190',
}),
})
).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
).rejects.toThrow(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
})
)
})
Expand Down Expand Up @@ -652,6 +655,7 @@ describe('V1CredentialProtocol', () => {
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
connectionId: connection.id,
role: CredentialRole.Issuer,
})
expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1)
const [[, updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls
Expand Down Expand Up @@ -746,18 +750,24 @@ describe('V1CredentialProtocol', () => {
const expected = mockCredentialRecord()
mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.getById(agentContext, expected.id)
expect(credentialRepository.getById).toBeCalledWith(agentContext, expected.id)
expect(credentialRepository.getById).toHaveBeenCalledWith(agentContext, expected.id)

expect(result).toBe(expected)
})

it('getById should return value from credentialRepository.getSingleByQuery', async () => {
const expected = mockCredentialRecord()
mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.getByThreadAndConnectionId(agentContext, 'threadId', 'connectionId')
expect(credentialRepository.getSingleByQuery).toBeCalledWith(agentContext, {
const result = await credentialProtocol.getByThreadIdConnectionIdAndRole(
agentContext,
'threadId',
CredentialRole.Issuer,
'connectionId'
)
expect(credentialRepository.getSingleByQuery).toHaveBeenCalledWith(agentContext, {
threadId: 'threadId',
connectionId: 'connectionId',
role: CredentialRole.Issuer,
})

expect(result).toBe(expected)
Expand All @@ -767,7 +777,7 @@ describe('V1CredentialProtocol', () => {
const expected = mockCredentialRecord()
mockFunction(credentialRepository.findById).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.findById(agentContext, expected.id)
expect(credentialRepository.findById).toBeCalledWith(agentContext, expected.id)
expect(credentialRepository.findById).toHaveBeenCalledWith(agentContext, expected.id)

expect(result).toBe(expected)
})
Expand All @@ -777,7 +787,7 @@ describe('V1CredentialProtocol', () => {

mockFunction(credentialRepository.getAll).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.getAll(agentContext)
expect(credentialRepository.getAll).toBeCalledWith(agentContext)
expect(credentialRepository.getAll).toHaveBeenCalledWith(agentContext)

expect(result).toEqual(expect.arrayContaining(expected))
})
Expand All @@ -787,7 +797,7 @@ describe('V1CredentialProtocol', () => {

mockFunction(credentialRepository.findByQuery).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.findAllByQuery(agentContext, { state: CredentialState.OfferSent })
expect(credentialRepository.findByQuery).toBeCalledWith(agentContext, { state: CredentialState.OfferSent })
expect(credentialRepository.findByQuery).toHaveBeenCalledWith(agentContext, { state: CredentialState.OfferSent })

expect(result).toEqual(expect.arrayContaining(expected))
})
Expand Down
21 changes: 16 additions & 5 deletions packages/anoncreds/src/protocols/proofs/v1/V1ProofProtocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,12 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<

agentContext.config.logger.debug(`Processing presentation proposal with message id ${proposalMessage.id}`)

let proofRecord = await this.findByThreadAndConnectionId(agentContext, proposalMessage.threadId, connection?.id)
let proofRecord = await this.findByThreadIdConnectionIdAndRole(
agentContext,
proposalMessage.threadId,
ProofRole.Verifier,
connection?.id
)

// Proof record already exists, this is a response to an earlier message sent by us
if (proofRecord) {
Expand Down Expand Up @@ -422,7 +427,12 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<

agentContext.config.logger.debug(`Processing presentation request with id ${proofRequestMessage.id}`)

let proofRecord = await this.findByThreadAndConnectionId(agentContext, proofRequestMessage.threadId, connection?.id)
let proofRecord = await this.findByThreadIdConnectionIdAndRole(
agentContext,
proofRequestMessage.threadId,
ProofRole.Prover,
connection?.id
)

const requestAttachment = proofRequestMessage.getRequestAttachmentById(INDY_PROOF_REQUEST_ATTACHMENT_ID)
if (!requestAttachment) {
Expand Down Expand Up @@ -760,7 +770,7 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const proofRecord = await this.getByThreadAndConnectionId(agentContext, presentationMessage.threadId)
const proofRecord = await this.getByThreadIdConnectionIdAndRole(agentContext, presentationMessage.threadId)

const proposalMessage = await didCommMessageRepository.findAgentMessage(agentContext, {
associatedRecordId: proofRecord.id,
Expand Down Expand Up @@ -887,9 +897,10 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const proofRecord = await this.getByThreadAndConnectionId(
const proofRecord = await this.getByThreadIdConnectionIdAndRole(
agentContext,
presentationAckMessage.threadId,
ProofRole.Prover,
connection?.id
)

Expand Down Expand Up @@ -920,7 +931,7 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<
}

public async createProblemReport(
agentContext: AgentContext,
_agentContext: AgentContext,
{ proofRecord, description }: ProofProtocolOptions.CreateProofProblemReportOptions
): Promise<ProofProtocolOptions.ProofProtocolMsgReturnType<ProblemReportMessage>> {
const message = new V1PresentationProblemReportMessage({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import type { Query } from '../../../storage/StorageService'
import type { ProblemReportMessage } from '../../problem-reports'
import type { CredentialStateChangedEvent } from '../CredentialEvents'
import type { CredentialFormatService, ExtractCredentialFormats } from '../formats'
import type { CredentialRole } from '../models'
import type { CredentialExchangeRecord } from '../repository'

import { EventEmitter } from '../../../agent/EventEmitter'
Expand Down Expand Up @@ -141,9 +142,10 @@ export abstract class BaseCredentialProtocol<CFs extends CredentialFormatService

agentContext.config.logger.debug(`Processing problem report with message id ${credentialProblemReportMessage.id}`)

const credentialRecord = await this.getByThreadAndConnectionId(
const credentialRecord = await this.getByThreadIdConnectionIdAndRole(
agentContext,
credentialProblemReportMessage.threadId,
undefined,
connection.id
)

Expand Down Expand Up @@ -275,41 +277,49 @@ export abstract class BaseCredentialProtocol<CFs extends CredentialFormatService
* Retrieve a credential record by connection id and thread id
*
* @param connectionId The connection id
* @param role The role of the record, i.e. holder or issuer
* @param threadId The thread id
*
* @throws {RecordNotFoundError} If no record is found
* @throws {RecordDuplicateError} If multiple records are found
* @returns The credential record
*/
public getByThreadAndConnectionId(
public getByThreadIdConnectionIdAndRole(
agentContext: AgentContext,
threadId: string,
role?: CredentialRole,
connectionId?: string
): Promise<CredentialExchangeRecord> {
const credentialRepository = agentContext.dependencyManager.resolve(CredentialRepository)

return credentialRepository.getSingleByQuery(agentContext, {
connectionId,
threadId,
role,
})
}

/**
* Find a credential record by connection id and thread id, returns null if not found
*
* @param connectionId The connection id
* @param threadId The thread id
* @param role The role of the record, i.e. holder or issuer
* @param connectionId The connection id
*
* @returns The credential record
*/
public findByThreadAndConnectionId(
public findByThreadIdConnectionIdAndRole(
agentContext: AgentContext,
threadId: string,
role?: CredentialRole,
connectionId?: string
): Promise<CredentialExchangeRecord | null> {
const credentialRepository = agentContext.dependencyManager.resolve(CredentialRepository)

return credentialRepository.findSingleByQuery(agentContext, {
connectionId,
threadId,
role,
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ export interface CredentialProtocol<CFs extends CredentialFormatService[] = Cred
credentialRecord: CredentialExchangeRecord,
options?: DeleteCredentialOptions
): Promise<void>
getByThreadAndConnectionId(
getByThreadIdConnectionIdAndRole(
agentContext: AgentContext,
threadId: string,
connectionId?: string
): Promise<CredentialExchangeRecord>
findByThreadAndConnectionId(
findByThreadIdConnectionIdAndRole(
agentContext: AgentContext,
threadId: string,
connectionId?: string
Expand Down
Loading

0 comments on commit 6f0a374

Please sign in to comment.