diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt index e361b8e17..8bc70d23f 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt @@ -648,19 +648,27 @@ object OfferTypes { } } + private fun isOfferTlv(tlv: GenericTlv): Boolean { + // Offer TLVs are in the range [1, 79] or [1000000000, 1999999999]. + return tlv.tag in 1..79 || tlv.tag in 1000000000..1999999999 + } + + private fun isInvoiceRequestTlv(tlv: GenericTlv): Boolean { + // Invoice request TLVs are in the range [0, 159] or [1000000000, 2999999999]. + return tlv.tag in 0..159 || tlv.tag in 1000000000..2999999999 + } + fun filterOfferFields(tlvs: TlvStream): TlvStream { - // Offer TLVs are in the range (0, 80). return TlvStream( tlvs.records.filterIsInstance().toSet(), - tlvs.unknown.filter { it.tag < 80 }.toSet() + tlvs.unknown.filter { isOfferTlv(it) }.toSet() ) } fun filterInvoiceRequestFields(tlvs: TlvStream): TlvStream { - // Invoice request TLVs are in the range [0, 160): invoice request metadata (tag 0), offer TLVs, and additional invoice request TLVs in the range [80, 160). return TlvStream( tlvs.records.filterIsInstance().toSet(), - tlvs.unknown.filter { it.tag < 160 }.toSet() + tlvs.unknown.filter { isInvoiceRequestTlv(it) }.toSet() ) } @@ -806,7 +814,7 @@ object OfferTypes { fun validate(records: TlvStream): Either { if (records.get() == null && records.get() != null) return Left(MissingRequiredTlv(10)) if (records.get() == null && records.get() == null) return Left(MissingRequiredTlv(22)) - if (records.unknown.any { it.tag >= 80 }) return Left(ForbiddenTlv(records.unknown.find { it.tag >= 80 }!!.tag)) + if (records.unknown.any { !isOfferTlv(it) }) return Left(ForbiddenTlv(records.unknown.find { !isOfferTlv(it) }!!.tag)) return Right(Offer(records)) } @@ -932,7 +940,7 @@ object OfferTypes { if (records.get() == null) return Left(MissingRequiredTlv(0L)) if (records.get() == null) return Left(MissingRequiredTlv(88)) if (records.get() == null) return Left(MissingRequiredTlv(240)) - if (records.unknown.any { it.tag >= 160 }) return Left(ForbiddenTlv(records.unknown.find { it.tag >= 160 }!!.tag)) + if (records.unknown.any { !isInvoiceRequestTlv(it) }) return Left(ForbiddenTlv(records.unknown.find { !isInvoiceRequestTlv(it) }!!.tag)) return Right(InvoiceRequest(records)) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt index 2580ddeef..6a63baab4 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt @@ -10,6 +10,7 @@ import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.Lightning.randomKey import fr.acinq.lightning.crypto.RouteBlinding import fr.acinq.lightning.logging.MDCLogger +import fr.acinq.lightning.payment.Bolt12Invoice import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.tests.utils.testLoggerFactory @@ -522,4 +523,21 @@ class OfferTypesTestsCommon : LightningTestSuite() { val expectedOffer = Offer.decode("lno1zrxq8pjw7qjlm68mtp7e3yvxee4y5xrgjhhyf2fxhlphpckrvevh50u0qf70a6j2x2akrhazctejaaqr8y4qtzjtjzmfesay6mzr3s789uryuqsr8dpgfgxuk56vh7cl89769zdpdrkqwtypzhu2t8ehp73dqeeq65lsqvlx5pj8mw2kz54p4f6ct66stdfxz0df8nqq7svjjdjn2dv8sz28y7z07yg3vqyfyy8ywevqc8kzp36lhd5cqwlpkg8vdcqsfvz89axkmv5sgdysmwn95tpsct6mdercmz8jh2r82qqscrf6uc3tse5gw5sv5xjdfw8f6c").get() assertEquals(expectedOffer, offer) } + + @Test + fun `experimental TLVs range`() { + val trampolineNode = PublicKey.fromHex("03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f") + val nodeParams = TestConstants.Alice.nodeParams.copy(chain = Chain.Mainnet) + val (defaultOffer, key) = nodeParams.defaultOffer(trampolineNode) + val offerWithUnknownTlvs = Offer.validate(TlvStream(defaultOffer.records.records, setOf(GenericTlv(53, ByteVector.fromHex("b46af6")), GenericTlv(1000759647, ByteVector.fromHex("41dec6"))))).right!! + assertTrue(Offer.validate(TlvStream(defaultOffer.records.records, setOf(GenericTlv(127, ByteVector.fromHex("cd58"))))).isLeft) + assertTrue(Offer.validate(TlvStream(defaultOffer.records.records, setOf(GenericTlv(2045259641, ByteVector.fromHex("e84ad9"))))).isLeft) + val request = InvoiceRequest(offerWithUnknownTlvs, 5500.msat, 1, Features.empty, randomKey(), null, Block.LivenetGenesisBlock.hash) + assertEquals(request.offer, offerWithUnknownTlvs) + val requestWithUnknownTlvs = InvoiceRequest.validate(TlvStream(request.records.records, setOf(GenericTlv(127, ByteVector.fromHex("cd58")), GenericTlv(2045259645, ByteVector.fromHex("e84ad9"))))).right!! + assertTrue(InvoiceRequest.validate(TlvStream(request.records.records, setOf(GenericTlv(197, ByteVector.fromHex("cd58"))))).isLeft) + assertTrue(InvoiceRequest.validate(TlvStream(request.records.records, setOf(GenericTlv(3975455643, ByteVector.fromHex("e84ad9"))))).isLeft) + val invoice = Bolt12Invoice(requestWithUnknownTlvs, randomBytes32(), key, 300, Features.empty, listOf()) + assertEquals(removeSignature(invoice.invoiceRequest.records), removeSignature(requestWithUnknownTlvs.records)) + } } \ No newline at end of file