Skip to content

Commit

Permalink
fix(prism-agent): fix concurrent requests breaking DID index counter (#…
Browse files Browse the repository at this point in the history
…571)

* chore: fix test typo

* chore: refactor create did handler

* fix: createDID with only 1 permit semaphore

* chore: pr cleanup

* ci: use undeprecated vault image

* chore: pr cleanup
  • Loading branch information
patlo-iog authored Jun 27, 2023
1 parent 5c5eb23 commit e8411dd
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 50 deletions.
2 changes: 1 addition & 1 deletion infrastructure/shared/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ services:
condition: service_healthy

vault-server:
image: vault:latest
image: hashicorp/vault:latest
# ports:
# - "8200:8200"
environment:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import scala.language.implicitConversions

import java.security.{PrivateKey as JavaPrivateKey, PublicKey as JavaPublicKey}
import scala.collection.immutable.ArraySeq
import io.iohk.atala.agent.walletapi.service.handler.DIDCreateHandler

/** A wrapper around Castor's DIDService providing key-management capability. Analogous to the secretAPI in
* indy-wallet-sdk.
Expand All @@ -28,7 +29,8 @@ final class ManagedDIDServiceImpl private[walletapi] (
private[walletapi] val secretStorage: DIDSecretStorage,
override private[walletapi] val nonSecretStorage: DIDNonSecretStorage,
apollo: Apollo,
seed: Array[Byte]
seed: Array[Byte],
createDIDSem: Semaphore
) extends ManagedDIDService {

private val CURVE = EllipticCurve.SECP256K1
Expand All @@ -38,10 +40,8 @@ final class ManagedDIDServiceImpl private[walletapi] (
private val keyResolver = KeyResolver(apollo, nonSecretStorage, secretStorage)(seed)

private val publicationHandler = PublicationHandler(didService, keyResolver)(DEFAULT_MASTER_KEY_ID)
private val didUpdateHandler = DIDUpdateHandler(apollo, nonSecretStorage, secretStorage, publicationHandler)(seed)

private val generateCreateOperationHdKey =
OperationFactory(apollo).makeCreateOperationHdKey(DEFAULT_MASTER_KEY_ID, seed)
private val didCreateHandler = DIDCreateHandler(apollo, nonSecretStorage)(seed, DEFAULT_MASTER_KEY_ID)
private val didUpdateHandler = DIDUpdateHandler(apollo, nonSecretStorage, publicationHandler)(seed)

def syncManagedDIDState: IO[GetManagedDIDError, Unit] = nonSecretStorage
.listManagedDID(offset = None, limit = None)
Expand Down Expand Up @@ -123,30 +123,26 @@ final class ManagedDIDServiceImpl private[walletapi] (
} yield outcome
}

// TODO: update this method to use the same handler as updateManagedDID
def createAndStoreDID(didTemplate: ManagedDIDTemplate): IO[CreateManagedDIDError, LongFormPrismDID] = {
for {
val effect = for {
_ <- ZIO
.fromEither(ManagedDIDTemplateValidator.validate(didTemplate))
.mapError(CreateManagedDIDError.InvalidArgument.apply)
didIndex <- nonSecretStorage
.getMaxDIDIndex()
.mapBoth(
CreateManagedDIDError.WalletStorageError.apply,
maybeIdx => maybeIdx.map(_ + 1).getOrElse(0)
)
generated <- generateCreateOperationHdKey(didIndex, didTemplate)
(createOperation, hdKey) = generated
longFormDID = PrismDID.buildLongFormFromOperation(createOperation)
did = longFormDID.asCanonical
material <- didCreateHandler.materialize(didTemplate)
_ <- ZIO
.fromEither(didOpValidator.validate(createOperation))
.fromEither(didOpValidator.validate(material.operation))
.mapError(CreateManagedDIDError.InvalidOperation.apply)
state = ManagedDIDState(createOperation, didIndex, PublicationState.Created())
_ <- nonSecretStorage
.insertManagedDID(did, state, hdKey.keyPaths ++ hdKey.internalKeyPaths)
.mapError(CreateManagedDIDError.WalletStorageError.apply)
} yield longFormDID
_ <- material.persist.mapError(CreateManagedDIDError.WalletStorageError.apply)
} yield PrismDID.buildLongFormFromOperation(material.operation)

// This synchronizes createDID effect to only allow 1 execution at a time
// to avoid concurrent didIndex update. Long-term solution should be
// solved at the DB level.
//
// Performance may be improved by not synchronizing the whole operation,
// but only the counter increment part allowing multiple in-flight create operations
// once didIndex is acquired.
createDIDSem.withPermit(effect)
}

def updateManagedDID(
Expand Down Expand Up @@ -366,7 +362,16 @@ object ManagedDIDServiceImpl {
nonSecretStorage <- ZIO.service[DIDNonSecretStorage]
apollo <- ZIO.service[Apollo]
seed <- ZIO.serviceWithZIO[SeedResolver](_.resolve)
} yield ManagedDIDServiceImpl(didService, didOpValidator, secretStorage, nonSecretStorage, apollo, seed)
createDIDSem <- Semaphore.make(1)
} yield ManagedDIDServiceImpl(
didService,
didOpValidator,
secretStorage,
nonSecretStorage,
apollo,
seed,
createDIDSem
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.iohk.atala.agent.walletapi.service.handler

import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.model.CreateDIDHdKey
import io.iohk.atala.agent.walletapi.model.ManagedDIDState
import io.iohk.atala.agent.walletapi.model.ManagedDIDTemplate
import io.iohk.atala.agent.walletapi.model.error.CreateManagedDIDError
import io.iohk.atala.agent.walletapi.storage.DIDNonSecretStorage
import io.iohk.atala.castor.core.model.did.PrismDIDOperation
import zio.*
import io.iohk.atala.agent.walletapi.util.OperationFactory
import io.iohk.atala.agent.walletapi.model.PublicationState

private[walletapi] class DIDCreateHandler(
apollo: Apollo,
nonSecretStorage: DIDNonSecretStorage
)(
seed: Array[Byte],
masterKeyId: String
) {
def materialize(
didTemplate: ManagedDIDTemplate
): IO[CreateManagedDIDError, DIDCreateMaterial] = {
val operationFactory = OperationFactory(apollo)
for {
didIndex <- nonSecretStorage
.getMaxDIDIndex()
.mapBoth(
CreateManagedDIDError.WalletStorageError.apply,
maybeIdx => maybeIdx.map(_ + 1).getOrElse(0)
)
generated <- operationFactory.makeCreateOperationHdKey(masterKeyId, seed)(didIndex, didTemplate)
(createOperation, hdKey) = generated
state = ManagedDIDState(createOperation, didIndex, PublicationState.Created())
} yield DIDCreateMaterialImpl(nonSecretStorage)(createOperation, state, hdKey)
}
}

private[walletapi] trait DIDCreateMaterial {
def operation: PrismDIDOperation.Create
def state: ManagedDIDState
def persist: Task[Unit]
}

private[walletapi] class DIDCreateMaterialImpl(nonSecretStorage: DIDNonSecretStorage)(
val operation: PrismDIDOperation.Create,
val state: ManagedDIDState,
hdKey: CreateDIDHdKey
) extends DIDCreateMaterial {
def persist: Task[Unit] = {
val did = operation.did
for {
_ <- nonSecretStorage
.insertManagedDID(did, state, hdKey.keyPaths ++ hdKey.internalKeyPaths)
.mapError(CreateManagedDIDError.WalletStorageError.apply)
} yield ()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ import io.iohk.atala.castor.core.model.did.ScheduledDIDOperationStatus
import io.iohk.atala.castor.core.model.did.SignedPrismDIDOperation
import scala.collection.immutable.ArraySeq

class DIDUpdateHandler(
private[walletapi] class DIDUpdateHandler(
apollo: Apollo,
nonSecretStorage: DIDNonSecretStorage,
secretStorage: DIDSecretStorage,
publicationHandler: PublicationHandler
)(
seed: Array[Byte]
Expand All @@ -44,12 +43,12 @@ class DIDUpdateHandler(
result <- operationFactory.makeUpdateOperationHdKey(seed)(did, previousOperationHash, actions, keyCounter)
(operation, hdKey) = result
signedOperation <- publicationHandler.signOperationWithMasterKey[UpdateManagedDIDError](state, operation)
} yield HdKeyUpdateMaterial(secretStorage, nonSecretStorage)(operation, signedOperation, state, hdKey)
} yield HdKeyUpdateMaterial(nonSecretStorage)(operation, signedOperation, state, hdKey)
}
}
}

trait DIDUpdateMaterial {
private[walletapi] trait DIDUpdateMaterial {

def operation: PrismDIDOperation.Update

Expand Down Expand Up @@ -78,7 +77,7 @@ trait DIDUpdateMaterial {

}

class HdKeyUpdateMaterial(secretStorage: DIDSecretStorage, nonSecretStorage: DIDNonSecretStorage)(
private class HdKeyUpdateMaterial(nonSecretStorage: DIDNonSecretStorage)(
val operation: PrismDIDOperation.Update,
val signedOperation: SignedPrismDIDOperation,
val state: ManagedDIDState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ object KeyDerivation extends ZIOSpecDefault, VaultTestContainerSupport {
private val seedHex = "00" * 64
private val seed = HexString.fromStringUnsafe(seedHex).toByteArray

override def spec = suite("Key derivation benchamrk")(
override def spec = suite("Key derivation benchmark")(
deriveKeyBenchmark.provide(Apollo.prism14Layer),
queryKeyBenchmark.provide(vaultKvClientLayer, Apollo.prism14Layer)
) @@ TestAspect.sequential @@ TestAspect.timed @@ TestAspect.tag("benchmark") @@ TestAspect.ignore

private val deriveKeyBenchmark = suite("Key derivation benchmark")(
benchamrkKeyDerivation(1),
benchamrkKeyDerivation(8),
benchamrkKeyDerivation(16),
benchamrkKeyDerivation(32),
benchmarkKeyDerivation(1),
benchmarkKeyDerivation(8),
benchmarkKeyDerivation(16),
benchmarkKeyDerivation(32),
) @@ TestAspect.before(deriveKeyWarmUp())

private val queryKeyBenchmark = suite("Query key benchmark - vault storage")(
Expand All @@ -36,7 +36,7 @@ object KeyDerivation extends ZIOSpecDefault, VaultTestContainerSupport {
benchmarkVaultQuery(32),
) @@ TestAspect.before(vaultWarmUp())

private def benchamrkKeyDerivation(parallelism: Int) = {
private def benchmarkKeyDerivation(parallelism: Int) = {
test(s"derive 50000 keys - $parallelism parallelism") {
for {
apollo <- ZIO.service[Apollo]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
package io.iohk.atala.agent.walletapi.service

import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.crypto.ApolloSpecHelper
import io.iohk.atala.agent.walletapi.model.UpdateManagedDIDAction
import io.iohk.atala.agent.walletapi.model.error.UpdateManagedDIDError
import io.iohk.atala.agent.walletapi.model.error.{CreateManagedDIDError, PublishManagedDIDError}
import io.iohk.atala.agent.walletapi.model.{DIDPublicKeyTemplate, ManagedDIDState, ManagedDIDTemplate, PublicationState}
import io.iohk.atala.agent.walletapi.sql.JdbcDIDNonSecretStorage
import io.iohk.atala.agent.walletapi.sql.JdbcDIDSecretStorage
import io.iohk.atala.agent.walletapi.util.SeedResolver
import io.iohk.atala.agent.walletapi.vault.VaultDIDSecretStorage
import io.iohk.atala.castor.core.model.did.InternalKeyPurpose
import io.iohk.atala.castor.core.model.did.{
DIDData,
DIDMetadata,
Expand All @@ -19,23 +28,13 @@ import io.iohk.atala.castor.core.model.did.{
import io.iohk.atala.castor.core.model.error
import io.iohk.atala.castor.core.service.DIDService
import io.iohk.atala.castor.core.util.DIDOperationValidator
import io.iohk.atala.test.container.DBTestUtils
import io.iohk.atala.test.container.PostgresTestContainerSupport
import io.iohk.atala.test.container.VaultTestContainerSupport
import scala.collection.immutable.ArraySeq
import zio.*
import zio.test.*
import zio.test.Assertion.*

import scala.collection.immutable.ArraySeq
import io.iohk.atala.test.container.PostgresTestContainerSupport
import io.iohk.atala.test.container.VaultTestContainerSupport
import io.iohk.atala.agent.walletapi.crypto.ApolloSpecHelper
import io.iohk.atala.agent.walletapi.sql.JdbcDIDSecretStorage
import io.iohk.atala.agent.walletapi.sql.JdbcDIDNonSecretStorage
import io.iohk.atala.test.container.DBTestUtils
import io.iohk.atala.castor.core.model.did.InternalKeyPurpose
import io.iohk.atala.agent.walletapi.model.error.UpdateManagedDIDError
import io.iohk.atala.agent.walletapi.model.UpdateManagedDIDAction
import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.util.SeedResolver
import io.iohk.atala.agent.walletapi.vault.VaultDIDSecretStorage
import zio.test.TestAspect.sequential

object ManagedDIDServiceSpec
Expand Down Expand Up @@ -270,6 +269,19 @@ object ManagedDIDServiceSpec
)
val result = ZIO.serviceWithZIO[ManagedDIDService](_.createAndStoreDID(template))
assertZIO(result.exit)(fails(isSubtype[CreateManagedDIDError.InvalidArgument](anything)))
},
test("concurrent DID creation successfully create DID using different did-index") {
for {
svc <- ZIO.service[ManagedDIDService]
dids <- ZIO
.foreachPar(1 to 50)(_ => svc.createAndStoreDID(generateDIDTemplate()).map(_.asCanonical))
.withParallelism(8)
.map(_.toList)
states <- ZIO
.foreach(dids)(did => svc.nonSecretStorage.getManagedDIDState(did))
.map(_.toList.flatten)
} yield assert(dids)(hasSize(equalTo(50))) &&
assert(states.map(_.didIndex))(hasSameElementsDistinct(0 until 50))
}
)

Expand Down

0 comments on commit e8411dd

Please sign in to comment.