Skip to content

Commit

Permalink
Add OfferManager
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Apr 22, 2024
1 parent fd2e313 commit daf6a19
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ data class RecipientCltvExpiryParams(val min: CltvExpiryDelta, val max: CltvExpi
* @param zeroConfPeers list of peers with whom we use zero-conf (note that this is a strong trust assumption).
* @param liquidityPolicy fee policy for liquidity events, can be modified at any time.
* @param minFinalCltvExpiryDelta cltv-expiry-delta that we require when receiving a payment.
* @param maxFinalCltvExpiryDelta maximum cltv-expiry-delta that we accept when receiving a payment.
* @param bolt12invoiceExpiry duration for which bolt12 invoices that we create are valid.
*/
data class NodeParams(
Expand Down Expand Up @@ -151,6 +152,7 @@ data class NodeParams(
val zeroConfPeers: Set<PublicKey>,
val liquidityPolicy: MutableStateFlow<LiquidityPolicy>,
val minFinalCltvExpiryDelta: CltvExpiryDelta,
val maxFinalCltvExpiryDelta: CltvExpiryDelta,
val bolt12invoiceExpiry: Duration
) {
val nodePrivateKey get() = keyManager.nodeKeys.nodeKey.privateKey
Expand Down Expand Up @@ -226,6 +228,7 @@ data class NodeParams(
paymentRecipientExpiryParams = RecipientCltvExpiryParams(CltvExpiryDelta(75), CltvExpiryDelta(200)),
liquidityPolicy = MutableStateFlow<LiquidityPolicy>(LiquidityPolicy.Auto(maxAbsoluteFee = 2_000.sat, maxRelativeFeeBasisPoints = 3_000 /* 3000 = 30 % */, skipAbsoluteFeeCheck = false)),
minFinalCltvExpiryDelta = Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA,
maxFinalCltvExpiryDelta = CltvExpiryDelta(500),
bolt12invoiceExpiry = 60.seconds
)
}
31 changes: 30 additions & 1 deletion src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import fr.acinq.lightning.channel.states.*
import fr.acinq.lightning.crypto.noise.*
import fr.acinq.lightning.db.*
import fr.acinq.lightning.logging.*
import fr.acinq.lightning.payment.OnionMessageAction
import fr.acinq.lightning.payment.*
import fr.acinq.lightning.payment.OfferManager
import fr.acinq.lightning.serialization.Encryption.from
import fr.acinq.lightning.serialization.Serialization.DeserializationResult
import fr.acinq.lightning.transactions.Transactions
Expand Down Expand Up @@ -74,9 +76,12 @@ data class PayToOpenResponseCommand(val payToOpenResponse: PayToOpenResponse) :
data class SendPayment(val paymentId: UUID, val amount: MilliSatoshi, val recipient: PublicKey, val paymentRequest: PaymentRequest, val trampolineFeesOverride: List<TrampolineFees>? = null) : PaymentCommand() {
val paymentHash: ByteVector32 = paymentRequest.paymentHash
}
data class PayOffer(val paymentId: UUID, val amount: MilliSatoshi, val quantity: Long, val offer: OfferTypes.Offer, val minReplyPathHops: Int, val trampolineFeesOverride: List<TrampolineFees>? = null) : PaymentCommand()

data class PurgeExpiredPayments(val fromCreatedAt: Long, val toCreatedAt: Long) : PaymentCommand()

data class SendMessage(val message: OnionMessage) : PeerCommand()

sealed class PeerEvent
@Deprecated("Replaced by NodeEvents", replaceWith = ReplaceWith("PaymentEvents.PaymentReceived", "fr.acinq.lightning.PaymentEvents"))
data class PaymentReceived(val incomingPayment: IncomingPayment, val received: IncomingPayment.Received) : PeerEvent()
Expand All @@ -86,6 +91,8 @@ sealed class SendPaymentResult : PeerEvent() {
}
data class PaymentNotSent(override val request: SendPayment, val reason: OutgoingPaymentFailure) : SendPaymentResult()
data class PaymentSent(override val request: SendPayment, val payment: LightningOutgoingPayment) : SendPaymentResult()
data class OfferNotPaid(val request: PayOffer, val reason: OfferPaymentFailure) : PeerEvent()
data class OfferInvoiceReceived(val request: PayOffer, val invoice: Bolt12Invoice, val payerKey: PrivateKey) : PeerEvent()
data class ChannelClosing(val channelId: ByteVector32) : PeerEvent()

/**
Expand Down Expand Up @@ -200,6 +207,8 @@ class Peer(

private var swapInJob: Job? = null

private val offerManager = OfferManager(nodeParams, walletParams, logger)

init {
logger.info { "initializing peer" }
launch {
Expand Down Expand Up @@ -1091,7 +1100,7 @@ class Peer(
}
is OnionMessage -> {
logger.info { "received ${msg::class.simpleName}" }
// TODO: process onion message
offerManager.receiveMessage(msg, _channels, currentTipFlow.filterNotNull().first().first)?.let { processOnionMessageAction(it) }
}
}
}
Expand Down Expand Up @@ -1240,6 +1249,26 @@ class Peer(
}
}
}
is PayOffer -> {
val invoiceRequests = offerManager.requestInvoice(cmd)
invoiceRequests.forEach { input.send(SendMessage(it)) }
if (invoiceRequests.isEmpty()) {
_eventsFlow.emit(OfferNotPaid(cmd, OfferPaymentFailure.NoResponse))
}
}
is SendMessage -> peerConnection?.send(cmd.message)
// TODO: timeout invoice requests
}
}

private suspend fun processOnionMessageAction(action: OnionMessageAction) {
when (action) {
is OnionMessageAction.PayInvoice -> {
_eventsFlow.emit(OfferInvoiceReceived(action.payOffer, action.invoice, action.payerKey))
input.send(SendPayment(action.payOffer.paymentId, action.payOffer.amount, action.invoice.nodeId, action.invoice, action.payOffer.trampolineFeesOverride))
}
is OnionMessageAction.ReportPaymentFailure -> _eventsFlow.emit(OfferNotPaid(action.payOffer, action.failure))
is OnionMessageAction.SendMessage -> input.send(SendMessage(action.message))
}
}
}
142 changes: 142 additions & 0 deletions src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package fr.acinq.lightning.payment

import fr.acinq.bitcoin.ByteVector
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.PrivateKey
import fr.acinq.bitcoin.PublicKey
import fr.acinq.bitcoin.utils.Either.Left
import fr.acinq.bitcoin.utils.Either.Right
import fr.acinq.lightning.*
import fr.acinq.lightning.Lightning.randomBytes32
import fr.acinq.lightning.Lightning.randomKey
import fr.acinq.lightning.channel.states.ChannelState
import fr.acinq.lightning.channel.states.Normal
import fr.acinq.lightning.channel.states.Offline
import fr.acinq.lightning.channel.states.Syncing
import fr.acinq.lightning.crypto.RouteBlinding
import fr.acinq.lightning.io.PayOffer
import fr.acinq.lightning.logging.MDCLogger
import fr.acinq.lightning.message.OnionMessages
import fr.acinq.lightning.message.OnionMessages.Destination
import fr.acinq.lightning.message.OnionMessages.IntermediateNode
import fr.acinq.lightning.message.OnionMessages.buildMessage
import fr.acinq.lightning.utils.currentTimestampMillis
import fr.acinq.lightning.utils.currentTimestampSeconds
import fr.acinq.lightning.utils.msat
import fr.acinq.lightning.utils.toByteVector
import fr.acinq.lightning.wire.*
import kotlin.math.max

sealed class OnionMessageAction {
data class SendMessage(val message: OnionMessage): OnionMessageAction()
data class PayInvoice(val payOffer: PayOffer, val invoice: Bolt12Invoice, val payerKey: PrivateKey): OnionMessageAction()
data class ReportPaymentFailure(val payOffer: PayOffer, val failure: OfferPaymentFailure): OnionMessageAction()
}

private data class PendingInvoiceRequest(val payOffer: PayOffer, val payerKey: PrivateKey, val request: OfferTypes.InvoiceRequest, val createAtSeconds: Long)

class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, val logger: MDCLogger) {
val remoteNodeId: PublicKey = walletParams.trampolineNode.id
private val pendingInvoiceRequests: HashMap<ByteVector32, PendingInvoiceRequest> = HashMap()
private val localOffers: HashMap<ByteVector32, OfferTypes.Offer> = HashMap()

fun registerOffer(offer: OfferTypes.Offer, pathId: ByteVector32) {
localOffers[pathId] = offer
}

fun requestInvoice(payOffer: PayOffer): List<OnionMessage> {
val payerKey = randomKey()
val request = OfferTypes.InvoiceRequest(payOffer.offer, payOffer.amount, payOffer.quantity, nodeParams.features.bolt12Features(), payerKey, nodeParams.chainHash)
val replyPathId = randomBytes32()
pendingInvoiceRequests[replyPathId] = PendingInvoiceRequest(payOffer, payerKey, request, currentTimestampSeconds())
val numHopsToAdd = max(0, payOffer.minReplyPathHops - 1)
val replyPathHops = (listOf(remoteNodeId) + List(numHopsToAdd) { nodeParams.nodeId }).map { IntermediateNode(it) }
val lastHop = Destination.Recipient(nodeParams.nodeId, replyPathId)
val replyPath = OnionMessages.buildRoute(randomKey(), replyPathHops, lastHop)
val messageContent = TlvStream(OnionMessagePayloadTlv.ReplyPath(replyPath), OnionMessagePayloadTlv.InvoiceRequest(request.records))
return payOffer.offer.contactInfos.mapNotNull { contactInfo -> buildMessage(randomKey(), randomKey(), listOf(IntermediateNode(remoteNodeId)), Destination(contactInfo), messageContent).right }
}

fun receiveMessage(msg: OnionMessage, channels: Map<ByteVector32, ChannelState>, currentBlockHeight: Int): OnionMessageAction? {
val decrypted = OnionMessages.decryptMessage(nodeParams.nodePrivateKey, msg, logger)
if (decrypted == null) {
return null
} else {
if (pendingInvoiceRequests.containsKey(decrypted.pathId)) {
val (payOffer, payerKey, request, _) = pendingInvoiceRequests[decrypted.pathId]!!
pendingInvoiceRequests.remove(decrypted.pathId)
val invoice = decrypted.content.records.get<OnionMessagePayloadTlv.Invoice>()?.let { Bolt12Invoice.validate(it.tlvs).right }
if (invoice == null) {
val error = decrypted.content.records.get<OnionMessagePayloadTlv.InvoiceError>()?.let { OfferTypes.InvoiceError.validate(it.tlvs).right }
val failure = error?.let { OfferPaymentFailure.InvoiceError(request, it) } ?: OfferPaymentFailure.InvalidResponse(request)
return OnionMessageAction.ReportPaymentFailure(payOffer, failure)
} else {
if (invoice.validateFor(request).isRight) {
return OnionMessageAction.PayInvoice(payOffer, invoice, payerKey)
} else {
return OnionMessageAction.ReportPaymentFailure(
payOffer,
OfferPaymentFailure.InvalidInvoice(request, invoice)
)
}
}
} else if (localOffers.containsKey(decrypted.pathId) && decrypted.content.replyPath != null) {
val offer = localOffers[decrypted.pathId]!!
val offerPathId = ByteVector32(decrypted.pathId)
val request = decrypted.content.records.get<OnionMessagePayloadTlv.InvoiceRequest>()?.let { OfferTypes.InvoiceRequest.validate(it.tlvs).right }
if (request != null && request.offer == offer && request.isValid()) {
val amount = request.amount ?: (request.offer.amount!! * request.quantity)
val preimage = randomBytes32()
val pathId = OfferPaymentMetadata.V1(offerPathId, amount, preimage, request.payerId, request.quantity, currentTimestampMillis()).toPathId(nodeParams.nodePrivateKey)
val recipientPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(pathId))).write().toByteVector()
val remoteChannelUpdates = channels.values.mapNotNull { channelState ->
when (channelState) {
is Normal -> channelState.remoteChannelUpdate
is Offline -> (channelState.state as? Normal)?.remoteChannelUpdate
is Syncing -> (channelState.state as? Normal)?.remoteChannelUpdate
else -> null
}
}
val paymentInfo = OfferTypes.PaymentInfo(
feeBase = remoteChannelUpdates.maxOfOrNull { it.feeBaseMsat } ?: walletParams.invoiceDefaultRoutingFees.feeBase,
feeProportionalMillionths = remoteChannelUpdates.maxOfOrNull { it.feeProportionalMillionths } ?: walletParams.invoiceDefaultRoutingFees.feeProportional,
cltvExpiryDelta = remoteChannelUpdates.maxOfOrNull { it.cltvExpiryDelta } ?: walletParams.invoiceDefaultRoutingFees.cltvExpiryDelta,
minHtlc = remoteChannelUpdates.minOfOrNull { it.htlcMinimumMsat } ?: 1.msat,
maxHtlc = amount,
allowedFeatures = Features.empty
)
val remoteNodePayload = RouteBlindingEncryptedData(TlvStream(
RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId.peerId(nodeParams.nodeId)),
RouteBlindingEncryptedDataTlv.PaymentRelay(paymentInfo.cltvExpiryDelta, paymentInfo.feeProportionalMillionths, paymentInfo.feeBase),
RouteBlindingEncryptedDataTlv.PaymentConstraints((paymentInfo.cltvExpiryDelta + nodeParams.maxFinalCltvExpiryDelta).toCltvExpiry(currentBlockHeight.toLong()), paymentInfo.minHtlc)
)).write().toByteVector()
val blindedRoute = RouteBlinding.create(randomKey(), listOf(remoteNodeId, nodeParams.nodeId), listOf(remoteNodePayload, recipientPayload))
val path = Bolt12Invoice.Companion.PaymentBlindedContactInfo(OfferTypes.ContactInfo.BlindedPath(blindedRoute), paymentInfo)
val invoice = Bolt12Invoice(request, preimage, decrypted.blindedPrivateKey, nodeParams.bolt12invoiceExpiry.inWholeSeconds, nodeParams.features.bolt12Features(), listOf(path))
return when (val invoiceMessage = buildMessage(randomKey(), randomKey(), listOf(IntermediateNode(remoteNodeId)), Destination.BlindedPath(decrypted.content.replyPath), TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)))) {
is Left -> null
is Right -> OnionMessageAction.SendMessage(invoiceMessage.value)
}
} else {
// TODO: send back invoice error
return null
}
} else {
// Ignore unexpected messages.
return null
}
}
}

fun checkInvoiceRequestTimeout(timeoutSeconds: Long): List<OnionMessageAction.ReportPaymentFailure> {
val timedOut = ArrayList<OnionMessageAction.ReportPaymentFailure>()
val cutoff = currentTimestampSeconds() - timeoutSeconds
for ((pathId, pending) in pendingInvoiceRequests) {
if (pending.createAtSeconds < cutoff) {
timedOut.add(OnionMessageAction.ReportPaymentFailure(pending.payOffer, OfferPaymentFailure.NoResponse))
pendingInvoiceRequests.remove(pathId)
}
}
return timedOut
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package fr.acinq.lightning.payment

import fr.acinq.lightning.wire.OfferTypes

sealed class OfferPaymentFailure {
data object NoResponse : OfferPaymentFailure()
data class InvalidResponse(val request: OfferTypes.InvoiceRequest) : OfferPaymentFailure()
data class InvoiceError(val request: OfferTypes.InvoiceRequest, val error: OfferTypes.InvoiceError) : OfferPaymentFailure()
data class InvalidInvoice(val request: OfferTypes.InvoiceRequest, val invoice: Bolt12Invoice) : OfferPaymentFailure()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package fr.acinq.lightning.message

import fr.acinq.bitcoin.ByteVector
import fr.acinq.bitcoin.utils.Either
import fr.acinq.lightning.EncodedNodeId
import fr.acinq.lightning.Lightning.randomBytes32
import fr.acinq.lightning.Lightning.randomKey
import fr.acinq.lightning.NodeUri
import fr.acinq.lightning.ShortChannelId
import fr.acinq.lightning.crypto.RouteBlinding
import fr.acinq.lightning.crypto.sphinx.DecryptedPacket
import fr.acinq.lightning.crypto.sphinx.Sphinx
import fr.acinq.lightning.io.PayOffer
import fr.acinq.lightning.logging.MDCLogger
import fr.acinq.lightning.payment.OfferManager
import fr.acinq.lightning.payment.OnionMessageAction
import fr.acinq.lightning.tests.TestConstants
import fr.acinq.lightning.tests.utils.LightningTestSuite
import fr.acinq.lightning.tests.utils.runSuspendTest
import fr.acinq.lightning.tests.utils.testLoggerFactory
import fr.acinq.lightning.utils.UUID
import fr.acinq.lightning.utils.msat
import fr.acinq.lightning.wire.MessageOnion
import fr.acinq.lightning.wire.OfferTypes
import fr.acinq.lightning.wire.OnionMessage
import fr.acinq.lightning.wire.RouteBlindingEncryptedData
import kotlin.test.*

class OfferManagerTestsCommon : LightningTestSuite() {
val trampolineKey = randomKey()
val walletParams = TestConstants.Alice.walletParams.copy(trampolineNode = NodeUri(trampolineKey.publicKey(), "trampoline.com", 9735))
val logger: MDCLogger = MDCLogger(testLoggerFactory.newLogger(this::class))

fun trampolineRelay(msg: OnionMessage): Pair<OnionMessage, Either<ShortChannelId, EncodedNodeId>> {
val blindedPrivateKey = RouteBlinding.derivePrivateKey(trampolineKey, msg.blindingKey)
val decrypted = Sphinx.peel(
blindedPrivateKey,
ByteVector.empty,
msg.onionRoutingPacket
)
assertIs<Either.Right<DecryptedPacket>>(decrypted)
assertFalse(decrypted.value.isLastPacket)
val message = MessageOnion.read(decrypted.value.payload.toByteArray())
val (decryptedPayload, nextBlinding) = RouteBlinding.decryptPayload(
trampolineKey,
msg.blindingKey,
message.encryptedData
)
val relayInfo = RouteBlindingEncryptedData.read(decryptedPayload.toByteArray())

return Pair(OnionMessage(relayInfo.nextBlindingOverride ?: nextBlinding, decrypted.value.nextPacket), relayInfo.nextNodeId?.let { Either.Right(it) } ?: Either.Left(relayInfo.outgoingChannelId!!))
}

@Test
fun `offer workflow`() = runSuspendTest {
val aliceOfferManager = OfferManager(TestConstants.Alice.nodeParams, walletParams, logger)
val bobOfferManager = OfferManager(TestConstants.Bob.nodeParams, walletParams, logger)

val pathId = randomBytes32()
val offer = OfferTypes.Offer.createBlindedOffer(
1000.msat,
"Blockaccino",
TestConstants.Alice.nodeParams,
walletParams.trampolineNode,
pathId,
setOf(OfferTypes.OfferQuantityMax(0))
)
aliceOfferManager.registerOffer(offer, pathId)

val payOffer = PayOffer(UUID.randomUUID(), 5500.msat, 5, offer, 2)
val invoiceRequests = bobOfferManager.requestInvoice(payOffer)
assertTrue(invoiceRequests.size == 1)
val relay1 = trampolineRelay(invoiceRequests.first())
assertEquals(Either.Right(EncodedNodeId(trampolineKey.publicKey())), relay1.second)
val relay2 = trampolineRelay(relay1.first)
assertEquals(Either.Left(ShortChannelId.peerId(TestConstants.Alice.nodeParams.nodeId)), relay2.second)
val invoiceResponse = aliceOfferManager.receiveMessage(relay2.first, mapOf(), 0)
assertIs<OnionMessageAction.SendMessage>(invoiceResponse)
val relay3 = trampolineRelay(invoiceResponse.message)
assertEquals(Either.Right(EncodedNodeId(trampolineKey.publicKey())), relay3.second)
val relay4 = trampolineRelay(relay3.first)
assertEquals(Either.Right(EncodedNodeId(TestConstants.Bob.nodeParams.nodeId)), relay4.second)
val payInvoice = bobOfferManager.receiveMessage(relay4.first, mapOf(), 0)
assertIs<OnionMessageAction.PayInvoice>(payInvoice)
assertEquals(payOffer, payInvoice.payOffer)
}
}

0 comments on commit daf6a19

Please sign in to comment.