From daf6a19f92bb2049bd35d09808a868f899178a83 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Fri, 17 Nov 2023 10:22:22 +0100 Subject: [PATCH] Add OfferManager --- .../kotlin/fr/acinq/lightning/NodeParams.kt | 3 + .../kotlin/fr/acinq/lightning/io/Peer.kt | 31 +++- .../acinq/lightning/payment/OfferManager.kt | 142 ++++++++++++++++++ .../lightning/payment/OfferPaymentFailure.kt | 10 ++ .../message/OfferManagerTestsCommon.kt | 87 +++++++++++ 5 files changed, 272 insertions(+), 1 deletion(-) create mode 100644 src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt create mode 100644 src/commonMain/kotlin/fr/acinq/lightning/payment/OfferPaymentFailure.kt create mode 100644 src/commonTest/kotlin/fr/acinq/lightning/message/OfferManagerTestsCommon.kt diff --git a/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt b/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt index 8d8fb0786..91bfb8b2f 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt @@ -120,6 +120,7 @@ data class RecipientCltvExpiryParams(val min: CltvExpiryDelta, val max: CltvExpi * @param zeroConfPeers list of peers with whom we use zero-conf (note that this is a strong trust assumption). * @param liquidityPolicy fee policy for liquidity events, can be modified at any time. * @param minFinalCltvExpiryDelta cltv-expiry-delta that we require when receiving a payment. + * @param maxFinalCltvExpiryDelta maximum cltv-expiry-delta that we accept when receiving a payment. * @param bolt12invoiceExpiry duration for which bolt12 invoices that we create are valid. */ data class NodeParams( @@ -151,6 +152,7 @@ data class NodeParams( val zeroConfPeers: Set, val liquidityPolicy: MutableStateFlow, val minFinalCltvExpiryDelta: CltvExpiryDelta, + val maxFinalCltvExpiryDelta: CltvExpiryDelta, val bolt12invoiceExpiry: Duration ) { val nodePrivateKey get() = keyManager.nodeKeys.nodeKey.privateKey @@ -226,6 +228,7 @@ data class NodeParams( paymentRecipientExpiryParams = RecipientCltvExpiryParams(CltvExpiryDelta(75), CltvExpiryDelta(200)), liquidityPolicy = MutableStateFlow(LiquidityPolicy.Auto(maxAbsoluteFee = 2_000.sat, maxRelativeFeeBasisPoints = 3_000 /* 3000 = 30 % */, skipAbsoluteFeeCheck = false)), minFinalCltvExpiryDelta = Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, + maxFinalCltvExpiryDelta = CltvExpiryDelta(500), bolt12invoiceExpiry = 60.seconds ) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt index afb150d13..706d03188 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt @@ -13,7 +13,9 @@ import fr.acinq.lightning.channel.states.* import fr.acinq.lightning.crypto.noise.* import fr.acinq.lightning.db.* import fr.acinq.lightning.logging.* +import fr.acinq.lightning.payment.OnionMessageAction import fr.acinq.lightning.payment.* +import fr.acinq.lightning.payment.OfferManager import fr.acinq.lightning.serialization.Encryption.from import fr.acinq.lightning.serialization.Serialization.DeserializationResult import fr.acinq.lightning.transactions.Transactions @@ -74,9 +76,12 @@ data class PayToOpenResponseCommand(val payToOpenResponse: PayToOpenResponse) : data class SendPayment(val paymentId: UUID, val amount: MilliSatoshi, val recipient: PublicKey, val paymentRequest: PaymentRequest, val trampolineFeesOverride: List? = null) : PaymentCommand() { val paymentHash: ByteVector32 = paymentRequest.paymentHash } +data class PayOffer(val paymentId: UUID, val amount: MilliSatoshi, val quantity: Long, val offer: OfferTypes.Offer, val minReplyPathHops: Int, val trampolineFeesOverride: List? = null) : PaymentCommand() data class PurgeExpiredPayments(val fromCreatedAt: Long, val toCreatedAt: Long) : PaymentCommand() +data class SendMessage(val message: OnionMessage) : PeerCommand() + sealed class PeerEvent @Deprecated("Replaced by NodeEvents", replaceWith = ReplaceWith("PaymentEvents.PaymentReceived", "fr.acinq.lightning.PaymentEvents")) data class PaymentReceived(val incomingPayment: IncomingPayment, val received: IncomingPayment.Received) : PeerEvent() @@ -86,6 +91,8 @@ sealed class SendPaymentResult : PeerEvent() { } data class PaymentNotSent(override val request: SendPayment, val reason: OutgoingPaymentFailure) : SendPaymentResult() data class PaymentSent(override val request: SendPayment, val payment: LightningOutgoingPayment) : SendPaymentResult() +data class OfferNotPaid(val request: PayOffer, val reason: OfferPaymentFailure) : PeerEvent() +data class OfferInvoiceReceived(val request: PayOffer, val invoice: Bolt12Invoice, val payerKey: PrivateKey) : PeerEvent() data class ChannelClosing(val channelId: ByteVector32) : PeerEvent() /** @@ -200,6 +207,8 @@ class Peer( private var swapInJob: Job? = null + private val offerManager = OfferManager(nodeParams, walletParams, logger) + init { logger.info { "initializing peer" } launch { @@ -1091,7 +1100,7 @@ class Peer( } is OnionMessage -> { logger.info { "received ${msg::class.simpleName}" } - // TODO: process onion message + offerManager.receiveMessage(msg, _channels, currentTipFlow.filterNotNull().first().first)?.let { processOnionMessageAction(it) } } } } @@ -1240,6 +1249,26 @@ class Peer( } } } + is PayOffer -> { + val invoiceRequests = offerManager.requestInvoice(cmd) + invoiceRequests.forEach { input.send(SendMessage(it)) } + if (invoiceRequests.isEmpty()) { + _eventsFlow.emit(OfferNotPaid(cmd, OfferPaymentFailure.NoResponse)) + } + } + is SendMessage -> peerConnection?.send(cmd.message) + // TODO: timeout invoice requests + } + } + + private suspend fun processOnionMessageAction(action: OnionMessageAction) { + when (action) { + is OnionMessageAction.PayInvoice -> { + _eventsFlow.emit(OfferInvoiceReceived(action.payOffer, action.invoice, action.payerKey)) + input.send(SendPayment(action.payOffer.paymentId, action.payOffer.amount, action.invoice.nodeId, action.invoice, action.payOffer.trampolineFeesOverride)) + } + is OnionMessageAction.ReportPaymentFailure -> _eventsFlow.emit(OfferNotPaid(action.payOffer, action.failure)) + is OnionMessageAction.SendMessage -> input.send(SendMessage(action.message)) } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt new file mode 100644 index 000000000..da6ae8e3e --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt @@ -0,0 +1,142 @@ +package fr.acinq.lightning.payment + +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.PrivateKey +import fr.acinq.bitcoin.PublicKey +import fr.acinq.bitcoin.utils.Either.Left +import fr.acinq.bitcoin.utils.Either.Right +import fr.acinq.lightning.* +import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.Lightning.randomKey +import fr.acinq.lightning.channel.states.ChannelState +import fr.acinq.lightning.channel.states.Normal +import fr.acinq.lightning.channel.states.Offline +import fr.acinq.lightning.channel.states.Syncing +import fr.acinq.lightning.crypto.RouteBlinding +import fr.acinq.lightning.io.PayOffer +import fr.acinq.lightning.logging.MDCLogger +import fr.acinq.lightning.message.OnionMessages +import fr.acinq.lightning.message.OnionMessages.Destination +import fr.acinq.lightning.message.OnionMessages.IntermediateNode +import fr.acinq.lightning.message.OnionMessages.buildMessage +import fr.acinq.lightning.utils.currentTimestampMillis +import fr.acinq.lightning.utils.currentTimestampSeconds +import fr.acinq.lightning.utils.msat +import fr.acinq.lightning.utils.toByteVector +import fr.acinq.lightning.wire.* +import kotlin.math.max + +sealed class OnionMessageAction { + data class SendMessage(val message: OnionMessage): OnionMessageAction() + data class PayInvoice(val payOffer: PayOffer, val invoice: Bolt12Invoice, val payerKey: PrivateKey): OnionMessageAction() + data class ReportPaymentFailure(val payOffer: PayOffer, val failure: OfferPaymentFailure): OnionMessageAction() +} + +private data class PendingInvoiceRequest(val payOffer: PayOffer, val payerKey: PrivateKey, val request: OfferTypes.InvoiceRequest, val createAtSeconds: Long) + +class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, val logger: MDCLogger) { + val remoteNodeId: PublicKey = walletParams.trampolineNode.id + private val pendingInvoiceRequests: HashMap = HashMap() + private val localOffers: HashMap = HashMap() + + fun registerOffer(offer: OfferTypes.Offer, pathId: ByteVector32) { + localOffers[pathId] = offer + } + + fun requestInvoice(payOffer: PayOffer): List { + val payerKey = randomKey() + val request = OfferTypes.InvoiceRequest(payOffer.offer, payOffer.amount, payOffer.quantity, nodeParams.features.bolt12Features(), payerKey, nodeParams.chainHash) + val replyPathId = randomBytes32() + pendingInvoiceRequests[replyPathId] = PendingInvoiceRequest(payOffer, payerKey, request, currentTimestampSeconds()) + val numHopsToAdd = max(0, payOffer.minReplyPathHops - 1) + val replyPathHops = (listOf(remoteNodeId) + List(numHopsToAdd) { nodeParams.nodeId }).map { IntermediateNode(it) } + val lastHop = Destination.Recipient(nodeParams.nodeId, replyPathId) + val replyPath = OnionMessages.buildRoute(randomKey(), replyPathHops, lastHop) + val messageContent = TlvStream(OnionMessagePayloadTlv.ReplyPath(replyPath), OnionMessagePayloadTlv.InvoiceRequest(request.records)) + return payOffer.offer.contactInfos.mapNotNull { contactInfo -> buildMessage(randomKey(), randomKey(), listOf(IntermediateNode(remoteNodeId)), Destination(contactInfo), messageContent).right } + } + + fun receiveMessage(msg: OnionMessage, channels: Map, currentBlockHeight: Int): OnionMessageAction? { + val decrypted = OnionMessages.decryptMessage(nodeParams.nodePrivateKey, msg, logger) + if (decrypted == null) { + return null + } else { + if (pendingInvoiceRequests.containsKey(decrypted.pathId)) { + val (payOffer, payerKey, request, _) = pendingInvoiceRequests[decrypted.pathId]!! + pendingInvoiceRequests.remove(decrypted.pathId) + val invoice = decrypted.content.records.get()?.let { Bolt12Invoice.validate(it.tlvs).right } + if (invoice == null) { + val error = decrypted.content.records.get()?.let { OfferTypes.InvoiceError.validate(it.tlvs).right } + val failure = error?.let { OfferPaymentFailure.InvoiceError(request, it) } ?: OfferPaymentFailure.InvalidResponse(request) + return OnionMessageAction.ReportPaymentFailure(payOffer, failure) + } else { + if (invoice.validateFor(request).isRight) { + return OnionMessageAction.PayInvoice(payOffer, invoice, payerKey) + } else { + return OnionMessageAction.ReportPaymentFailure( + payOffer, + OfferPaymentFailure.InvalidInvoice(request, invoice) + ) + } + } + } else if (localOffers.containsKey(decrypted.pathId) && decrypted.content.replyPath != null) { + val offer = localOffers[decrypted.pathId]!! + val offerPathId = ByteVector32(decrypted.pathId) + val request = decrypted.content.records.get()?.let { OfferTypes.InvoiceRequest.validate(it.tlvs).right } + if (request != null && request.offer == offer && request.isValid()) { + val amount = request.amount ?: (request.offer.amount!! * request.quantity) + val preimage = randomBytes32() + val pathId = OfferPaymentMetadata.V1(offerPathId, amount, preimage, request.payerId, request.quantity, currentTimestampMillis()).toPathId(nodeParams.nodePrivateKey) + val recipientPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(pathId))).write().toByteVector() + val remoteChannelUpdates = channels.values.mapNotNull { channelState -> + when (channelState) { + is Normal -> channelState.remoteChannelUpdate + is Offline -> (channelState.state as? Normal)?.remoteChannelUpdate + is Syncing -> (channelState.state as? Normal)?.remoteChannelUpdate + else -> null + } + } + val paymentInfo = OfferTypes.PaymentInfo( + feeBase = remoteChannelUpdates.maxOfOrNull { it.feeBaseMsat } ?: walletParams.invoiceDefaultRoutingFees.feeBase, + feeProportionalMillionths = remoteChannelUpdates.maxOfOrNull { it.feeProportionalMillionths } ?: walletParams.invoiceDefaultRoutingFees.feeProportional, + cltvExpiryDelta = remoteChannelUpdates.maxOfOrNull { it.cltvExpiryDelta } ?: walletParams.invoiceDefaultRoutingFees.cltvExpiryDelta, + minHtlc = remoteChannelUpdates.minOfOrNull { it.htlcMinimumMsat } ?: 1.msat, + maxHtlc = amount, + allowedFeatures = Features.empty + ) + val remoteNodePayload = RouteBlindingEncryptedData(TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId.peerId(nodeParams.nodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(paymentInfo.cltvExpiryDelta, paymentInfo.feeProportionalMillionths, paymentInfo.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints((paymentInfo.cltvExpiryDelta + nodeParams.maxFinalCltvExpiryDelta).toCltvExpiry(currentBlockHeight.toLong()), paymentInfo.minHtlc) + )).write().toByteVector() + val blindedRoute = RouteBlinding.create(randomKey(), listOf(remoteNodeId, nodeParams.nodeId), listOf(remoteNodePayload, recipientPayload)) + val path = Bolt12Invoice.Companion.PaymentBlindedContactInfo(OfferTypes.ContactInfo.BlindedPath(blindedRoute), paymentInfo) + val invoice = Bolt12Invoice(request, preimage, decrypted.blindedPrivateKey, nodeParams.bolt12invoiceExpiry.inWholeSeconds, nodeParams.features.bolt12Features(), listOf(path)) + return when (val invoiceMessage = buildMessage(randomKey(), randomKey(), listOf(IntermediateNode(remoteNodeId)), Destination.BlindedPath(decrypted.content.replyPath), TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)))) { + is Left -> null + is Right -> OnionMessageAction.SendMessage(invoiceMessage.value) + } + } else { + // TODO: send back invoice error + return null + } + } else { + // Ignore unexpected messages. + return null + } + } + } + + fun checkInvoiceRequestTimeout(timeoutSeconds: Long): List { + val timedOut = ArrayList() + val cutoff = currentTimestampSeconds() - timeoutSeconds + for ((pathId, pending) in pendingInvoiceRequests) { + if (pending.createAtSeconds < cutoff) { + timedOut.add(OnionMessageAction.ReportPaymentFailure(pending.payOffer, OfferPaymentFailure.NoResponse)) + pendingInvoiceRequests.remove(pathId) + } + } + return timedOut + } +} diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OfferPaymentFailure.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OfferPaymentFailure.kt new file mode 100644 index 000000000..42bd0033d --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OfferPaymentFailure.kt @@ -0,0 +1,10 @@ +package fr.acinq.lightning.payment + +import fr.acinq.lightning.wire.OfferTypes + +sealed class OfferPaymentFailure { + data object NoResponse : OfferPaymentFailure() + data class InvalidResponse(val request: OfferTypes.InvoiceRequest) : OfferPaymentFailure() + data class InvoiceError(val request: OfferTypes.InvoiceRequest, val error: OfferTypes.InvoiceError) : OfferPaymentFailure() + data class InvalidInvoice(val request: OfferTypes.InvoiceRequest, val invoice: Bolt12Invoice) : OfferPaymentFailure() +} \ No newline at end of file diff --git a/src/commonTest/kotlin/fr/acinq/lightning/message/OfferManagerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/message/OfferManagerTestsCommon.kt new file mode 100644 index 000000000..d37f48a06 --- /dev/null +++ b/src/commonTest/kotlin/fr/acinq/lightning/message/OfferManagerTestsCommon.kt @@ -0,0 +1,87 @@ +package fr.acinq.lightning.message + +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.utils.Either +import fr.acinq.lightning.EncodedNodeId +import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.Lightning.randomKey +import fr.acinq.lightning.NodeUri +import fr.acinq.lightning.ShortChannelId +import fr.acinq.lightning.crypto.RouteBlinding +import fr.acinq.lightning.crypto.sphinx.DecryptedPacket +import fr.acinq.lightning.crypto.sphinx.Sphinx +import fr.acinq.lightning.io.PayOffer +import fr.acinq.lightning.logging.MDCLogger +import fr.acinq.lightning.payment.OfferManager +import fr.acinq.lightning.payment.OnionMessageAction +import fr.acinq.lightning.tests.TestConstants +import fr.acinq.lightning.tests.utils.LightningTestSuite +import fr.acinq.lightning.tests.utils.runSuspendTest +import fr.acinq.lightning.tests.utils.testLoggerFactory +import fr.acinq.lightning.utils.UUID +import fr.acinq.lightning.utils.msat +import fr.acinq.lightning.wire.MessageOnion +import fr.acinq.lightning.wire.OfferTypes +import fr.acinq.lightning.wire.OnionMessage +import fr.acinq.lightning.wire.RouteBlindingEncryptedData +import kotlin.test.* + +class OfferManagerTestsCommon : LightningTestSuite() { + val trampolineKey = randomKey() + val walletParams = TestConstants.Alice.walletParams.copy(trampolineNode = NodeUri(trampolineKey.publicKey(), "trampoline.com", 9735)) + val logger: MDCLogger = MDCLogger(testLoggerFactory.newLogger(this::class)) + + fun trampolineRelay(msg: OnionMessage): Pair> { + val blindedPrivateKey = RouteBlinding.derivePrivateKey(trampolineKey, msg.blindingKey) + val decrypted = Sphinx.peel( + blindedPrivateKey, + ByteVector.empty, + msg.onionRoutingPacket + ) + assertIs>(decrypted) + assertFalse(decrypted.value.isLastPacket) + val message = MessageOnion.read(decrypted.value.payload.toByteArray()) + val (decryptedPayload, nextBlinding) = RouteBlinding.decryptPayload( + trampolineKey, + msg.blindingKey, + message.encryptedData + ) + val relayInfo = RouteBlindingEncryptedData.read(decryptedPayload.toByteArray()) + + return Pair(OnionMessage(relayInfo.nextBlindingOverride ?: nextBlinding, decrypted.value.nextPacket), relayInfo.nextNodeId?.let { Either.Right(it) } ?: Either.Left(relayInfo.outgoingChannelId!!)) + } + + @Test + fun `offer workflow`() = runSuspendTest { + val aliceOfferManager = OfferManager(TestConstants.Alice.nodeParams, walletParams, logger) + val bobOfferManager = OfferManager(TestConstants.Bob.nodeParams, walletParams, logger) + + val pathId = randomBytes32() + val offer = OfferTypes.Offer.createBlindedOffer( + 1000.msat, + "Blockaccino", + TestConstants.Alice.nodeParams, + walletParams.trampolineNode, + pathId, + setOf(OfferTypes.OfferQuantityMax(0)) + ) + aliceOfferManager.registerOffer(offer, pathId) + + val payOffer = PayOffer(UUID.randomUUID(), 5500.msat, 5, offer, 2) + val invoiceRequests = bobOfferManager.requestInvoice(payOffer) + assertTrue(invoiceRequests.size == 1) + val relay1 = trampolineRelay(invoiceRequests.first()) + assertEquals(Either.Right(EncodedNodeId(trampolineKey.publicKey())), relay1.second) + val relay2 = trampolineRelay(relay1.first) + assertEquals(Either.Left(ShortChannelId.peerId(TestConstants.Alice.nodeParams.nodeId)), relay2.second) + val invoiceResponse = aliceOfferManager.receiveMessage(relay2.first, mapOf(), 0) + assertIs(invoiceResponse) + val relay3 = trampolineRelay(invoiceResponse.message) + assertEquals(Either.Right(EncodedNodeId(trampolineKey.publicKey())), relay3.second) + val relay4 = trampolineRelay(relay3.first) + assertEquals(Either.Right(EncodedNodeId(TestConstants.Bob.nodeParams.nodeId)), relay4.second) + val payInvoice = bobOfferManager.receiveMessage(relay4.first, mapOf(), 0) + assertIs(payInvoice) + assertEquals(payOffer, payInvoice.payOffer) + } +} \ No newline at end of file