Skip to content

Commit

Permalink
Move update_add_htlc tlv to its own file
Browse files Browse the repository at this point in the history
And use a tlv stream for pay-to-open to be future-proof.
Fix a bug where the `blinding` point of an HTLC was not encoded.
  • Loading branch information
t-bast committed Apr 17, 2024
1 parent 7ae4e57 commit 63d3a5e
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 51 deletions.
14 changes: 14 additions & 0 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,18 @@ sealed class PleaseOpenChannelRejectedTlv : Tlv {
override fun read(input: Input): ExpectedFees = ExpectedFees(LightningCodecs.tu64(input).msat)
}
}
}

sealed class PayToOpenRequestTlv : Tlv {
/** Blinding ephemeral public key that should be used to derive shared secrets when using route blinding. */
data class Blinding(val publicKey: PublicKey) : PayToOpenRequestTlv() {
override val tag: Long get() = Blinding.tag

override fun write(out: Output) = LightningCodecs.writeBytes(publicKey.value, out)

companion object : TlvValueReader<Blinding> {
const val tag: Long = 0
override fun read(input: Input): Blinding = Blinding(PublicKey(LightningCodecs.bytes(input, 33)))
}
}
}
19 changes: 19 additions & 0 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/HtlcTlv.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package fr.acinq.lightning.wire

import fr.acinq.bitcoin.PublicKey
import fr.acinq.bitcoin.io.Input
import fr.acinq.bitcoin.io.Output

sealed class UpdateAddHtlcTlv : Tlv {
/** Blinding ephemeral public key that should be used to derive shared secrets when using route blinding. */
data class Blinding(val publicKey: PublicKey) : UpdateAddHtlcTlv() {
override val tag: Long get() = Blinding.tag

override fun write(out: Output) = LightningCodecs.writeBytes(publicKey.value, out)

companion object : TlvValueReader<Blinding> {
const val tag: Long = 0
override fun read(input: Input): Blinding = Blinding(PublicKey(LightningCodecs.bytes(input, 33)))
}
}
}
52 changes: 20 additions & 32 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import fr.acinq.lightning.logging.*
import fr.acinq.lightning.router.Announcements
import fr.acinq.lightning.utils.*
import fr.acinq.secp256k1.Hex
import kotlinx.serialization.Serializable
import kotlin.math.max
import kotlin.math.min

Expand Down Expand Up @@ -1059,20 +1058,25 @@ data class UpdateAddHtlc(
LightningCodecs.writeBytes(paymentHash, out)
LightningCodecs.writeU32(cltvExpiry.toLong().toInt(), out)
OnionRoutingPacketSerializer(OnionRoutingPacket.PaymentPacketLength).write(onionRoutingPacket, out)
UpdateAddHtlcTlv.tlvSerializer.write(tlvStream)
TlvStreamSerializer(false, readers).write(tlvStream, out)
}

companion object : LightningMessageReader<UpdateAddHtlc> {
const val type: Long = 128

@Suppress("UNCHECKED_CAST")
private val readers = mapOf(
UpdateAddHtlcTlv.Blinding.tag to UpdateAddHtlcTlv.Blinding as TlvValueReader<UpdateAddHtlcTlv>
)

override fun read(input: Input): UpdateAddHtlc {
val channelId = ByteVector32(LightningCodecs.bytes(input, 32))
val id = LightningCodecs.u64(input)
val amount = MilliSatoshi(LightningCodecs.u64(input))
val paymentHash = ByteVector32(LightningCodecs.bytes(input, 32))
val expiry = CltvExpiry(LightningCodecs.u32(input).toLong())
val onion = OnionRoutingPacketSerializer(OnionRoutingPacket.PaymentPacketLength).read(input)
val tlvStream = UpdateAddHtlcTlv.tlvSerializer.read(input)
val tlvStream = TlvStreamSerializer(false, readers).read(input)
return UpdateAddHtlc(channelId, id, amount, paymentHash, expiry, onion, tlvStream)
}

Expand All @@ -1083,37 +1087,14 @@ data class UpdateAddHtlc(
paymentHash: ByteVector32,
cltvExpiry: CltvExpiry,
onionRoutingPacket: OnionRoutingPacket,
blinding: PublicKey?): UpdateAddHtlc {
val tlvStream: TlvStream<UpdateAddHtlcTlv> = blinding?.let { TlvStream(UpdateAddHtlcTlv.Blinding(it)) } ?: TlvStream.empty()
blinding: PublicKey?
): UpdateAddHtlc {
val tlvStream = TlvStream(setOfNotNull<UpdateAddHtlcTlv>(blinding?.let { UpdateAddHtlcTlv.Blinding(it) }))
return UpdateAddHtlc(channelId, id, amountMsat, paymentHash, cltvExpiry, onionRoutingPacket, tlvStream)
}
}
}

sealed class UpdateAddHtlcTlv : Tlv {
data class Blinding(val publicKey: PublicKey) : UpdateAddHtlcTlv() {
override val tag: Long get() = Blinding.tag

override fun write(out: Output) {
LightningCodecs.writeBytes(publicKey.value, out)
}

companion object : TlvValueReader<Blinding> {
const val tag: Long = 0

override fun read(input: Input): Blinding = Blinding(PublicKey(LightningCodecs.bytes(input, 33)))
}
}

companion object {
val tlvSerializer = TlvStreamSerializer(
false, @Suppress("UNCHECKED_CAST") mapOf(
UpdateAddHtlcTlv.Blinding.tag to UpdateAddHtlcTlv.Blinding as TlvValueReader<UpdateAddHtlcTlv>,
)
)
}
}

data class UpdateFulfillHtlc(
override val channelId: ByteVector32,
override val id: Long,
Expand Down Expand Up @@ -1636,10 +1617,12 @@ data class PayToOpenRequest(
val paymentHash: ByteVector32,
val expireAt: Long,
val finalPacket: OnionRoutingPacket,
val blinding: PublicKey? = null
val tlvStream: TlvStream<PayToOpenRequestTlv> = TlvStream.empty(),
) : LightningMessage, HasChainHash {
override val type: Long get() = PayToOpenRequest.type

val blinding: PublicKey? = tlvStream.get<PayToOpenRequestTlv.Blinding>()?.publicKey

override fun write(out: Output) {
LightningCodecs.writeBytes(chainHash.value, out)
LightningCodecs.writeU64(fundingSatoshis.toLong(), out)
Expand All @@ -1650,12 +1633,17 @@ data class PayToOpenRequest(
LightningCodecs.writeU32(expireAt.toInt(), out)
LightningCodecs.writeU16(finalPacket.payload.size(), out)
OnionRoutingPacketSerializer(finalPacket.payload.size()).write(finalPacket, out)
blinding?.let { LightningCodecs.writeBytes(it.value, out) }
TlvStreamSerializer(false, readers).write(tlvStream, out)
}

companion object : LightningMessageReader<PayToOpenRequest> {
const val type: Long = 35021

@Suppress("UNCHECKED_CAST")
private val readers = mapOf(
PayToOpenRequestTlv.Blinding.tag to PayToOpenRequestTlv.Blinding as TlvValueReader<PayToOpenRequestTlv>
)

override fun read(input: Input): PayToOpenRequest {
return PayToOpenRequest(
chainHash = BlockHash(LightningCodecs.bytes(input, 32)),
Expand All @@ -1666,7 +1654,7 @@ data class PayToOpenRequest(
paymentHash = ByteVector32(LightningCodecs.bytes(input, 32)),
expireAt = LightningCodecs.u32(input).toLong(),
finalPacket = OnionRoutingPacketSerializer(LightningCodecs.u16(input)).read(input),
blinding = if (input.availableBytes > 0) PublicKey(LightningCodecs.bytes(input, 33)) else null
tlvStream = TlvStreamSerializer(false, readers).read(input),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() {
hops = channelHops(paymentHandler.nodeParams.nodeId),
finalPayload = makeMppPayload(defaultAmount, defaultAmount, randomBytes32()),
payloadLength = OnionRoutingPacket.PaymentPacketLength
).third.packet,
blinding = null
).third.packet
)
val result = paymentHandler.process(payToOpenRequest, TestConstants.defaultBlockHeight)

Expand Down Expand Up @@ -344,8 +343,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() {
hops = trampolineHops,
finalPayload = makeMppPayload(defaultAmount, defaultAmount, paymentSecret.reversed()), // <-- wrong secret
payloadLength = 400
).third.packet,
blinding = null
).third.packet
)
val result = paymentHandler.process(payToOpenRequest, TestConstants.defaultBlockHeight)

Expand Down Expand Up @@ -1319,8 +1317,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() {
hops = channelHops(TestConstants.Bob.nodeParams.nodeId),
finalPayload = finalPayload,
payloadLength = OnionRoutingPacket.PaymentPacketLength
).third.packet,
blinding = null
).third.packet
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,18 +752,25 @@ class LightningCodecsTestsCommon : LightningTestSuite() {

@Test
fun `encode - decode pay-to-open messages`() {
val onionPacket = OnionRoutingPacket(0, ByteVector("0209be9bd1016d73fc1f611c6f8fdccd99ffb0885594e96156a268ee9afd35559c"), ByteVector("0102030405"), ByteVector32("e0a0d5be2ca6faafa03880258e4af33a0d15aa950ab738c88566a471bf3bb14f"))
val blinding = PublicKey.fromHex("033da8b63fd839472b49935127072039e65d8f99d4603f14d79cdf74b59f895721")
val preimage = ByteVector32("339770785632e71fe1f4b48b8b90d14af94a2a3a2c70af66f2156ed8a150f795")
val testCases = listOf(
PayToOpenRequest(BlockHash(randomBytes32()), 10_000.sat, 5_000.msat, 100.msat, 10.sat, randomBytes32(), 100, OnionRoutingPacket(0, randomKey().publicKey().value, ByteVector("0102030405"), randomBytes32())),
PayToOpenResponse(BlockHash(randomBytes32()), randomBytes32(), PayToOpenResponse.Result.Success(randomBytes32())),
PayToOpenResponse(BlockHash(randomBytes32()), randomBytes32(), PayToOpenResponse.Result.Failure(null)),
PayToOpenResponse(BlockHash(randomBytes32()), randomBytes32(), PayToOpenResponse.Result.Failure(ByteVector("deadbeef"))),
// @formatter:off
PayToOpenRequest(Block.LivenetGenesisBlock.hash, 10_000.sat, 5_000.msat, 100.msat, 10.sat, preimage.sha256(), 100, onionPacket) to Hex.decode("88cd 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 0000000000002710 0000000000001388 0000000000000064 000000000000000a e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 00000064 0005 000209be9bd1016d73fc1f611c6f8fdccd99ffb0885594e96156a268ee9afd35559c0102030405e0a0d5be2ca6faafa03880258e4af33a0d15aa950ab738c88566a471bf3bb14f"),
PayToOpenRequest(Block.LivenetGenesisBlock.hash, 10_000.sat, 5_000.msat, 100.msat, 10.sat, preimage.sha256(), 100, onionPacket, TlvStream(PayToOpenRequestTlv.Blinding(blinding))) to Hex.decode("88cd 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 0000000000002710 0000000000001388 0000000000000064 000000000000000a e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 00000064 0005 000209be9bd1016d73fc1f611c6f8fdccd99ffb0885594e96156a268ee9afd35559c0102030405e0a0d5be2ca6faafa03880258e4af33a0d15aa950ab738c88566a471bf3bb14f 0021033da8b63fd839472b49935127072039e65d8f99d4603f14d79cdf74b59f895721"),
PayToOpenResponse(Block.LivenetGenesisBlock.hash, preimage.sha256(), PayToOpenResponse.Result.Success(preimage)) to Hex.decode("88bb 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 339770785632e71fe1f4b48b8b90d14af94a2a3a2c70af66f2156ed8a150f795"),
PayToOpenResponse(Block.LivenetGenesisBlock.hash, preimage.sha256(), PayToOpenResponse.Result.Failure(null)) to Hex.decode("88bb 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 0000000000000000000000000000000000000000000000000000000000000000"),
PayToOpenResponse(Block.LivenetGenesisBlock.hash, preimage.sha256(), PayToOpenResponse.Result.Failure(ByteVector("deadbeef"))) to Hex.decode("88bb 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 0000000000000000000000000000000000000000000000000000000000000000 0004deadbeef"),
// @formatter:on
)

testCases.forEach {
val encoded = LightningMessage.encode(it)
val decoded = LightningMessage.decode(encoded)
val decoded = LightningMessage.decode(it.second)
assertNotNull(decoded)
assertEquals(it, decoded)
assertEquals(it.first, decoded)
val encoded = LightningMessage.encode(decoded)
assertArrayEquals(it.second, encoded)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import fr.acinq.lightning.tests.utils.LightningTestSuite
import fr.acinq.lightning.utils.msat
import fr.acinq.secp256k1.Hex
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertFails
import kotlin.test.assertNull

class PaymentOnionTestsCommon : LightningTestSuite() {
@Test
Expand Down Expand Up @@ -55,7 +55,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() {
}

@Test
fun `encode - decode variable-length -- tlv -- node relay per-hop payload`() {
fun `encode - decode node relay per-hop payload`() {
val nodeId = PublicKey(Hex.decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"))
val expected = PaymentOnion.NodeRelayPayload(TlvStream(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.OutgoingNodeId(nodeId)))
val bin = Hex.decode("2e 02020231 04012a fe000102322102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")
Expand All @@ -72,7 +72,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() {
}

@Test
fun `encode - decode variable-length -- tlv -- node relay to legacy per-hop payload`() {
fun `encode - decode node relay to legacy per-hop payload`() {
val nodeId = PublicKey(Hex.decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"))
val features = ByteVector("0a")
val node1 = PublicKey(Hex.decode("036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2"))
Expand Down Expand Up @@ -110,7 +110,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() {
}

@Test
fun `encode - decode variable-length -- tlv -- final per-hop payload`() {
fun `encode - decode final per-hop payload`() {
val testCases = mapOf(
TlvStream(
OnionPaymentPayloadTlv.AmountToForward(561.msat),
Expand Down Expand Up @@ -181,7 +181,27 @@ class PaymentOnionTestsCommon : LightningTestSuite() {
}

@Test
fun `encode - decode variable-length -- tlv -- final per-hop payload with custom user records`() {
fun `encode - decode final blinded per-hop payload`() {
val blindedTlvs = TlvStream(
RouteBlindingEncryptedDataTlv.PathId(ByteVector("2a2a2a2a")),
RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1.msat)
)
val testCases = mapOf(
// @formatter:off
TlvStream(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(1234567)), OnionPaymentPayloadTlv.EncryptedRecipientData(ByteVector("deadbeef")), OnionPaymentPayloadTlv.TotalAmount(1105.msat)) to Hex.decode("13 02020231 040312d687 0a04deadbeef 12020451"),
TlvStream(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(1234567)), OnionPaymentPayloadTlv.EncryptedRecipientData(ByteVector("deadbeef")), OnionPaymentPayloadTlv.BlindingPoint(PublicKey.fromHex("036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2")), OnionPaymentPayloadTlv.TotalAmount(1105.msat)) to Hex.decode("36 02020231 040312d687 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2 12020451"),
// @formatter:on
)
testCases.forEach {
val decoded = PaymentOnion.PerHopPayload.tlvSerializer.read(it.value)
assertEquals(it.key, decoded)
val payload = PaymentOnion.FinalPayload.Blinded(it.key, blindedTlvs)
assertContentEquals(payload.write(), it.value)
}
}

@Test
fun `encode - decode final per-hop payload with custom user records`() {
val tlvs = TlvStream(
setOf(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.PaymentData(ByteVector32.Zeroes, 0.msat)),
setOf(GenericTlv(5432123457L, ByteVector("16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828")))
Expand Down Expand Up @@ -213,7 +233,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() {
}

@Test
fun `decode variable-length -- tlv -- final per-hop payload missing information`() {
fun `decode final per-hop payload missing information`() {
val testCases = listOf(
Hex.decode("25 04012a 0820ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // missing amount
Hex.decode("26 02020231 0820ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // missing cltv
Expand Down

0 comments on commit 63d3a5e

Please sign in to comment.