From 09ff67a70beb36c4a8394bce5de2dd809e4e84ec Mon Sep 17 00:00:00 2001 From: sstone Date: Mon, 29 Jan 2024 17:22:39 +0100 Subject: [PATCH] Improve error handling for musig2 module --- .../fr/acinq/bitcoin/crypto/musig2/Musig2.kt | 52 +++++++--- .../crypto/musig2/Musig2TestsCommon.kt | 97 ++++++++++--------- 2 files changed, 85 insertions(+), 64 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2.kt b/src/commonMain/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2.kt index 8402cb75..e050506c 100644 --- a/src/commonMain/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2.kt +++ b/src/commonMain/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2.kt @@ -1,6 +1,7 @@ package fr.acinq.bitcoin.crypto.musig2 import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.utils.Either import fr.acinq.secp256k1.Hex import fr.acinq.secp256k1.Secp256k1 import kotlin.jvm.JvmStatic @@ -23,14 +24,16 @@ public data class KeyAggCache(val data: ByteVector) { * @param isXonly true if the tweak is an x-only tweak * @return an updated cache, and the tweaked aggregated public key */ - public fun tweak(tweak: ByteVector32, isXonly: Boolean): Pair { + public fun tweak(tweak: ByteVector32, isXonly: Boolean): Either> = try { val localCache = toByteArray() val tweaked = if (isXonly) { Secp256k1.musigPubkeyXonlyTweakAdd(localCache, tweak.toByteArray()) } else { Secp256k1.musigPubkeyTweakAdd(localCache, tweak.toByteArray()) } - return Pair(KeyAggCache(localCache), PublicKey.parse(tweaked)) + Either.Right(Pair(KeyAggCache(localCache), PublicKey.parse(tweaked))) + } catch (t: Throwable) { + Either.Left(t) } public companion object { @@ -40,10 +43,12 @@ public data class KeyAggCache(val data: ByteVector) { * @return a new (if cache was null) or updated cache, and the aggregated public key */ @JvmStatic - public fun add(pubkeys: List, cache: KeyAggCache?): Pair { + public fun add(pubkeys: List, cache: KeyAggCache? = null): Either> = try { val localCache = cache?.data?.toByteArray() ?: ByteArray(Secp256k1.MUSIG2_PUBLIC_KEYAGG_CACHE_SIZE) val aggkey = Secp256k1.musigPubkeyAgg(pubkeys.map { it.value.toByteArray() }.toTypedArray(), localCache) - return Pair(XonlyPublicKey(aggkey.byteVector32()), KeyAggCache(localCache.byteVector())) + Either.Right(Pair(XonlyPublicKey(aggkey.byteVector32()), KeyAggCache(localCache.byteVector()))) + } catch (t: Throwable) { + Either.Left(t) } } } @@ -64,8 +69,10 @@ public data class Session(val data: ByteVector) { * @param aggCache key aggregation cache * @return a Musig2 partial signature */ - public fun sign(secretNonce: SecretNonce, pk: PrivateKey, aggCache: KeyAggCache): ByteVector32 { - return Secp256k1.musigPartialSign(secretNonce.data.toByteArray(), pk.value.toByteArray(), aggCache.data.toByteArray(), toByteArray()).byteVector32() + public fun sign(secretNonce: SecretNonce, pk: PrivateKey, aggCache: KeyAggCache): Either = try { + Either.Right(Secp256k1.musigPartialSign(secretNonce.data.toByteArray(), pk.value.toByteArray(), aggCache.data.toByteArray(), toByteArray()).byteVector32()) + } catch (t: Throwable) { + Either.Left(t) } /** @@ -75,18 +82,23 @@ public data class Session(val data: ByteVector) { * @param cache key aggregation cache * @return true if the partial signature is valid */ - public fun verify(psig: ByteVector32, pubnonce: IndividualNonce, pubkey: PublicKey, cache: KeyAggCache): Boolean { - return Secp256k1.musigPartialSigVerify(psig.toByteArray(), pubnonce.toByteArray(), pubkey.value.toByteArray(), cache.data.toByteArray(), toByteArray()) == 1 + public fun verify(psig: ByteVector32, pubnonce: IndividualNonce, pubkey: PublicKey, cache: KeyAggCache): Boolean = try { + Secp256k1.musigPartialSigVerify(psig.toByteArray(), pubnonce.toByteArray(), pubkey.value.toByteArray(), cache.data.toByteArray(), toByteArray()) == 1 + } catch (t: Throwable) { + false } /** * @param psigs partial signatures * @return the aggregate of all input partial signatures */ - public fun add(psigs: List): ByteVector64 { - return Secp256k1.musigPartialSigAgg(toByteArray(), psigs.map { it.toByteArray() }.toTypedArray()).byteVector64() + public fun add(psigs: List): Either = try { + Either.Right(Secp256k1.musigPartialSigAgg(toByteArray(), psigs.map { it.toByteArray() }.toTypedArray()).byteVector64()) + } catch (t: Throwable) { + Either.Left(t) } + public companion object { /** * @param aggregatedNonce aggregated public nonce @@ -95,9 +107,11 @@ public data class Session(val data: ByteVector) { * @return a Musig signing session */ @JvmStatic - public fun build(aggregatedNonce: AggregatedNonce, msg: ByteVector32, cache: KeyAggCache): Session { + public fun build(aggregatedNonce: AggregatedNonce, msg: ByteVector32, cache: KeyAggCache): Either = try { val session = Secp256k1.musigNonceProcess(aggregatedNonce.toByteArray(), msg.toByteArray(), cache.data.toByteArray()) - return Session(session.byteVector()) + Either.Right(Session(session.byteVector())) + } catch (t: Throwable) { + Either.Left(t) } } } @@ -125,9 +139,13 @@ public data class SecretNonce(val data: ByteVector) { * @return a (secret nonce, public nonce) tuple */ @JvmStatic - public fun generate(sessionId: ByteVector32, seckey: PrivateKey?, pubkey: PublicKey, msg: ByteVector32?, cache: KeyAggCache?, extraInput: ByteVector32?): Pair { + public fun generate(sessionId: ByteVector32, seckey: PrivateKey?, pubkey: PublicKey, msg: ByteVector32?, cache: KeyAggCache?, extraInput: ByteVector32?): Either> = try { val nonce = Secp256k1.musigNonceGen(sessionId.toByteArray(), seckey?.value?.toByteArray(), pubkey.value.toByteArray(), msg?.toByteArray(), cache?.data?.toByteArray(), extraInput?.toByteArray()) - return Pair(SecretNonce(nonce.copyOfRange(0, Secp256k1.MUSIG2_SECRET_NONCE_SIZE)), IndividualNonce(nonce.copyOfRange(Secp256k1.MUSIG2_SECRET_NONCE_SIZE, Secp256k1.MUSIG2_SECRET_NONCE_SIZE + Secp256k1.MUSIG2_PUBLIC_NONCE_SIZE))) + val secretNonce = SecretNonce(nonce.copyOfRange(0, Secp256k1.MUSIG2_SECRET_NONCE_SIZE)) + val publicNonce = IndividualNonce(nonce.copyOfRange(Secp256k1.MUSIG2_SECRET_NONCE_SIZE, Secp256k1.MUSIG2_SECRET_NONCE_SIZE + Secp256k1.MUSIG2_PUBLIC_NONCE_SIZE)) + Either.Right(Pair(secretNonce, publicNonce)) + } catch (t: Throwable) { + Either.Left(t) } } } @@ -148,9 +166,11 @@ public data class IndividualNonce(val data: ByteVector) { public companion object { @JvmStatic - public fun aggregate(nonces: List): AggregatedNonce { + public fun aggregate(nonces: List): Either = try { val agg = Secp256k1.musigNonceAgg(nonces.map { it.toByteArray() }.toTypedArray()) - return AggregatedNonce(agg) + Either.Right(AggregatedNonce(agg)) + } catch (t: Throwable) { + Either.Left(t) } } } diff --git a/src/commonTest/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2TestsCommon.kt b/src/commonTest/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2TestsCommon.kt index 868efef4..2c573ed4 100644 --- a/src/commonTest/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2TestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/bitcoin/crypto/musig2/Musig2TestsCommon.kt @@ -2,6 +2,7 @@ package fr.acinq.bitcoin.crypto.musig2 import fr.acinq.bitcoin.* import fr.acinq.bitcoin.reference.TransactionTestsCommon +import fr.acinq.bitcoin.utils.flatMap import fr.acinq.secp256k1.Hex import kotlinx.serialization.json.* import kotlin.random.Random @@ -20,7 +21,7 @@ class Musig2TestsCommon { tests.jsonObject["valid_test_cases"]!!.jsonArray.forEach { val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val expected = XonlyPublicKey(ByteVector32.fromValidHex(it.jsonObject["expected"]!!.jsonPrimitive.content)) - val (aggkey, _) = KeyAggCache.add(keyIndices.map { pubkeys[it] }, null) + val (aggkey, _) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!! assertEquals(expected, aggkey) } tests.jsonObject["error_test_cases"]!!.jsonArray.forEach { @@ -28,8 +29,8 @@ class Musig2TestsCommon { val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean } assertFails { - var (_, cache) = KeyAggCache.add(keyIndices.map { pubkeys[it] }, null) - tweakIndices.zip(isXonly).forEach { cache = cache.tweak(tweaks[it.first], it.second).first } + var (_, cache) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!! + tweakIndices.zip(isXonly).forEach { cache = cache.tweak(tweaks[it.first], it.second).right!!.first } } } } @@ -47,7 +48,7 @@ class Musig2TestsCommon { //val expectedSecnonce = SecretNonce(it.jsonObject["expected_secnonce"]!!.jsonPrimitive.content) val expectedPubnonce = IndividualNonce(it.jsonObject["expected_pubnonce"]!!.jsonPrimitive.content) if (aggpk == null) { - val (_, pubnonce) = SecretNonce.generate(randprime, sk, pk, msg?.byteVector32(), null, extraInput?.byteVector32()) + val (_, pubnonce) = SecretNonce.generate(randprime, sk, pk, msg?.byteVector32(), null, extraInput?.byteVector32()).right!! // assertEquals(expectedSecnonce, secnonce) assertEquals(expectedPubnonce, pubnonce) } @@ -61,13 +62,13 @@ class Musig2TestsCommon { tests.jsonObject["valid_test_cases"]!!.jsonArray.forEach { val nonceIndices = it.jsonObject["pnonce_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val expected = AggregatedNonce(it.jsonObject["expected"]!!.jsonPrimitive.content) - val agg = IndividualNonce.aggregate(nonceIndices.map { nonces[it] }) + val agg = IndividualNonce.aggregate(nonceIndices.map { nonces[it] }).right!! assertEquals(expected, agg) } tests.jsonObject["error_test_cases"]!!.jsonArray.forEach { val nonceIndices = it.jsonObject["pnonce_indices"]!!.jsonArray.map { it.jsonPrimitive.int } - assertFails { - IndividualNonce.aggregate(nonceIndices.map { nonces[it] }) + assertTrue { + IndividualNonce.aggregate(nonceIndices.map { nonces[it] }).isLeft } } } @@ -86,39 +87,39 @@ class Musig2TestsCommon { val nonceIndices = it.jsonObject["nonce_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val psigIndices = it.jsonObject["psig_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val expected = ByteVector64.fromValidHex(it.jsonObject["expected"]!!.jsonPrimitive.content) - val aggnonce = IndividualNonce.aggregate(nonceIndices.map { pnonces[it] }) + val aggnonce = IndividualNonce.aggregate(nonceIndices.map { pnonces[it] }).right!! val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean } assertEquals(AggregatedNonce(it.jsonObject["aggnonce"]!!.jsonPrimitive.content), aggnonce) val cache = run { - var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] }, null) + var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!! tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second }.forEach { (tweak, isXonly) -> - c = c.tweak(tweak, isXonly).first + c = c.tweak(tweak, isXonly).right!!.first } c } - val session = Session.build(aggnonce, msg, cache) - val aggsig = session.add(psigIndices.map { psigs[it] }) + val session = Session.build(aggnonce, msg, cache).right!! + val aggsig = session.add(psigIndices.map { psigs[it] }).right!! assertEquals(expected, aggsig) } tests.jsonObject["error_test_cases"]!!.jsonArray.forEach { val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val nonceIndices = it.jsonObject["nonce_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val psigIndices = it.jsonObject["psig_indices"]!!.jsonArray.map { it.jsonPrimitive.int } - val aggnonce = IndividualNonce.aggregate(nonceIndices.map { pnonces[it] }) + val aggnonce = IndividualNonce.aggregate(nonceIndices.map { pnonces[it] }).right!! val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int } val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean } assertEquals(AggregatedNonce(it.jsonObject["aggnonce"]!!.jsonPrimitive.content), aggnonce) val cache = run { - var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] }, null) + var (_, c) = KeyAggCache.add(keyIndices.map { pubkeys[it] }).right!! tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second }.forEach { (tweak, isXonly) -> - c = c.tweak(tweak, isXonly).first + c = c.tweak(tweak, isXonly).right!!.first } c } - val session = Session.build(aggnonce, msg, cache) - assertFails { - session.add(psigIndices.map { psigs[it] }) + val session = Session.build(aggnonce, msg, cache).right!! + assertTrue { + session.add(psigIndices.map { psigs[it] }).isLeft } } } @@ -140,23 +141,23 @@ class Musig2TestsCommon { val aggsig = run { val nonces = privkeys.map { - SecretNonce.generate(random.nextBytes(32).byteVector32(), it, it.publicKey(), null, null, null) + SecretNonce.generate(random.nextBytes(32).byteVector32(), it, it.publicKey(), null, null, null).right!! } val secnonces = nonces.map { it.first } val pubnonces = nonces.map { it.second } // aggregate public nonces - val aggnonce = IndividualNonce.aggregate(pubnonces) + val aggnonce = IndividualNonce.aggregate(pubnonces).right!! val cache = run { - val (_, c) = KeyAggCache.add(pubkeys, null) - val (c1, _) = c.tweak(plainTweak, false) - val (c2, _) = c1.tweak(xonlyTweak, true) + val (_, c) = KeyAggCache.add(pubkeys).right!! + val (c1, _) = c.tweak(plainTweak, false).right!! + val (c2, _) = c1.tweak(xonlyTweak, true).right!! c2 } - val session = Session.build(aggnonce, msg, cache) + val session = Session.build(aggnonce, msg, cache).right!! // create partial signatures val psigs = privkeys.indices.map { - session.sign(secnonces[it], privkeys[it], cache) + session.sign(secnonces[it], privkeys[it], cache).right!! } // verify partial signatures @@ -165,14 +166,14 @@ class Musig2TestsCommon { } // aggregate partial signatures - session.add(psigs) + session.add(psigs).right!! } // aggregate public keys val aggpub = run { - val (_, c) = KeyAggCache.add(pubkeys, null) - val (c1, _) = c.tweak(plainTweak, false) - val (_, p) = c1.tweak(xonlyTweak, true) + val (_, c) = KeyAggCache.add(pubkeys).right!! + val (c1, _) = c.tweak(plainTweak, false).right!! + val (_, p) = c1.tweak(xonlyTweak, true).right!! p } @@ -188,7 +189,7 @@ class Musig2TestsCommon { val bobPubKey = bobPrivKey.publicKey() // Alice and Bob exchange public keys and agree on a common aggregated key - val (internalPubKey, cache) = KeyAggCache.add(listOf(alicePubKey, bobPubKey), null) + val (internalPubKey, cache) = KeyAggCache.add(listOf(alicePubKey, bobPubKey)).right!! // we use the standard BIP86 tweak val commonPubKey = internalPubKey.outputKey(Crypto.TaprootTweak.NoScriptTweak).first @@ -199,19 +200,19 @@ class Musig2TestsCommon { val spendingTx = Transaction(2, listOf(TxIn(OutPoint(tx, 0), sequence = 0)), listOf(TxOut(Satoshi(10000), Script.pay2wpkh(alicePubKey))), 0) val commonSig = run { - val random = kotlin.random.Random.Default - val aliceNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), alicePrivKey, alicePubKey, null, cache, null) - val bobNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), bobPrivKey, bobPubKey, null, null, null) + val random = Random.Default + val aliceNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), alicePrivKey, alicePubKey, null, cache, null).right!! + val bobNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), bobPrivKey, bobPubKey, null, null, null).right!! - val aggnonce = IndividualNonce.aggregate(listOf(aliceNonce.second, bobNonce.second)) + val aggnonce = IndividualNonce.aggregate(listOf(aliceNonce.second, bobNonce.second)).right!! val msg = Transaction.hashForSigningSchnorr(spendingTx, 0, listOf(tx.txOut[0]), SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT) // we use the same ctx for Alice and Bob, they both know all the public keys that are used here - val (cache1, _) = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true) - val session = Session.build(aggnonce, msg, cache1) - val aliceSig = session.sign(aliceNonce.first, alicePrivKey, cache1) - val bobSig = session.sign(bobNonce.first, bobPrivKey, cache1) - session.add(listOf(aliceSig, bobSig)) + val (cache1, _) = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true).right!! + val session = Session.build(aggnonce, msg, cache1).right!! + val aliceSig = session.sign(aliceNonce.first, alicePrivKey, cache1).right!! + val bobSig = session.sign(bobNonce.first, bobPrivKey, cache1).right!! + session.add(listOf(aliceSig, bobSig)).right!! } // this tx looks like any other tx that spends a p2tr output, with a single signature @@ -235,7 +236,7 @@ class Musig2TestsCommon { val merkleRoot = scriptTree.hash() // the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key - val (internalPubKey, cache) = KeyAggCache.add(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey()), null) + val (internalPubKey, cache) = KeyAggCache.add(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey())).right!! // it is tweaked with the script's merkle root to get the pubkey that will be exposed val pubkeyScript: List = Script.pay2tr(internalPubKey, merkleRoot) @@ -257,17 +258,17 @@ class Musig2TestsCommon { ) // this is the beginning of an interactive musig2 signing session. if user and server are disconnected before they have exchanged partial // signatures they will have to start again with fresh nonces - val userNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), userPrivateKey, userPrivateKey.publicKey(), null, cache, null) - val serverNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), serverPrivateKey, serverPrivateKey.publicKey(), null, cache, null) + val userNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), userPrivateKey, userPrivateKey.publicKey(), null, cache, null).right!! + val serverNonce = SecretNonce.generate(random.nextBytes(32).byteVector32(), serverPrivateKey, serverPrivateKey.publicKey(), null, cache, null).right!! val txHash = Transaction.hashForSigningSchnorr(tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT) - val commonNonce = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce.second)) + val commonNonce = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce.second)).right!! - val (cache1, _) = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true) - val session = Session.build(commonNonce, txHash, cache1) - val userSig = session.sign(userNonce.first, userPrivateKey, cache1) - val serverSig = session.sign(serverNonce.first, serverPrivateKey, cache1) - val commonSig = session.add(listOf(userSig, serverSig)) + val (cache1, _) = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true).right!! + val session = Session.build(commonNonce, txHash, cache1).right!! + val userSig = session.sign(userNonce.first, userPrivateKey, cache1).right!! + val serverSig = session.sign(serverNonce.first, serverPrivateKey, cache1).right!! + val commonSig = session.add(listOf(userSig, serverSig)).right!! val signedTx = tx.updateWitness(0, ScriptWitness(listOf(commonSig))) Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) }