diff --git a/src/fiber/graph.rs b/src/fiber/graph.rs index d3e20b107..ccc899703 100644 --- a/src/fiber/graph.rs +++ b/src/fiber/graph.rs @@ -532,6 +532,11 @@ where self.history.reset(); } + #[cfg(test)] + pub fn set_source(&mut self, source: Pubkey) { + self.source = source; + } + /// Returns a list of `PaymentHopData` for all nodes in the route, /// including the origin and the target node. pub fn build_route( diff --git a/src/fiber/tests/graph.rs b/src/fiber/tests/graph.rs index 34c3ab7b0..6cd3af894 100644 --- a/src/fiber/tests/graph.rs +++ b/src/fiber/tests/graph.rs @@ -70,6 +70,10 @@ impl MockNetworkGraph { } } + fn set_source(&mut self, source: PublicKey) { + self.graph.set_source(source.into()); + } + pub fn mark_node_failed(&mut self, node: usize) { self.graph.mark_node_failed(self.keys[node].into()); } @@ -139,8 +143,6 @@ impl MockNetworkGraph { tlc_minimum_value: min_htlc_value.unwrap_or(0), channel_outpoint: channel_outpoint.clone(), }; - eprintln!("add channel_info: {:?}", channel_info); - eprintln!("add channel_update: {:?}", channel_update); self.graph.process_channel_update(channel_update).unwrap(); if let Some(fee_rate) = other_fee_rate { let channel_update = ChannelUpdate { @@ -157,7 +159,6 @@ impl MockNetworkGraph { }; eprintln!("add rev channel_update: {:?}", channel_update); self.graph.process_channel_update(channel_update).unwrap(); - //eprintln!("add channel_info: {:?}", channel_info); } } @@ -200,7 +201,7 @@ impl MockNetworkGraph { ); } - pub fn find_route( + pub fn find_path( &self, source: usize, target: usize, @@ -213,7 +214,7 @@ impl MockNetworkGraph { .find_path(source, target, amount, Some(max_fee), None, false) } - pub fn find_route_udt( + pub fn find_path_udt( &self, source: usize, target: usize, @@ -297,18 +298,18 @@ fn test_graph_find_path_basic() { network.add_edge(1, 2, Some(1), Some(2)); let node2 = network.keys[2]; - let route = network.find_route(1, 2, 100, 1000); + let route = network.find_path(1, 2, 100, 1000); assert!(route.is_err()); network.add_edge(1, 2, Some(120), Some(2)); - let route = network.find_route(1, 2, 100, 1000); + let route = network.find_path(1, 2, 100, 1000); assert!(route.is_ok()); let route = route.unwrap(); assert_eq!(route.len(), 1); assert_eq!(route[0].target, node2.into()); assert_eq!(route[0].channel_outpoint, network.edges[1].2); - let route = network.find_route(1, 3, 10, 100); + let route = network.find_path(1, 3, 10, 100); assert!(route.is_err()); } @@ -321,7 +322,7 @@ fn test_graph_find_path_three_nodes() { let node3 = network.keys[3]; // Test route from node 1 to node 3 - let route = network.find_route(1, 3, 100, 1000); + let route = network.find_path(1, 3, 100, 1000); assert!(route.is_ok()); let route = route.unwrap(); assert_eq!(route.len(), 2); @@ -331,7 +332,7 @@ fn test_graph_find_path_three_nodes() { assert_eq!(route[1].channel_outpoint, network.edges[1].2); // Test route from node 1 to node 2 - let route = network.find_route(1, 2, 100, 1000); + let route = network.find_path(1, 2, 100, 1000); assert!(route.is_ok()); let route = route.unwrap(); assert_eq!(route.len(), 1); @@ -339,7 +340,7 @@ fn test_graph_find_path_three_nodes() { assert_eq!(route[0].channel_outpoint, network.edges[0].2); // Test route from node 2 to node 3 - let route = network.find_route(2, 3, 100, 1000); + let route = network.find_path(2, 3, 100, 1000); assert!(route.is_ok()); let route = route.unwrap(); assert_eq!(route.len(), 1); @@ -347,7 +348,7 @@ fn test_graph_find_path_three_nodes() { assert_eq!(route[0].channel_outpoint, network.edges[1].2); // Test route from node 3 to node 1 (should fail) - let route = network.find_route(3, 1, 100, 1000); + let route = network.find_path(3, 1, 100, 1000); assert!(route.is_err()); } @@ -361,7 +362,7 @@ fn test_graph_find_path_fee() { network.add_edge(1, 3, Some(1000), Some(20000)); network.add_edge(3, 4, Some(1000), Some(10000)); - let route = network.find_route(1, 4, 100, 1000); + let route = network.find_path(1, 4, 100, 1000); assert!(route.is_ok()); let route = route.unwrap(); @@ -381,7 +382,7 @@ fn test_graph_find_path_direct_linear() { network.add_edge(3, 4, Some(1000), Some(2)); network.add_edge(4, 5, Some(1000), Some(1)); - let route = network.find_route(1, 5, 100, 1000); + let route = network.find_path(1, 5, 100, 1000); assert!(route.is_ok()); let route = route.unwrap(); @@ -401,14 +402,14 @@ fn test_graph_find_path_cycle() { network.add_edge(2, 3, Some(1000), Some(3)); network.add_edge(3, 1, Some(1000), Some(2)); - let route = network.find_route(1, 3, 100, 1000); + let route = network.find_path(1, 3, 100, 1000); assert!(route.is_ok()); network.add_edge(3, 4, Some(1000), Some(2)); network.add_edge(4, 5, Some(1000), Some(1)); - let route = network.find_route(1, 5, 100, 1000); + let route = network.find_path(1, 5, 100, 1000); assert!(route.is_ok()); } @@ -424,7 +425,7 @@ fn test_graph_find_path_cycle_in_middle() { network.add_edge(4, 5, Some(1000), Some(1)); - let route = network.find_route(1, 5, 100, 1000); + let route = network.find_path(1, 5, 100, 1000); assert!(route.is_ok()); } @@ -436,12 +437,12 @@ fn test_graph_find_path_loop_exit() { network.add_edge(2, 3, Some(1000), Some(3)); network.add_edge(3, 2, Some(1000), Some(2)); - let route = network.find_route(1, 3, 100, 1000); + let route = network.find_path(1, 3, 100, 1000); assert!(route.is_err()); // now add a path from node1 to node2, so that node1 can reach node3 network.add_edge(1, 2, Some(1000), Some(4)); - let route = network.find_route(1, 3, 100, 1000); + let route = network.find_path(1, 3, 100, 1000); assert!(route.is_ok()); } @@ -454,7 +455,7 @@ fn test_graph_find_path_amount_failed() { network.add_edge(3, 4, Some(1000), Some(4)); network.add_edge(4, 5, Some(1000), Some(1)); - let route = network.find_route(1, 5, 1000, 10); + let route = network.find_path(1, 5, 1000, 10); assert!(route.is_err()); } @@ -475,19 +476,15 @@ fn test_graph_find_optimal_path() { network.add_edge(1, 6, Some(500), Some(10000)); network.add_edge(6, 5, Some(500), Some(10000)); - let route = network.find_route(1, 5, 1000, 1000); - assert!(route.is_ok()); - let route = route.unwrap(); - // Check that the algorithm chose the longer path with lower fees + let route = network.find_path(1, 5, 1000, 1000).unwrap(); assert_eq!(route.len(), 4); - assert_eq!(route[0].channel_outpoint, network.edges[1].2); - assert_eq!(route[1].channel_outpoint, network.edges[2].2); - assert_eq!(route[2].channel_outpoint, network.edges[3].2); - assert_eq!(route[3].channel_outpoint, network.edges[4].2); + for (i, edge_index) in (1..=4).enumerate() { + assert_eq!(route[i].channel_outpoint, network.edges[edge_index].2); + } // Test with a smaller amount that allows using the direct path - let small_route = network.find_route(1, 5, 100, 100); + let small_route = network.find_path(1, 5, 100, 100); assert!(small_route.is_ok()); let small_route = small_route.unwrap(); @@ -497,13 +494,119 @@ fn test_graph_find_optimal_path() { assert_eq!(small_route[1].channel_outpoint, network.edges[6].2); } +#[test] +fn test_graph_build_router_is_ok_with_fee_rate() { + let mut network = MockNetworkGraph::new(6); + + // Direct path with high fee + network.add_edge(1, 5, Some(2000), Some(50000)); + + // Longer path with lower total fee + network.add_edge(1, 2, Some(2000), Some(10000)); + // this node has a very low fee rate + network.add_edge(2, 3, Some(2000), Some(1)); + network.add_edge(3, 4, Some(2000), Some(10000)); + network.add_edge(4, 5, Some(2000), Some(10000)); + + // check the fee rate + let source = network.keys[1]; + network.set_source(source); + let node5 = network.keys[5]; + let route = network.graph.build_route(SendPaymentData { + target_pubkey: node5.into(), + amount: 1000, + payment_hash: Hash256::default(), + invoice: None, + final_htlc_expiry_delta: None, + timeout: None, + max_fee_amount: Some(1000), + max_parts: None, + keysend: false, + udt_type_script: None, + preimage: None, + allow_self_payment: false, + }); + assert!(route.is_ok()); + let route = route.unwrap(); + let amounts = route.iter().map(|x| x.amount).collect::>(); + assert_eq!(amounts, vec![1022, 1011, 1010, 1000, 1000]); +} + +#[test] +fn test_graph_build_router_fee_rate_optimize() { + let mut network = MockNetworkGraph::new(10); + + // Direct path with low total fee rate + network.add_edge(1, 6, Some(2000), Some(50000)); + network.add_edge(6, 5, Some(2000), Some(50000)); + + // Longer path with lower total fee + network.add_edge(1, 2, Some(2000), Some(10000)); + network.add_edge(2, 3, Some(2000), Some(20000)); + network.add_edge(3, 4, Some(2000), Some(30000)); + network.add_edge(4, 5, Some(2000), Some(40000)); + + // check the fee rate + let source = network.keys[1]; + network.set_source(source); + let node5 = network.keys[5]; + let route = network.graph.build_route(SendPaymentData { + target_pubkey: node5.into(), + amount: 1000, + payment_hash: Hash256::default(), + invoice: None, + final_htlc_expiry_delta: None, + timeout: None, + max_fee_amount: Some(1000), + max_parts: None, + keysend: false, + udt_type_script: None, + preimage: None, + allow_self_payment: false, + }); + assert!(route.is_ok()); + let route = route.unwrap(); + let amounts = route.iter().map(|x| x.amount).collect::>(); + assert_eq!(amounts, vec![1050, 1000, 1000]); +} + +#[test] +fn test_graph_build_router_no_fee_with_direct_pay() { + let mut network = MockNetworkGraph::new(10); + + network.add_edge(1, 5, Some(2000), Some(50000)); + + // check the fee rate + let source = network.keys[1]; + network.set_source(source); + let node5 = network.keys[5]; + let route = network.graph.build_route(SendPaymentData { + target_pubkey: node5.into(), + amount: 1000, + payment_hash: Hash256::default(), + invoice: None, + final_htlc_expiry_delta: None, + timeout: None, + max_fee_amount: Some(1000), + max_parts: None, + keysend: false, + udt_type_script: None, + preimage: None, + allow_self_payment: false, + }); + assert!(route.is_ok()); + let route = route.unwrap(); + let amounts = route.iter().map(|x| x.amount).collect::>(); + assert_eq!(amounts, vec![1000, 1000]); +} + #[test] fn test_graph_find_path_err() { let mut network = MockNetworkGraph::new(6); let (node1, _node5) = (network.keys[1], network.keys[5]); network.add_edge(1, 2, Some(1000), Some(4)); - let route = network.find_route(1, 1, 100, 1000); + let route = network.find_path(1, 1, 100, 1000); assert!(route.is_err()); let no_exits_public_key = network.keys[0]; @@ -624,7 +727,7 @@ fn test_graph_find_path_udt() { network.add_edge_udt(1, 2, Some(1000), Some(1), udt_type_script.clone()); let node2 = network.keys[2]; - let route = network.find_route_udt(1, 2, 100, 1000, udt_type_script.clone()); + let route = network.find_path_udt(1, 2, 100, 1000, udt_type_script.clone()); assert!(route.is_ok()); let route = route.unwrap(); @@ -632,7 +735,7 @@ fn test_graph_find_path_udt() { assert_eq!(route[0].target, node2.into()); assert_eq!(route[0].channel_outpoint, network.edges[0].2); - let route = network.find_route(1, 3, 10, 100); + let route = network.find_path(1, 3, 10, 100); assert!(route.is_err()); }