Skip to content

Commit

Permalink
feat(prism-agent): check issuing DID validity when creating a VC offe…
Browse files Browse the repository at this point in the history
…r + return 'metaRetries' (#740)

Signed-off-by: Benjamin Voiturier <[email protected]>
Signed-off-by: Milos Backonja <[email protected]>
Signed-off-by: Anton Baliasnikov <[email protected]>
Signed-off-by: Pat Losoponkul <[email protected]>
Co-authored-by: Milos Backonja <[email protected]>
Co-authored-by: atala-dev <[email protected]>
Co-authored-by: patlo-iog <[email protected]>
Signed-off-by: Shota Jolbordi <[email protected]>
  • Loading branch information
4 people authored and Shota Jolbordi committed Oct 2, 2023
1 parent 33c90c1 commit 3a55455
Show file tree
Hide file tree
Showing 32 changed files with 286 additions and 75 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ lazy val D_Castor = new {
Seq(
D.zio,
D.zioTest,
D.zioMock,
D.zioTestSbt,
D.zioTestMagnolia,
D.circeCore,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.iohk.atala.castor.core.service

import io.iohk.atala.castor.core.model.did.*
import io.iohk.atala.castor.core.model.error
import io.iohk.atala.prism.crypto.EC
import io.iohk.atala.prism.crypto.keys.ECKeyPair
import io.iohk.atala.shared.models.Base64UrlString
import zio.mock.{Expectation, Mock, Proxy}
import zio.test.Assertion
import zio.{IO, URLayer, ZIO, ZLayer, mock}

import scala.collection.immutable.ArraySeq

object MockDIDService extends Mock[DIDService] {

object ScheduleOperation extends Effect[SignedPrismDIDOperation, error.DIDOperationError, ScheduleDIDOperationOutcome]
// FIXME leaving this out for now as it gives a "java.lang.AssertionError: assertion failed: class Array" compilation error
// object GetScheduledDIDOperationDetail extends Effect[Array[Byte], error.DIDOperationError, Option[ScheduledDIDOperationDetail]]
object ResolveDID extends Effect[PrismDID, error.DIDResolutionError, Option[(DIDMetadata, DIDData)]]

override val compose: URLayer[mock.Proxy, DIDService] =
ZLayer {
for {
proxy <- ZIO.service[Proxy]
} yield new DIDService {
override def scheduleOperation(
operation: SignedPrismDIDOperation
): IO[error.DIDOperationError, ScheduleDIDOperationOutcome] =
proxy(ScheduleOperation, operation)

override def getScheduledDIDOperationDetail(
operationId: Array[Byte]
): IO[error.DIDOperationError, Option[ScheduledDIDOperationDetail]] =
???

override def resolveDID(did: PrismDID): IO[error.DIDResolutionError, Option[(DIDMetadata, DIDData)]] =
proxy(ResolveDID, did)
}
}

def createDID(
verificationRelationship: VerificationRelationship
): (PrismDIDOperation.Create, ECKeyPair, DIDMetadata, DIDData) = {
val masterKeyPair = EC.INSTANCE.generateKeyPair()
val keyPair = EC.INSTANCE.generateKeyPair()
val createOperation = PrismDIDOperation.Create(
publicKeys = Seq(
InternalPublicKey(
id = "master-0",
purpose = InternalKeyPurpose.Master,
publicKeyData = PublicKeyData.ECCompressedKeyData(
crv = EllipticCurve.SECP256K1,
data = Base64UrlString.fromByteArray(masterKeyPair.getPublicKey.getEncodedCompressed)
)
),
PublicKey(
id = "key-0",
purpose = verificationRelationship,
publicKeyData = PublicKeyData.ECCompressedKeyData(
crv = EllipticCurve.SECP256K1,
data = Base64UrlString.fromByteArray(keyPair.getPublicKey.getEncodedCompressed)
)
),
),
services = Nil,
context = Nil,
)
val longFormDid = PrismDID.buildLongFormFromOperation(createOperation)
// val canonicalDid = longFormDid.asCanonical

val didMetadata =
DIDMetadata(
lastOperationHash = ArraySeq.from(longFormDid.stateHash.toByteArray),
canonicalId = None, // unpublished DID must not contain canonicalId
deactivated = false, // unpublished DID cannot be deactivated
created = None, // unpublished DID cannot have timestamp
updated = None // unpublished DID cannot have timestamp
)
val didData = DIDData(
id = longFormDid.asCanonical,
publicKeys = createOperation.publicKeys.collect { case pk: PublicKey => pk },
services = createOperation.services,
internalKeys = createOperation.publicKeys.collect { case pk: InternalPublicKey => pk },
context = createOperation.context
)
(createOperation, keyPair, didMetadata, didData)
}

def resolveDIDExpectation(didMetadata: DIDMetadata, didData: DIDData): Expectation[DIDService] =
MockDIDService.ResolveDID(
assertion = Assertion.anything,
result = Expectation.value(Some(didMetadata, didData))
)
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package io.iohk.atala.pollux.core.model

import io.iohk.atala.mercury.protocol.issuecredential.OfferCredential
import io.iohk.atala.mercury.protocol.issuecredential.RequestCredential
import io.iohk.atala.mercury.protocol.issuecredential.IssueCredential
import IssueCredentialRecord._
import java.time.Instant
import io.iohk.atala.castor.core.model.did.CanonicalPrismDID
import io.iohk.atala.mercury.protocol.issuecredential.{IssueCredential, OfferCredential, RequestCredential}
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.*

import java.time.Instant

final case class IssueCredentialRecord(
id: DidCommID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import zio.*
trait CredentialRepository {
def createIssueCredentialRecord(record: IssueCredentialRecord): RIO[WalletAccessContext, Int]
def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): RIO[WalletAccessContext, (Seq[IssueCredentialRecord], Int)]
Expand All @@ -27,7 +27,7 @@ trait CredentialRepository {

def getIssueCredentialRecordByThreadId(
thid: DidCommID,
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
): RIO[WalletAccessContext, Option[IssueCredentialRecord]]

def updateCredentialRecordProtocolState(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package io.iohk.atala.pollux.core.repository

import io.iohk.atala.mercury.protocol.issuecredential.IssueCredential
import io.iohk.atala.mercury.protocol.issuecredential.RequestCredential
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.ProtocolState
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.PublicationState
import io.iohk.atala.pollux.core.model._
import io.iohk.atala.pollux.core.model.error.CredentialRepositoryError._
import io.iohk.atala.mercury.protocol.issuecredential.{IssueCredential, RequestCredential}
import io.iohk.atala.pollux.core.model.*
import io.iohk.atala.pollux.core.model.IssueCredentialRecord.{ProtocolState, PublicationState}
import io.iohk.atala.pollux.core.model.error.CredentialRepositoryError.*
import io.iohk.atala.prism.crypto.MerkleInclusionProof
import io.iohk.atala.shared.models.WalletId
import io.iohk.atala.shared.models.{WalletAccessContext, WalletId}
import zio.*

import java.time.Instant
import io.iohk.atala.shared.models.WalletAccessContext

class CredentialRepositoryInMemory(
walletRefs: Ref[Map[WalletId, Ref[Map[DidCommID, IssueCredentialRecord]]]],
Expand Down Expand Up @@ -76,14 +73,15 @@ class CredentialRepositoryInMemory(
}

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
offset: Option[Int],
limit: Option[Int]
): RIO[WalletAccessContext, (Seq[IssueCredentialRecord], Int)] = {
for {
storeRef <- walletStoreRef
store <- storeRef.get
paginated = store.values.toSeq.drop(offset.getOrElse(0)).take(limit.getOrElse(Int.MaxValue))
records = if (ignoreWithZeroRetries) store.values.filter(_.metaRetries > 0) else store.values
paginated = records.toSeq.drop(offset.getOrElse(0)).take(limit.getOrElse(Int.MaxValue))
} yield paginated -> store.values.size
}

Expand Down Expand Up @@ -209,20 +207,22 @@ class CredentialRepositoryInMemory(
for {
storeRef <- walletStoreRef
store <- storeRef.get
} yield store.values
.filter(rec => states.contains(rec.protocolState) & (!ignoreWithZeroRetries | rec.metaRetries > 0))
records = if (ignoreWithZeroRetries) store.values.filter(_.metaRetries > 0) else store.values
} yield records
.filter(rec => states.contains(rec.protocolState))
.take(limit)
.toSeq
}

override def getIssueCredentialRecordByThreadId(
thid: DidCommID,
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
): RIO[WalletAccessContext, Option[IssueCredentialRecord]] = {
for {
storeRef <- walletStoreRef
store <- storeRef.get
} yield store.values.find(_.thid == thid).filter(!ignoreWithZeroRetries | _.metaRetries > 0)
records = if (ignoreWithZeroRetries) store.values.filter(_.metaRetries > 0) else store.values
} yield records.find(_.thid == thid)
}

override def updateWithSubjectId(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import zio.*

trait PresentationRepository {
def createPresentationRecord(record: PresentationRecord): RIO[WalletAccessContext, Int]
def getPresentationRecords(ignoreWithZeroRetries: Boolean = true): RIO[WalletAccessContext, Seq[PresentationRecord]]
def getPresentationRecords(ignoreWithZeroRetries: Boolean): RIO[WalletAccessContext, Seq[PresentationRecord]]
def getPresentationRecord(recordId: DidCommID): RIO[WalletAccessContext, Option[PresentationRecord]]
def getPresentationRecordsByStates(
ignoreWithZeroRetries: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PresentationRepositoryInMemory(
}

override def getPresentationRecords(
ignoreWithZeroRetries: Boolean = true,
ignoreWithZeroRetries: Boolean,
): RIO[WalletAccessContext, Seq[PresentationRecord]] = {
for {
storeRef <- walletStoreRef
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ trait CredentialService {

/** Return a list of records as well as a count of all filtered items */
def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): ZIO[WalletAccessContext, CredentialServiceError, (Seq[IssueCredentialRecord], Int)]
Expand All @@ -49,7 +50,8 @@ trait CredentialService {
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]]

def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]]

def receiveCredentialOffer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,24 @@ private class CredentialServiceImpl(
credential.maybeId.map(_.split("/").last).map(DidCommID(_))

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int],
limit: Option[Int]
): ZIO[WalletAccessContext, CredentialServiceError, (Seq[IssueCredentialRecord], Int)] = {
for {
records <- credentialRepository
.getIssueCredentialRecords(offset = offset, limit = limit)
.getIssueCredentialRecords(ignoreWithZeroRetries = ignoreWithZeroRetries, offset = offset, limit = limit)
.mapError(RepositoryError.apply)
} yield records
}

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]] =
for {
record <- credentialRepository
.getIssueCredentialRecordByThreadId(thid)
.getIssueCredentialRecordByThreadId(thid, ignoreWithZeroRetries)
.mapError(RepositoryError.apply)
} yield record

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,17 @@ class CredentialServiceNotifier(
svc.getIssueCredentialRecord(recordId)

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]] =
svc.getIssueCredentialRecordByThreadId(thid)
svc.getIssueCredentialRecordByThreadId(thid, ignoreWithZeroRetries)

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): ZIO[WalletAccessContext, CredentialServiceError, (Seq[IssueCredentialRecord], Int)] =
svc.getIssueCredentialRecords(offset, limit)
svc.getIssueCredentialRecords(ignoreWithZeroRetries, offset, limit)

override def getIssueCredentialRecordsByStates(
ignoreWithZeroRetries: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ object MockCredentialService extends Mock[CredentialService] {
???

override def getIssueCredentialRecords(
ignoreWithZeroRetries: Boolean,
offset: Option[Int] = None,
limit: Option[Int] = None
): IO[CredentialServiceError, (Seq[IssueCredentialRecord], Int)] =
Expand All @@ -190,7 +191,8 @@ object MockCredentialService extends Mock[CredentialService] {
???

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
thid: DidCommID,
ignoreWithZeroRetries: Boolean
): IO[CredentialServiceError, Option[IssueCredentialRecord]] = ???
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ object MockPresentationService extends Mock[PresentationService] {

override def extractIdFromCredential(credential: W3cCredentialPayload): Option[UUID] = ???

override def getPresentationRecords(): IO[PresentationError, Seq[PresentationRecord]] = ???
override def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): IO[PresentationError, Seq[PresentationRecord]] = ???

override def createPresentationPayloadFromRecord(
record: DidCommID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ trait PresentationService {
options: Option[io.iohk.atala.pollux.core.model.presentation.Options]
): ZIO[WalletAccessContext, PresentationError, PresentationRecord]

def getPresentationRecords(): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]]
def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]]

def createPresentationPayloadFromRecord(
record: DidCommID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ private class PresentationServiceImpl(
override def extractIdFromCredential(credential: W3cCredentialPayload): Option[UUID] =
credential.maybeId.map(_.split("/").last).map(UUID.fromString)

override def getPresentationRecords(): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] = {
override def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] = {
for {
records <- presentationRepository
.getPresentationRecords()
.getPresentationRecords(ignoreWithZeroRetries)
.mapError(RepositoryError.apply)
} yield records
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ class PresentationServiceNotifier(
override def extractIdFromCredential(credential: W3cCredentialPayload): Option[UUID] =
svc.extractIdFromCredential(credential)

override def getPresentationRecords(): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] =
svc.getPresentationRecords()
override def getPresentationRecords(
ignoreWithZeroRetries: Boolean
): ZIO[WalletAccessContext, PresentationError, Seq[PresentationRecord]] =
svc.getPresentationRecords(ignoreWithZeroRetries)

override def createPresentationPayloadFromRecord(
record: DidCommID,
Expand Down
Loading

0 comments on commit 3a55455

Please sign in to comment.