Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly type Sphinx shared secrets #2959

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,17 @@ object Sphinx extends Logging {
val isLastPacket: Boolean = nextPacket.hmac == ByteVector32.Zeroes
}

/** Shared secret used to encrypt the payload for a given node. */
case class SharedSecret(secret: ByteVector32, remoteNodeId: PublicKey)

/**
* A encrypted onion packet with all the associated shared secrets.
*
* @param packet encrypted onion packet.
* @param sharedSecrets shared secrets (one per node in the route). Known (and needed) only if you're creating the
* packet. Empty if you're just forwarding the packet to the next node.
*/
case class PacketAndSecrets(packet: OnionRoutingPacket, sharedSecrets: Seq[(ByteVector32, PublicKey)])
case class PacketAndSecrets(packet: OnionRoutingPacket, sharedSecrets: Seq[SharedSecret])

/**
* Generate a deterministic filler to prevent intermediate nodes from knowing their position in the route.
Expand Down Expand Up @@ -239,12 +242,12 @@ object Sphinx extends Logging {
*/
def create(sessionKey: PrivateKey, packetPayloadLength: Int, publicKeys: Seq[PublicKey], payloads: Seq[ByteVector], associatedData: Option[ByteVector32]): Try[PacketAndSecrets] = Try {
require(payloadsTotalSize(payloads) <= packetPayloadLength, s"packet per-hop payloads cannot exceed $packetPayloadLength bytes")
val (ephemeralPublicKeys, sharedsecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys)
val filler = generateFiller("rho", packetPayloadLength, sharedsecrets.dropRight(1), payloads.dropRight(1))
val (ephemeralPublicKeys, sharedSecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys)
val filler = generateFiller("rho", packetPayloadLength, sharedSecrets.dropRight(1), payloads.dropRight(1))

// We deterministically-derive the initial payload bytes: see https://github.com/lightningnetwork/lightning-rfc/pull/697
val startingBytes = generateStream(generateKey("pad", sessionKey.value), packetPayloadLength)
val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedsecrets.last, Left(startingBytes), filler)
val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedSecrets.last, Left(startingBytes), filler)

@tailrec
def loop(hopPayloads: Seq[ByteVector], ephKeys: Seq[PublicKey], sharedSecrets: Seq[ByteVector32], packet: OnionRoutingPacket): OnionRoutingPacket = {
Expand All @@ -254,8 +257,8 @@ object Sphinx extends Logging {
}
}

val packet = loop(payloads.dropRight(1), ephemeralPublicKeys.dropRight(1), sharedsecrets.dropRight(1), lastPacket)
PacketAndSecrets(packet, sharedsecrets.zip(publicKeys))
val packet = loop(payloads.dropRight(1), ephemeralPublicKeys.dropRight(1), sharedSecrets.dropRight(1), lastPacket)
PacketAndSecrets(packet, sharedSecrets.zip(publicKeys).map { case (secret, remoteNodeId) => SharedSecret(secret, remoteNodeId) })
}

/**
Expand Down Expand Up @@ -324,20 +327,18 @@ object Sphinx extends Logging {
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
* failure packet otherwise.
*/
def decrypt(packet: ByteVector, sharedSecrets: Seq[(ByteVector32, PublicKey)]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = {
@tailrec
def loop(packet: ByteVector, secrets: Seq[(ByteVector32, PublicKey)]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = secrets match {
@tailrec
def decrypt(packet: ByteVector, sharedSecrets: Seq[SharedSecret]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = {
sharedSecrets match {
case Nil => Left(CannotDecryptFailurePacket(packet))
case (secret, pubkey) :: tail =>
val packet1 = wrap(packet, secret)
val um = generateKey("um", secret)
case ss :: tail =>
val packet1 = wrap(packet, ss.secret)
val um = generateKey("um", ss.secret)
FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match {
case Attempt.Successful(value) => Right(DecryptedFailurePacket(pubkey, value.value))
case _ => loop(packet1, tail)
case Attempt.Successful(value) => Right(DecryptedFailurePacket(ss.remoteNodeId, value.value))
case _ => decrypt(packet1, tail)
}
}

loop(packet, sharedSecrets)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ object IncomingPaymentPacket {
* @param outgoingChannel channel to send the HTLC to.
* @param sharedSecrets shared secrets (used to decrypt the error in case of payment failure).
*/
case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[(ByteVector32, PublicKey)])
case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[Sphinx.SharedSecret])

/** Helpers to create outgoing payment packets. */
object OutgoingPaymentPacket {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ object PaymentLifecycle {
sealed trait Data
case object WaitingForRequest extends Data
case class WaitingForRoute(request: SendPayment, failures: Seq[PaymentFailure], ignore: Ignore) extends Data
case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignore: Ignore, route: Route) extends Data {
case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[Sphinx.SharedSecret], ignore: Ignore, route: Route) extends Data {
val recipient = request.recipient
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2)
val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3)
assert(Seq(payload0, payload1, payload2, payload3, payload4) == referencePaymentPayloads)
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1))
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret))

val packets = Seq(nextPacket0, nextPacket1, nextPacket2, nextPacket3, nextPacket4)
assert(packets(0).hmac == ByteVector32(hex"901fb2bb905d1cfac67727f900daa2bb9da6801ac31ccce78663e5021e83983b"))
Expand All @@ -159,7 +159,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2)
val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3)
assert(Seq(payload0, payload1, payload2, payload3, payload4) == paymentPayloadsFull)
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1))
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret))

val packets = Seq(nextPacket0, nextPacket1, nextPacket2, nextPacket3, nextPacket4)
assert(packets(0).hmac == ByteVector32(hex"859cd694cf604442547246f4fae144f255e71e30cb366b9775f488cac713f0db"))
Expand Down Expand Up @@ -196,7 +196,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2)
val Right(DecryptedPacket(payload4, _, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3)
assert(Seq(payload0, payload1, payload2, payload3, payload4) == trampolinePaymentPayloads)
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1))
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret))
}

test("create packet with invalid payload") {
Expand Down Expand Up @@ -229,19 +229,19 @@ class SphinxSpec extends AnyFunSuite {
val packet1 = FailurePacket.create(sharedSecrets.head, expected.failureMessage)
assert(packet1.length == 292)

val Right(decrypted1) = FailurePacket.decrypt(packet1, Seq(0).map(i => (sharedSecrets(i), publicKeys(i))))
val Right(decrypted1) = FailurePacket.decrypt(packet1, Seq(0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i))))
assert(decrypted1 == expected)

val packet2 = FailurePacket.wrap(packet1, sharedSecrets(1))
assert(packet2.length == 292)

val Right(decrypted2) = FailurePacket.decrypt(packet2, Seq(1, 0).map(i => (sharedSecrets(i), publicKeys(i))))
val Right(decrypted2) = FailurePacket.decrypt(packet2, Seq(1, 0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i))))
assert(decrypted2 == expected)

val packet3 = FailurePacket.wrap(packet2, sharedSecrets(2))
assert(packet3.length == 292)

val Right(decrypted3) = FailurePacket.decrypt(packet3, Seq(2, 1, 0).map(i => (sharedSecrets(i), publicKeys(i))))
val Right(decrypted3) = FailurePacket.decrypt(packet3, Seq(2, 1, 0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i))))
assert(decrypted3 == expected)
}

Expand All @@ -258,7 +258,7 @@ class SphinxSpec extends AnyFunSuite {
sharedSecrets(1)),
sharedSecrets(2))

assert(FailurePacket.decrypt(packet, Seq(0, 2, 1).map(i => (sharedSecrets(i), publicKeys(i)))).isLeft)
assert(FailurePacket.decrypt(packet, Seq(0, 2, 1).map(i => SharedSecret(sharedSecrets(i), publicKeys(i)))).isLeft)
}

test("last node replies with a short failure message (old reference test vector)") {
Expand Down Expand Up @@ -565,7 +565,7 @@ class SphinxSpec extends AnyFunSuite {
assert(payloadEve.allowedFeatures.isEmpty)

assert(Seq(onionPayloadAlice, onionPayloadBob, onionPayloadCarol, onionPayloadDave, onionPayloadEve) == payloads)
assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_._1))
assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_.secret))

val packets = Seq(packetForBob, packetForCarol, packetForDave, packetForEve, packetForNobody)
assert(packets(0).hmac == ByteVector32(hex"73fba184685e19b9af78afe876aa4e4b4242382b293133771d95a2bd83fa9c62"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {

register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1))
val failure = TemporaryChannelFailure(Some(update_bc))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

// payment lifecycle will ask the router to temporarily exclude this channel from its route calculations
assert(routerForwarder.expectMsgType[ChannelCouldNotRelay].hop.shortChannelId == update_bc.shortChannelId)
Expand Down Expand Up @@ -533,7 +533,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING
routerForwarder.expectMsg(defaultRouteRequest(a, cfg))
Expand All @@ -548,7 +548,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified_2 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(43), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure2 = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified_2))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head._1, failure2)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head.secret, failure2)))))

// this time the payment lifecycle will ask the router to temporarily exclude this channel from its route calculations
routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c), Some(nodeParams.routerConf.channelExcludeDuration)))
Expand Down Expand Up @@ -578,7 +578,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {

// the node replies with a temporary failure containing the same update as the one we already have (likely a balance issue)
val failure = TemporaryChannelFailure(Some(update_bc))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))
// we should temporarily exclude that channel
assert(routerForwarder.expectMsgType[ChannelCouldNotRelay].hop.shortChannelId == update_bc.shortChannelId)
routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c), Some(nodeParams.routerConf.channelExcludeDuration)))
Expand Down Expand Up @@ -612,7 +612,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING
val extraEdges1 = Seq(
Expand Down Expand Up @@ -651,7 +651,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
// we disable the channel
val channelUpdate_cd_disabled = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, scid_cd, CltvExpiryDelta(42), update_cd.htlcMinimumMsat, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.htlcMaximumMsat, enable = false)
val failure = ChannelDisabled(channelUpdate_cd_disabled.messageFlags, channelUpdate_cd_disabled.channelFlags, Some(channelUpdate_cd_disabled))
val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1)._1, failure), sharedSecrets1.head._1)
val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1).secret, failure), sharedSecrets1.head.secret)
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion))))

assert(routerForwarder.expectMsgType[RouteCouldRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, update_bc).map(_.shortChannelId))
Expand All @@ -674,7 +674,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData

register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

// payment lifecycle forwards the embedded channelUpdate to the router
awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE)
Expand Down Expand Up @@ -713,7 +713,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {

// The payment fails inside the blinded route: the introduction node sends back an error.
val failure = InvalidOnionBlinding(randomBytes32())
val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head._1, failure)
val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head.secret, failure)
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion))))

// We retry but we exclude the failed blinded route.
Expand Down Expand Up @@ -955,7 +955,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

// The payment fails without retrying
sender.expectMsgType[PaymentFailed]
Expand Down
Loading