diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala index 7eef03c941..edea5c157a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala @@ -20,6 +20,7 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.wire.protocol.MessageOnion.{BlindedFinalPayload, BlindedRelayPayload, FinalPayload, RelayPayload} import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.EncryptedData +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ import fr.acinq.eclair.wire.protocol._ import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} @@ -36,21 +37,21 @@ object OnionMessages { intermediateNodes: Seq[IntermediateNode], destination: Either[Recipient, Sphinx.RouteBlinding.BlindedRoute]): Sphinx.RouteBlinding.BlindedRoute = { val last = destination match { - case Left(Recipient(nodeId, _, _)) => EncryptedRecipientDataTlv.OutgoingNodeId(nodeId) :: Nil - case Right(Sphinx.RouteBlinding.BlindedRoute(nodeId, blindingKey, _)) => EncryptedRecipientDataTlv.OutgoingNodeId(nodeId) :: EncryptedRecipientDataTlv.NextBlinding(blindingKey) :: Nil + case Left(Recipient(nodeId, _, _)) => OutgoingNodeId(nodeId) :: Nil + case Right(Sphinx.RouteBlinding.BlindedRoute(nodeId, blindingKey, _)) => OutgoingNodeId(nodeId) :: NextBlinding(blindingKey) :: Nil } val intermediatePayloads = if (intermediateNodes.isEmpty) { Nil } else { - (intermediateNodes.tail.map(node => EncryptedRecipientDataTlv.OutgoingNodeId(node.nodeId) :: Nil) :+ last) - .zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(EncryptedRecipientDataTlv.Padding(_) :: Nil).getOrElse(Nil) ++ tlvs } + (intermediateNodes.tail.map(node => OutgoingNodeId(node.nodeId) :: Nil) :+ last) + .zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(Padding(_) :: Nil).getOrElse(Nil) ++ tlvs } .map(tlvs => BlindedRelayPayload(TlvStream(tlvs))) .map(MessageOnionCodecs.blindedRelayPayloadCodec.encode(_).require.bytes) } destination match { case Left(Recipient(nodeId, pathId, padding)) => - val tlvs = padding.map(EncryptedRecipientDataTlv.Padding(_) :: Nil).getOrElse(Nil) ++ pathId.map(EncryptedRecipientDataTlv.PathId(_) :: Nil).getOrElse(Nil) + val tlvs = padding.map(Padding(_) :: Nil).getOrElse(Nil) ++ pathId.map(PathId(_) :: Nil).getOrElse(Nil) val lastPayload = MessageOnionCodecs.blindedFinalPayloadCodec.encode(BlindedFinalPayload(TlvStream(tlvs))).require.bytes Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ nodeId, intermediatePayloads :+ lastPayload) case Right(route) => @@ -61,11 +62,12 @@ object OnionMessages { /** * Builds an encrypted onion containing a message that should be relayed to the destination. - * @param sessionKey A random key to encrypt the onion - * @param blindingSecret A random key to encrypt the onion + * + * @param sessionKey A random key to encrypt the onion + * @param blindingSecret A random key to encrypt the onion * @param intermediateNodes List of intermediate nodes between us and the destination, can be empty if we want to contact the destination directly - * @param destination The destination of this message, can be a node id or a blinded route - * @param content List of TLVs to send to the recipient of the message + * @param destination The destination of this message, can be a node id or a blinded route + * @param content List of TLVs to send to the recipient of the message * @return The node id to send the onion to and the onion containing the message */ def buildMessage(sessionKey: PrivateKey, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala index bc01fd3990..ecc28a86f5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala @@ -53,9 +53,9 @@ object MessageOnion { } /** Content of the encrypted data of an intermediate node's per-hop payload. */ - case class BlindedRelayPayload(records: TlvStream[EncryptedRecipientDataTlv]) { - val nextNodeId: PublicKey = records.get[EncryptedRecipientDataTlv.OutgoingNodeId].get.nodeId - val nextBlindingOverride: Option[PublicKey] = records.get[EncryptedRecipientDataTlv.NextBlinding].map(_.blinding) + case class BlindedRelayPayload(records: TlvStream[RouteBlindingEncryptedDataTlv]) { + val nextNodeId: PublicKey = records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId + val nextBlindingOverride: Option[PublicKey] = records.get[RouteBlindingEncryptedDataTlv.NextBlinding].map(_.blinding) } /** Per-hop payload for a final node. */ @@ -65,8 +65,8 @@ object MessageOnion { } /** Content of the encrypted data of a final node's per-hop payload. */ - case class BlindedFinalPayload(records: TlvStream[EncryptedRecipientDataTlv]) { - val pathId: Option[ByteVector] = records.get[EncryptedRecipientDataTlv.PathId].map(_.data) + case class BlindedFinalPayload(records: TlvStream[RouteBlindingEncryptedDataTlv]) { + val pathId: Option[ByteVector] = records.get[RouteBlindingEncryptedDataTlv.PathId].map(_.data) } } @@ -106,15 +106,15 @@ object MessageOnionCodecs { case FinalPayload(tlvs) => tlvs }) - val blindedRelayPayloadCodec: Codec[BlindedRelayPayload] = EncryptedRecipientDataCodecs.encryptedRecipientDataCodec.narrow({ - case tlvs if tlvs.get[EncryptedRecipientDataTlv.OutgoingNodeId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) - case tlvs if tlvs.get[EncryptedRecipientDataTlv.PathId].nonEmpty => Attempt.failure(ForbiddenTlv(UInt64(6))) + val blindedRelayPayloadCodec: Codec[BlindedRelayPayload] = RouteBlindingEncryptedDataCodecs.encryptedDataCodec.narrow({ + case tlvs if tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) + case tlvs if tlvs.get[RouteBlindingEncryptedDataTlv.PathId].nonEmpty => Attempt.failure(ForbiddenTlv(UInt64(6))) case tlvs => Attempt.successful(BlindedRelayPayload(tlvs)) }, { case BlindedRelayPayload(tlvs) => tlvs }) - val blindedFinalPayloadCodec: Codec[BlindedFinalPayload] = EncryptedRecipientDataCodecs.encryptedRecipientDataCodec.narrow( + val blindedFinalPayloadCodec: Codec[BlindedFinalPayload] = RouteBlindingEncryptedDataCodecs.encryptedDataCodec.narrow( tlvs => Attempt.successful(BlindedFinalPayload(tlvs)), { case BlindedFinalPayload(tlvs) => tlvs @@ -128,4 +128,5 @@ object MessageOnionCodecs { ("publicKey" | bytes(33)), ("onionPayload" | bytes)) ~ ("hmac" | bytes32) flattenLeftPairs).as[OnionRoutingPacket] + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala similarity index 67% rename from eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataTlv.scala rename to eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 2eef111ce0..546383ea72 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataTlv.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -17,24 +17,28 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding +import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.{ShortChannelId, UInt64} import scodec.bits.ByteVector import scala.util.Try -sealed trait EncryptedRecipientDataTlv extends Tlv +/** + * Created by t-bast on 19/10/2021. + */ + +sealed trait RouteBlindingEncryptedDataTlv extends Tlv -object EncryptedRecipientDataTlv { +object RouteBlindingEncryptedDataTlv { /** Some padding can be added to ensure all payloads are the same size to improve privacy. */ - case class Padding(dummy: ByteVector) extends EncryptedRecipientDataTlv + case class Padding(dummy: ByteVector) extends RouteBlindingEncryptedDataTlv /** Id of the outgoing channel, used to identify the next node. */ - case class OutgoingChannelId(shortChannelId: ShortChannelId) extends EncryptedRecipientDataTlv + case class OutgoingChannelId(shortChannelId: ShortChannelId) extends RouteBlindingEncryptedDataTlv /** Id of the next node. */ - case class OutgoingNodeId(nodeId: PublicKey) extends EncryptedRecipientDataTlv + case class OutgoingNodeId(nodeId: PublicKey) extends RouteBlindingEncryptedDataTlv /** * The final recipient may store some data in the encrypted payload for itself to avoid storing it locally. @@ -42,16 +46,16 @@ object EncryptedRecipientDataTlv { * It should use that field to detect when blinded routes are used outside of their intended use (malicious probing) * and react accordingly (ignore the message or send an error depending on the use-case). */ - case class PathId(data: ByteVector) extends EncryptedRecipientDataTlv + case class PathId(data: ByteVector) extends RouteBlindingEncryptedDataTlv /** Blinding override for the rest of the route. */ - case class NextBlinding(blinding: PublicKey) extends EncryptedRecipientDataTlv + case class NextBlinding(blinding: PublicKey) extends RouteBlindingEncryptedDataTlv } -object EncryptedRecipientDataCodecs { +object RouteBlindingEncryptedDataCodecs { - import EncryptedRecipientDataTlv._ + import RouteBlindingEncryptedDataTlv._ import fr.acinq.eclair.wire.protocol.CommonCodecs.{publicKey, shortchannelid, varint, varintoverflow} import scodec.Codec import scodec.bits.HexStringSyntax @@ -63,27 +67,27 @@ object EncryptedRecipientDataCodecs { private val pathId: Codec[PathId] = variableSizeBytesLong(varintoverflow, "path_id" | bytes).as[PathId] private val nextBlinding: Codec[NextBlinding] = (("length" | constant(hex"21")) :: ("blinding" | publicKey)).as[NextBlinding] - private val encryptedRecipientDataTlvCodec = discriminated[EncryptedRecipientDataTlv].by(varint) + private val encryptedDataTlvCodec = discriminated[RouteBlindingEncryptedDataTlv].by(varint) .typecase(UInt64(1), padding) .typecase(UInt64(2), outgoingChannelId) .typecase(UInt64(4), outgoingNodeId) .typecase(UInt64(6), pathId) .typecase(UInt64(8), nextBlinding) - val encryptedRecipientDataCodec: Codec[TlvStream[EncryptedRecipientDataTlv]] = TlvCodecs.tlvStream[EncryptedRecipientDataTlv](encryptedRecipientDataTlvCodec).complete + val encryptedDataCodec: Codec[TlvStream[RouteBlindingEncryptedDataTlv]] = TlvCodecs.tlvStream[RouteBlindingEncryptedDataTlv](encryptedDataTlvCodec).complete /** * Decrypt and decode the contents of an encrypted_recipient_data TLV field. * - * @param nodePrivKey this node's private key. - * @param blindingKey blinding point (usually provided in the lightning message). - * @param encryptedRecipientData encrypted recipient data (usually provided inside an onion). + * @param nodePrivKey this node's private key. + * @param blindingKey blinding point (usually provided in the lightning message). + * @param encryptedData encrypted route blinding data (usually provided inside an onion). * @return decrypted contents of the encrypted recipient data, which usually contain information about the next node, * and the blinding point that should be sent to the next node. */ - def decode(nodePrivKey: PrivateKey, blindingKey: PublicKey, encryptedRecipientData: ByteVector): Try[(TlvStream[EncryptedRecipientDataTlv], PublicKey)] = { - RouteBlinding.decryptPayload(nodePrivKey, blindingKey, encryptedRecipientData).flatMap { - case (payload, nextBlindingKey) => encryptedRecipientDataCodec.decode(payload.bits).map(r => (r.value, nextBlindingKey)).toTry + def decode(nodePrivKey: PrivateKey, blindingKey: PublicKey, encryptedData: ByteVector): Try[(TlvStream[RouteBlindingEncryptedDataTlv], PublicKey)] = { + Sphinx.RouteBlinding.decryptPayload(nodePrivKey, blindingKey, encryptedData).flatMap { + case (payload, nextBlindingKey) => encryptedDataCodec.decode(payload.bits).map(r => (r.value, nextBlindingKey)).toTry } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 92aac5c7b7..c423ca0a79 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, ShortChannelId, UInt64, randomKey} @@ -387,7 +388,7 @@ class SphinxSpec extends AnyFunSuite { )) assert(blindedRoute.encryptedPayloads === blindedRoute.introductionNode.encryptedPayload +: blindedRoute.subsequentNodes.map(_.encryptedPayload)) assert(blindedRoute.subsequentNodes.map(_.encryptedPayload) === Seq( - hex"146c9694ead7de2a54fc43e8bb927bfc377dda7ed5a2e36b327b739e368aa602e43e07e14b3d7ed493e7ea6245924d9a03d22f0fca56babd7da19f49b7", + hex"146c9694ead7de2a54fc43e8bb927bfc377dda7ed5a2e36b327b739e368aa602e43e07e14bfb81d66e1e295f848b6f15ee6483005abb830f4ef08a9da6", hex"8ad7d5d448f15208417a1840f82274101b3c254c24b1b49fd676fd0c4293c9aa66ed51da52579e934a869f016f213044d1b13b63bf586e9c9832106b59", hex"52a45a884542d180e76fe84fc13e71a01f65d943ff89aed29b94644a91b037b9143cfda8f1ff25ba61c37108a5ae57d9ddc5ab688ee8b2f9f6bd94522c", hex"6a4ac764cbf146ffd73299563b07c56052af4acd681d9d0882728c6f399ace90392b694d5e347612dc1417f1b3a9c82d6d4db18b6eb32134e554db7d00", @@ -422,6 +423,79 @@ class SphinxSpec extends AnyFunSuite { assert(payload4 === routeBlindingPayloads(4)) } + test("concatenate blinded routes (reference test vector)") { + // The recipient creates a blinded route to himself. + val (blindingOverride, blindedRouteEnd, payloadsEnd) = { + val sessionKey = PrivateKey(hex"0101010101010101010101010101010101010101010101010101010101010101") + val payloads = Seq( + hex"0421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", + hex"042102edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145", + hex"010f000000000000000000000000000000 061000112233445566778899aabbccddeeff" + ) + val blindedRoute = RouteBlinding.create(sessionKey, publicKeys.drop(2), payloads) + assert(blindedRoute.blindingKey === PublicKey(hex"031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f")) + (blindedRoute.blindingKey, blindedRoute, payloads) + } + // The sender also wants to use route blinding to reach the introduction point. + val (blindedRouteStart, payloadsStart) = { + val sessionKey = PrivateKey(hex"0202020202020202020202020202020202020202020202020202020202020202") + val payloads = Seq( + hex"0121000000000000000000000000000000000000000000000000000000000000000000 04210324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c", + // NB: this payload contains the blinding key override. + hex"0421027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007 0821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f" + ) + (RouteBlinding.create(sessionKey, publicKeys.take(2), payloads), payloads) + } + val blindedRoute = BlindedRoute(publicKeys(0), blindedRouteStart.blindingKey, blindedRouteStart.blindedNodes ++ blindedRouteEnd.blindedNodes) + assert(blindedRoute.blindingKey === PublicKey(hex"024d4b6cd1361032ca9bd2aeb9d900aa4d45d9ead80ac9423374c451a7254d0766")) + assert(blindedRoute.blindedNodeIds === Seq( + PublicKey(hex"0303176d13958a8a59d59517a6223e12cf291ba5f65c8011efcdca0a52c3850abc"), + PublicKey(hex"03adbdd3c0fb69641e96de2d5ac923ffc0910d3ed4dfe2314609fae61a71df4da2"), + PublicKey(hex"021026e6369e42b7f6d723c0c56a3e0b4d67111f07685bd03e9fa6d93ac6bb6dbe"), + PublicKey(hex"02ba3db3fe7f1ed28c4d82f28cf358373cbf3241a16aba265b1b6fb26f094c0c7f"), + PublicKey(hex"0379d4ca14cb19e2f7bcb217d36267e3d03b027bc4228923967f5b2e32cbb763c1"), + )) + assert(blindedRoute.encryptedPayloads === Seq( + hex"31da0d438752ed0f19ccd970a386ead7155fd187becd4e1770d561dffdb03d3568dac746dde98725f146582cb040207e8b6c070e28d707564a4dd9fb53f9274ad69d09add393b509a2fa42df5055d7c8aeda5881d5aa", + hex"d9dfa92f898dc8e37b73c944aa4205f225337b2edde67623e775c79e2bcf395dc205004aa07fdc65712afa5c2687aff9bb3d5e6af7c89cc94f23f962a27844ce7629773f9413ebcf131dbc35818410df207f29b013b0", + hex"30015dcdcbce70bdcd0125be8ccd541b101d95bcb049ccfc737f91c98cc139cb6f16354ec5a38e77eca769c2245ac4467524d6", + hex"11e49a0e5f4f8a73b30551bd20448abeb297339b6983ab30d4a227a858311656cbf2444aeff66bd4c8f320ce00ce4ddfed7ca3", + hex"fe7e62b65ac8e1c2a319ba53a5519b3f8073416971ae3e722ebc008f38999d590d70d40557e44557c0d32b891bd967119c1f78", + )) + + // The introduction point can decrypt its encrypted payload and obtain the next ephemeral public key. + val Success((payload0, ephKey1)) = RouteBlinding.decryptPayload(privKeys(0), blindedRoute.blindingKey, blindedRoute.encryptedPayloads(0)) + assert(payload0 === payloadsStart(0)) + assert(ephKey1 === PublicKey(hex"02be4b436dbc6cfa43d7d5652bc630ffdaf0dac93e6682db7950828506055ad1a7")) + + // The next node can derive the private key used to unwrap the onion and decrypt its encrypted payload. + assert(RouteBlinding.derivePrivateKey(privKeys(1), ephKey1).publicKey === blindedRoute.blindedNodeIds(1)) + val Success((payload1, ephKey2)) = RouteBlinding.decryptPayload(privKeys(1), ephKey1, blindedRoute.encryptedPayloads(1)) + assert(payload1 === payloadsStart(1)) + assert(ephKey2 === PublicKey(hex"03fb82254d740754efddc3318674f4e26cefcb8dec42a3910c08c64d19f25e50b7")) + // NB: this node finds a blinding override and will transmit that instead of ephKey2 to the next node. + assert(payload1.containsSlice(blindingOverride.value)) + + // The next node must be given the blinding override to derive the private key used to unwrap the onion and decrypt its encrypted payload. + assert(RouteBlinding.decryptPayload(privKeys(2), ephKey2, blindedRoute.encryptedPayloads(2)).isFailure) + assert(RouteBlinding.derivePrivateKey(privKeys(2), blindingOverride).publicKey === blindedRoute.blindedNodeIds(2)) + val Success((payload2, ephKey3)) = RouteBlinding.decryptPayload(privKeys(2), blindingOverride, blindedRoute.encryptedPayloads(2)) + assert(payload2 === payloadsEnd(0)) + assert(ephKey3 === PublicKey(hex"03932f4ab7605e8c046b5677becd4d61fdfdc8b9d10f1e9c3080ced0d64fd76931")) + + // The next node can derive the private key used to unwrap the onion and decrypt its encrypted payload. + assert(RouteBlinding.derivePrivateKey(privKeys(3), ephKey3).publicKey === blindedRoute.blindedNodeIds(3)) + val Success((payload3, ephKey4)) = RouteBlinding.decryptPayload(privKeys(3), ephKey3, blindedRoute.encryptedPayloads(3)) + assert(payload3 === payloadsEnd(1)) + assert(ephKey4 === PublicKey(hex"037bceb365470d24f8204c622e1b7959c6beeb774c634640de6c8401079159fc58")) + + // The last node can derive the private key used to unwrap the onion and decrypt its encrypted payload. + assert(RouteBlinding.derivePrivateKey(privKeys(4), ephKey4).publicKey === blindedRoute.blindedNodeIds(4)) + val Success((payload4, ephKey5)) = RouteBlinding.decryptPayload(privKeys(4), ephKey4, blindedRoute.encryptedPayloads(4)) + assert(payload4 === payloadsEnd(2)) + assert(ephKey5 === PublicKey(hex"0339ddfa85a2155fb27e94742885fad85696e54920aa148cb86e00bcb8ee346bd4")) + } + test("invalid blinded route") { val encryptedPayloads = RouteBlinding.create(sessionKey, publicKeys, routeBlindingPayloads).encryptedPayloads // Invalid node private key: @@ -473,9 +547,9 @@ class SphinxSpec extends AnyFunSuite { val tlvs2 = PaymentOnionCodecs.tlvPerHopPayloadCodec.decode(payload2.bits).require.value assert(tlvs2.get[OnionPaymentPayloadTlv.BlindingPoint].map(_.publicKey) === Some(blindingEphemeralKey0)) assert(tlvs2.get[OnionPaymentPayloadTlv.EncryptedRecipientData].nonEmpty) - val Success((recipientTlvs2, blindingEphemeralKey1)) = EncryptedRecipientDataCodecs.decode(privKeys(2), blindingEphemeralKey0, tlvs2.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) - assert(recipientTlvs2.get[EncryptedRecipientDataTlv.OutgoingChannelId].map(_.shortChannelId) === Some(ShortChannelId(1105))) - assert(recipientTlvs2.get[EncryptedRecipientDataTlv.OutgoingNodeId].map(_.nodeId) === Some(publicKeys(3))) + val Success((recipientTlvs2, blindingEphemeralKey1)) = RouteBlindingEncryptedDataCodecs.decode(privKeys(2), blindingEphemeralKey0, tlvs2.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) + assert(recipientTlvs2.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].map(_.shortChannelId) === Some(ShortChannelId(1105))) + assert(recipientTlvs2.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].map(_.nodeId) === Some(publicKeys(3))) // The fourth hop is a blinded hop. // It receives the blinding key from the previous node (e.g. in a tlv field in update_add_htlc) which it can use to @@ -485,8 +559,8 @@ class SphinxSpec extends AnyFunSuite { val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(blindedPrivKey3, associatedData, nextPacket2) val tlvs3 = PaymentOnionCodecs.tlvPerHopPayloadCodec.decode(payload3.bits).require.value assert(tlvs3.get[OnionPaymentPayloadTlv.EncryptedRecipientData].nonEmpty) - val Success((recipientTlvs3, blindingEphemeralKey2)) = EncryptedRecipientDataCodecs.decode(privKeys(3), blindingEphemeralKey1, tlvs3.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) - assert(recipientTlvs3.get[EncryptedRecipientDataTlv.OutgoingNodeId].map(_.nodeId) === Some(publicKeys(4))) + val Success((recipientTlvs3, blindingEphemeralKey2)) = RouteBlindingEncryptedDataCodecs.decode(privKeys(3), blindingEphemeralKey1, tlvs3.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) + assert(recipientTlvs3.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].map(_.nodeId) === Some(publicKeys(4))) // The fifth hop is the blinded recipient. // It receives the blinding key from the previous node (e.g. in a tlv field in update_add_htlc) which it can use to @@ -495,8 +569,8 @@ class SphinxSpec extends AnyFunSuite { val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(blindedPrivKey4, associatedData, nextPacket3) val tlvs4 = PaymentOnionCodecs.tlvPerHopPayloadCodec.decode(payload4.bits).require.value assert(tlvs4.get[OnionPaymentPayloadTlv.EncryptedRecipientData].nonEmpty) - val Success((recipientTlvs4, _)) = EncryptedRecipientDataCodecs.decode(privKeys(4), blindingEphemeralKey2, tlvs4.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) - assert(recipientTlvs4.get[EncryptedRecipientDataTlv.PathId].map(_.data) === associatedData.map(_.bytes)) + val Success((recipientTlvs4, _)) = RouteBlindingEncryptedDataCodecs.decode(privKeys(4), blindingEphemeralKey2, tlvs4.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) + assert(recipientTlvs4.get[RouteBlindingEncryptedDataTlv.PathId].map(_.data) === associatedData.map(_.bytes)) assert(Seq(payload0, payload1, payload2, payload3, payload4) == payloads) assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1)) @@ -582,10 +656,10 @@ object SphinxSpec { hex"23 f8 21 02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619" ) - // This test vector uses route blinding payloads (encrypted_recipient_data). + // This test vector uses route blinding payloads (encrypted_data). val routeBlindingPayloads = Seq( hex"0208000000000000002a 3903123456", - hex"011900000000000000000000000000000000000000000000000000 02080000000000000231 fdffff0206c1 3b00", + hex"011900000000000000000000000000000000000000000000000000 02080000000000000231 3b00 fdffff0206c1", hex"02080000000000000451 0421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"01080000000000000000 042102edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145", hex"0109000000000000000000 06204242424242424242424242424242424242424242424242424242424242424242", diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index 772bba1378..166c895707 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -22,10 +22,11 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.PacketAndSecrets import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} import fr.acinq.eclair.randomKey -import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.{blindedRelayPayloadCodec, blindedFinalPayloadCodec, relayPerHopPayloadCodec} import fr.acinq.eclair.wire.protocol.MessageOnion.{BlindedFinalPayload, BlindedRelayPayload, RelayPayload} +import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.{blindedFinalPayloadCodec, blindedRelayPayloadCodec, relayPerHopPayloadCodec} import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.EncryptedData -import fr.acinq.eclair.wire.protocol.{EncryptedRecipientDataTlv, OnionMessage, TlvStream} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ +import fr.acinq.eclair.wire.protocol.{OnionMessage, TlvStream} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -66,16 +67,16 @@ class OnionMessagesSpec extends AnyFunSuite { /* * Building the onion manually */ - val messageForAlice = BlindedRelayPayload(TlvStream(EncryptedRecipientDataTlv.OutgoingNodeId(bob.publicKey))) + val messageForAlice = BlindedRelayPayload(TlvStream(OutgoingNodeId(bob.publicKey))) val encodedForAlice = blindedRelayPayloadCodec.encode(messageForAlice).require.bytes assert(encodedForAlice == hex"04210324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c") - val messageForBob = BlindedRelayPayload(TlvStream(EncryptedRecipientDataTlv.OutgoingNodeId(carol.publicKey), EncryptedRecipientDataTlv.NextBlinding(blindingOverride.publicKey))) + val messageForBob = BlindedRelayPayload(TlvStream(OutgoingNodeId(carol.publicKey), NextBlinding(blindingOverride.publicKey))) val encodedForBob = blindedRelayPayloadCodec.encode(messageForBob).require.bytes assert(encodedForBob == hex"0421027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007082102989c0b76cb563971fdc9bef31ec06c3560f3249d6ee9e5d83c57625596e05f6f") - val messageForCarol = BlindedRelayPayload(TlvStream(EncryptedRecipientDataTlv.Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), EncryptedRecipientDataTlv.OutgoingNodeId(dave.publicKey))) + val messageForCarol = BlindedRelayPayload(TlvStream(Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), OutgoingNodeId(dave.publicKey))) val encodedForCarol = blindedRelayPayloadCodec.encode(messageForCarol).require.bytes assert(encodedForCarol == hex"012300000000000000000000000000000000000000000000000000000000000000000000000421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991") - val messageForDave = BlindedFinalPayload(TlvStream(EncryptedRecipientDataTlv.PathId(hex"01234567"))) + val messageForDave = BlindedFinalPayload(TlvStream(PathId(hex"01234567"))) val encodedForDave = blindedFinalPayloadCodec.encode(messageForDave).require.bytes assert(encodedForDave == hex"060401234567") @@ -96,7 +97,7 @@ class OnionMessagesSpec extends AnyFunSuite { val sessionKey = PrivateKey(hex"090909090909090909090909090909090909090909090909090909090909090901") - val PacketAndSecrets(packet, _) = Sphinx.create(sessionKey,1300, publicKeys, payloads, None) + val PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, 1300, publicKeys, payloads, None) assert(packet.hmac == ByteVector32(hex"d84e7135092450c8cc98bb969aa6d9127dd07da53a3c46b2e9339d111f5f301d")) assert(packet.publicKey == PublicKey(hex"0256b328b30c8bf5839e24058747879408bdb36241dc9c2e7c619faa12b2920967").value) assert(packet.payload == @@ -155,7 +156,7 @@ class OnionMessagesSpec extends AnyFunSuite { assert(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes) == ByteVector32(hex"bae3d9ea2b06efd1b7b9b49b6cdcaad0e789474a6939ffa54ff5ec9224d5b76c")) val enctlv = hex"6970e870b473ddbc27e3098bfa45bb1aa54f1f637f803d957e6271d8ffeba89da2665d62123763d9b634e30714144a1c165ac9" assert(blindedNodes.head.encryptedPayload == enctlv) - val message = BlindedRelayPayload(TlvStream(EncryptedRecipientDataTlv.OutgoingNodeId(nextNodeId))) + val message = BlindedRelayPayload(TlvStream(OutgoingNodeId(nextNodeId))) assert(blindedRelayPayloadCodec.encode(message).require.bytes == encmsg) val relayNext = blindedRelayPayloadCodec.decode(encmsg.bits).require.value assert(relayNext.nextNodeId == nextNodeId) @@ -182,7 +183,7 @@ class OnionMessagesSpec extends AnyFunSuite { assert(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes) == ByteVector32(hex"9afb8b2ebc174dcf9e270be24771da7796542398d29d4ff6a4e7b6b4b9205cfe")) val enctlv = hex"1630da85e8759b8f3b94d74a539c6f0d870a87cf03d4986175865a2985553c997b560c32613bd9184c1a6d41a37027aabdab5433009d8409a1b638eb90373778a05716af2c2140b3196dca23997cdad4cfa7a7adc8d4" assert(blindedHops.head.encryptedPayload == enctlv) - val message = BlindedRelayPayload(TlvStream(EncryptedRecipientDataTlv.OutgoingNodeId(nextNodeId), EncryptedRecipientDataTlv.NextBlinding(PrivateKey(hex"070707070707070707070707070707070707070707070707070707070707070701").publicKey))) + val message = BlindedRelayPayload(TlvStream(OutgoingNodeId(nextNodeId), NextBlinding(PrivateKey(hex"070707070707070707070707070707070707070707070707070707070707070701").publicKey))) assert(blindedRelayPayloadCodec.encode(message).require.bytes == encmsg) val relayNext = blindedRelayPayloadCodec.decode(encmsg.bits).require.value assert(relayNext.nextNodeId == nextNodeId) @@ -209,7 +210,7 @@ class OnionMessagesSpec extends AnyFunSuite { assert(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes) == ByteVector32(hex"cc3b918cda6b1b049bdbe469c4dd952935e7c1518dd9c7ed0cd2cd5bc2742b82")) val enctlv = hex"8285acbceb37dfb38b877a888900539be656233cd74a55c55344fb068f9d8da365340d21db96fb41b76123207daeafdfb1f571e3fea07a22e10da35f03109a0380b3c69fcbed9c698086671809658761cf65ecbc3c07a2e5" assert(blindedHops.head.encryptedPayload == enctlv) - val message = BlindedRelayPayload(TlvStream(EncryptedRecipientDataTlv.Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), EncryptedRecipientDataTlv.OutgoingNodeId(nextNodeId))) + val message = BlindedRelayPayload(TlvStream(Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), OutgoingNodeId(nextNodeId))) assert(blindedRelayPayloadCodec.encode(message).require.bytes == encmsg) val relayNext = blindedRelayPayloadCodec.decode(encmsg.bits).require.value assert(relayNext.nextNodeId == nextNodeId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala deleted file mode 100644 index 315e864cbf..0000000000 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala +++ /dev/null @@ -1,62 +0,0 @@ -package fr.acinq.eclair.wire.protocol - -import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding -import fr.acinq.eclair.wire.protocol.EncryptedRecipientDataTlv._ -import fr.acinq.eclair.{ShortChannelId, UInt64, randomKey} -import org.scalatest.funsuite.AnyFunSuiteLike -import scodec.bits.HexStringSyntax - -import scala.util.Success - -class EncryptedRecipientDataSpec extends AnyFunSuiteLike { - - test("decode encrypted recipient data") { - val sessionKey = randomKey() - val nodePrivKeys = Seq(randomKey(), randomKey(), randomKey(), randomKey(), randomKey()) - val payloads = Seq( - (TlvStream[EncryptedRecipientDataTlv](Padding(hex"000000"), OutgoingChannelId(ShortChannelId(561))), hex"0103000000 02080000000000000231"), - (TlvStream[EncryptedRecipientDataTlv](OutgoingNodeId(PublicKey(hex"025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486"))), hex"0421025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486"), - (TlvStream[EncryptedRecipientDataTlv](OutgoingNodeId(PublicKey(hex"025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486")), NextBlinding(PublicKey(hex"027710df7a1d7ad02e3572841a829d141d9f56b17de9ea124d2f83ea687b2e0461"))), hex"0421025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486 0821027710df7a1d7ad02e3572841a829d141d9f56b17de9ea124d2f83ea687b2e0461"), - (TlvStream[EncryptedRecipientDataTlv](PathId(hex"0101010101010101010101010101010101010101010101010101010101010101")), hex"06200101010101010101010101010101010101010101010101010101010101010101"), - (TlvStream[EncryptedRecipientDataTlv](Seq(OutgoingChannelId(ShortChannelId(42))), Seq(GenericTlv(UInt64(65535), hex"06c1"))), hex"0208000000000000002a fdffff0206c1"), - ) - - val blindedRoute = RouteBlinding.create(sessionKey, nodePrivKeys.map(_.publicKey), payloads.map(_._2)) - val blinding0 = sessionKey.publicKey - val Success((decryptedPayload0, blinding1)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(0), blinding0, blindedRoute.encryptedPayloads(0)) - val Success((decryptedPayload1, blinding2)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(1), blinding1, blindedRoute.encryptedPayloads(1)) - val Success((decryptedPayload2, blinding3)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(2), blinding2, blindedRoute.encryptedPayloads(2)) - val Success((decryptedPayload3, blinding4)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(3), blinding3, blindedRoute.encryptedPayloads(3)) - val Success((decryptedPayload4, _)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(4), blinding4, blindedRoute.encryptedPayloads(4)) - assert(Seq(decryptedPayload0, decryptedPayload1, decryptedPayload2, decryptedPayload3, decryptedPayload4) === payloads.map(_._1)) - } - - test("decode invalid encrypted recipient data") { - val testCases = Seq( - hex"02080000000000000231 ff", // additional trailing bytes after tlv stream - hex"01040000 02080000000000000231", // invalid padding tlv - hex"0420025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce14", // invalid public key length - hex"0c20025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce14", // invalid next blinding length - hex"02080000000000000231 0103000000", // invalid tlv stream ordering - hex"02080000000000000231 10080000000000000231", // unknown even tlv field - ) - - for (testCase <- testCases) { - val nodePrivKeys = Seq(randomKey(), randomKey()) - val payloads = Seq(hex"02080000000000000231", testCase) - val blindingPrivKey = randomKey() - val blindedRoute = RouteBlinding.create(blindingPrivKey, nodePrivKeys.map(_.publicKey), payloads) - // The payload for the first node is valid. - val blinding0 = blindingPrivKey.publicKey - val Success((_, blinding1)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads.head) - // If the first node is given invalid decryption material, it cannot decrypt recipient data. - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.last, blinding0, blindedRoute.encryptedPayloads.head).isFailure) - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding1, blindedRoute.encryptedPayloads.head).isFailure) - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads.last).isFailure) - // The payload for the last node is invalid, even with valid decryption material. - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.last, blinding1, blindedRoute.encryptedPayloads.last).isFailure) - } - } - -} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala new file mode 100644 index 0000000000..67eaea0dd7 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala @@ -0,0 +1,84 @@ +package fr.acinq.eclair.wire.protocol + +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ +import fr.acinq.eclair.{ShortChannelId, UInt64, randomKey} +import org.scalatest.funsuite.AnyFunSuiteLike +import scodec.bits.{ByteVector, HexStringSyntax} + +import scala.util.Success + +class RouteBlindingSpec extends AnyFunSuiteLike { + + test("decode route blinding data (reference test vector)") { + val payloads = Map[ByteVector, TlvStream[RouteBlindingEncryptedDataTlv]]( + hex"0208000000000000002a 3903123456" -> TlvStream(Seq(OutgoingChannelId(ShortChannelId(42))), Seq(GenericTlv(UInt64(57), hex"123456"))), + hex"011900000000000000000000000000000000000000000000000000 02080000000000000231 3b00 fdffff0206c1" -> TlvStream(Seq(Padding(hex"00000000000000000000000000000000000000000000000000"), OutgoingChannelId(ShortChannelId(561))), Seq(GenericTlv(UInt64(59), hex""), GenericTlv(UInt64(65535), hex"06c1"))), + hex"02080000000000000451 0421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991" -> TlvStream(OutgoingChannelId(ShortChannelId(1105)), OutgoingNodeId(PublicKey(hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991"))), + hex"01080000000000000000 042102edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145" -> TlvStream(Padding(hex"0000000000000000"), OutgoingNodeId(PublicKey(hex"02edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145"))), + hex"0109000000000000000000 06204242424242424242424242424242424242424242424242424242424242424242" -> TlvStream(Padding(hex"000000000000000000"), PathId(hex"4242424242424242424242424242424242424242424242424242424242424242")), + hex"0421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991" -> TlvStream(OutgoingNodeId(PublicKey(hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991"))), + hex"042102edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145" -> TlvStream(OutgoingNodeId(PublicKey(hex"02edabbd16b41c8371b92ef2f04c1185b4f03b6dcd52ba9b78d9d7c89c8f221145"))), + hex"010f000000000000000000000000000000 061000112233445566778899aabbccddeeff" -> TlvStream(Padding(hex"000000000000000000000000000000"), PathId(hex"00112233445566778899aabbccddeeff")), + hex"0121000000000000000000000000000000000000000000000000000000000000000000 04210324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c" -> TlvStream(Padding(hex"000000000000000000000000000000000000000000000000000000000000000000"), OutgoingNodeId(PublicKey(hex"0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c"))), + hex"0421027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007 0821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f" -> TlvStream(OutgoingNodeId(PublicKey(hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007")), NextBlinding(PublicKey(hex"031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f"))), + ) + + for ((encoded, data) <- payloads) { + val decoded = RouteBlindingEncryptedDataCodecs.encryptedDataCodec.decode(encoded.bits).require.value + assert(decoded === data) + val reEncoded = RouteBlindingEncryptedDataCodecs.encryptedDataCodec.encode(data).require.bytes + assert(reEncoded === encoded) + } + } + + test("decode encrypted route blinding data") { + val sessionKey = randomKey() + val nodePrivKeys = Seq(randomKey(), randomKey(), randomKey(), randomKey(), randomKey()) + val payloads = Seq[(TlvStream[RouteBlindingEncryptedDataTlv], ByteVector)]( + (TlvStream(Padding(hex"000000"), OutgoingChannelId(ShortChannelId(561))), hex"0103000000 02080000000000000231"), + (TlvStream(OutgoingNodeId(PublicKey(hex"025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486"))), hex"0421025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486"), + (TlvStream(OutgoingNodeId(PublicKey(hex"025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486")), NextBlinding(PublicKey(hex"027710df7a1d7ad02e3572841a829d141d9f56b17de9ea124d2f83ea687b2e0461"))), hex"0421025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce1486 0821027710df7a1d7ad02e3572841a829d141d9f56b17de9ea124d2f83ea687b2e0461"), + (TlvStream(PathId(hex"0101010101010101010101010101010101010101010101010101010101010101")), hex"06200101010101010101010101010101010101010101010101010101010101010101"), + (TlvStream(Seq(OutgoingChannelId(ShortChannelId(42))), Seq(GenericTlv(UInt64(65535), hex"06c1"))), hex"0208000000000000002a fdffff0206c1"), + ) + + val blindedRoute = Sphinx.RouteBlinding.create(sessionKey, nodePrivKeys.map(_.publicKey), payloads.map(_._2)) + val blinding0 = sessionKey.publicKey + val Success((decryptedPayload0, blinding1)) = RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys(0), blinding0, blindedRoute.encryptedPayloads(0)) + val Success((decryptedPayload1, blinding2)) = RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys(1), blinding1, blindedRoute.encryptedPayloads(1)) + val Success((decryptedPayload2, blinding3)) = RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys(2), blinding2, blindedRoute.encryptedPayloads(2)) + val Success((decryptedPayload3, blinding4)) = RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys(3), blinding3, blindedRoute.encryptedPayloads(3)) + val Success((decryptedPayload4, _)) = RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys(4), blinding4, blindedRoute.encryptedPayloads(4)) + assert(Seq(decryptedPayload0, decryptedPayload1, decryptedPayload2, decryptedPayload3, decryptedPayload4) === payloads.map(_._1)) + } + + test("decode invalid encrypted route blinding data") { + val testCases = Seq( + hex"02080000000000000231 ff", // additional trailing bytes after tlv stream + hex"01040000 02080000000000000231", // invalid padding tlv + hex"0420025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce14", // invalid public key length + hex"0c20025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce14", // invalid next blinding length + hex"02080000000000000231 0103000000", // invalid tlv stream ordering + hex"02080000000000000231 10080000000000000231", // unknown even tlv field + ) + + for (testCase <- testCases) { + val nodePrivKeys = Seq(randomKey(), randomKey()) + val payloads = Seq(hex"02080000000000000231", testCase) + val blindingPrivKey = randomKey() + val blindedRoute = Sphinx.RouteBlinding.create(blindingPrivKey, nodePrivKeys.map(_.publicKey), payloads) + // The payload for the first node is valid. + val blinding0 = blindingPrivKey.publicKey + val Success((_, blinding1)) = RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads.head) + // If the first node is given invalid decryption material, it cannot decrypt recipient data. + assert(RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys.last, blinding0, blindedRoute.encryptedPayloads.head).isFailure) + assert(RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys.head, blinding1, blindedRoute.encryptedPayloads.head).isFailure) + assert(RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads.last).isFailure) + // The payload for the last node is invalid, even with valid decryption material. + assert(RouteBlindingEncryptedDataCodecs.decode(nodePrivKeys.last, blinding1, blindedRoute.encryptedPayloads.last).isFailure) + } + } + +}