diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index f4506c854c..f3e7c9f282 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -624,7 +624,15 @@ object Graph { } } - case class Vertex(features: Features[NodeFeature], incomingEdges: Map[ChannelDesc, GraphEdge]) + case class Vertex(features: Features[NodeFeature], incomingEdges: Map[ChannelDesc, GraphEdge]) { + def update(desc: ChannelDesc, newShortChannelId: RealShortChannelId, newCapacity: Satoshi): Vertex = + incomingEdges.get(desc) match { + case None => this + case Some(edge) => + val updatedEdge = edge.copy(desc = desc.copy(shortChannelId = newShortChannelId), capacity = newCapacity) + copy(incomingEdges = incomingEdges - desc + (desc.copy(shortChannelId = newShortChannelId) -> updatedEdge)) + } + } /** A graph data structure that uses an adjacency list, stores the incoming edges of the neighbors */ case class DirectedGraph(private val vertices: Map[PublicKey, Vertex]) { @@ -688,14 +696,10 @@ object Graph { * @return a new graph with updated vertexes */ def updateChannel(desc: ChannelDesc, newShortChannelId: RealShortChannelId, newCapacity: Satoshi): DirectedGraph = { - val newDesc = desc.copy(shortChannelId = newShortChannelId) - val updatedVertices = - vertices - .updatedWith(desc.b)(_.map(vertexB => vertexB.copy(incomingEdges = vertexB.incomingEdges - desc + - (newDesc -> vertexB.incomingEdges(desc).copy(desc = newDesc, capacity = newCapacity))))) - .updatedWith(desc.a)(_.map(vertexA => vertexA.copy(incomingEdges = vertexA.incomingEdges - desc.reversed + - (newDesc.reversed -> vertexA.incomingEdges(desc.reversed).copy(desc = newDesc.reversed, capacity = newCapacity))))) - DirectedGraph(updatedVertices) + DirectedGraph(vertices + .updatedWith(desc.b)(_.map(_.update(desc, newShortChannelId, newCapacity))) + .updatedWith(desc.a)(_.map(_.update(desc.reversed, newShortChannelId, newCapacity))) + ) } /** diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala index 1bc11f26d3..15cc094900 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala @@ -98,8 +98,7 @@ class BalanceEstimateSpec extends AnyFunSuite { .couldSend(60_000 msat, TimestampSecond.now()) // a splice-in that increases channel capacity increases high but not low bounds - val balance1 = balance - .updateEdge(a.desc, RealShortChannelId(5), 250 sat) + val balance1 = balance.updateEdge(a.desc, RealShortChannelId(5), 250 sat) assert(balance1.maxCapacity == 250.sat) assert(balance1.low == 60_000.msat) assert(balance1.high == 190_000.msat) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala index a5ce621282..8d286c33d1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala @@ -461,4 +461,25 @@ class GraphSpec extends AnyFunSuite { assert(MessagePath.dijkstraMessagePath(graph, a, f, Set.empty, boundaries, BlockHeight(793397), wr).isEmpty) } } + + test("a channel update only changes the scid and capacity of one edge") { + // A --> B has two edges with different short channel ids. + val edge = makeEdge(7, a, b, 1 msat, 1) + val g = makeTestGraph().addEdge(edge) + + val g1 = g.updateChannel(ChannelDesc(ShortChannelId(7), a, b), RealShortChannelId(10), 99 sat) + val edge1 = g1.getEdge(ChannelDesc(ShortChannelId(10), a, b)).get + assert(edge1.capacity == 99.sat) + assert(g1.getEdge(ChannelDesc(ShortChannelId(7), a, b)).isEmpty) + + // Only the scid and capacity of one edge changes. + assert(g1 == makeTestGraph().addEdge(edge1)) + + // Updates are symmetric. + assert(g1 == g.updateChannel(ChannelDesc(ShortChannelId(7), b, a), RealShortChannelId(10), 99 sat)) + + // Updates to an unknown channel do not change the graph. + assert(g == g.updateChannel(ChannelDesc(ShortChannelId(1), randomKey().publicKey, b), RealShortChannelId(10), 99 sat)) + } + }