Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(prism-agent): check issuing DID validity when creating a VC offer #740

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ lazy val D_Castor = new {
Seq(
D.zio,
D.zioTest,
D.zioMock,
D.zioTestSbt,
D.zioTestMagnolia,
D.circeCore,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.iohk.atala.castor.core.service

import io.iohk.atala.castor.core.model.did.*
import io.iohk.atala.castor.core.model.error
import io.iohk.atala.prism.crypto.EC
import io.iohk.atala.prism.crypto.keys.ECKeyPair
import io.iohk.atala.shared.models.Base64UrlString
import zio.mock.{Expectation, Mock, Proxy}
import zio.test.Assertion
import zio.{IO, URLayer, ZIO, ZLayer, mock}

import scala.collection.immutable.ArraySeq

object MockDIDService extends Mock[DIDService] {

object ScheduleOperation extends Effect[SignedPrismDIDOperation, error.DIDOperationError, ScheduleDIDOperationOutcome]
// FIXME leaving this out for now as it gives a "java.lang.AssertionError: assertion failed: class Array" compilation error
// object GetScheduledDIDOperationDetail extends Effect[Array[Byte], error.DIDOperationError, Option[ScheduledDIDOperationDetail]]
object ResolveDID extends Effect[PrismDID, error.DIDResolutionError, Option[(DIDMetadata, DIDData)]]

override val compose: URLayer[mock.Proxy, DIDService] =
ZLayer {
for {
proxy <- ZIO.service[Proxy]
} yield new DIDService {
override def scheduleOperation(
operation: SignedPrismDIDOperation
): IO[error.DIDOperationError, ScheduleDIDOperationOutcome] =
proxy(ScheduleOperation, operation)

override def getScheduledDIDOperationDetail(
operationId: Array[Byte]
): IO[error.DIDOperationError, Option[ScheduledDIDOperationDetail]] =
???

override def resolveDID(did: PrismDID): IO[error.DIDResolutionError, Option[(DIDMetadata, DIDData)]] =
proxy(ResolveDID, did)
}
}

def createDID(
verificationRelationship: VerificationRelationship
): (PrismDIDOperation.Create, ECKeyPair, DIDMetadata, DIDData) = {
val masterKeyPair = EC.INSTANCE.generateKeyPair()
val keyPair = EC.INSTANCE.generateKeyPair()
val createOperation = PrismDIDOperation.Create(
publicKeys = Seq(
InternalPublicKey(
id = "master-0",
purpose = InternalKeyPurpose.Master,
publicKeyData = PublicKeyData.ECCompressedKeyData(
crv = EllipticCurve.SECP256K1,
data = Base64UrlString.fromByteArray(masterKeyPair.getPublicKey.getEncodedCompressed)
)
),
PublicKey(
id = "key-0",
purpose = verificationRelationship,
publicKeyData = PublicKeyData.ECCompressedKeyData(
crv = EllipticCurve.SECP256K1,
data = Base64UrlString.fromByteArray(keyPair.getPublicKey.getEncodedCompressed)
)
),
),
services = Nil,
context = Nil,
)
val longFormDid = PrismDID.buildLongFormFromOperation(createOperation)
// val canonicalDid = longFormDid.asCanonical

val didMetadata =
DIDMetadata(
lastOperationHash = ArraySeq.from(longFormDid.stateHash.toByteArray),
canonicalId = None, // unpublished DID must not contain canonicalId
deactivated = false, // unpublished DID cannot be deactivated
created = None, // unpublished DID cannot have timestamp
updated = None // unpublished DID cannot have timestamp
)
val didData = DIDData(
id = longFormDid.asCanonical,
publicKeys = createOperation.publicKeys.collect { case pk: PublicKey => pk },
services = createOperation.services,
internalKeys = createOperation.publicKeys.collect { case pk: InternalPublicKey => pk },
context = createOperation.context
)
(createOperation, keyPair, didMetadata, didData)
}

def resolveDIDExpectation(didMetadata: DIDMetadata, didData: DIDData): Expectation[DIDService] =
MockDIDService.ResolveDID(
assertion = Assertion.anything,
result = Expectation.value(Some(didMetadata, didData))
)
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package io.iohk.atala.pollux.core.model

import io.iohk.atala.mercury.protocol.issuecredential.OfferCredential
import io.iohk.atala.mercury.protocol.issuecredential.RequestCredential
import io.iohk.atala.mercury.protocol.issuecredential.IssueCredential
import IssueCredentialRecord._
import java.time.Instant
import io.iohk.atala.castor.core.model.did.CanonicalPrismDID
import io.iohk.atala.mercury.protocol.issuecredential.{IssueCredential, OfferCredential, RequestCredential}
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.*

import java.time.Instant

final case class IssueCredentialRecord(
id: DidCommID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import zio.*
trait CredentialRepository {
def createIssueCredentialRecord(record: IssueCredentialRecord): RIO[WalletAccessContext, Int]
def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): RIO[WalletAccessContext, (Seq[IssueCredentialRecord], Int)]
Expand All @@ -27,7 +27,7 @@ trait CredentialRepository {

def getIssueCredentialRecordByThreadId(
thid: DidCommID,
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
): RIO[WalletAccessContext, Option[IssueCredentialRecord]]

def updateCredentialRecordProtocolState(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package io.iohk.atala.pollux.core.repository

import io.iohk.atala.mercury.protocol.issuecredential.IssueCredential
import io.iohk.atala.mercury.protocol.issuecredential.RequestCredential
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.ProtocolState
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.PublicationState
import io.iohk.atala.pollux.core.model._
import io.iohk.atala.pollux.core.model.error.CredentialRepositoryError._
import io.iohk.atala.mercury.protocol.issuecredential.{IssueCredential, RequestCredential}
import io.iohk.atala.pollux.core.model.*
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.{ProtocolState, PublicationState}
import io.iohk.atala.pollux.core.model.error.CredentialRepositoryError.*
import io.iohk.atala.prism.crypto.MerkleInclusionProof
import io.iohk.atala.shared.models.WalletId
import io.iohk.atala.shared.models.{WalletAccessContext, WalletId}
import zio.*

import java.time.Instant
import io.iohk.atala.shared.models.WalletAccessContext

class CredentialRepositoryInMemory(
walletRefs: Ref[Map[WalletId, Ref[Map[DidCommID, IssueCredentialRecord]]]],
Expand Down Expand Up @@ -76,14 +73,15 @@ class CredentialRepositoryInMemory(
}

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
offset: Option[Int],
limit: Option[Int]
): RIO[WalletAccessContext, (Seq[IssueCredentialRecord], Int)] = {
for {
storeRef <- walletStoreRef
store <- storeRef.get
paginated = store.values.toSeq.drop(offset.getOrElse(0)).take(limit.getOrElse(Int.MaxValue))
records = if (ignoreWithZeroRetries) store.values.filter(_.metaRetries > 0) else store.values
paginated = records.toSeq.drop(offset.getOrElse(0)).take(limit.getOrElse(Int.MaxValue))
} yield paginated -> store.values.size
}

Expand Down Expand Up @@ -209,20 +207,22 @@ class CredentialRepositoryInMemory(
for {
storeRef <- walletStoreRef
store <- storeRef.get
} yield store.values
.filter(rec => states.contains(rec.protocolState) & (!ignoreWithZeroRetries | rec.metaRetries > 0))
records = if (ignoreWithZeroRetries) store.values.filter(_.metaRetries > 0) else store.values
} yield records
.filter(rec => states.contains(rec.protocolState))
.take(limit)
.toSeq
}

override def getIssueCredentialRecordByThreadId(
thid: DidCommID,
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
): RIO[WalletAccessContext, Option[IssueCredentialRecord]] = {
for {
storeRef <- walletStoreRef
store <- storeRef.get
} yield store.values.find(_.thid == thid).filter(!ignoreWithZeroRetries | _.metaRetries > 0)
records = if (ignoreWithZeroRetries) store.values.filter(_.metaRetries > 0) else store.values
} yield records.find(_.thid == thid)
}

override def updateWithSubjectId(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import zio.*

trait PresentationRepository {
def createPresentationRecord(record: PresentationRecord): RIO[WalletAccessContext, Int]
def getPresentationRecords(ignoreWithZeroRetries: Boolean = true): RIO[WalletAccessContext, Seq[PresentationRecord]]
def getPresentationRecords(ignoreWithZeroRetries: Boolean): RIO[WalletAccessContext, Seq[PresentationRecord]]
def getPresentationRecord(recordId: DidCommID): RIO[WalletAccessContext, Option[PresentationRecord]]
def getPresentationRecordsByStates(
ignoreWithZeroRetries: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PresentationRepositoryInMemory(
}

override def getPresentationRecords(
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
): RIO[WalletAccessContext, Seq[PresentationRecord]] = {
for {
storeRef <- walletStoreRef
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ trait CredentialService {

/** Return a list of records as well as a count of all filtered items */
def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): ZIO[WalletAccessContext, CredentialServiceError, (Seq[IssueCredentialRecord], Int)]
Expand All @@ -49,7 +50,8 @@ trait CredentialService {
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]]

def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]]

def receiveCredentialOffer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,24 @@ private class CredentialServiceImpl(
credential.maybeId.map(_.split("/").last).map(DidCommID(_))

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int],
limit: Option[Int]
): ZIO[WalletAccessContext, CredentialServiceError, (Seq[IssueCredentialRecord], Int)] = {
for {
records <- credentialRepository
.getIssueCredentialRecords(offset = offset, limit = limit)
.getIssueCredentialRecords(ignoreWithZeroRetries = ignoreWithZeroRetries, offset = offset, limit = limit)
.mapError(RepositoryError.apply)
} yield records
}

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]] =
for {
record <- credentialRepository
.getIssueCredentialRecordByThreadId(thid)
.getIssueCredentialRecordByThreadId(thid, ignoreWithZeroRetries)
.mapError(RepositoryError.apply)
} yield record

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,17 @@ class CredentialServiceNotifier(
svc.getIssueCredentialRecord(recordId)

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]] =
svc.getIssueCredentialRecordByThreadId(thid)
svc.getIssueCredentialRecordByThreadId(thid, ignoreWithZeroRetries)

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): ZIO[WalletAccessContext, CredentialServiceError, (Seq[IssueCredentialRecord], Int)] =
svc.getIssueCredentialRecords(offset, limit)
svc.getIssueCredentialRecords(ignoreWithZeroRetries, offset, limit)

override def getIssueCredentialRecordsByStates(
ignoreWithZeroRetries: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ object MockCredentialService extends Mock[CredentialService] {
???

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): IO[CredentialServiceError, (Seq[IssueCredentialRecord], Int)] =
Expand All @@ -190,7 +191,8 @@ object MockCredentialService extends Mock[CredentialService] {
???

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): IO[CredentialServiceError, Option[IssueCredentialRecord]] = ???
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ object MockPresentationService extends Mock[PresentationService] {

override def extractIdFromCredential(credential: W3cCredentialPayload): Option[UUID] = ???

override def getPresentationRecords(): IO[PresentationError, Seq[PresentationRecord]] = ???
override def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): IO[PresentationError, Seq[PresentationRecord]] = ???

override def createPresentationPayloadFromRecord(
record: DidCommID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ trait PresentationService {
options: Option[io.iohk.atala.pollux.core.model.presentation.Options]
): ZIO[WalletAccessContext, PresentationError, PresentationRecord]

def getPresentationRecords(): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]]
def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]]

def createPresentationPayloadFromRecord(
record: DidCommID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ private class PresentationServiceImpl(
override def extractIdFromCredential(credential: W3cCredentialPayload): Option[UUID] =
credential.maybeId.map(_.split("/").last).map(UUID.fromString)

override def getPresentationRecords(): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] = {
override def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] = {
for {
records <- presentationRepository
.getPresentationRecords()
.getPresentationRecords(ignoreWithZeroRetries)
.mapError(RepositoryError.apply)
} yield records
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ class PresentationServiceNotifier(
override def extractIdFromCredential(credential: W3cCredentialPayload): Option[UUID] =
svc.extractIdFromCredential(credential)

override def getPresentationRecords(): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] =
svc.getPresentationRecords()
override def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] =
svc.getPresentationRecords(ignoreWithZeroRetries)

override def createPresentationPayloadFromRecord(
record: DidCommID,
Expand Down
Loading
Loading