Skip to content

Commit

Permalink
Apply latest bitcoin-kmp musig2 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sstone committed Dec 6, 2023
1 parent fe20db0 commit 294b474
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 25 deletions.
5 changes: 5 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ kotlin {
}
}

configurations.all {
// do not cache changing (i.e. SNAPSHOT) dependencies
resolutionStrategy.cacheChangingModulesFor(0, TimeUnit.SECONDS)
}

sourceSets.all {
languageSettings.optIn("kotlin.RequiresOptIn")
languageSettings.optIn("kotlin.ExperimentalStdlibApi")
Expand Down
44 changes: 31 additions & 13 deletions src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ sealed class InteractiveTxInput {
override val previousTx: Transaction,
override val previousTxOutput: Long,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParams) : Local() {
val swapInParams: TxAddInputTlv.SwapInParams
) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
}

Expand All @@ -129,9 +130,11 @@ sealed class InteractiveTxInput {
override val previousTx: Transaction,
override val previousTxOutput: Long,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Local() {
val swapInParams: TxAddInputTlv.SwapInParamsMusig2
) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
}

/**
* A remote input that funds the interactive transaction.
* We only keep the data we need from our peer's TxAddInput to avoid storing potentially large messages in our DB.
Expand All @@ -147,17 +150,20 @@ sealed class InteractiveTxInput {
override val outPoint: OutPoint,
override val txOut: TxOut,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParams) : Remote()
val swapInParams: TxAddInputTlv.SwapInParams
) : Remote()

data class RemoteSwapInMusig2(
override val serialId: Long,
override val outPoint: OutPoint,
override val txOut: TxOut,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Remote()
val swapInParams: TxAddInputTlv.SwapInParamsMusig2
) : Remote()

/** The shared input can be added by us or by our peer, depending on who initiated the protocol. */
data class Shared(override val serialId: Long, override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, val localAmount: MilliSatoshi, val remoteAmount: MilliSatoshi) : InteractiveTxInput(), Incoming, Outgoing
data class Shared(override val serialId: Long, override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, val localAmount: MilliSatoshi, val remoteAmount: MilliSatoshi) : InteractiveTxInput(), Incoming,
Outgoing
}

sealed class InteractiveTxOutput {
Expand Down Expand Up @@ -282,6 +288,7 @@ data class FundingContributions(val inputs: List<InteractiveTxInput.Outgoing>, v
0xfffffffdU,
TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.refundDelay)
)

else -> InteractiveTxInput.LocalMusig2SwapIn(
0,
i.previousTx.stripInputWitnesses(),
Expand Down Expand Up @@ -372,6 +379,7 @@ data class SharedTransaction(
val localFees: MilliSatoshi = localAmountIn - localAmountOut
val remoteFees: MilliSatoshi = remoteAmountIn - remoteAmountOut
val fees: Satoshi = (localFees + remoteFees).truncateToSatoshi()

// tx outputs spent by this transaction
val spentOutputs: Map<OutPoint, TxOut> = run {
val sharedOutput = sharedInput?.let { i -> mapOf(i.outPoint to i.txOut) } ?: mapOf()
Expand Down Expand Up @@ -432,13 +440,15 @@ data class SharedTransaction(
.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
.find { txIn.outPoint == it.outPoint }
?.let { input ->
val userNonce = session.secretNonces[input.serialId]
val userNonce = session.secretNonces[input.serialId]
require(userNonce != null)
require(session.txCompleteReceived != null)
val serverNonce = receivedNonces[input.serialId]
require(serverNonce != null) { "missing server nonce for input ${input.serialId}" }
val commonNonce = IndividualNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce))
TxSignatures.Companion.PartialSignature(keyManager.swapInOnChainWallet.signSwapInputUserMusig2(unsignedTx, i, previousOutputs, userNonce, serverNonce), commonNonce)
val psig = keyManager.swapInOnChainWallet.signSwapInputUserMusig2(unsignedTx, i, previousOutputs, userNonce, serverNonce)
require(psig != null) { "cannot create partial signature for input ${input.serialId}" }
TxSignatures.Companion.PartialSignature(psig, commonNonce)
}
}.filterNotNull()

Expand All @@ -460,14 +470,16 @@ data class SharedTransaction(
.find { txIn.outPoint == it.outPoint }
?.let { input ->
val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId)
val userNonce = session.secretNonces[input.serialId]
val userNonce = session.secretNonces[input.serialId]
require(userNonce != null)
require(session.txCompleteReceived != null)
val serverNonce = receivedNonces[input.serialId]
require(serverNonce != null) { "missing server nonce for input ${input.serialId}" }
val commonNonce = IndividualNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce))
val swapInProtocol = SwapInProtocolMusig2(input.swapInParams.userKey, serverKey.publicKey(), input.swapInParams.userRefundKey, input.swapInParams.refundDelay)
TxSignatures.Companion.PartialSignature(swapInProtocol.signSwapInputServer(unsignedTx, i, previousOutputs, serverNonce, serverKey, userNonce), commonNonce)
val psig = swapInProtocol.signSwapInputServer(unsignedTx, i, previousOutputs, serverNonce, serverKey, userNonce)
require(psig != null) { "cannot create partial signature for input ${input.serialId}" }
TxSignatures.Companion.PartialSignature(psig, commonNonce)
}
}.filterNotNull()

Expand Down Expand Up @@ -532,6 +544,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over
val unsignedTx = tx.buildUnsignedTx()
val ctx = swapInProtocol.signingCtx(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce)
val commonSig = ctx.partialSigAgg(listOf(userSig.sig, serverSig.sig))
require(commonSig != null)
val witness = swapInProtocol.witness(commonSig)
Pair(i.serialId, TxIn(OutPoint(i.previousTx, i.previousTxOutput), ByteVector.empty, i.sequence.toLong(), witness))
}
Expand All @@ -551,6 +564,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over
val unsignedTx = tx.buildUnsignedTx()
val ctx = swapInProtocol.signingCtx(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce)
val commonSig = ctx.partialSigAgg(listOf(userSig.sig, serverSig.sig))
require(commonSig != null)
val witness = swapInProtocol.witness(commonSig)
Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), witness))
}
Expand Down Expand Up @@ -650,14 +664,14 @@ data class InteractiveTxSession(
null -> {
// generate a new secret nonce for each musig2 new swapin every time we send TxComplete
val localMusig2SwapIns = localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
val secretNonces1 = localMusig2SwapIns.fold(secretNonces){ nonces, i ->
val secretNonces1 = localMusig2SwapIns.fold(secretNonces) { nonces, i ->
nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null, randomBytes32())))
}
val remoteMusig2SwapIns = remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>()
val secretNonces2 = remoteMusig2SwapIns.fold(secretNonces1){ nonces, i ->
nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(null, i.swapInParams.serverKey, null, null, null, randomBytes32()) ))
val secretNonces2 = remoteMusig2SwapIns.fold(secretNonces1) { nonces, i ->
nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(null, i.swapInParams.serverKey, null, null, null, randomBytes32())))
}
val serialIds = (localMusig2SwapIns.map { it.serialId} + remoteMusig2SwapIns.map { it.serialId }).sorted()
val serialIds = (localMusig2SwapIns.map { it.serialId } + remoteMusig2SwapIns.map { it.serialId }).sorted()
val nonces = serialIds.map { secretNonces2[it]?.publicNonce() }.filterNotNull()
val txComplete = TxComplete(fundingParams.channelId, nonces)
val next = copy(secretNonces = secretNonces2, txCompleteSent = txComplete)
Expand All @@ -676,10 +690,12 @@ data class InteractiveTxSession(
val swapInParams = TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.refundDelay)
TxAddInput(fundingParams.channelId, msg.value.serialId, msg.value.previousTx, msg.value.previousTxOutput, msg.value.sequence, TlvStream(swapInParams))
}

is InteractiveTxInput.LocalMusig2SwapIn -> {
val swapInParams = TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay)
TxAddInput(fundingParams.channelId, msg.value.serialId, msg.value.previousTx, msg.value.previousTxOutput, msg.value.sequence, TlvStream(swapInParams))
}

is InteractiveTxInput.Shared -> TxAddInput(fundingParams.channelId, msg.value.serialId, msg.value.outPoint, msg.value.sequence)
}
Pair(next, InteractiveTxSessionAction.SendMessage(txAddInput))
Expand Down Expand Up @@ -730,6 +746,7 @@ data class InteractiveTxSession(
message.swapInParamsMusig2 != null -> {
InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2)
}

message.swapInParams != null -> InteractiveTxInput.RemoteSwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParams)
else -> InteractiveTxInput.RemoteOnly(message.serialId, outpoint, txOut, message.sequence)
}
Expand All @@ -747,6 +764,7 @@ data class InteractiveTxSession(
val secretNonce = secretNonces[input.serialId] ?: SecretNonce.generate(null, input.swapInParams.serverKey, null, null, null, randomBytes32())
session1.copy(secretNonces = secretNonces + (input.serialId to secretNonce))
}

else -> session1
}
return Either.Right(session2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ interface KeyManager {
return swapInProtocol.signSwapInputUser(fundingTx, index, parentTxOuts[fundingTx.txIn[index].outPoint.index.toInt()] , userPrivateKey)
}

fun signSwapInputUserMusig2(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>, userNonce: SecretNonce, serverNonce: IndividualNonce): ByteVector32 {
fun signSwapInputUserMusig2(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>, userNonce: SecretNonce, serverNonce: IndividualNonce): ByteVector32? {
return swapInProtocolMusig2.signSwapInputUser(fundingTx, index, parentTxOuts, userPrivateKey, userNonce, serverNonce)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ object Deserialization {

private fun Input.readTxId(): TxId = TxId(readByteVector32())

private fun Input.readPublicNonce() = IndividualNonce.fromBin(ByteArray(66).also { read(it, 0, it.size) })
private fun Input.readPublicNonce() = IndividualNonce(ByteArray(66).also { read(it, 0, it.size) })

private fun Input.readDelimitedByteArray(): ByteArray {
val size = readNumber().toInt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class SwapInProtocolMusig2(val userPublicKey: PublicKey, val serverPublicKey: Pu

fun witnessRefund(userSig: ByteVector64): ScriptWitness = ScriptWitness.empty.push(userSig).push(redeemScript).push(controlBlock)

fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>, userPrivateKey: PrivateKey, userNonce: SecretNonce, serverNonce: IndividualNonce): ByteVector32 {
fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>, userPrivateKey: PrivateKey, userNonce: SecretNonce, serverNonce: IndividualNonce): ByteVector32? {
require(userPrivateKey.publicKey() == userPublicKey)
val txHash = Transaction.hashForSigningSchnorr(fundingTx, index, parentTxOuts, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT)
val commonNonce = IndividualNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce))
Expand All @@ -101,7 +101,7 @@ class SwapInProtocolMusig2(val userPublicKey: PublicKey, val serverPublicKey: Pu
return Crypto.signSchnorr(txHash, userPrivateKey, Crypto.SchnorrTweak.NoTweak)
}

fun signSwapInputServer(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>, userNonce: IndividualNonce, serverPrivateKey: PrivateKey, serverNonce: SecretNonce): ByteVector32 {
fun signSwapInputServer(fundingTx: Transaction, index: Int, parentTxOuts: List<TxOut>, userNonce: IndividualNonce, serverPrivateKey: PrivateKey, serverNonce: SecretNonce): ByteVector32? {
val txHash = Transaction.hashForSigningSchnorr(fundingTx, index, parentTxOuts, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT)
val commonNonce = IndividualNonce.aggregate(listOf(userNonce, serverNonce.publicNonce()))
val ctx = SessionCtx(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ sealed class TxCompleteTlv : Tlv {
const val tag: Long = 101
override fun read(input: Input): Nonces {
val count = input.availableBytes / 66
val nonces = (0 until count).map { IndividualNonce.fromBin(LightningCodecs.bytes(input, 66)) }
val nonces = (0 until count).map { IndividualNonce(LightningCodecs.bytes(input, 66)) }
return Nonces(nonces)
}
}
Expand Down Expand Up @@ -146,7 +146,7 @@ sealed class TxSignaturesTlv : Tlv {
val count = input.availableBytes / (32 + 66)
val psigs = (0 until count).map {
val sig = LightningCodecs.bytes(input, 32).byteVector32()
val nonce = AggregatedNonce.fromBin(LightningCodecs.bytes(input, 66))
val nonce = AggregatedNonce(LightningCodecs.bytes(input, 66))
TxSignatures.Companion.PartialSignature(sig, nonce)
}
return SwapInUserPartialSigs(psigs)
Expand All @@ -167,7 +167,7 @@ sealed class TxSignaturesTlv : Tlv {
val count = input.availableBytes / (32 + 66)
val psigs = (0 until count).map {
val sig = LightningCodecs.bytes(input, 32).byteVector32()
val nonce = AggregatedNonce.fromBin(LightningCodecs.bytes(input, 66))
val nonce = AggregatedNonce(LightningCodecs.bytes(input, 66))
TxSignatures.Companion.PartialSignature(sig, nonce)
}
return SwapInServerPartialSigs(psigs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,10 @@ class TransactionsTestsCommon : LightningTestSuite() {
val userNonce = SecretNonce.generate(userPrivateKey, userPrivateKey.publicKey(), commonPubKey, null, null, randomBytes32())
val serverNonce = SecretNonce.generate(serverPrivateKey, serverPrivateKey.publicKey(), commonPubKey, null, null, randomBytes32())

val userSig = swapInProtocolMusig2.signSwapInputUser(tx, 0, swapInTx.txOut, userPrivateKey, userNonce, serverNonce.publicNonce())
val serverSig = swapInProtocolMusig2.signSwapInputServer(tx, 0, swapInTx.txOut, userNonce.publicNonce(), serverPrivateKey, serverNonce)
val userSig = swapInProtocolMusig2.signSwapInputUser(tx, 0, swapInTx.txOut, userPrivateKey, userNonce, serverNonce.publicNonce())!!
val serverSig = swapInProtocolMusig2.signSwapInputServer(tx, 0, swapInTx.txOut, userNonce.publicNonce(), serverPrivateKey, serverNonce)!!
val ctx = swapInProtocolMusig2.signingCtx(tx, 0, swapInTx.txOut, IndividualNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce.publicNonce())))
val commonSig = ctx.partialSigAgg(listOf(userSig, serverSig))
val commonSig = ctx.partialSigAgg(listOf(userSig, serverSig))!!
val signedTx = tx.updateWitness(0, swapInProtocolMusig2.witness(commonSig))
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ class LightningCodecsTestsCommon : LightningTestSuite() {
val pubKey1 = PrivateKey.fromHex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").publicKey()
val pubKey2 = PrivateKey.fromHex("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb").publicKey()
val swapInPartialSignatures = listOf(
TxSignatures.Companion.PartialSignature(ByteVector32("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"), AggregatedNonce(pubKey1, pubKey2)),
TxSignatures.Companion.PartialSignature(ByteVector32("dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"), AggregatedNonce(pubKey1, pubKey2))
TxSignatures.Companion.PartialSignature(ByteVector32("cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"), AggregatedNonce(pubKey1.value + pubKey2.value)),
TxSignatures.Companion.PartialSignature(ByteVector32("dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"), AggregatedNonce(pubKey1.value + pubKey2.value))
)
val signature = ByteVector64("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
// This is a random mainnet transaction.
Expand Down

0 comments on commit 294b474

Please sign in to comment.