Skip to content

Commit

Permalink
Make AsyncPaymentTriggerer the entry point for cancels
Browse files Browse the repository at this point in the history
  • Loading branch information
remyers committed Dec 29, 2022
1 parent c72c6be commit df11a7b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ import fr.acinq.eclair.io.{PeerReadyNotifier, Switchboard}
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer.Command
import fr.acinq.eclair.{BlockHeight, Logs}

import scala.concurrent.duration.Duration

/**
* This actor waits for an async payment receiver to become ready to receive a payment or for a block timeout to expire.
* If the receiver of the payment is a connected peer, spawn a PeerReadyNotifier actor.
Expand All @@ -40,13 +38,15 @@ object AsyncPaymentTriggerer {
sealed trait Command
case class Start(switchboard: ActorRef[Switchboard.GetPeerInfo]) extends Command
case class Watch(replyTo: ActorRef[Result], remoteNodeId: PublicKey, paymentHash: ByteVector32, timeout: BlockHeight) extends Command
case class Cancel(paymentHash: ByteVector32) extends Command
private[relay] case class NotifierStopped(remoteNodeId: PublicKey) extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command
private case class WrappedCurrentBlockHeight(currentBlockHeight: CurrentBlockHeight) extends Command

sealed trait Result
case object AsyncPaymentTriggered extends Result
case object AsyncPaymentTimeout extends Result
case object AsyncPaymentCanceled extends Result
// @formatter:on

def apply(): Behavior[Command] = Behaviors.setup { context =>
Expand All @@ -70,12 +70,21 @@ private class AsyncPaymentTriggerer(switchboard: ActorRef[Switchboard.GetPeerInf
def update(currentBlockHeight: BlockHeight): Option[PeerPayments] = {
val expiredPayments = pendingPayments.filter(_.expired(currentBlockHeight))
expiredPayments.foreach(e => e.replyTo ! AsyncPaymentTimeout)
val pendingPayments1 = pendingPayments.removedAll(expiredPayments)
if (pendingPayments1.isEmpty) {
updatePaymentsOrStop(pendingPayments.removedAll(expiredPayments))
}

def cancel(paymentHash: ByteVector32): Option[PeerPayments] = {
val canceledPayment = pendingPayments.find(_.paymentHash == paymentHash)
if (canceledPayment.isDefined) canceledPayment.get.replyTo ! AsyncPaymentCanceled
updatePaymentsOrStop(pendingPayments.removedAll(canceledPayment))
}

private def updatePaymentsOrStop(pendingPayments: Set[Payment]): Option[PeerPayments] = {
if (pendingPayments.isEmpty) {
context.stop(notifier)
None
} else {
Some(PeerPayments(notifier, pendingPayments1))
Some(PeerPayments(notifier, pendingPayments))
}
}

Expand Down Expand Up @@ -105,6 +114,11 @@ private class AsyncPaymentTriggerer(switchboard: ActorRef[Switchboard.GetPeerInf
val peer1 = PeerPayments(peer.notifier, peer.pendingPayments + Payment(replyTo, timeout, paymentHash))
watching(peers + (remoteNodeId -> peer1))
}
case Cancel(paymentHash) =>
val peers1 = peers.flatMap {
case (remoteNodeId, peer) => peer.cancel(paymentHash).map(peer1 => remoteNodeId -> peer1)
}
watching(peers1)
case WrappedCurrentBlockHeight(CurrentBlockHeight(currentBlockHeight)) =>
val peers1 = peers.flatMap {
case (remoteNodeId, peer) => peer.update(currentBlockHeight).map(peer1 => remoteNodeId -> peer1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ object NodeRelay {
case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket) extends Command
case object Stop extends Command
case object RelayAsyncPayment extends Command
case object CancelAsyncPayment extends Command
private case class WrappedMultiPartExtraPaymentReceived(mppExtraReceived: MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]) extends Command
private case class WrappedMultiPartPaymentFailed(mppFailed: MultiPartPaymentFSM.MultiPartPaymentFailed) extends Command
private case class WrappedMultiPartPaymentSucceeded(mppSucceeded: MultiPartPaymentFSM.MultiPartPaymentSucceeded) extends Command
private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command
private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
private case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command
private[relay] case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command
// @formatter:on

trait OutgoingPaymentFactory {
Expand Down Expand Up @@ -228,7 +227,7 @@ class NodeRelay private(nodeParams: NodeParams,
context.log.warn("rejecting async payment; was not triggered before block {}", notifierTimeout)
rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized
stopping()
case CancelAsyncPayment =>
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentCanceled) =>
context.log.warn(s"payment sender canceled a waiting async payment")
rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized
stopping()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
probe.expectNoMessage(100 millis)
}

test("remote node does not connect before sender cancels") { f =>
import f._

triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId)

// cancel of an unwatched payment does nothing
triggerer ! Cancel(ByteVector32.One)
probe.expectNoMessage(100 millis)

triggerer ! Cancel(ByteVector32.Zeroes)
probe.expectMessage(AsyncPaymentCanceled)
}

test("duplicate watches should emit only one trigger") { f =>
import f._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional}
import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion}
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register}
import fr.acinq.eclair.crypto.Sphinx
import AsyncPaymentTriggerer.{AsyncPaymentTimeout, AsyncPaymentTriggered, Watch}
import AsyncPaymentTriggerer.{AsyncPaymentCanceled, AsyncPaymentTimeout, AsyncPaymentTriggered, Watch}
import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop
import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket
import fr.acinq.eclair.payment.Invoice.ExtraEdge
Expand Down Expand Up @@ -432,7 +432,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger

// fail the payment if waiting when payment sender sends cancel message
nodeRelayer ! NodeRelay.CancelAsyncPayment
nodeRelayer ! NodeRelay.WrappedPeerReadyResult(AsyncPaymentCanceled)

incomingAsyncPayment.foreach { p =>
val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]]
Expand Down

0 comments on commit df11a7b

Please sign in to comment.