Skip to content

Commit

Permalink
Revert to using a map to store musig2 secret nonces
Browse files Browse the repository at this point in the history
The semantics of the secret nonce field added to tx inputs were wrong, these nonces are transient and should be tied to the lifecycle
of the interactive tx session, this is much more explicit now.
  • Loading branch information
sstone committed Dec 5, 2023
1 parent aa0c3c9 commit b825d1a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 71 deletions.
98 changes: 31 additions & 67 deletions src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,35 +129,8 @@ sealed class InteractiveTxInput {
override val previousTx: Transaction,
override val previousTxOutput: Long,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParamsMusig2,
val secretNonce: SecretNonce) : Local() {
val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as LocalMusig2SwapIn

if (serialId != other.serialId) return false
if (previousTx != other.previousTx) return false
if (previousTxOutput != other.previousTxOutput) return false
if (sequence != other.sequence) return false
if (swapInParams != other.swapInParams) return false
if (outPoint != other.outPoint) return false

return true
}

override fun hashCode(): Int {
var result = serialId.hashCode()
result = 31 * result + previousTx.hashCode()
result = 31 * result + previousTxOutput.hashCode()
result = 31 * result + sequence.hashCode()
result = 31 * result + swapInParams.hashCode()
result = 31 * result + outPoint.hashCode()
return result
}

}
/**
* A remote input that funds the interactive transaction.
Expand All @@ -181,32 +154,7 @@ sealed class InteractiveTxInput {
override val outPoint: OutPoint,
override val txOut: TxOut,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParamsMusig2,
val secretNonce: SecretNonce) : Remote() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as RemoteSwapInMusig2

if (serialId != other.serialId) return false
if (outPoint != other.outPoint) return false
if (txOut != other.txOut) return false
if (sequence != other.sequence) return false
if (swapInParams != other.swapInParams) return false

return true
}

override fun hashCode(): Int {
var result = serialId.hashCode()
result = 31 * result + outPoint.hashCode()
result = 31 * result + txOut.hashCode()
result = 31 * result + sequence.hashCode()
result = 31 * result + swapInParams.hashCode()
return result
}
}
val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Remote()

/** The shared input can be added by us or by our peer, depending on who initiated the protocol. */
data class Shared(override val serialId: Long, override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, val localAmount: MilliSatoshi, val remoteAmount: MilliSatoshi) : InteractiveTxInput(), Incoming, Outgoing
Expand Down Expand Up @@ -340,7 +288,6 @@ data class FundingContributions(val inputs: List<InteractiveTxInput.Outgoing>, v
i.outputIndex.toLong(),
0xfffffffdU,
TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay),
SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPrivateKey.publicKey(), null, null, null, randomBytes32()),
)
}
}
Expand Down Expand Up @@ -485,7 +432,8 @@ data class SharedTransaction(
.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
.find { txIn.outPoint == it.outPoint }
?.let { input ->
val userNonce = input.secretNonce
val userNonce = session.secretNonces[input.serialId]
require(userNonce != null)
require(session.txCompleteReceived != null)
val serverNonce = receivedNonces[input.serialId]
require(serverNonce != null) { "missing server nonce for input ${input.serialId}" }
Expand All @@ -512,7 +460,8 @@ data class SharedTransaction(
.find { txIn.outPoint == it.outPoint }
?.let { input ->
val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId)
val userNonce = input.secretNonce
val userNonce = session.secretNonces[input.serialId]
require(userNonce != null)
require(session.txCompleteReceived != null)
val serverNonce = receivedNonces[input.serialId]
require(serverNonce != null) { "missing server nonce for input ${input.serialId}" }
Expand Down Expand Up @@ -659,7 +608,9 @@ data class InteractiveTxSession(
val txCompleteSent: TxComplete? = null,
val txCompleteReceived: TxComplete? = null,
val inputsReceivedCount: Int = 0,
val outputsReceivedCount: Int = 0) {
val outputsReceivedCount: Int = 0,
val secretNonces: Map<Long, SecretNonce> = mapOf()
) {

// Example flow:
// +-------+ +-------+
Expand Down Expand Up @@ -699,11 +650,17 @@ data class InteractiveTxSession(
null -> {
// generate a new secret nonce for each musig2 new swapin every time we send TxComplete
val localMusig2SwapIns = localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }
val secretNonces1 = localMusig2SwapIns.fold(secretNonces){ nonces, i ->
nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null, randomBytes32())))
}
val remoteMusig2SwapIns = remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>()
val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }
val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces).sortedBy { it.first }.map { it.second })
val next = copy(txCompleteSent = txComplete)
val secretNonces2 = remoteMusig2SwapIns.fold(secretNonces1){ nonces, i ->
nonces + (i.serialId to (nonces[i.serialId] ?: SecretNonce.generate(null, i.swapInParams.serverKey, null, null, null, randomBytes32()) ))
}
val serialIds = (localMusig2SwapIns.map { it.serialId} + remoteMusig2SwapIns.map { it.serialId }).sorted()
val nonces = serialIds.map { secretNonces2[it]?.publicNonce() }.filterNotNull()
val txComplete = TxComplete(fundingParams.channelId, nonces)
val next = copy(secretNonces = secretNonces2, txCompleteSent = txComplete)
if (next.isComplete) {
Pair(next, next.validateTx(txComplete))
} else {
Expand Down Expand Up @@ -739,7 +696,7 @@ data class InteractiveTxSession(
}
}

private fun receiveInput(message: TxAddInput): Either<InteractiveTxSessionAction.RemoteFailure, InteractiveTxInput.Incoming> {
private fun receiveInput(message: TxAddInput): Either<InteractiveTxSessionAction.RemoteFailure, InteractiveTxSession> {
if (inputsReceivedCount + 1 >= MAX_INPUTS_OUTPUTS_RECEIVED) {
return Either.Left(InteractiveTxSessionAction.TooManyInteractiveTxRounds(message.channelId))
}
Expand Down Expand Up @@ -771,8 +728,7 @@ data class InteractiveTxSession(
val txOut = message.previousTx.txOut[message.previousTxOutput.toInt()]
when {
message.swapInParamsMusig2 != null -> {
val secretNonce = SecretNonce.generate(null, message.swapInParamsMusig2.serverKey, null, null, null, randomBytes32())
InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2, secretNonce)
InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2)
}
message.swapInParams != null -> InteractiveTxInput.RemoteSwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParams)
else -> InteractiveTxInput.RemoteOnly(message.serialId, outpoint, txOut, message.sequence)
Expand All @@ -785,7 +741,15 @@ data class InteractiveTxSession(
if (message.sequence > 0xfffffffdU) {
return Either.Left(InteractiveTxSessionAction.NonReplaceableInput(message.channelId, message.serialId, input.outPoint.txid, input.outPoint.index, message.sequence.toLong()))
}
return Either.Right(input)
val session1 = this.copy(remoteInputs = remoteInputs + input, inputsReceivedCount = inputsReceivedCount + 1, txCompleteReceived = null)
val session2 = when (input) {
is InteractiveTxInput.RemoteSwapInMusig2 -> {
val secretNonce = secretNonces[input.serialId] ?: SecretNonce.generate(null, input.swapInParams.serverKey, null, null, null, randomBytes32())
session1.copy(secretNonces = secretNonces + (input.serialId to secretNonce))
}
else -> session1
}
return Either.Right(session2)
}

private fun receiveOutput(message: TxAddOutput): Either<InteractiveTxSessionAction.RemoteFailure, InteractiveTxOutput.Incoming> {
Expand Down Expand Up @@ -814,7 +778,7 @@ data class InteractiveTxSession(
is TxAddInput -> {
receiveInput(message).fold(
{ f -> Pair(this, f) },
{ input -> copy(remoteInputs = remoteInputs + input, inputsReceivedCount = inputsReceivedCount + 1, txCompleteReceived = null).send() }
{ next -> next.send() }
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ object Deserialization {
previousTx = readTransaction(),
previousTxOutput = readNumber(),
sequence = readNumber().toUInt(),
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this),
secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey())
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this)
)
else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Local::class}")
}
Expand All @@ -265,8 +264,7 @@ object Deserialization {
outPoint = readOutPoint(),
txOut = TxOut.read(readDelimitedByteArray()),
sequence = readNumber().toUInt(),
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this),
secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey())
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this)
)
else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Remote::class}")
}
Expand Down

0 comments on commit b825d1a

Please sign in to comment.