Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(prism-agent): fix concurrent requests breaking DID index counter #571

Merged
merged 6 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infrastructure/shared/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ services:
condition: service_healthy

vault-server:
image: vault:latest
image: hashicorp/vault:latest
# ports:
# - "8200:8200"
environment:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import scala.language.implicitConversions

import java.security.{PrivateKey as JavaPrivateKey, PublicKey as JavaPublicKey}
import scala.collection.immutable.ArraySeq
import io.iohk.atala.agent.walletapi.service.handler.DIDCreateHandler

/** A wrapper around Castor's DIDService providing key-management capability. Analogous to the secretAPI in
* indy-wallet-sdk.
Expand All @@ -28,7 +29,8 @@ final class ManagedDIDServiceImpl private[walletapi] (
private[walletapi] val secretStorage: DIDSecretStorage,
override private[walletapi] val nonSecretStorage: DIDNonSecretStorage,
apollo: Apollo,
seed: Array[Byte]
seed: Array[Byte],
createDIDSem: Semaphore
) extends ManagedDIDService {

private val CURVE = EllipticCurve.SECP256K1
Expand All @@ -38,10 +40,8 @@ final class ManagedDIDServiceImpl private[walletapi] (
private val keyResolver = KeyResolver(apollo, nonSecretStorage, secretStorage)(seed)

private val publicationHandler = PublicationHandler(didService, keyResolver)(DEFAULT_MASTER_KEY_ID)
private val didUpdateHandler = DIDUpdateHandler(apollo, nonSecretStorage, secretStorage, publicationHandler)(seed)

private val generateCreateOperationHdKey =
OperationFactory(apollo).makeCreateOperationHdKey(DEFAULT_MASTER_KEY_ID, seed)
private val didCreateHandler = DIDCreateHandler(apollo, nonSecretStorage)(seed, DEFAULT_MASTER_KEY_ID)
private val didUpdateHandler = DIDUpdateHandler(apollo, nonSecretStorage, publicationHandler)(seed)

def syncManagedDIDState: IO[GetManagedDIDError, Unit] = nonSecretStorage
.listManagedDID(offset = None, limit = None)
Expand Down Expand Up @@ -123,30 +123,26 @@ final class ManagedDIDServiceImpl private[walletapi] (
} yield outcome
}

// TODO: update this method to use the same handler as updateManagedDID
def createAndStoreDID(didTemplate: ManagedDIDTemplate): IO[CreateManagedDIDError, LongFormPrismDID] = {
for {
val effect = for {
_ <- ZIO
.fromEither(ManagedDIDTemplateValidator.validate(didTemplate))
.mapError(CreateManagedDIDError.InvalidArgument.apply)
didIndex <- nonSecretStorage
.getMaxDIDIndex()
.mapBoth(
CreateManagedDIDError.WalletStorageError.apply,
maybeIdx => maybeIdx.map(_ + 1).getOrElse(0)
)
generated <- generateCreateOperationHdKey(didIndex, didTemplate)
(createOperation, hdKey) = generated
longFormDID = PrismDID.buildLongFormFromOperation(createOperation)
did = longFormDID.asCanonical
material <- didCreateHandler.materialize(didTemplate)
_ <- ZIO
.fromEither(didOpValidator.validate(createOperation))
.fromEither(didOpValidator.validate(material.operation))
.mapError(CreateManagedDIDError.InvalidOperation.apply)
state = ManagedDIDState(createOperation, didIndex, PublicationState.Created())
_ <- nonSecretStorage
.insertManagedDID(did, state, hdKey.keyPaths ++ hdKey.internalKeyPaths)
.mapError(CreateManagedDIDError.WalletStorageError.apply)
} yield longFormDID
_ <- material.persist.mapError(CreateManagedDIDError.WalletStorageError.apply)
} yield PrismDID.buildLongFormFromOperation(material.operation)

// This synchronizes createDID effect to only allow 1 execution at a time
// to avoid concurrent didIndex update. Long-term solution should be
// solved at the DB level.
//
// Performance may be improved by not synchronizing the whole operation,
// but only the counter increment part allowing multiple in-flight create operations
// once didIndex is acquired.
createDIDSem.withPermit(effect)
}

def updateManagedDID(
Expand Down Expand Up @@ -366,7 +362,16 @@ object ManagedDIDServiceImpl {
nonSecretStorage <- ZIO.service[DIDNonSecretStorage]
apollo <- ZIO.service[Apollo]
seed <- ZIO.serviceWithZIO[SeedResolver](_.resolve)
} yield ManagedDIDServiceImpl(didService, didOpValidator, secretStorage, nonSecretStorage, apollo, seed)
createDIDSem <- Semaphore.make(1)
} yield ManagedDIDServiceImpl(
didService,
didOpValidator,
secretStorage,
nonSecretStorage,
apollo,
seed,
createDIDSem
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.iohk.atala.agent.walletapi.service.handler

import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.model.CreateDIDHdKey
import io.iohk.atala.agent.walletapi.model.ManagedDIDState
import io.iohk.atala.agent.walletapi.model.ManagedDIDTemplate
import io.iohk.atala.agent.walletapi.model.error.CreateManagedDIDError
import io.iohk.atala.agent.walletapi.storage.DIDNonSecretStorage
import io.iohk.atala.castor.core.model.did.PrismDIDOperation
import zio.*
import io.iohk.atala.agent.walletapi.util.OperationFactory
import io.iohk.atala.agent.walletapi.model.PublicationState

private[walletapi] class DIDCreateHandler(
apollo: Apollo,
nonSecretStorage: DIDNonSecretStorage
)(
seed: Array[Byte],
masterKeyId: String
) {
def materialize(
didTemplate: ManagedDIDTemplate
): IO[CreateManagedDIDError, DIDCreateMaterial] = {
val operationFactory = OperationFactory(apollo)
for {
didIndex <- nonSecretStorage
.getMaxDIDIndex()
.mapBoth(
CreateManagedDIDError.WalletStorageError.apply,
maybeIdx => maybeIdx.map(_ + 1).getOrElse(0)
)
generated <- operationFactory.makeCreateOperationHdKey(masterKeyId, seed)(didIndex, didTemplate)
(createOperation, hdKey) = generated
state = ManagedDIDState(createOperation, didIndex, PublicationState.Created())
} yield DIDCreateMaterialImpl(nonSecretStorage)(createOperation, state, hdKey)
}
}

private[walletapi] trait DIDCreateMaterial {
def operation: PrismDIDOperation.Create
def state: ManagedDIDState
def persist: Task[Unit]
}

private[walletapi] class DIDCreateMaterialImpl(nonSecretStorage: DIDNonSecretStorage)(
val operation: PrismDIDOperation.Create,
val state: ManagedDIDState,
hdKey: CreateDIDHdKey
) extends DIDCreateMaterial {
def persist: Task[Unit] = {
val did = operation.did
for {
_ <- nonSecretStorage
.insertManagedDID(did, state, hdKey.keyPaths ++ hdKey.internalKeyPaths)
.mapError(CreateManagedDIDError.WalletStorageError.apply)
} yield ()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ import io.iohk.atala.castor.core.model.did.ScheduledDIDOperationStatus
import io.iohk.atala.castor.core.model.did.SignedPrismDIDOperation
import scala.collection.immutable.ArraySeq

class DIDUpdateHandler(
private[walletapi] class DIDUpdateHandler(
apollo: Apollo,
nonSecretStorage: DIDNonSecretStorage,
secretStorage: DIDSecretStorage,
publicationHandler: PublicationHandler
)(
seed: Array[Byte]
Expand All @@ -44,12 +43,12 @@ class DIDUpdateHandler(
result <- operationFactory.makeUpdateOperationHdKey(seed)(did, previousOperationHash, actions, keyCounter)
(operation, hdKey) = result
signedOperation <- publicationHandler.signOperationWithMasterKey[UpdateManagedDIDError](state, operation)
} yield HdKeyUpdateMaterial(secretStorage, nonSecretStorage)(operation, signedOperation, state, hdKey)
} yield HdKeyUpdateMaterial(nonSecretStorage)(operation, signedOperation, state, hdKey)
}
}
}

trait DIDUpdateMaterial {
private[walletapi] trait DIDUpdateMaterial {

def operation: PrismDIDOperation.Update

Expand Down Expand Up @@ -78,7 +77,7 @@ trait DIDUpdateMaterial {

}

class HdKeyUpdateMaterial(secretStorage: DIDSecretStorage, nonSecretStorage: DIDNonSecretStorage)(
private class HdKeyUpdateMaterial(nonSecretStorage: DIDNonSecretStorage)(
val operation: PrismDIDOperation.Update,
val signedOperation: SignedPrismDIDOperation,
val state: ManagedDIDState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ object KeyDerivation extends ZIOSpecDefault, VaultTestContainerSupport {
private val seedHex = "00" * 64
private val seed = HexString.fromStringUnsafe(seedHex).toByteArray

override def spec = suite("Key derivation benchamrk")(
override def spec = suite("Key derivation benchmark")(
deriveKeyBenchmark.provide(Apollo.prism14Layer),
queryKeyBenchmark.provide(vaultKvClientLayer, Apollo.prism14Layer)
) @@ TestAspect.sequential @@ TestAspect.timed @@ TestAspect.tag("benchmark") @@ TestAspect.ignore

private val deriveKeyBenchmark = suite("Key derivation benchmark")(
benchamrkKeyDerivation(1),
benchamrkKeyDerivation(8),
benchamrkKeyDerivation(16),
benchamrkKeyDerivation(32),
benchmarkKeyDerivation(1),
benchmarkKeyDerivation(8),
benchmarkKeyDerivation(16),
benchmarkKeyDerivation(32),
) @@ TestAspect.before(deriveKeyWarmUp())

private val queryKeyBenchmark = suite("Query key benchmark - vault storage")(
Expand All @@ -36,7 +36,7 @@ object KeyDerivation extends ZIOSpecDefault, VaultTestContainerSupport {
benchmarkVaultQuery(32),
) @@ TestAspect.before(vaultWarmUp())

private def benchamrkKeyDerivation(parallelism: Int) = {
private def benchmarkKeyDerivation(parallelism: Int) = {
test(s"derive 50000 keys - $parallelism parallelism") {
for {
apollo <- ZIO.service[Apollo]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
package io.iohk.atala.agent.walletapi.service

import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.crypto.ApolloSpecHelper
import io.iohk.atala.agent.walletapi.model.UpdateManagedDIDAction
import io.iohk.atala.agent.walletapi.model.error.UpdateManagedDIDError
import io.iohk.atala.agent.walletapi.model.error.{CreateManagedDIDError, PublishManagedDIDError}
import io.iohk.atala.agent.walletapi.model.{DIDPublicKeyTemplate, ManagedDIDState, ManagedDIDTemplate, PublicationState}
import io.iohk.atala.agent.walletapi.sql.JdbcDIDNonSecretStorage
import io.iohk.atala.agent.walletapi.sql.JdbcDIDSecretStorage
import io.iohk.atala.agent.walletapi.util.SeedResolver
import io.iohk.atala.agent.walletapi.vault.VaultDIDSecretStorage
import io.iohk.atala.castor.core.model.did.InternalKeyPurpose
import io.iohk.atala.castor.core.model.did.{
DIDData,
DIDMetadata,
Expand All @@ -19,23 +28,13 @@ import io.iohk.atala.castor.core.model.did.{
import io.iohk.atala.castor.core.model.error
import io.iohk.atala.castor.core.service.DIDService
import io.iohk.atala.castor.core.util.DIDOperationValidator
import io.iohk.atala.test.container.DBTestUtils
import io.iohk.atala.test.container.PostgresTestContainerSupport
import io.iohk.atala.test.container.VaultTestContainerSupport
import scala.collection.immutable.ArraySeq
import zio.*
import zio.test.*
import zio.test.Assertion.*

import scala.collection.immutable.ArraySeq
import io.iohk.atala.test.container.PostgresTestContainerSupport
import io.iohk.atala.test.container.VaultTestContainerSupport
import io.iohk.atala.agent.walletapi.crypto.ApolloSpecHelper
import io.iohk.atala.agent.walletapi.sql.JdbcDIDSecretStorage
import io.iohk.atala.agent.walletapi.sql.JdbcDIDNonSecretStorage
import io.iohk.atala.test.container.DBTestUtils
import io.iohk.atala.castor.core.model.did.InternalKeyPurpose
import io.iohk.atala.agent.walletapi.model.error.UpdateManagedDIDError
import io.iohk.atala.agent.walletapi.model.UpdateManagedDIDAction
import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.util.SeedResolver
import io.iohk.atala.agent.walletapi.vault.VaultDIDSecretStorage
import zio.test.TestAspect.sequential

object ManagedDIDServiceSpec
Expand Down Expand Up @@ -270,6 +269,19 @@ object ManagedDIDServiceSpec
)
val result = ZIO.serviceWithZIO[ManagedDIDService](_.createAndStoreDID(template))
assertZIO(result.exit)(fails(isSubtype[CreateManagedDIDError.InvalidArgument](anything)))
},
test("concurrent DID creation successfully create DID using different did-index") {
for {
svc <- ZIO.service[ManagedDIDService]
dids <- ZIO
.foreachPar(1 to 50)(_ => svc.createAndStoreDID(generateDIDTemplate()).map(_.asCanonical))
.withParallelism(8)
.map(_.toList)
states <- ZIO
.foreach(dids)(did => svc.nonSecretStorage.getManagedDIDState(did))
.map(_.toList.flatten)
} yield assert(dids)(hasSize(equalTo(50))) &&
assert(states.map(_.didIndex))(hasSameElementsDistinct(0 until 50))
}
)

Expand Down