diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt index bd1aa3e0d..0ac111bab 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt @@ -372,4 +372,18 @@ sealed class PleaseOpenChannelRejectedTlv : Tlv { override fun read(input: Input): ExpectedFees = ExpectedFees(LightningCodecs.tu64(input).msat) } } +} + +sealed class PayToOpenRequestTlv : Tlv { + /** Blinding ephemeral public key that should be used to derive shared secrets when using route blinding. */ + data class Blinding(val publicKey: PublicKey) : PayToOpenRequestTlv() { + override val tag: Long get() = Blinding.tag + + override fun write(out: Output) = LightningCodecs.writeBytes(publicKey.value, out) + + companion object : TlvValueReader { + const val tag: Long = 0 + override fun read(input: Input): Blinding = Blinding(PublicKey(LightningCodecs.bytes(input, 33))) + } + } } \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/HtlcTlv.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/HtlcTlv.kt new file mode 100644 index 000000000..e24adc3cd --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/HtlcTlv.kt @@ -0,0 +1,19 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.PublicKey +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output + +sealed class UpdateAddHtlcTlv : Tlv { + /** Blinding ephemeral public key that should be used to derive shared secrets when using route blinding. */ + data class Blinding(val publicKey: PublicKey) : UpdateAddHtlcTlv() { + override val tag: Long get() = Blinding.tag + + override fun write(out: Output) = LightningCodecs.writeBytes(publicKey.value, out) + + companion object : TlvValueReader { + const val tag: Long = 0 + override fun read(input: Input): Blinding = Blinding(PublicKey(LightningCodecs.bytes(input, 33))) + } + } +} \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt index 17d15826d..27dc8b6a0 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt @@ -14,7 +14,6 @@ import fr.acinq.lightning.logging.* import fr.acinq.lightning.router.Announcements import fr.acinq.lightning.utils.* import fr.acinq.secp256k1.Hex -import kotlinx.serialization.Serializable import kotlin.math.max import kotlin.math.min @@ -1059,12 +1058,17 @@ data class UpdateAddHtlc( LightningCodecs.writeBytes(paymentHash, out) LightningCodecs.writeU32(cltvExpiry.toLong().toInt(), out) OnionRoutingPacketSerializer(OnionRoutingPacket.PaymentPacketLength).write(onionRoutingPacket, out) - UpdateAddHtlcTlv.tlvSerializer.write(tlvStream) + TlvStreamSerializer(false, readers).write(tlvStream, out) } companion object : LightningMessageReader { const val type: Long = 128 + @Suppress("UNCHECKED_CAST") + private val readers = mapOf( + UpdateAddHtlcTlv.Blinding.tag to UpdateAddHtlcTlv.Blinding as TlvValueReader + ) + override fun read(input: Input): UpdateAddHtlc { val channelId = ByteVector32(LightningCodecs.bytes(input, 32)) val id = LightningCodecs.u64(input) @@ -1072,7 +1076,7 @@ data class UpdateAddHtlc( val paymentHash = ByteVector32(LightningCodecs.bytes(input, 32)) val expiry = CltvExpiry(LightningCodecs.u32(input).toLong()) val onion = OnionRoutingPacketSerializer(OnionRoutingPacket.PaymentPacketLength).read(input) - val tlvStream = UpdateAddHtlcTlv.tlvSerializer.read(input) + val tlvStream = TlvStreamSerializer(false, readers).read(input) return UpdateAddHtlc(channelId, id, amount, paymentHash, expiry, onion, tlvStream) } @@ -1083,37 +1087,14 @@ data class UpdateAddHtlc( paymentHash: ByteVector32, cltvExpiry: CltvExpiry, onionRoutingPacket: OnionRoutingPacket, - blinding: PublicKey?): UpdateAddHtlc { - val tlvStream: TlvStream = blinding?.let { TlvStream(UpdateAddHtlcTlv.Blinding(it)) } ?: TlvStream.empty() + blinding: PublicKey? + ): UpdateAddHtlc { + val tlvStream = TlvStream(setOfNotNull(blinding?.let { UpdateAddHtlcTlv.Blinding(it) })) return UpdateAddHtlc(channelId, id, amountMsat, paymentHash, cltvExpiry, onionRoutingPacket, tlvStream) } } } -sealed class UpdateAddHtlcTlv : Tlv { - data class Blinding(val publicKey: PublicKey) : UpdateAddHtlcTlv() { - override val tag: Long get() = Blinding.tag - - override fun write(out: Output) { - LightningCodecs.writeBytes(publicKey.value, out) - } - - companion object : TlvValueReader { - const val tag: Long = 0 - - override fun read(input: Input): Blinding = Blinding(PublicKey(LightningCodecs.bytes(input, 33))) - } - } - - companion object { - val tlvSerializer = TlvStreamSerializer( - false, @Suppress("UNCHECKED_CAST") mapOf( - UpdateAddHtlcTlv.Blinding.tag to UpdateAddHtlcTlv.Blinding as TlvValueReader, - ) - ) - } -} - data class UpdateFulfillHtlc( override val channelId: ByteVector32, override val id: Long, @@ -1636,10 +1617,12 @@ data class PayToOpenRequest( val paymentHash: ByteVector32, val expireAt: Long, val finalPacket: OnionRoutingPacket, - val blinding: PublicKey? = null + val tlvStream: TlvStream = TlvStream.empty(), ) : LightningMessage, HasChainHash { override val type: Long get() = PayToOpenRequest.type + val blinding: PublicKey? = tlvStream.get()?.publicKey + override fun write(out: Output) { LightningCodecs.writeBytes(chainHash.value, out) LightningCodecs.writeU64(fundingSatoshis.toLong(), out) @@ -1650,12 +1633,17 @@ data class PayToOpenRequest( LightningCodecs.writeU32(expireAt.toInt(), out) LightningCodecs.writeU16(finalPacket.payload.size(), out) OnionRoutingPacketSerializer(finalPacket.payload.size()).write(finalPacket, out) - blinding?.let { LightningCodecs.writeBytes(it.value, out) } + TlvStreamSerializer(false, readers).write(tlvStream, out) } companion object : LightningMessageReader { const val type: Long = 35021 + @Suppress("UNCHECKED_CAST") + private val readers = mapOf( + PayToOpenRequestTlv.Blinding.tag to PayToOpenRequestTlv.Blinding as TlvValueReader + ) + override fun read(input: Input): PayToOpenRequest { return PayToOpenRequest( chainHash = BlockHash(LightningCodecs.bytes(input, 32)), @@ -1666,7 +1654,7 @@ data class PayToOpenRequest( paymentHash = ByteVector32(LightningCodecs.bytes(input, 32)), expireAt = LightningCodecs.u32(input).toLong(), finalPacket = OnionRoutingPacketSerializer(LightningCodecs.u16(input)).read(input), - blinding = if (input.availableBytes > 0) PublicKey(LightningCodecs.bytes(input, 33)) else null + tlvStream = TlvStreamSerializer(false, readers).read(input), ) } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt index f256a683d..ce9015630 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt @@ -251,8 +251,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { hops = channelHops(paymentHandler.nodeParams.nodeId), finalPayload = makeMppPayload(defaultAmount, defaultAmount, randomBytes32()), payloadLength = OnionRoutingPacket.PaymentPacketLength - ).third.packet, - blinding = null + ).third.packet ) val result = paymentHandler.process(payToOpenRequest, TestConstants.defaultBlockHeight) @@ -344,8 +343,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { hops = trampolineHops, finalPayload = makeMppPayload(defaultAmount, defaultAmount, paymentSecret.reversed()), // <-- wrong secret payloadLength = 400 - ).third.packet, - blinding = null + ).third.packet ) val result = paymentHandler.process(payToOpenRequest, TestConstants.defaultBlockHeight) @@ -1319,8 +1317,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { hops = channelHops(TestConstants.Bob.nodeParams.nodeId), finalPayload = finalPayload, payloadLength = OnionRoutingPacket.PaymentPacketLength - ).third.packet, - blinding = null + ).third.packet ) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/LightningCodecsTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/LightningCodecsTestsCommon.kt index acb976052..5ab79a091 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/wire/LightningCodecsTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/LightningCodecsTestsCommon.kt @@ -752,18 +752,25 @@ class LightningCodecsTestsCommon : LightningTestSuite() { @Test fun `encode - decode pay-to-open messages`() { + val onionPacket = OnionRoutingPacket(0, ByteVector("0209be9bd1016d73fc1f611c6f8fdccd99ffb0885594e96156a268ee9afd35559c"), ByteVector("0102030405"), ByteVector32("e0a0d5be2ca6faafa03880258e4af33a0d15aa950ab738c88566a471bf3bb14f")) + val blinding = PublicKey.fromHex("033da8b63fd839472b49935127072039e65d8f99d4603f14d79cdf74b59f895721") + val preimage = ByteVector32("339770785632e71fe1f4b48b8b90d14af94a2a3a2c70af66f2156ed8a150f795") val testCases = listOf( - PayToOpenRequest(BlockHash(randomBytes32()), 10_000.sat, 5_000.msat, 100.msat, 10.sat, randomBytes32(), 100, OnionRoutingPacket(0, randomKey().publicKey().value, ByteVector("0102030405"), randomBytes32())), - PayToOpenResponse(BlockHash(randomBytes32()), randomBytes32(), PayToOpenResponse.Result.Success(randomBytes32())), - PayToOpenResponse(BlockHash(randomBytes32()), randomBytes32(), PayToOpenResponse.Result.Failure(null)), - PayToOpenResponse(BlockHash(randomBytes32()), randomBytes32(), PayToOpenResponse.Result.Failure(ByteVector("deadbeef"))), + // @formatter:off + PayToOpenRequest(Block.LivenetGenesisBlock.hash, 10_000.sat, 5_000.msat, 100.msat, 10.sat, preimage.sha256(), 100, onionPacket) to Hex.decode("88cd 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 0000000000002710 0000000000001388 0000000000000064 000000000000000a e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 00000064 0005 000209be9bd1016d73fc1f611c6f8fdccd99ffb0885594e96156a268ee9afd35559c0102030405e0a0d5be2ca6faafa03880258e4af33a0d15aa950ab738c88566a471bf3bb14f"), + PayToOpenRequest(Block.LivenetGenesisBlock.hash, 10_000.sat, 5_000.msat, 100.msat, 10.sat, preimage.sha256(), 100, onionPacket, TlvStream(PayToOpenRequestTlv.Blinding(blinding))) to Hex.decode("88cd 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 0000000000002710 0000000000001388 0000000000000064 000000000000000a e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 00000064 0005 000209be9bd1016d73fc1f611c6f8fdccd99ffb0885594e96156a268ee9afd35559c0102030405e0a0d5be2ca6faafa03880258e4af33a0d15aa950ab738c88566a471bf3bb14f 0021033da8b63fd839472b49935127072039e65d8f99d4603f14d79cdf74b59f895721"), + PayToOpenResponse(Block.LivenetGenesisBlock.hash, preimage.sha256(), PayToOpenResponse.Result.Success(preimage)) to Hex.decode("88bb 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 339770785632e71fe1f4b48b8b90d14af94a2a3a2c70af66f2156ed8a150f795"), + PayToOpenResponse(Block.LivenetGenesisBlock.hash, preimage.sha256(), PayToOpenResponse.Result.Failure(null)) to Hex.decode("88bb 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 0000000000000000000000000000000000000000000000000000000000000000"), + PayToOpenResponse(Block.LivenetGenesisBlock.hash, preimage.sha256(), PayToOpenResponse.Result.Failure(ByteVector("deadbeef"))) to Hex.decode("88bb 6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000 e7e2ae1540f63627007acc816c8b978b8344265b581840d5feec7ff0a85bbf0b 0000000000000000000000000000000000000000000000000000000000000000 0004deadbeef"), + // @formatter:on ) testCases.forEach { - val encoded = LightningMessage.encode(it) - val decoded = LightningMessage.decode(encoded) + val decoded = LightningMessage.decode(it.second) assertNotNull(decoded) - assertEquals(it, decoded) + assertEquals(it.first, decoded) + val encoded = LightningMessage.encode(decoded) + assertArrayEquals(it.second, encoded) } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt index 91a543049..a28dad51f 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt @@ -13,9 +13,9 @@ import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.utils.msat import fr.acinq.secp256k1.Hex import kotlin.test.Test +import kotlin.test.assertContentEquals import kotlin.test.assertEquals import kotlin.test.assertFails -import kotlin.test.assertNull class PaymentOnionTestsCommon : LightningTestSuite() { @Test @@ -55,7 +55,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() { } @Test - fun `encode - decode variable-length -- tlv -- node relay per-hop payload`() { + fun `encode - decode node relay per-hop payload`() { val nodeId = PublicKey(Hex.decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) val expected = PaymentOnion.NodeRelayPayload(TlvStream(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.OutgoingNodeId(nodeId))) val bin = Hex.decode("2e 02020231 04012a fe000102322102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") @@ -72,7 +72,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() { } @Test - fun `encode - decode variable-length -- tlv -- node relay to legacy per-hop payload`() { + fun `encode - decode node relay to legacy per-hop payload`() { val nodeId = PublicKey(Hex.decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) val features = ByteVector("0a") val node1 = PublicKey(Hex.decode("036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2")) @@ -110,7 +110,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() { } @Test - fun `encode - decode variable-length -- tlv -- final per-hop payload`() { + fun `encode - decode final per-hop payload`() { val testCases = mapOf( TlvStream( OnionPaymentPayloadTlv.AmountToForward(561.msat), @@ -181,7 +181,27 @@ class PaymentOnionTestsCommon : LightningTestSuite() { } @Test - fun `encode - decode variable-length -- tlv -- final per-hop payload with custom user records`() { + fun `encode - decode final blinded per-hop payload`() { + val blindedTlvs = TlvStream( + RouteBlindingEncryptedDataTlv.PathId(ByteVector("2a2a2a2a")), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1.msat) + ) + val testCases = mapOf( + // @formatter:off + TlvStream(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(1234567)), OnionPaymentPayloadTlv.EncryptedRecipientData(ByteVector("deadbeef")), OnionPaymentPayloadTlv.TotalAmount(1105.msat)) to Hex.decode("13 02020231 040312d687 0a04deadbeef 12020451"), + TlvStream(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(1234567)), OnionPaymentPayloadTlv.EncryptedRecipientData(ByteVector("deadbeef")), OnionPaymentPayloadTlv.BlindingPoint(PublicKey.fromHex("036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2")), OnionPaymentPayloadTlv.TotalAmount(1105.msat)) to Hex.decode("36 02020231 040312d687 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2 12020451"), + // @formatter:on + ) + testCases.forEach { + val decoded = PaymentOnion.PerHopPayload.tlvSerializer.read(it.value) + assertEquals(it.key, decoded) + val payload = PaymentOnion.FinalPayload.Blinded(it.key, blindedTlvs) + assertContentEquals(payload.write(), it.value) + } + } + + @Test + fun `encode - decode final per-hop payload with custom user records`() { val tlvs = TlvStream( setOf(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.PaymentData(ByteVector32.Zeroes, 0.msat)), setOf(GenericTlv(5432123457L, ByteVector("16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828"))) @@ -213,7 +233,7 @@ class PaymentOnionTestsCommon : LightningTestSuite() { } @Test - fun `decode variable-length -- tlv -- final per-hop payload missing information`() { + fun `decode final per-hop payload missing information`() { val testCases = listOf( Hex.decode("25 04012a 0820ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // missing amount Hex.decode("26 02020231 0820ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // missing cltv