From 2a1e84d46a185e0c63d567c8254352d3682f1a40 Mon Sep 17 00:00:00 2001 From: t-bast Date: Fri, 9 Aug 2024 18:23:17 +0200 Subject: [PATCH] Add support for trampoline failures Add support for the trampoline failure messages added to the BOLTs. We also add supports for encrypting failure e2e using the trampoline shared secrets on top of the outer onion shared secrets. This is a work-in-progress: the basic mechanism works, but it needs some clean-up / refactoring. --- .../payment/OutgoingPaymentFailure.kt | 5 +- .../payment/OutgoingPaymentHandler.kt | 10 ++-- .../payment/OutgoingPaymentPacket.kt | 35 +++++++---- .../fr/acinq/lightning/wire/FailureMessage.kt | 36 +++++++---- .../OutgoingPaymentHandlerTestsCommon.kt | 59 +++++++++++++++---- .../wire/FailureMessageTestsCommon.kt | 5 +- 6 files changed, 108 insertions(+), 42 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentFailure.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentFailure.kt index 77bab3048..a67f82bc4 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentFailure.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentFailure.kt @@ -78,13 +78,13 @@ data class OutgoingPaymentFailure(val reason: FinalFailure, val failures: List when (failure.value) { is AmountBelowMinimum -> LightningOutgoingPayment.Part.Status.Failure.PaymentAmountTooSmall is FeeInsufficient -> LightningOutgoingPayment.Part.Status.Failure.NotEnoughFees - TrampolineExpiryTooSoon -> LightningOutgoingPayment.Part.Status.Failure.NotEnoughFees - TrampolineFeeInsufficient -> LightningOutgoingPayment.Part.Status.Failure.NotEnoughFees + is TrampolineFeeOrExpiryInsufficient -> LightningOutgoingPayment.Part.Status.Failure.NotEnoughFees is FinalIncorrectCltvExpiry -> LightningOutgoingPayment.Part.Status.Failure.RecipientRejectedPayment is FinalIncorrectHtlcAmount -> LightningOutgoingPayment.Part.Status.Failure.RecipientRejectedPayment is IncorrectOrUnknownPaymentDetails -> LightningOutgoingPayment.Part.Status.Failure.RecipientRejectedPayment PaymentTimeout -> LightningOutgoingPayment.Part.Status.Failure.RecipientLiquidityIssue UnknownNextPeer -> LightningOutgoingPayment.Part.Status.Failure.RecipientIsOffline + UnknownNextTrampoline -> LightningOutgoingPayment.Part.Status.Failure.RecipientIsOffline is ExpiryTooSoon -> LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure ExpiryTooFar -> LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure is ChannelDisabled -> LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure @@ -92,6 +92,7 @@ data class OutgoingPaymentFailure(val reason: FinalFailure, val failures: List LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure PermanentChannelFailure -> LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure PermanentNodeFailure -> LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure + TemporaryTrampolineFailure -> LightningOutgoingPayment.Part.Status.Failure.TemporaryRemoteFailure is InvalidOnionBlinding -> LightningOutgoingPayment.Part.Status.Failure.Uninterpretable(failure.value.message) is InvalidOnionHmac -> LightningOutgoingPayment.Part.Status.Failure.Uninterpretable(failure.value.message) is InvalidOnionKey -> LightningOutgoingPayment.Part.Status.Failure.Uninterpretable(failure.value.message) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt index 34df48660..50f4cea10 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt @@ -20,10 +20,7 @@ import fr.acinq.lightning.logging.mdc import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.utils.UUID import fr.acinq.lightning.utils.msat -import fr.acinq.lightning.wire.FailureMessage -import fr.acinq.lightning.wire.TrampolineExpiryTooSoon -import fr.acinq.lightning.wire.TrampolineFeeInsufficient -import fr.acinq.lightning.wire.UnknownNextPeer +import fr.acinq.lightning.wire.* class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: WalletParams, val db: OutgoingPaymentsDb) { @@ -168,8 +165,9 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle val trampolineFees = payment.request.trampolineFeesOverride ?: walletParams.trampolineFees val finalError = when { trampolineFees.size <= payment.attemptNumber + 1 -> FinalFailure.RetryExhausted - failure == Either.Right(UnknownNextPeer) -> FinalFailure.RecipientUnreachable - failure != Either.Right(TrampolineExpiryTooSoon) && failure != Either.Right(TrampolineFeeInsufficient) -> FinalFailure.UnknownError // non-retriable error + failure == Either.Right(UnknownNextPeer) || failure == Either.Right(UnknownNextTrampoline) -> FinalFailure.RecipientUnreachable + // TODO: take actual fees returned into account (rework the trampoline fees mechanism). + failure != Either.Right(TemporaryTrampolineFailure) && failure.right !is TrampolineFeeOrExpiryInsufficient -> FinalFailure.UnknownError // non-retriable error else -> null } return if (finalError != null) { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt index c92968c6e..a664f0cc8 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt @@ -2,6 +2,7 @@ package fr.acinq.lightning.payment import fr.acinq.bitcoin.* import fr.acinq.bitcoin.utils.Either +import fr.acinq.bitcoin.utils.flatMap import fr.acinq.lightning.CltvExpiry import fr.acinq.lightning.Feature import fr.acinq.lightning.Lightning @@ -9,6 +10,7 @@ import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.channel.ChannelCommand import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets +import fr.acinq.lightning.crypto.sphinx.SharedSecrets import fr.acinq.lightning.crypto.sphinx.Sphinx import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.wire.* @@ -53,7 +55,9 @@ object OutgoingPaymentPacket { val trampolinePaymentSecret = Lightning.randomBytes32() val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet) val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) - return Triple(trampolineAmount, trampolineExpiry, paymentOnion) + // We merge the shared secrets from each onion to allow decrypting failure onions. + val sharedSecrets = SharedSecrets(paymentOnion.sharedSecrets.perHopSecrets + trampolineOnion.sharedSecrets.perHopSecrets) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion.copy(sharedSecrets = sharedSecrets)) } /** @@ -162,16 +166,16 @@ object OutgoingPaymentPacket { } fun buildHtlcFailure(nodeSecret: PrivateKey, paymentHash: ByteVector32, onion: OnionRoutingPacket, reason: ChannelCommand.Htlc.Settlement.Fail.Reason): Either { - // we need to decrypt the payment onion to obtain the shared secret to build the error packet - return when (val result = Sphinx.peel(nodeSecret, paymentHash, onion)) { - is Either.Right -> { - val encryptedReason = when (reason) { - is ChannelCommand.Htlc.Settlement.Fail.Reason.Bytes -> FailurePacket.wrap(reason.bytes.toByteArray(), result.value.sharedSecret) - is ChannelCommand.Htlc.Settlement.Fail.Reason.Failure -> FailurePacket.create(result.value.sharedSecret, reason.message) - } - Either.Right(ByteVector(encryptedReason)) + return extractSharedSecrets(nodeSecret, paymentHash, onion).map { sharedSecrets -> + val encryptedReason = when (reason) { + is ChannelCommand.Htlc.Settlement.Fail.Reason.Bytes -> FailurePacket.wrap(reason.bytes.toByteArray(), sharedSecrets.first()) + is ChannelCommand.Htlc.Settlement.Fail.Reason.Failure -> FailurePacket.create(sharedSecrets.first(), reason.message) + } + if (sharedSecrets.size == 2) { + ByteVector(FailurePacket.wrap(encryptedReason, sharedSecrets.last())) + } else { + ByteVector(encryptedReason) } - is Either.Left -> Either.Left(result.value) } } @@ -183,4 +187,15 @@ object OutgoingPaymentPacket { } } + private fun extractSharedSecrets(nodeSecret: PrivateKey, paymentHash: ByteVector32, onion: OnionRoutingPacket): Either> { + // We decrypt the payment onion to obtain the shared secret. + return Sphinx.peel(nodeSecret, paymentHash, onion).flatMap { outer -> + // If it contains a trampoline onion, we decrypt it as well to obtain the shared secret. + when (val trampolineOnion = PaymentOnion.PerHopPayload.read(outer.payload.toByteArray()).map { it.get() }.right) { + null -> Either.Right(listOf(outer.sharedSecret)) + else -> Sphinx.peel(nodeSecret, paymentHash, trampolineOnion.packet).map { listOf(it.sharedSecret, outer.sharedSecret) } + } + } + } + } \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/FailureMessage.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/FailureMessage.kt index 38a75ebc6..e9f81aca4 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/FailureMessage.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/FailureMessage.kt @@ -5,6 +5,7 @@ import fr.acinq.bitcoin.io.ByteArrayInput import fr.acinq.bitcoin.io.ByteArrayOutput import fr.acinq.bitcoin.io.Output import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.CltvExpiryDelta import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.utils.toByteVector32 @@ -41,10 +42,8 @@ sealed class FailureMessage { UnknownNextPeer.code -> UnknownNextPeer AmountBelowMinimum.code -> AmountBelowMinimum(MilliSatoshi(LightningCodecs.u64(stream)), readChannelUpdate(stream)) FeeInsufficient.code -> FeeInsufficient(MilliSatoshi(LightningCodecs.u64(stream)), readChannelUpdate(stream)) - TrampolineFeeInsufficient.code -> TrampolineFeeInsufficient IncorrectCltvExpiry.code -> IncorrectCltvExpiry(CltvExpiry(LightningCodecs.u32(stream).toLong()), readChannelUpdate(stream)) ExpiryTooSoon.code -> ExpiryTooSoon(readChannelUpdate(stream)) - TrampolineExpiryTooSoon.code -> TrampolineExpiryTooSoon IncorrectOrUnknownPaymentDetails.code -> { val amount = if (stream.availableBytes > 0) MilliSatoshi(LightningCodecs.u64(stream)) else MilliSatoshi(0) val blockHeight = if (stream.availableBytes > 0) LightningCodecs.u32(stream).toLong() else 0L @@ -56,6 +55,9 @@ sealed class FailureMessage { ExpiryTooFar.code -> ExpiryTooFar InvalidOnionPayload.code -> InvalidOnionPayload(LightningCodecs.bigSize(stream), LightningCodecs.u16(stream)) PaymentTimeout.code -> PaymentTimeout + TemporaryTrampolineFailure.code -> TemporaryTrampolineFailure + TrampolineFeeOrExpiryInsufficient.code -> TrampolineFeeOrExpiryInsufficient(MilliSatoshi(LightningCodecs.u32(stream).toLong()), LightningCodecs.u32(stream), CltvExpiryDelta(LightningCodecs.u16(stream))) + UnknownNextTrampoline.code -> UnknownNextTrampoline else -> UnknownFailureMessage(code) } } @@ -90,13 +92,11 @@ sealed class FailureMessage { LightningCodecs.writeU64(input.amount.toLong(), out) writeChannelUpdate(input.update, out) } - TrampolineFeeInsufficient -> {} is IncorrectCltvExpiry -> { LightningCodecs.writeU32(input.expiry.toLong().toInt(), out) writeChannelUpdate(input.update, out) } is ExpiryTooSoon -> writeChannelUpdate(input.update, out) - TrampolineExpiryTooSoon -> {} is IncorrectOrUnknownPaymentDetails -> { LightningCodecs.writeU64(input.amount.toLong(), out) LightningCodecs.writeU32(input.height.toInt(), out) @@ -114,6 +114,13 @@ sealed class FailureMessage { LightningCodecs.writeU16(input.offset, out) } PaymentTimeout -> {} + TemporaryTrampolineFailure -> {} + is TrampolineFeeOrExpiryInsufficient -> { + LightningCodecs.writeU32(input.feeBase.toLong().toInt(), out) + LightningCodecs.writeU32(input.feeProportionalMillionths, out) + LightningCodecs.writeU16(input.expiryDelta.toInt(), out) + } + UnknownNextTrampoline -> {} is UnknownFailureMessage -> {} } } @@ -195,10 +202,6 @@ data class FeeInsufficient(val amount: MilliSatoshi, override val update: Channe override val message get() = "payment fee was below the minimum required by the channel" companion object { const val code = UPDATE or 12 } } -object TrampolineFeeInsufficient : FailureMessage(), Node { - override val code get() = NODE or 51 - override val message get() = "payment fee was below the minimum required by the trampoline node" -} data class IncorrectCltvExpiry(val expiry: CltvExpiry, override val update: ChannelUpdate) : FailureMessage(), Update { override val code get() = IncorrectCltvExpiry.code override val message get() = "payment expiry doesn't match the value in the onion" @@ -209,10 +212,6 @@ data class ExpiryTooSoon(override val update: ChannelUpdate) : FailureMessage(), override val message get() = "payment expiry is too close to the current block height for safe handling by the relaying node" companion object { const val code = UPDATE or 14 } } -object TrampolineExpiryTooSoon : FailureMessage(), Node { - override val code get() = NODE or 52 - override val message get() = "payment expiry is too close to the current block height for safe handling by the relaying node" -} data class IncorrectOrUnknownPaymentDetails(val amount: MilliSatoshi, val height: Long) : FailureMessage(), Perm { override val code get() = IncorrectOrUnknownPaymentDetails.code override val message get() = "incorrect payment details or unknown payment hash" @@ -246,6 +245,19 @@ data object PaymentTimeout : FailureMessage() { override val code get() = 23 override val message get() = "the complete payment amount was not received within a reasonable time" } +data object TemporaryTrampolineFailure : FailureMessage(), Node { + override val code get() = NODE or 25 + override val message get() = "the trampoline node was unable to relay the payment because of downstream temporary failures" +} +data class TrampolineFeeOrExpiryInsufficient(val feeBase: MilliSatoshi, val feeProportionalMillionths: Int, val expiryDelta: CltvExpiryDelta) : FailureMessage(), Node { + override val code get() = TrampolineFeeOrExpiryInsufficient.code + override val message get() = "trampoline fees or expiry are insufficient to relay the payment" + companion object { const val code = NODE or 26 } +} +data object UnknownNextTrampoline : FailureMessage(), Perm { + override val code get() = PERM or 27 + override val message get() = "the trampoline node was unable to find the next trampoline node" +} /** * We allow remote nodes to send us unknown failure codes (e.g. deprecated failure codes). * By reading the PERM and NODE bits of the failure code we can still extract useful information for payment retry even diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt index 6cc39e607..a0d894b31 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt @@ -376,7 +376,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // This first attempt fails because fees are too low. val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val progress2 = outgoingPaymentHandler.processAddSettled(channelId1, createRemoteFailure(add1, attempt, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + val progress2 = outgoingPaymentHandler.processAddSettled(channelId1, createTrampolineFailure(add1, attempt, TrampolineFeeOrExpiryInsufficient(100.msat, 100, CltvExpiryDelta(48))), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) assertIs(progress2) val (channelId2, add2) = findAddHtlcCommand(progress2) assertEquals(channelId1, channelId2) @@ -451,9 +451,9 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { assertEquals(83_100_000.msat, add1.amount) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add1, attempt, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createTrampolineFailure(add1, attempt, TrampolineFeeOrExpiryInsufficient(100.msat, 100, CltvExpiryDelta(48))), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) assertIs(fail) - val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.InsufficientBalance, listOf(Either.Right(TrampolineFeeInsufficient)))) + val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.InsufficientBalance, listOf(Either.Right(TrampolineFeeOrExpiryInsufficient(100.msat, 100, CltvExpiryDelta(48)))))) assertFailureEquals(expected, fail) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) @@ -479,15 +479,15 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { assertEquals(230_000.msat, add1.amount) val attempt1 = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val progress2 = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add1, attempt1, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + val progress2 = outgoingPaymentHandler.processAddSettled(alice.channelId, createTrampolineFailure(add1, attempt1, TrampolineFeeOrExpiryInsufficient(100.msat, 100, CltvExpiryDelta(48))), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) assertIs(progress2) val (_, add2) = findAddHtlcCommand(progress2) assertEquals(240_000.msat, add2.amount) val attempt2 = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add2, attempt2, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createTrampolineFailure(add2, attempt2, TemporaryTrampolineFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) assertIs(fail) - val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.RetryExhausted, listOf(Either.Right(TrampolineFeeInsufficient), Either.Right(TrampolineFeeInsufficient)))) + val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.RetryExhausted, listOf(Either.Right(TrampolineFeeOrExpiryInsufficient(100.msat, 100, CltvExpiryDelta(48))), Either.Right(TemporaryTrampolineFailure)))) assertFailureEquals(expected, fail) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) @@ -512,7 +512,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { assertEquals(50_000.msat, add.amount) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add, attempt, remoteFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createTrampolineFailure(add, attempt, remoteFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) assertIs(fail) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(userFailure, listOf(Either.Right(remoteFailure)))) assertFailureEquals(expected, fail) @@ -522,6 +522,28 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } } + @Test + fun `recipient failure`()= runSuspendTest { + val (alice, _) = TestsHelper.reachNormal() + val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) + val invoice = makeInvoice(amount = null, supportsTrampoline = true) + val payment = PayInvoice(UUID.randomUUID(), 50_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) + + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + val (_, add) = findAddHtlcCommand(progress) + assertEquals(50_000.msat, add.amount) + + val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRecipientFailure(add, attempt, IncorrectOrUnknownPaymentDetails(50_000.msat, 0)), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) + val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.UnknownError, listOf(Either.Right(IncorrectOrUnknownPaymentDetails(50_000.msat, 0))))) + assertFailureEquals(expected, fail) + + assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) + assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) + } + @Test fun `failure after a wallet restart`() = runSuspendTest { val (alice, _) = TestsHelper.reachNormal() @@ -543,7 +565,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Step 2: the wallet restarts and payment fails. run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add, attempt, TemporaryNodeFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createTrampolineFailure(add, attempt, TemporaryNodeFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) assertIs(fail) assertEquals(attempt.request, fail.request) assertEquals(FinalFailure.WalletRestarted, fail.failure.reason) @@ -618,8 +640,25 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { return ChannelAction.ProcessCmdRes.AddSettledFulfill(add.paymentId, updateAddHtlc, ChannelAction.HtlcResult.Fulfill.RemoteFulfill(UpdateFulfillHtlc(channelId, updateAddHtlc.id, preimage))) } - private fun createRemoteFailure(add: ChannelCommand.Htlc.Add, attempt: OutgoingPaymentHandler.PaymentAttempt, failureMessage: FailureMessage): ChannelAction.ProcessCmdRes.AddSettledFail { - val reason = FailurePacket.create(attempt.sharedSecrets.perHopSecrets.last().first, failureMessage) + private fun createTrampolineFailure(add: ChannelCommand.Htlc.Add, attempt: OutgoingPaymentHandler.PaymentAttempt, failureMessage: FailureMessage): ChannelAction.ProcessCmdRes.AddSettledFail { + val reason = FailurePacket.create(attempt.sharedSecrets.perHopSecrets.first().first, failureMessage) + val updateAddHtlc = makeUpdateAddHtlc(randomBytes32(), add) + return ChannelAction.ProcessCmdRes.AddSettledFail( + add.paymentId, + updateAddHtlc, + ChannelAction.HtlcResult.Fail.RemoteFail(UpdateFailHtlc(updateAddHtlc.channelId, updateAddHtlc.id, reason.toByteVector())) + ) + } + + private fun createRecipientFailure(add: ChannelCommand.Htlc.Add, attempt: OutgoingPaymentHandler.PaymentAttempt, failureMessage: FailureMessage) : ChannelAction.ProcessCmdRes.AddSettledFail { + // TODO: explain (1 shared secret for the outer onion, 2 shared secrets for the trampoline hop trampoline -> recipient) + assertEquals(3, attempt.sharedSecrets.perHopSecrets.size) + // The recipient encrypts the failure with its trampoline shared secret. + val failure = FailurePacket.create(attempt.sharedSecrets.perHopSecrets[2].first, failureMessage) + // The trampoline node encrypts the failure with its trampoline shared secret. + val intermediate = FailurePacket.wrap(failure, attempt.sharedSecrets.perHopSecrets[1].first) + // The trampoline node encrypts the failure with its outer shared secret. + val reason = FailurePacket.wrap(intermediate, attempt.sharedSecrets.perHopSecrets[0].first) val updateAddHtlc = makeUpdateAddHtlc(randomBytes32(), add) return ChannelAction.ProcessCmdRes.AddSettledFail( add.paymentId, diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/FailureMessageTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/FailureMessageTestsCommon.kt index d9e5d50dc..356007cc9 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/wire/FailureMessageTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/FailureMessageTestsCommon.kt @@ -53,8 +53,9 @@ class FailureMessageTestsCommon : LightningTestSuite() { ExpiryTooFar, InvalidOnionPayload(561, 1105), PaymentTimeout, - TrampolineFeeInsufficient, - TrampolineExpiryTooSoon + TemporaryTrampolineFailure, + UnknownNextTrampoline, + TrampolineFeeOrExpiryInsufficient(100.msat, 50, CltvExpiryDelta(36)) ) msgs.forEach {