From b825d1a0a67a4d9727b93378676b3d76f51c0d10 Mon Sep 17 00:00:00 2001 From: sstone Date: Mon, 4 Dec 2023 13:42:18 +0100 Subject: [PATCH] Revert to using a map to store musig2 secret nonces The semantics of the secret nonce field added to tx inputs were wrong, these nonces are transient and should be tied to the lifecycle of the interactive tx session, this is much more explicit now. --- .../acinq/lightning/channel/InteractiveTx.kt | 98 ++++++------------- .../serialization/v4/Deserialization.kt | 6 +- 2 files changed, 33 insertions(+), 71 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt index 62a06ab24..2c0640cf6 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt @@ -129,35 +129,8 @@ sealed class InteractiveTxInput { override val previousTx: Transaction, override val previousTxOutput: Long, override val sequence: UInt, - val swapInParams: TxAddInputTlv.SwapInParamsMusig2, - val secretNonce: SecretNonce) : Local() { + val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Local() { override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput) - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other == null || this::class != other::class) return false - - other as LocalMusig2SwapIn - - if (serialId != other.serialId) return false - if (previousTx != other.previousTx) return false - if (previousTxOutput != other.previousTxOutput) return false - if (sequence != other.sequence) return false - if (swapInParams != other.swapInParams) return false - if (outPoint != other.outPoint) return false - - return true - } - - override fun hashCode(): Int { - var result = serialId.hashCode() - result = 31 * result + previousTx.hashCode() - result = 31 * result + previousTxOutput.hashCode() - result = 31 * result + sequence.hashCode() - result = 31 * result + swapInParams.hashCode() - result = 31 * result + outPoint.hashCode() - return result - } - } /** * A remote input that funds the interactive transaction. @@ -181,32 +154,7 @@ sealed class InteractiveTxInput { override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, - val swapInParams: TxAddInputTlv.SwapInParamsMusig2, - val secretNonce: SecretNonce) : Remote() { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other == null || this::class != other::class) return false - - other as RemoteSwapInMusig2 - - if (serialId != other.serialId) return false - if (outPoint != other.outPoint) return false - if (txOut != other.txOut) return false - if (sequence != other.sequence) return false - if (swapInParams != other.swapInParams) return false - - return true - } - - override fun hashCode(): Int { - var result = serialId.hashCode() - result = 31 * result + outPoint.hashCode() - result = 31 * result + txOut.hashCode() - result = 31 * result + sequence.hashCode() - result = 31 * result + swapInParams.hashCode() - return result - } - } + val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Remote() /** The shared input can be added by us or by our peer, depending on who initiated the protocol. */ data class Shared(override val serialId: Long, override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, val localAmount: MilliSatoshi, val remoteAmount: MilliSatoshi) : InteractiveTxInput(), Incoming, Outgoing @@ -340,7 +288,6 @@ data class FundingContributions(val inputs: List, v i.outputIndex.toLong(), 0xfffffffdU, TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay), - SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPrivateKey.publicKey(), null, null, null, randomBytes32()), ) } } @@ -485,7 +432,8 @@ data class SharedTransaction( .filterIsInstance() .find { txIn.outPoint == it.outPoint } ?.let { input -> - val userNonce = input.secretNonce + val userNonce = session.secretNonces[input.serialId] + require(userNonce != null) require(session.txCompleteReceived != null) val serverNonce = receivedNonces[input.serialId] require(serverNonce != null) { "missing server nonce for input ${input.serialId}" } @@ -512,7 +460,8 @@ data class SharedTransaction( .find { txIn.outPoint == it.outPoint } ?.let { input -> val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId) - val userNonce = input.secretNonce + val userNonce = session.secretNonces[input.serialId] + require(userNonce != null) require(session.txCompleteReceived != null) val serverNonce = receivedNonces[input.serialId] require(serverNonce != null) { "missing server nonce for input ${input.serialId}" } @@ -659,7 +608,9 @@ data class InteractiveTxSession( val txCompleteSent: TxComplete? = null, val txCompleteReceived: TxComplete? = null, val inputsReceivedCount: Int = 0, - val outputsReceivedCount: Int = 0) { + val outputsReceivedCount: Int = 0, + val secretNonces: Map = mapOf() +) { // Example flow: // +-------+ +-------+ @@ -699,11 +650,17 @@ data class InteractiveTxSession( null -> { // generate a new secret nonce for each musig2 new swapin every time we send TxComplete val localMusig2SwapIns = localInputs.filterIsInstance() - val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() } + val secretNonces1 = localMusig2SwapIns.fold(secretNonces){ nonces, i -> + nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null, randomBytes32()))) + } val remoteMusig2SwapIns = remoteInputs.filterIsInstance() - val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() } - val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces).sortedBy { it.first }.map { it.second }) - val next = copy(txCompleteSent = txComplete) + val secretNonces2 = remoteMusig2SwapIns.fold(secretNonces1){ nonces, i -> + nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(null, i.swapInParams.serverKey, null, null, null, randomBytes32()) )) + } + val serialIds = (localMusig2SwapIns.map { it.serialId} + remoteMusig2SwapIns.map { it.serialId }).sorted() + val nonces = serialIds.map { secretNonces2[it]?.publicNonce() }.filterNotNull() + val txComplete = TxComplete(fundingParams.channelId, nonces) + val next = copy(secretNonces = secretNonces2, txCompleteSent = txComplete) if (next.isComplete) { Pair(next, next.validateTx(txComplete)) } else { @@ -739,7 +696,7 @@ data class InteractiveTxSession( } } - private fun receiveInput(message: TxAddInput): Either { + private fun receiveInput(message: TxAddInput): Either { if (inputsReceivedCount + 1 >= MAX_INPUTS_OUTPUTS_RECEIVED) { return Either.Left(InteractiveTxSessionAction.TooManyInteractiveTxRounds(message.channelId)) } @@ -771,8 +728,7 @@ data class InteractiveTxSession( val txOut = message.previousTx.txOut[message.previousTxOutput.toInt()] when { message.swapInParamsMusig2 != null -> { - val secretNonce = SecretNonce.generate(null, message.swapInParamsMusig2.serverKey, null, null, null, randomBytes32()) - InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2, secretNonce) + InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2) } message.swapInParams != null -> InteractiveTxInput.RemoteSwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParams) else -> InteractiveTxInput.RemoteOnly(message.serialId, outpoint, txOut, message.sequence) @@ -785,7 +741,15 @@ data class InteractiveTxSession( if (message.sequence > 0xfffffffdU) { return Either.Left(InteractiveTxSessionAction.NonReplaceableInput(message.channelId, message.serialId, input.outPoint.txid, input.outPoint.index, message.sequence.toLong())) } - return Either.Right(input) + val session1 = this.copy(remoteInputs = remoteInputs + input, inputsReceivedCount = inputsReceivedCount + 1, txCompleteReceived = null) + val session2 = when (input) { + is InteractiveTxInput.RemoteSwapInMusig2 -> { + val secretNonce = secretNonces[input.serialId] ?: SecretNonce.generate(null, input.swapInParams.serverKey, null, null, null, randomBytes32()) + session1.copy(secretNonces = secretNonces + (input.serialId to secretNonce)) + } + else -> session1 + } + return Either.Right(session2) } private fun receiveOutput(message: TxAddOutput): Either { @@ -814,7 +778,7 @@ data class InteractiveTxSession( is TxAddInput -> { receiveInput(message).fold( { f -> Pair(this, f) }, - { input -> copy(remoteInputs = remoteInputs + input, inputsReceivedCount = inputsReceivedCount + 1, txCompleteReceived = null).send() } + { next -> next.send() } ) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt index 62724a32c..96958ea3c 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt @@ -240,8 +240,7 @@ object Deserialization { previousTx = readTransaction(), previousTxOutput = readNumber(), sequence = readNumber().toUInt(), - swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this), - secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey()) + swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this) ) else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Local::class}") } @@ -265,8 +264,7 @@ object Deserialization { outPoint = readOutPoint(), txOut = TxOut.read(readDelimitedByteArray()), sequence = readNumber().toUInt(), - swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this), - secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey()) + swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this) ) else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Remote::class}") }