From 84766ce3e5728f8c6b158019e3f71f423decc592 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Fri, 27 Jan 2023 22:04:09 -0800 Subject: [PATCH 01/12] Start Katz centrality --- rustworkx-core/src/centrality.rs | 70 +++++++++++++++++++++++ src/centrality.rs | 97 ++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 3 files changed, 169 insertions(+) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index f76a80a9f..2729f60fd 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -423,6 +423,76 @@ where Ok(None) } +pub fn katz_centrality( + graph: G, + mut weight_fn: F, + alpha: Option, + beta_map: Option>, + beta_scalar: Option, + max_iter: Option, + tol: Option, +) -> Result>, E> +where + G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount, + G::NodeId: Eq + std::hash::Hash, + F: FnMut(G::EdgeRef) -> Result, +{ + let alpha: f64 = alpha.unwrap_or(0.1); + + let mut beta: HashMap = beta_map.unwrap_or_else(|| HashMap::new()); + + if beta.is_empty() { + // beta_map was none + // populate hashmap with default value + let beta_scalar = beta_scalar.unwrap_or(1.0); + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + beta.insert(node.clone(), beta_scalar); + } + } + + // Check if beta contains all node indices + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + if !beta.contains_key(&node) { + return Ok(None); // beta_map was provided but did not include all nodes + } + } + + let tol: f64 = tol.unwrap_or(1e-6); + let max_iter = max_iter.unwrap_or(100); + + let mut x: Vec = vec![0.; graph.node_bound()]; + let node_count = graph.node_count(); + for _ in 0..max_iter { + let x_last = x.clone(); + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + x[node] += beta.get(&node).unwrap_or(&0.0); + for edge in graph.edges(node_index) { + let w = weight_fn(edge)?; + let neighbor = edge.target(); + x[graph.to_index(neighbor)] += alpha * x_last[node] * w; + } + } + let norm: f64 = x.iter().map(|val| val.powi(2)).sum::().sqrt(); + if norm == 0. { + return Ok(None); + } + for v in x.iter_mut() { + *v /= norm; + } + if (0..x.len()) + .map(|node| (x[node] - x_last[node]).abs()) + .sum::() + < node_count as f64 * tol + { + return Ok(Some(x)); + } + } + Ok(None) +} + #[cfg(test)] mod test_eigenvector_centrality { diff --git a/src/centrality.rs b/src/centrality.rs index 7393fbbd5..3e81c5e40 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -296,3 +296,100 @@ pub fn digraph_eigenvector_centrality( ))), } } + +#[pyfunction(default_weight = "1.0", max_iter = "100", tol = "1e-6")] +#[pyo3(text_signature = "(graph, /, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)")] +pub fn graph_katz_centrality( + py: Python, + graph: &graph::PyGraph, + weight_fn: Option, + default_weight: f64, + max_iter: usize, + tol: f64, +) -> PyResult { + let mut edge_weights = vec![default_weight; graph.graph.edge_bound()]; + if weight_fn.is_some() { + let cost_fn = CostFn::try_from((weight_fn, default_weight))?; + for edge in graph.graph.edge_indices() { + edge_weights[edge.index()] = + cost_fn.call(py, graph.graph.edge_weight(edge).unwrap())?; + } + } + let ev_centrality = centrality::katz_centrality( + &graph.graph, + |e| -> PyResult { Ok(edge_weights[e.id().index()]) }, + None, + None, + None, + Some(max_iter), + Some(tol), + )?; + match ev_centrality { + Some(centrality) => Ok(CentralityMapping { + centralities: centrality + .iter() + .enumerate() + .filter_map(|(k, v)| { + if graph.graph.contains_node(NodeIndex::new(k)) { + Some((k, *v)) + } else { + None + } + }) + .collect(), + }), + None => Err(FailedToConverge::new_err(format!( + "Function failed to converge on a solution in {} iterations", + max_iter + ))), + } +} + +#[pyfunction(default_weight = "1.0", max_iter = "100", tol = "1e-6")] +#[pyo3(text_signature = "(graph, /, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)")] +pub fn digraph_katz_centrality( + py: Python, + graph: &digraph::PyDiGraph, + weight_fn: Option, + default_weight: f64, + max_iter: usize, + tol: f64, +) -> PyResult { + let mut edge_weights = vec![default_weight; graph.graph.edge_bound()]; + if weight_fn.is_some() { + let cost_fn = CostFn::try_from((weight_fn, default_weight))?; + for edge in graph.graph.edge_indices() { + edge_weights[edge.index()] = + cost_fn.call(py, graph.graph.edge_weight(edge).unwrap())?; + } + } + let ev_centrality = centrality::katz_centrality( + &graph.graph, + |e| -> PyResult { Ok(edge_weights[e.id().index()]) }, + None, + None, + None, + Some(max_iter), + Some(tol), + )?; + + match ev_centrality { + Some(centrality) => Ok(CentralityMapping { + centralities: centrality + .iter() + .enumerate() + .filter_map(|(k, v)| { + if graph.graph.contains_node(NodeIndex::new(k)) { + Some((k, *v)) + } else { + None + } + }) + .collect(), + }), + None => Err(FailedToConverge::new_err(format!( + "Function failed to converge on a solution in {} iterations", + max_iter + ))), + } +} diff --git a/src/lib.rs b/src/lib.rs index a9ad4f4dd..95711673d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -416,6 +416,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_betweenness_centrality))?; m.add_wrapped(wrap_pyfunction!(graph_eigenvector_centrality))?; m.add_wrapped(wrap_pyfunction!(digraph_eigenvector_centrality))?; + m.add_wrapped(wrap_pyfunction!(graph_katz_centrality))?; + m.add_wrapped(wrap_pyfunction!(digraph_katz_centrality))?; m.add_wrapped(wrap_pyfunction!(graph_astar_shortest_path))?; m.add_wrapped(wrap_pyfunction!(digraph_astar_shortest_path))?; m.add_wrapped(wrap_pyfunction!(graph_greedy_color))?; From ef1252252027467a327a7efc71802554116c4c01 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Sun, 29 Jan 2023 16:38:04 -0800 Subject: [PATCH 02/12] More Katz centrality --- docs/source/api.rst | 1 + rustworkx/__init__.py | 37 ++++++++++++++++++++++++ src/centrality.rs | 66 +++++++++++++++++++++++++++++++++++++------ 3 files changed, 96 insertions(+), 8 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 8368188c2..d8de10ffb 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -55,6 +55,7 @@ Centrality rustworkx.betweenness_centrality rustworkx.eigenvector_centrality + rustworkx.katz_centrality .. _traversal: diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 6d82d41f2..b4ba48df7 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -1654,6 +1654,43 @@ def _graph_eigenvector_centrality( ) +@functools.singledispatch +def katz_centrality( + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 +): + pass + + +@katz_centrality.register(PyDiGraph) +def _digraph_katz_centrality( + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 +): + return digraph_katz_centrality( + graph, + alpha=alpha, + beta=beta, + weight_fn=weight_fn, + default_weight=default_weight, + max_iter=max_iter, + tol=tol, + ) + + +@katz_centrality.register(PyGraph) +def _graph_katz_centrality( + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 +): + return graph_katz_centrality( + graph, + alpha=alpha, + beta=beta, + weight_fn=weight_fn, + default_weight=default_weight, + max_iter=max_iter, + tol=tol, + ) + + @functools.singledispatch def vf2_mapping( first, diff --git a/src/centrality.rs b/src/centrality.rs index 3e81c5e40..ebecb33bc 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -21,8 +21,10 @@ use crate::FailedToConverge; use petgraph::graph::NodeIndex; use petgraph::visit::EdgeIndexable; use petgraph::visit::EdgeRef; +use petgraph::visit::IntoNodeIdentifiers; use pyo3::prelude::*; use rustworkx_core::centrality; +use hashbrown::HashMap; /// Compute the betweenness centrality of all nodes in a PyGraph. /// @@ -297,11 +299,20 @@ pub fn digraph_eigenvector_centrality( } } -#[pyfunction(default_weight = "1.0", max_iter = "100", tol = "1e-6")] -#[pyo3(text_signature = "(graph, /, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)")] +#[pyfunction( + alpha = "0.1", + default_weight = "1.0", + max_iter = "100", + tol = "1e-6" +)] +#[pyo3( + text_signature = "(graph, /, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)" +)] pub fn graph_katz_centrality( py: Python, graph: &graph::PyGraph, + alpha: f64, + beta: PyObject, weight_fn: Option, default_weight: f64, max_iter: usize, @@ -315,11 +326,26 @@ pub fn graph_katz_centrality( cost_fn.call(py, graph.graph.edge_weight(edge).unwrap())?; } } + + let mut beta_map: HashMap = HashMap::new(); + + match beta.extract::(py) { + Ok(beta_scalar) => { + // User provided a scalar, populate beta_map with the value + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), beta_scalar); + } + } + Err(_) => { + beta_map = beta.extract::>(py)?; + } + } + let ev_centrality = centrality::katz_centrality( &graph.graph, |e| -> PyResult { Ok(edge_weights[e.id().index()]) }, - None, - None, + Some(alpha), + Some(beta_map), None, Some(max_iter), Some(tol), @@ -345,11 +371,20 @@ pub fn graph_katz_centrality( } } -#[pyfunction(default_weight = "1.0", max_iter = "100", tol = "1e-6")] -#[pyo3(text_signature = "(graph, /, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)")] +#[pyfunction( + alpha = "0.1", + default_weight = "1.0", + max_iter = "100", + tol = "1e-6" +)] +#[pyo3( + text_signature = "(graph, /, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)" +)] pub fn digraph_katz_centrality( py: Python, graph: &digraph::PyDiGraph, + alpha: f64, + beta: PyObject, weight_fn: Option, default_weight: f64, max_iter: usize, @@ -363,11 +398,26 @@ pub fn digraph_katz_centrality( cost_fn.call(py, graph.graph.edge_weight(edge).unwrap())?; } } + + let mut beta_map: HashMap = HashMap::new(); + + match beta.extract::(py) { + Ok(beta_scalar) => { + // User provided a scalar, populate beta_map with the value + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), beta_scalar); + } + } + Err(_) => { + beta_map = beta.extract::>(py)?; + } + } + let ev_centrality = centrality::katz_centrality( &graph.graph, |e| -> PyResult { Ok(edge_weights[e.id().index()]) }, - None, - None, + Some(alpha), + Some(beta_map), None, Some(max_iter), Some(tol), From 11bda245fccafc8daaf56276dbed3ab97381c6f9 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Mon, 30 Jan 2023 14:56:21 -0800 Subject: [PATCH 03/12] Almost done with Katz --- docs/source/api.rst | 2 + rustworkx-core/src/centrality.rs | 29 ++++--- rustworkx/__init__.py | 4 +- src/centrality.rs | 79 ++++++++++++++----- .../digraph/test_centrality.py | 46 +++++++++++ .../rustworkx_tests/graph/test_centrality.py | 42 ++++++++++ 6 files changed, 168 insertions(+), 34 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index d8de10ffb..bbe636b7a 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -322,6 +322,7 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_num_shortest_paths_unweighted rustworkx.digraph_betweenness_centrality rustworkx.digraph_eigenvector_centrality + rustworkx.digraph_katz_centrality rustworkx.digraph_unweighted_average_shortest_path_length rustworkx.digraph_bfs_search rustworkx.digraph_dijkstra_search @@ -376,6 +377,7 @@ typed API based on the data type. rustworkx.graph_num_shortest_paths_unweighted rustworkx.graph_betweenness_centrality rustworkx.graph_eigenvector_centrality + rustworkx.graph_katz_centrality rustworkx.graph_unweighted_average_shortest_path_length rustworkx.graph_bfs_search rustworkx.graph_dijkstra_search diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index 2729f60fd..e18d32cb1 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -392,7 +392,7 @@ where F: FnMut(G::EdgeRef) -> Result, { let tol: f64 = tol.unwrap_or(1e-6); - let max_iter = max_iter.unwrap_or(100); + let max_iter = max_iter.unwrap_or(1000); let mut x: Vec = vec![1.; graph.node_bound()]; let node_count = graph.node_count(); for _ in 0..max_iter { @@ -439,7 +439,7 @@ where { let alpha: f64 = alpha.unwrap_or(0.1); - let mut beta: HashMap = beta_map.unwrap_or_else(|| HashMap::new()); + let mut beta: HashMap = beta_map.unwrap_or_else(HashMap::new); if beta.is_empty() { // beta_map was none @@ -447,7 +447,7 @@ where let beta_scalar = beta_scalar.unwrap_or(1.0); for node_index in graph.node_identifiers() { let node = graph.to_index(node_index); - beta.insert(node.clone(), beta_scalar); + beta.insert(node, beta_scalar); } } @@ -466,30 +466,37 @@ where let node_count = graph.node_count(); for _ in 0..max_iter { let x_last = x.clone(); + x = vec![0.; graph.node_bound()]; for node_index in graph.node_identifiers() { let node = graph.to_index(node_index); - x[node] += beta.get(&node).unwrap_or(&0.0); for edge in graph.edges(node_index) { let w = weight_fn(edge)?; let neighbor = edge.target(); - x[graph.to_index(neighbor)] += alpha * x_last[node] * w; + x[graph.to_index(neighbor)] += x_last[node] * w; } } - let norm: f64 = x.iter().map(|val| val.powi(2)).sum::().sqrt(); - if norm == 0. { - return Ok(None); - } - for v in x.iter_mut() { - *v /= norm; + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + x[node] = alpha * x[node] + beta.get(&node).unwrap_or(&0.0); } if (0..x.len()) .map(|node| (x[node] - x_last[node]).abs()) .sum::() < node_count as f64 * tol { + // Normalize vector + let norm: f64 = x.iter().map(|val| val.powi(2)).sum::().sqrt(); + if norm == 0. { + return Ok(None); + } + for v in x.iter_mut() { + *v /= norm; + } + return Ok(Some(x)); } } + Ok(None) } diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index b4ba48df7..dbaa0ec54 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -1663,7 +1663,7 @@ def katz_centrality( @katz_centrality.register(PyDiGraph) def _digraph_katz_centrality( - graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6 ): return digraph_katz_centrality( graph, @@ -1678,7 +1678,7 @@ def _digraph_katz_centrality( @katz_centrality.register(PyGraph) def _graph_katz_centrality( - graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6 ): return graph_katz_centrality( graph, diff --git a/src/centrality.rs b/src/centrality.rs index ebecb33bc..2779a1b7d 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -10,6 +10,8 @@ // License for the specific language governing permissions and limitations // under the License. +#![allow(clippy::too_many_arguments)] + use std::convert::TryFrom; use crate::digraph; @@ -18,13 +20,14 @@ use crate::iterators::CentralityMapping; use crate::CostFn; use crate::FailedToConverge; +use hashbrown::HashMap; use petgraph::graph::NodeIndex; use petgraph::visit::EdgeIndexable; use petgraph::visit::EdgeRef; use petgraph::visit::IntoNodeIdentifiers; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use rustworkx_core::centrality; -use hashbrown::HashMap; /// Compute the betweenness centrality of all nodes in a PyGraph. /// @@ -301,18 +304,20 @@ pub fn digraph_eigenvector_centrality( #[pyfunction( alpha = "0.1", + beta = "None", + weight_fn = "None", default_weight = "1.0", - max_iter = "100", + max_iter = "1000", tol = "1e-6" )] #[pyo3( - text_signature = "(graph, /, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)" + text_signature = "(graph, /, alpha=0.1, beta=None, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6)" )] pub fn graph_katz_centrality( py: Python, graph: &graph::PyGraph, alpha: f64, - beta: PyObject, + beta: Option, weight_fn: Option, default_weight: f64, max_iter: usize, @@ -329,15 +334,30 @@ pub fn graph_katz_centrality( let mut beta_map: HashMap = HashMap::new(); - match beta.extract::(py) { - Ok(beta_scalar) => { - // User provided a scalar, populate beta_map with the value - for node_index in graph.graph.node_identifiers() { - beta_map.insert(node_index.index(), beta_scalar); + if let Some(beta) = beta { + match beta.extract::(py) { + Ok(beta_scalar) => { + // User provided a scalar, populate beta_map with the value + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), beta_scalar); + } + } + Err(_) => { + beta_map = beta.extract::>(py)?; + + for node_index in graph.graph.node_identifiers() { + if !beta_map.contains_key(&node_index.index()) { + return Err(PyValueError::new_err( + "Beta does not contain all node indices", + )); + } + } } } - Err(_) => { - beta_map = beta.extract::>(py)?; + } else { + // Populate with 1.0 + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), 1.0); } } @@ -373,18 +393,20 @@ pub fn graph_katz_centrality( #[pyfunction( alpha = "0.1", + beta = "None", + weight_fn = "None", default_weight = "1.0", - max_iter = "100", + max_iter = "1000", tol = "1e-6" )] #[pyo3( - text_signature = "(graph, /, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6)" + text_signature = "(graph, /, alpha=0.1, beta=None, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6)" )] pub fn digraph_katz_centrality( py: Python, graph: &digraph::PyDiGraph, alpha: f64, - beta: PyObject, + beta: Option, weight_fn: Option, default_weight: f64, max_iter: usize, @@ -401,15 +423,30 @@ pub fn digraph_katz_centrality( let mut beta_map: HashMap = HashMap::new(); - match beta.extract::(py) { - Ok(beta_scalar) => { - // User provided a scalar, populate beta_map with the value - for node_index in graph.graph.node_identifiers() { - beta_map.insert(node_index.index(), beta_scalar); + if let Some(beta) = beta { + match beta.extract::(py) { + Ok(beta_scalar) => { + // User provided a scalar, populate beta_map with the value + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), beta_scalar); + } + } + Err(_) => { + beta_map = beta.extract::>(py)?; + + for node_index in graph.graph.node_identifiers() { + if !beta_map.contains_key(&node_index.index()) { + return Err(PyValueError::new_err( + "Beta does not contain all node indices", + )); + } + } } } - Err(_) => { - beta_map = beta.extract::>(py)?; + } else { + // Populate with 1.0 + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), 1.0); } } diff --git a/tests/rustworkx_tests/digraph/test_centrality.py b/tests/rustworkx_tests/digraph/test_centrality.py index 4a739b0bf..116688756 100644 --- a/tests/rustworkx_tests/digraph/test_centrality.py +++ b/tests/rustworkx_tests/digraph/test_centrality.py @@ -14,6 +14,7 @@ import unittest import rustworkx +import networkx as nx class TestCentralityDiGraph(unittest.TestCase): @@ -150,3 +151,48 @@ def test_no_convergence(self): graph = rustworkx.PyDiGraph() with self.assertRaises(rustworkx.FailedToConverge): rustworkx.eigenvector_centrality(graph, max_iter=0) + + +class TestKatzCentrality(unittest.TestCase): + def test_complete_graph(self): + graph = rustworkx.generators.directed_complete_graph(5) + centrality = rustworkx.digraph_katz_centrality(graph) + expected_value = math.sqrt(1.0 / 5.0) + for value in centrality.values(): + self.assertAlmostEqual(value, expected_value, delta=1e-4) + + def test_no_convergence(self): + graph = rustworkx.generators.directed_complete_graph(5) + with self.assertRaises(rustworkx.FailedToConverge): + rustworkx.katz_centrality(graph, max_iter=0) + + def test_beta_scalar(self): + rx_graph = rustworkx.generators.directed_grid_graph(5, 2) + beta = 0.3 + + rx_centrality = rustworkx.katz_centrality(rx_graph, alpha=0.25, beta=beta) + + nx_graph = nx.DiGraph() + nx_graph.add_edges_from(rx_graph.edge_list()) + nx_centrality = nx.katz_centrality(nx_graph, alpha=0.25, beta=beta) + + for key in rx_centrality.keys(): + self.assertAlmostEqual(rx_centrality[key], nx_centrality[key], delta=1e-4) + + def test_beta_dictionary(self): + rx_graph = rustworkx.generators.directed_grid_graph(5, 2) + beta = {i: 0.1 * i**2 for i in range(10)} + + rx_centrality = rustworkx.katz_centrality(rx_graph, alpha=0.25, beta=beta) + + nx_graph = nx.DiGraph() + nx_graph.add_edges_from(rx_graph.edge_list()) + nx_centrality = nx.katz_centrality(nx_graph, alpha=0.25, beta=beta) + + for key in rx_centrality.keys(): + self.assertAlmostEqual(rx_centrality[key], nx_centrality[key], delta=1e-4) + + def test_beta_incomplete(self): + graph = rustworkx.generators.directed_grid_graph(5, 2) + with self.assertRaises(ValueError): + rustworkx.katz_centrality(graph, beta={0: 0.25}) diff --git a/tests/rustworkx_tests/graph/test_centrality.py b/tests/rustworkx_tests/graph/test_centrality.py index ff5559e2a..94c7f0571 100644 --- a/tests/rustworkx_tests/graph/test_centrality.py +++ b/tests/rustworkx_tests/graph/test_centrality.py @@ -14,6 +14,7 @@ import unittest import rustworkx +import networkx as nx class TestCentralityGraph(unittest.TestCase): @@ -121,3 +122,44 @@ def test_no_convergence(self): graph = rustworkx.PyGraph() with self.assertRaises(rustworkx.FailedToConverge): rustworkx.eigenvector_centrality(graph, max_iter=0) + + +class TestKatzCentrality(unittest.TestCase): + def test_complete_graph(self): + graph = rustworkx.generators.complete_graph(5) + centrality = rustworkx.graph_katz_centrality(graph) + expected_value = math.sqrt(1.0 / 5.0) + for value in centrality.values(): + self.assertAlmostEqual(value, expected_value, delta=1e-4) + + def test_no_convergence(self): + graph = rustworkx.generators.complete_graph(5) + with self.assertRaises(rustworkx.FailedToConverge): + rustworkx.katz_centrality(graph, max_iter=0) + + def test_beta_scalar(self): + graph = rustworkx.generators.generalized_petersen_graph(5, 2) + expected_value = 0.31622776601683794 + + centrality = rustworkx.katz_centrality(graph, alpha=0.1, beta=0.1, tol=1e-8) + + for value in centrality.values(): + self.assertAlmostEqual(value, expected_value, delta=1e-4) + + def test_beta_dictionary(self): + rx_graph = rustworkx.generators.generalized_petersen_graph(5, 2) + beta = {i: 0.1 * i**2 for i in range(10)} + + rx_centrality = rustworkx.katz_centrality(rx_graph, alpha=0.25, beta=beta) + + nx_graph = nx.Graph() + nx_graph.add_edges_from(rx_graph.edge_list()) + nx_centrality = nx.katz_centrality(nx_graph, alpha=0.25, beta=beta) + + for key in rx_centrality.keys(): + self.assertAlmostEqual(rx_centrality[key], nx_centrality[key], delta=1e-4) + + def test_beta_incomplete(self): + graph = rustworkx.generators.generalized_petersen_graph(5, 2) + with self.assertRaises(ValueError): + rustworkx.katz_centrality(graph, beta={0: 0.25}) From a9cb24cfd165eede934ea56a924ebe08ca1b232c Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Mon, 30 Jan 2023 21:02:05 -0800 Subject: [PATCH 04/12] Add release note --- .../notes/add-katz-5389c6e5bd30e176.yaml | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 releasenotes/notes/add-katz-5389c6e5bd30e176.yaml diff --git a/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml b/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml new file mode 100644 index 000000000..a10e4dc7c --- /dev/null +++ b/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml @@ -0,0 +1,33 @@ +--- +features: + - | + Added a new function, :func:`~.katz_centrality()` which is used to + compute the Katz centrality for all nodes in a given graph. For + example: + + .. jupyter-execute:: + + import rustworkx as rx + from rustworkx.visualization import mpl_draw + + graph = rx.generators.hexagonal_lattice_graph(4, 4) + centrality = rx.katz_centrality(graph) + + # Generate a color list + colors = [] + for node in graph.node_indices(): + centrality_score = centrality[node] + graph[node] = centrality_score + colors.append(centrality_score) + mpl_draw( + graph, + with_labels=True, + node_color=colors, + node_size=650, + labels=lambda x: "{0:.2f}".format(x) + ) + + - | + Added a new function to rustworkx-core ``katz_centrality`` to the + ``rustworkx_core::centrality`` modules which is used to compute the + Katz centrality for all nodes in a given graph. From 55c3890bce8271ce082d5a04931199ecdc440755 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Tue, 31 Jan 2023 09:18:06 -0800 Subject: [PATCH 05/12] Add docs to rustworkx-core --- rustworkx-core/src/centrality.rs | 47 ++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index e18d32cb1..8c4130114 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -423,6 +423,53 @@ where Ok(None) } +/// Compute the Katz centrality of a graph +/// +/// For details on the Katz centrality refer to: +/// +/// Leo Katz. “A New Status Index Derived from Sociometric Index.” +/// Psychometrika 18(1):39–43, 1953 +/// +/// +/// This function uses a power iteration method to compute the eigenvector +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm is based on the NetworkX +/// [`eigenvector_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html) +/// function. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the eigenvector centrality. +/// +/// Arguments: +/// +/// * `graph` - The graph object to run the algorithm on +/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for +/// an edge in the graph and is expected to return a `Result` of +/// the weight of that edge. +/// * `alpha` - Attenuation factor. If set to `None`, a default value of 0.1 is used. +/// * `beta_map` - Immediate neighbourhood weights. Must contain all node indices or be `None`. +/// * `beta_scalar` - Immediate neighbourhood scalar that replaces `beta_map` in case `beta_map` is None. +/// Defaults to 1.0 in case `None` is provided. +/// * `max_iter` - The maximum number of iterations in the power method. If +/// set to `None` a default value of 100 is used. +/// * `tol` - The error tolerance used when checking for convergence in the +/// power method. If set to `None` a default value of 1e-6 is used. +/// +/// # Example +/// ```rust +/// use rustworkx_core::Result; +/// use rustworkx_core::petgraph; +/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers}; +/// use rustworkx_core::centrality::katz_centrality; +/// +/// let g = petgraph::graph::UnGraph::::from_edges(&[ +/// (0, 1), (1, 2) +/// ]); +/// // Calculate the eigenvector centrality +/// let output: Result>> = katz_centrality(&g, |_| {Ok(1.)}, None, None, None, None, None); +/// ``` pub fn katz_centrality( graph: G, mut weight_fn: F, From 758fe89ccf544e2ef4f02497384bb41799fa2d18 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Tue, 31 Jan 2023 10:26:59 -0800 Subject: [PATCH 06/12] Add documentation and rustworkx-core tests --- rustworkx-core/src/centrality.rs | 116 ++++++++++++++++++++++++++++++- src/centrality.rs | 76 ++++++++++++++++++++ 2 files changed, 191 insertions(+), 1 deletion(-) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index 8c4130114..79e72d815 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -436,7 +436,7 @@ where /// iterations is reached or when the computed vector between two iterations /// is smaller than the error tolerance multiplied by the number of nodes. /// The implementation of this algorithm is based on the NetworkX -/// [`eigenvector_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html) +/// [`katz_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html) /// function. /// /// In the case of multigraphs the weights of any parallel edges will be @@ -636,3 +636,117 @@ mod test_eigenvector_centrality { } } } + +#[cfg(test)] +mod test_katz_centrality { + + use crate::centrality::katz_centrality; + use crate::petgraph; + use crate::Result; + use hashbrown::HashMap; + + macro_rules! assert_almost_equal { + ($x:expr, $y:expr, $d:expr) => { + if ($x - $y).abs() >= $d { + panic!("{} != {} within delta of {}", $x, $y, $d); + } + }; + } + #[test] + fn test_no_convergence() { + let g = petgraph::graph::UnGraph::::from_edges(&[(0, 1), (1, 2)]); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, None, None, Some(0), None); + let result = output.unwrap(); + assert_eq!(None, result); + } + + #[test] + fn test_incomplete_beta() { + let g = petgraph::graph::UnGraph::::from_edges(&[(0, 1), (1, 2)]); + let beta_map: HashMap = [(0, 1.0)].iter().cloned().collect(); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, Some(beta_map), None, None, None); + let result = output.unwrap(); + assert_eq!(None, result); + } + + #[test] + fn test_complete_beta() { + let g = petgraph::graph::UnGraph::::from_edges(&[(0, 1), (1, 2)]); + let beta_map: HashMap = + [(0, 0.5), (1, 1.0), (2, 0.5)].iter().cloned().collect(); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, Some(beta_map), None, None, None); + let result = output.unwrap().unwrap(); + let expected_values: Vec = + vec![0.4318894504492167, 0.791797325823564, 0.4318894504492167]; + for i in 0..3 { + assert_almost_equal!(expected_values[i], result[i], 1e-4); + } + } + + #[test] + fn test_undirected_complete_graph() { + let g = petgraph::graph::UnGraph::::from_edges([ + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (1, 2), + (1, 3), + (1, 4), + (2, 3), + (2, 4), + (3, 4), + ]); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), Some(0.2), None, Some(1.1), None, None); + let result = output.unwrap().unwrap(); + let expected_value: f64 = (1_f64 / 5_f64).sqrt(); + let expected_values: Vec = vec![expected_value; 5]; + for i in 0..5 { + assert_almost_equal!(expected_values[i], result[i], 1e-4); + } + } + + #[test] + fn test_directed_graph() { + let g = petgraph::graph::DiGraph::::from_edges([ + (0, 1), + (0, 2), + (1, 3), + (2, 1), + (2, 4), + (3, 1), + (3, 4), + (3, 5), + (4, 5), + (4, 6), + (4, 7), + (5, 7), + (6, 0), + (6, 4), + (6, 7), + (7, 5), + (7, 6), + ]); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, None, None, None, None); + let result = output.unwrap().unwrap(); + let expected_values: Vec = vec![ + 0.3135463087489011, + 0.3719056758615039, + 0.3094350787809586, + 0.31527101632646026, + 0.3760169058294464, + 0.38618584417917906, + 0.35465874858087904, + 0.38976653416801743, + ]; + + for i in 0..8 { + assert_almost_equal!(expected_values[i], result[i], 1e-4); + } + } +} diff --git a/src/centrality.rs b/src/centrality.rs index 2779a1b7d..fe1c15724 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -302,6 +302,44 @@ pub fn digraph_eigenvector_centrality( } } +/// Compute the Katz centrality of a :class:`~PyGraph`. +/// +/// For details on the eigenvector centrality refer to: +/// +/// Leo Katz. “A New Status Index Derived from Sociometric Index.” +/// Psychometrika 18(1):39–43, 1953 +/// +/// +/// This function uses a power iteration method to compute the eigenvector +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm is based on the NetworkX +/// `katz_centrality() `__ +/// function. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the eigenvector centrality. +/// +/// :param PyGraph graph: The graph object to run the algorithm on +/// :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. +/// :param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood +/// weight is used for all nodes. If a dictionary is provided, it must contain all node indices. +/// If beta is not specified, a default value of 1.0 is used. +/// :param weight_fn: An optional input callable that will be passed the edge's +/// payload object and is expected to return a `float` weight for that edge. +/// If this is not specified ``default_weight`` will be used as the weight +/// for every edge in ``graph`` +/// :param float default_weight: If ``weight_fn`` is not set the default weight +/// value to use for the weight of all edges +/// :param int max_iter: The maximum number of iterations in the power method. If +/// not specified a default value of 1000 is used. +/// :param float tol: The error tolerance used when checking for convergence in the +/// power method. If this is not specified default value of 1e-6 is used. +/// +/// :returns: a read-only dict-like object whose keys are the node indices and values are the +/// centrality score for that node. +/// :rtype: CentralityMapping #[pyfunction( alpha = "0.1", beta = "None", @@ -391,6 +429,44 @@ pub fn graph_katz_centrality( } } +/// Compute the Katz centrality of a :class:`~PyDiGraph`. +/// +/// For details on the eigenvector centrality refer to: +/// +/// Leo Katz. “A New Status Index Derived from Sociometric Index.” +/// Psychometrika 18(1):39–43, 1953 +/// +/// +/// This function uses a power iteration method to compute the eigenvector +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm is based on the NetworkX +/// `katz_centrality() `__ +/// function. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the eigenvector centrality. +/// +/// :param PyDiGraph graph: The graph object to run the algorithm on +/// :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. +/// :param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood +/// weight is used for all nodes. If a dictionary is provided, it must contain all node indices. +/// If beta is not specified, a default value of 1.0 is used. +/// :param weight_fn: An optional input callable that will be passed the edge's +/// payload object and is expected to return a `float` weight for that edge. +/// If this is not specified ``default_weight`` will be used as the weight +/// for every edge in ``graph`` +/// :param float default_weight: If ``weight_fn`` is not set the default weight +/// value to use for the weight of all edges +/// :param int max_iter: The maximum number of iterations in the power method. If +/// not specified a default value of 1000 is used. +/// :param float tol: The error tolerance used when checking for convergence in the +/// power method. If this is not specified default value of 1e-6 is used. +/// +/// :returns: a read-only dict-like object whose keys are the node indices and values are the +/// centrality score for that node. +/// :rtype: CentralityMapping #[pyfunction( alpha = "0.1", beta = "None", From 3bb57bc918aa4b224cc04734d34ba3188961176e Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Tue, 31 Jan 2023 10:30:34 -0800 Subject: [PATCH 07/12] Fix max iter --- rustworkx-core/src/centrality.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index 79e72d815..6924db1be 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -392,7 +392,7 @@ where F: FnMut(G::EdgeRef) -> Result, { let tol: f64 = tol.unwrap_or(1e-6); - let max_iter = max_iter.unwrap_or(1000); + let max_iter = max_iter.unwrap_or(100); let mut x: Vec = vec![1.; graph.node_bound()]; let node_count = graph.node_count(); for _ in 0..max_iter { @@ -507,7 +507,7 @@ where } let tol: f64 = tol.unwrap_or(1e-6); - let max_iter = max_iter.unwrap_or(100); + let max_iter = max_iter.unwrap_or(1000); let mut x: Vec = vec![0.; graph.node_bound()]; let node_count = graph.node_count(); From 7384261be8324b0c2a0f9501b94589555e58f4d0 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Tue, 31 Jan 2023 13:56:49 -0800 Subject: [PATCH 08/12] Documentation details --- rustworkx/__init__.py | 41 ++++++++++++++++++++++++++++++++++++++++- src/centrality.rs | 4 ++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index dbaa0ec54..d62484cc2 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -1658,7 +1658,46 @@ def _graph_eigenvector_centrality( def katz_centrality( graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 ): - pass + """Compute the Katz centrality of a graph. + + For details on the Katz centrality refer to: + + Leo Katz. “A New Status Index Derived from Sociometric Index.” + Psychometrika 18(1):39–43, 1953 + + + This function uses a power iteration method to compute the eigenvector + and convergence is not guaranteed. The function will stop when `max_iter` + iterations is reached or when the computed vector between two iterations + is smaller than the error tolerance multiplied by the number of nodes. + The implementation of this algorithm is based on the NetworkX + `katz_centrality() `__ + function. + + In the case of multigraphs the weights of any parallel edges will be + summed when computing the Katz centrality. + + :param graph: Graph to be used. Can either be a + :class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`. + :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. + :param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood + weight is used for all nodes. If a dictionary is provided, it must contain all node indices. + If beta is not specified, a default value of 1.0 is used. + :param weight_fn: An optional input callable that will be passed the edge's + payload object and is expected to return a `float` weight for that edge. + If this is not specified ``default_weight`` will be used as the weight + for every edge in ``graph`` + :param float default_weight: If ``weight_fn`` is not set the default weight + value to use for the weight of all edges + :param int max_iter: The maximum number of iterations in the power method. If + not specified a default value of 100 is used. + :param float tol: The error tolerance used when checking for convergence in the + power method. If this is not specified default value of 1e-6 is used. + + :returns: a read-only dict-like object whose keys are the node indices and values are the + centrality score for that node. + :rtype: CentralityMapping + """ @katz_centrality.register(PyDiGraph) diff --git a/src/centrality.rs b/src/centrality.rs index fe1c15724..0518e912a 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -319,7 +319,7 @@ pub fn digraph_eigenvector_centrality( /// function. /// /// In the case of multigraphs the weights of any parallel edges will be -/// summed when computing the eigenvector centrality. +/// summed when computing the Katz centrality. /// /// :param PyGraph graph: The graph object to run the algorithm on /// :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. @@ -446,7 +446,7 @@ pub fn graph_katz_centrality( /// function. /// /// In the case of multigraphs the weights of any parallel edges will be -/// summed when computing the eigenvector centrality. +/// summed when computing the Katz centrality. /// /// :param PyDiGraph graph: The graph object to run the algorithm on /// :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. From 8d3821ab7093e8aac3967b4d0a86dd7d9fe73741 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Wed, 1 Feb 2023 11:24:22 -0800 Subject: [PATCH 09/12] Tweak signature --- src/centrality.rs | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/centrality.rs b/src/centrality.rs index 5b146c241..9c438b3d5 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -371,12 +371,15 @@ pub fn digraph_eigenvector_centrality( /// centrality score for that node. /// :rtype: CentralityMapping #[pyfunction( - alpha = "0.1", - beta = "None", - weight_fn = "None", - default_weight = "1.0", - max_iter = "1000", - tol = "1e-6" + signature = ( + graph, + alpha=0.1, + beta=None, + weight_fn=None, + default_weight=1.0, + max_iter=1000, + tol=1e-6 + ) )] #[pyo3( text_signature = "(graph, /, alpha=0.1, beta=None, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6)" @@ -498,12 +501,15 @@ pub fn graph_katz_centrality( /// centrality score for that node. /// :rtype: CentralityMapping #[pyfunction( - alpha = "0.1", - beta = "None", - weight_fn = "None", - default_weight = "1.0", - max_iter = "1000", - tol = "1e-6" + signature = ( + graph, + alpha=0.1, + beta=None, + weight_fn=None, + default_weight=1.0, + max_iter=1000, + tol=1e-6 + ) )] #[pyo3( text_signature = "(graph, /, alpha=0.1, beta=None, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6)" From beb1b3e8f6bd05fcd7f8eef141411686225d7c32 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Tue, 30 May 2023 21:08:56 -0400 Subject: [PATCH 10/12] Apply suggestions from code review Co-authored-by: Matthew Treinish --- src/centrality.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/centrality.rs b/src/centrality.rs index e2b27e1b1..ca055cec3 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -565,7 +565,7 @@ pub fn digraph_eigenvector_centrality( /// Compute the Katz centrality of a :class:`~PyGraph`. /// -/// For details on the eigenvector centrality refer to: +/// For details on the Katz centrality refer to: /// /// Leo Katz. “A New Status Index Derived from Sociometric Index.” /// Psychometrika 18(1):39–43, 1953 @@ -695,7 +695,7 @@ pub fn graph_katz_centrality( /// Compute the Katz centrality of a :class:`~PyDiGraph`. /// -/// For details on the eigenvector centrality refer to: +/// For details on the Katz centrality refer to: /// /// Leo Katz. “A New Status Index Derived from Sociometric Index.” /// Psychometrika 18(1):39–43, 1953 From 33db1fe127ef3b24b32799916e61ffa9f40e9b76 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Thu, 1 Jun 2023 01:13:58 +0000 Subject: [PATCH 11/12] Suggestion from code review --- rustworkx-core/src/centrality.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index b1bbf0d5a..7cd60f4c9 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -745,13 +745,13 @@ where let node = graph.to_index(node_index); beta.insert(node, beta_scalar); } - } - - // Check if beta contains all node indices - for node_index in graph.node_identifiers() { - let node = graph.to_index(node_index); - if !beta.contains_key(&node) { - return Ok(None); // beta_map was provided but did not include all nodes + } else { + // Check if beta contains all node indices + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + if !beta.contains_key(&node) { + return Ok(None); // beta_map was provided but did not include all nodes + } } } From 77a26e5900212402e6430f580592bdbc30048018 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Thu, 1 Jun 2023 01:24:37 +0000 Subject: [PATCH 12/12] Code review suggestions --- rustworkx-core/src/centrality.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index 7cd60f4c9..c099c6304 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -718,6 +718,9 @@ where /// ]); /// // Calculate the eigenvector centrality /// let output: Result>> = katz_centrality(&g, |_| {Ok(1.)}, None, None, None, None, None); +/// let centralities = output.unwrap().unwrap(); +/// assert!(centralities[1] > centralities[0], "Node 1 is more central than node 0"); +/// assert!(centralities[1] > centralities[2], "Node 1 is more central than node 2"); /// ``` pub fn katz_centrality( graph: G,