Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: query the record by credential and proof role #1784

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ export class V1CredentialProtocol

agentContext.config.logger.debug(`Processing credential proposal with message id ${proposalMessage.id}`)

let credentialRecord = await this.findByThreadAndConnectionId(
messageContext.agentContext,
proposalMessage.threadId,
connection?.id
)
let credentialRecord = await this.findByProperties(messageContext.agentContext, {
threadId: proposalMessage.threadId,
role: CredentialRole.Issuer,
connectionId: connection?.id,
})

// Credential record already exists, this is a response to an earlier message sent by us
if (credentialRecord) {
Expand Down Expand Up @@ -503,7 +503,11 @@ export class V1CredentialProtocol

agentContext.config.logger.debug(`Processing credential offer with id ${offerMessage.id}`)

let credentialRecord = await this.findByThreadAndConnectionId(agentContext, offerMessage.threadId, connection?.id)
let credentialRecord = await this.findByProperties(agentContext, {
threadId: offerMessage.threadId,
role: CredentialRole.Holder,
connectionId: connection?.id,
})

const offerAttachment = offerMessage.getOfferAttachmentById(INDY_CREDENTIAL_OFFER_ATTACHMENT_ID)
if (!offerAttachment) {
Expand Down Expand Up @@ -739,7 +743,11 @@ export class V1CredentialProtocol

agentContext.config.logger.debug(`Processing credential request with id ${requestMessage.id}`)

const credentialRecord = await this.getByThreadAndConnectionId(messageContext.agentContext, requestMessage.threadId)
const credentialRecord = await this.getByProperties(messageContext.agentContext, {
threadId: requestMessage.threadId,
role: CredentialRole.Issuer,
})

agentContext.config.logger.trace('Credential record found when processing credential request', credentialRecord)

const proposalMessage = await didCommMessageRepository.findAgentMessage(messageContext.agentContext, {
Expand Down Expand Up @@ -885,11 +893,11 @@ export class V1CredentialProtocol
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const credentialRecord = await this.getByThreadAndConnectionId(
messageContext.agentContext,
issueMessage.threadId,
connection?.id
)
const credentialRecord = await this.getByProperties(messageContext.agentContext, {
threadId: issueMessage.threadId,
role: CredentialRole.Holder,
connectionId: connection?.id,
})

const requestCredentialMessage = await didCommMessageRepository.findAgentMessage(messageContext.agentContext, {
associatedRecordId: credentialRecord.id,
Expand Down Expand Up @@ -990,11 +998,12 @@ export class V1CredentialProtocol
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const credentialRecord = await this.getByThreadAndConnectionId(
messageContext.agentContext,
ackMessage.threadId,
connection?.id
)
const credentialRecord = await this.getByProperties(messageContext.agentContext, {
threadId: ackMessage.threadId,

role: CredentialRole.Issuer,
connectionId: connection?.id,
})

const requestCredentialMessage = await didCommMessageRepository.getAgentMessage(messageContext.agentContext, {
associatedRecordId: credentialRecord.id,
Expand Down Expand Up @@ -1029,7 +1038,7 @@ export class V1CredentialProtocol
*
*/
public async createProblemReport(
agentContext: AgentContext,
_agentContext: AgentContext,
{ credentialRecord, description }: CredentialProtocolOptions.CreateCredentialProblemReportOptions
): Promise<CredentialProtocolOptions.CredentialProtocolMsgReturnType<ProblemReportMessage>> {
const message = new V1CredentialProblemReportMessage({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ const didCommMessageRecord = new DidCommMessageRecord({
})

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const getAgentMessageMock = async (agentContext: AgentContext, options: { messageClass: any }) => {
const getAgentMessageMock = async (_agentContext: AgentContext, options: { messageClass: any }) => {
if (options.messageClass === V1ProposeCredentialMessage) {
return credentialProposalMessage
}
Expand Down Expand Up @@ -312,7 +312,7 @@ describe('V1CredentialProtocol', () => {
invalidCredentialStates.map(async (state) => {
await expect(
credentialProtocol.acceptOffer(agentContext, { credentialRecord: mockCredentialRecord({ state }) })
).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
).rejects.toThrow(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
})
)
})
Expand Down Expand Up @@ -347,6 +347,7 @@ describe('V1CredentialProtocol', () => {
// then
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
role: CredentialRole.Issuer,
})
expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1)
expect(returnedCredentialRecord.state).toEqual(CredentialState.RequestReceived)
Expand All @@ -363,6 +364,7 @@ describe('V1CredentialProtocol', () => {
// then
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
role: CredentialRole.Issuer,
})
expect(returnedCredentialRecord.state).toEqual(CredentialState.RequestReceived)
})
Expand All @@ -375,7 +377,7 @@ describe('V1CredentialProtocol', () => {
mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(
Promise.resolve(mockCredentialRecord({ state }))
)
await expect(credentialProtocol.processRequest(messageContext)).rejects.toThrowError(
await expect(credentialProtocol.processRequest(messageContext)).rejects.toThrow(
`Credential record is in invalid state ${state}. Valid states are: ${validState}.`
)
})
Expand Down Expand Up @@ -516,6 +518,7 @@ describe('V1CredentialProtocol', () => {
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
connectionId: connection.id,
role: CredentialRole.Holder,
})

expect(didCommMessageRepository.saveAgentMessage).toHaveBeenCalledWith(agentContext, {
Expand Down Expand Up @@ -614,7 +617,7 @@ describe('V1CredentialProtocol', () => {
connectionId: 'b1e2f039-aa39-40be-8643-6ce2797b5190',
}),
})
).rejects.toThrowError(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
).rejects.toThrow(`Credential record is in invalid state ${state}. Valid states are: ${validState}.`)
})
)
})
Expand Down Expand Up @@ -652,6 +655,7 @@ describe('V1CredentialProtocol', () => {
expect(credentialRepository.getSingleByQuery).toHaveBeenNthCalledWith(1, agentContext, {
threadId: 'somethreadid',
connectionId: connection.id,
role: CredentialRole.Issuer,
})
expect(repositoryUpdateSpy).toHaveBeenCalledTimes(1)
const [[, updatedCredentialRecord]] = repositoryUpdateSpy.mock.calls
Expand Down Expand Up @@ -746,18 +750,25 @@ describe('V1CredentialProtocol', () => {
const expected = mockCredentialRecord()
mockFunction(credentialRepository.getById).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.getById(agentContext, expected.id)
expect(credentialRepository.getById).toBeCalledWith(agentContext, expected.id)
expect(credentialRepository.getById).toHaveBeenCalledWith(agentContext, expected.id)

expect(result).toBe(expected)
})

it('getById should return value from credentialRepository.getSingleByQuery', async () => {
const expected = mockCredentialRecord()
mockFunction(credentialRepository.getSingleByQuery).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.getByThreadAndConnectionId(agentContext, 'threadId', 'connectionId')
expect(credentialRepository.getSingleByQuery).toBeCalledWith(agentContext, {

const result = await credentialProtocol.getByProperties(agentContext, {
threadId: 'threadId',
role: CredentialRole.Issuer,
connectionId: 'connectionId',
})

expect(credentialRepository.getSingleByQuery).toHaveBeenCalledWith(agentContext, {
threadId: 'threadId',
connectionId: 'connectionId',
role: CredentialRole.Issuer,
})

expect(result).toBe(expected)
Expand All @@ -767,7 +778,7 @@ describe('V1CredentialProtocol', () => {
const expected = mockCredentialRecord()
mockFunction(credentialRepository.findById).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.findById(agentContext, expected.id)
expect(credentialRepository.findById).toBeCalledWith(agentContext, expected.id)
expect(credentialRepository.findById).toHaveBeenCalledWith(agentContext, expected.id)

expect(result).toBe(expected)
})
Expand All @@ -777,7 +788,7 @@ describe('V1CredentialProtocol', () => {

mockFunction(credentialRepository.getAll).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.getAll(agentContext)
expect(credentialRepository.getAll).toBeCalledWith(agentContext)
expect(credentialRepository.getAll).toHaveBeenCalledWith(agentContext)

expect(result).toEqual(expect.arrayContaining(expected))
})
Expand All @@ -787,7 +798,7 @@ describe('V1CredentialProtocol', () => {

mockFunction(credentialRepository.findByQuery).mockReturnValue(Promise.resolve(expected))
const result = await credentialProtocol.findAllByQuery(agentContext, { state: CredentialState.OfferSent })
expect(credentialRepository.findByQuery).toBeCalledWith(agentContext, { state: CredentialState.OfferSent })
expect(credentialRepository.findByQuery).toHaveBeenCalledWith(agentContext, { state: CredentialState.OfferSent })

expect(result).toEqual(expect.arrayContaining(expected))
})
Expand Down
29 changes: 20 additions & 9 deletions packages/anoncreds/src/protocols/proofs/v1/V1ProofProtocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<

agentContext.config.logger.debug(`Processing presentation proposal with message id ${proposalMessage.id}`)

let proofRecord = await this.findByThreadAndConnectionId(agentContext, proposalMessage.threadId, connection?.id)
let proofRecord = await this.findByProperties(agentContext, {
threadId: proposalMessage.threadId,
role: ProofRole.Verifier,
connectionId: connection?.id,
})

// Proof record already exists, this is a response to an earlier message sent by us
if (proofRecord) {
Expand Down Expand Up @@ -422,7 +426,11 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<

agentContext.config.logger.debug(`Processing presentation request with id ${proofRequestMessage.id}`)

let proofRecord = await this.findByThreadAndConnectionId(agentContext, proofRequestMessage.threadId, connection?.id)
let proofRecord = await this.findByProperties(agentContext, {
threadId: proofRequestMessage.threadId,
role: ProofRole.Prover,
connectionId: connection?.id,
})

const requestAttachment = proofRequestMessage.getRequestAttachmentById(INDY_PROOF_REQUEST_ATTACHMENT_ID)
if (!requestAttachment) {
Expand Down Expand Up @@ -760,7 +768,10 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const proofRecord = await this.getByThreadAndConnectionId(agentContext, presentationMessage.threadId)
const proofRecord = await this.getByProperties(agentContext, {
threadId: presentationMessage.threadId,
role: ProofRole.Verifier,
})

const proposalMessage = await didCommMessageRepository.findAgentMessage(agentContext, {
associatedRecordId: proofRecord.id,
Expand Down Expand Up @@ -887,11 +898,11 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<
// only depends on the public api, rather than the internal API (this helps with breaking changes)
const connectionService = agentContext.dependencyManager.resolve(ConnectionService)

const proofRecord = await this.getByThreadAndConnectionId(
agentContext,
presentationAckMessage.threadId,
connection?.id
)
const proofRecord = await this.getByProperties(agentContext, {
threadId: presentationAckMessage.threadId,
role: ProofRole.Prover,
connectionId: connection?.id,
})

const lastReceivedMessage = await didCommMessageRepository.getAgentMessage(agentContext, {
associatedRecordId: proofRecord.id,
Expand Down Expand Up @@ -920,7 +931,7 @@ export class V1ProofProtocol extends BaseProofProtocol implements ProofProtocol<
}

public async createProblemReport(
agentContext: AgentContext,
_agentContext: AgentContext,
{ proofRecord, description }: ProofProtocolOptions.CreateProofProblemReportOptions
): Promise<ProofProtocolOptions.ProofProtocolMsgReturnType<ProblemReportMessage>> {
const message = new V1PresentationProblemReportMessage({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import type { Query } from '../../../storage/StorageService'
import type { ProblemReportMessage } from '../../problem-reports'
import type { CredentialStateChangedEvent } from '../CredentialEvents'
import type { CredentialFormatService, ExtractCredentialFormats } from '../formats'
import type { CredentialRole } from '../models'
import type { CredentialExchangeRecord } from '../repository'

import { EventEmitter } from '../../../agent/EventEmitter'
Expand Down Expand Up @@ -141,11 +142,10 @@ export abstract class BaseCredentialProtocol<CFs extends CredentialFormatService

agentContext.config.logger.debug(`Processing problem report with message id ${credentialProblemReportMessage.id}`)

const credentialRecord = await this.getByThreadAndConnectionId(
agentContext,
credentialProblemReportMessage.threadId,
connection.id
)
const credentialRecord = await this.getByProperties(agentContext, {
threadId: credentialProblemReportMessage.threadId,
connectionId: connection.id,
})

// Update record
credentialRecord.errorMessage = `${credentialProblemReportMessage.description.code}: ${credentialProblemReportMessage.description.en}`
Expand Down Expand Up @@ -274,42 +274,54 @@ export abstract class BaseCredentialProtocol<CFs extends CredentialFormatService
/**
* Retrieve a credential record by connection id and thread id
*
* @param connectionId The connection id
* @param threadId The thread id
* @param properties Properties to query by
*
* @throws {RecordNotFoundError} If no record is found
* @throws {RecordDuplicateError} If multiple records are found
* @returns The credential record
*/
public getByThreadAndConnectionId(
public getByProperties(
agentContext: AgentContext,
threadId: string,
connectionId?: string
properties: {
threadId: string
role?: CredentialRole
connectionId?: string
}
): Promise<CredentialExchangeRecord> {
const { role, connectionId, threadId } = properties
const credentialRepository = agentContext.dependencyManager.resolve(CredentialRepository)

return credentialRepository.getSingleByQuery(agentContext, {
connectionId,
threadId,
role,
})
}

/**
* Find a credential record by connection id and thread id, returns null if not found
*
* @param connectionId The connection id
* @param threadId The thread id
* @param role The role of the record, i.e. holder or issuer
* @param connectionId The connection id
*
* @returns The credential record
*/
public findByThreadAndConnectionId(
public findByProperties(
agentContext: AgentContext,
threadId: string,
connectionId?: string
properties: {
threadId: string
role?: CredentialRole
connectionId?: string
}
): Promise<CredentialExchangeRecord | null> {
const { role, connectionId, threadId } = properties
const credentialRepository = agentContext.dependencyManager.resolve(CredentialRepository)

return credentialRepository.findSingleByQuery(agentContext, {
connectionId,
threadId,
role,
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import type { DependencyManager } from '../../../plugins'
import type { Query } from '../../../storage/StorageService'
import type { ProblemReportMessage } from '../../problem-reports'
import type { CredentialFormatService, ExtractCredentialFormats } from '../formats'
import type { CredentialRole } from '../models'
import type { CredentialState } from '../models/CredentialState'
import type { CredentialExchangeRecord } from '../repository'

Expand Down Expand Up @@ -112,15 +113,21 @@ export interface CredentialProtocol<CFs extends CredentialFormatService[] = Cred
credentialRecord: CredentialExchangeRecord,
options?: DeleteCredentialOptions
): Promise<void>
getByThreadAndConnectionId(
getByProperties(
agentContext: AgentContext,
threadId: string,
connectionId?: string
properties: {
threadId: string
connectionId?: string
role?: CredentialRole
}
): Promise<CredentialExchangeRecord>
findByThreadAndConnectionId(
findByProperties(
agentContext: AgentContext,
threadId: string,
connectionId?: string
properties: {
threadId: string
connectionId?: string
role?: CredentialRole
}
): Promise<CredentialExchangeRecord | null>
update(agentContext: AgentContext, credentialRecord: CredentialExchangeRecord): Promise<void>

Expand Down
Loading
Loading