Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Katz Centrality #797

Merged
merged 21 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Centrality
rustworkx.betweenness_centrality
rustworkx.edge_betweenness_centrality
rustworkx.eigenvector_centrality
rustworkx.katz_centrality
rustworkx.closeness_centrality

.. _link-analysis:
Expand Down Expand Up @@ -336,6 +337,7 @@ the functions from the explicitly typed based on the data type.
rustworkx.digraph_edge_betweenness_centrality
rustworkx.digraph_closeness_centrality
rustworkx.digraph_eigenvector_centrality
rustworkx.digraph_katz_centrality
rustworkx.digraph_unweighted_average_shortest_path_length
rustworkx.digraph_bfs_search
rustworkx.digraph_dijkstra_search
Expand Down Expand Up @@ -393,6 +395,7 @@ typed API based on the data type.
rustworkx.graph_edge_betweenness_centrality
rustworkx.graph_closeness_centrality
rustworkx.graph_eigenvector_centrality
rustworkx.graph_katz_centrality
rustworkx.graph_unweighted_average_shortest_path_length
rustworkx.graph_bfs_search
rustworkx.graph_dijkstra_search
Expand Down
33 changes: 33 additions & 0 deletions releasenotes/notes/add-katz-5389c6e5bd30e176.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
---
features:
- |
Added a new function, :func:`~.katz_centrality()` which is used to
compute the Katz centrality for all nodes in a given graph. For
example:

.. jupyter-execute::

import rustworkx as rx
from rustworkx.visualization import mpl_draw

graph = rx.generators.hexagonal_lattice_graph(4, 4)
centrality = rx.katz_centrality(graph)

# Generate a color list
colors = []
for node in graph.node_indices():
centrality_score = centrality[node]
graph[node] = centrality_score
colors.append(centrality_score)
mpl_draw(
graph,
with_labels=True,
node_color=colors,
node_size=650,
labels=lambda x: "{0:.2f}".format(x)
)

- |
Added a new function to rustworkx-core ``katz_centrality`` to the
``rustworkx_core::centrality`` modules which is used to compute the
Katz centrality for all nodes in a given graph.
238 changes: 238 additions & 0 deletions rustworkx-core/src/centrality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,130 @@ where
Ok(None)
}

/// Compute the Katz centrality of a graph
///
/// For details on the Katz centrality refer to:
///
/// Leo Katz. “A New Status Index Derived from Sociometric Index.”
/// Psychometrika 18(1):39–43, 1953
/// <https://link.springer.com/content/pdf/10.1007/BF02289026.pdf>
///
/// This function uses a power iteration method to compute the eigenvector
/// and convergence is not guaranteed. The function will stop when `max_iter`
/// iterations is reached or when the computed vector between two iterations
/// is smaller than the error tolerance multiplied by the number of nodes.
/// The implementation of this algorithm is based on the NetworkX
/// [`katz_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html)
/// function.
///
/// In the case of multigraphs the weights of any parallel edges will be
/// summed when computing the eigenvector centrality.
///
/// Arguments:
///
/// * `graph` - The graph object to run the algorithm on
/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for
/// an edge in the graph and is expected to return a `Result<f64>` of
/// the weight of that edge.
/// * `alpha` - Attenuation factor. If set to `None`, a default value of 0.1 is used.
/// * `beta_map` - Immediate neighbourhood weights. Must contain all node indices or be `None`.
/// * `beta_scalar` - Immediate neighbourhood scalar that replaces `beta_map` in case `beta_map` is None.
/// Defaults to 1.0 in case `None` is provided.
/// * `max_iter` - The maximum number of iterations in the power method. If
/// set to `None` a default value of 100 is used.
/// * `tol` - The error tolerance used when checking for convergence in the
/// power method. If set to `None` a default value of 1e-6 is used.
///
/// # Example
/// ```rust
/// use rustworkx_core::Result;
/// use rustworkx_core::petgraph;
/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers};
/// use rustworkx_core::centrality::katz_centrality;
///
/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
/// (0, 1), (1, 2)
/// ]);
/// // Calculate the eigenvector centrality
/// let output: Result<Option<Vec<f64>>> = katz_centrality(&g, |_| {Ok(1.)}, None, None, None, None, None);
IvanIsCoding marked this conversation as resolved.
Show resolved Hide resolved
/// ```
pub fn katz_centrality<G, F, E>(
graph: G,
mut weight_fn: F,
alpha: Option<f64>,
beta_map: Option<HashMap<usize, f64>>,
beta_scalar: Option<f64>,
max_iter: Option<usize>,
tol: Option<f64>,
) -> Result<Option<Vec<f64>>, E>
where
G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount,
G::NodeId: Eq + std::hash::Hash,
F: FnMut(G::EdgeRef) -> Result<f64, E>,
{
let alpha: f64 = alpha.unwrap_or(0.1);

let mut beta: HashMap<usize, f64> = beta_map.unwrap_or_else(HashMap::new);

if beta.is_empty() {
// beta_map was none
// populate hashmap with default value
let beta_scalar = beta_scalar.unwrap_or(1.0);
for node_index in graph.node_identifiers() {
let node = graph.to_index(node_index);
beta.insert(node, beta_scalar);
}
}

// Check if beta contains all node indices
for node_index in graph.node_identifiers() {
let node = graph.to_index(node_index);
if !beta.contains_key(&node) {
return Ok(None); // beta_map was provided but did not include all nodes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be an error instead of None? It feels like a user error if they pass in a mapping that's incomplete here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None means it’s not defined… it map to an error in Python but in Rust we let the caller handle it

}
}
IvanIsCoding marked this conversation as resolved.
Show resolved Hide resolved

let tol: f64 = tol.unwrap_or(1e-6);
let max_iter = max_iter.unwrap_or(1000);

let mut x: Vec<f64> = vec![0.; graph.node_bound()];
let node_count = graph.node_count();
for _ in 0..max_iter {
let x_last = x.clone();
x = vec![0.; graph.node_bound()];
for node_index in graph.node_identifiers() {
let node = graph.to_index(node_index);
for edge in graph.edges(node_index) {
let w = weight_fn(edge)?;
let neighbor = edge.target();
x[graph.to_index(neighbor)] += x_last[node] * w;
}
}
for node_index in graph.node_identifiers() {
let node = graph.to_index(node_index);
x[node] = alpha * x[node] + beta.get(&node).unwrap_or(&0.0);
}
if (0..x.len())
.map(|node| (x[node] - x_last[node]).abs())
.sum::<f64>()
< node_count as f64 * tol
{
// Normalize vector
let norm: f64 = x.iter().map(|val| val.powi(2)).sum::<f64>().sqrt();
if norm == 0. {
return Ok(None);
}
for v in x.iter_mut() {
*v /= norm;
}

return Ok(Some(x));
}
}

Ok(None)
}

#[cfg(test)]
mod test_eigenvector_centrality {

Expand Down Expand Up @@ -761,6 +885,120 @@ mod test_eigenvector_centrality {
}
}

#[cfg(test)]
mod test_katz_centrality {

use crate::centrality::katz_centrality;
use crate::petgraph;
use crate::Result;
use hashbrown::HashMap;

macro_rules! assert_almost_equal {
($x:expr, $y:expr, $d:expr) => {
if ($x - $y).abs() >= $d {
panic!("{} != {} within delta of {}", $x, $y, $d);
}
};
}
#[test]
fn test_no_convergence() {
let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[(0, 1), (1, 2)]);
let output: Result<Option<Vec<f64>>> =
katz_centrality(&g, |_| Ok(1.), None, None, None, Some(0), None);
let result = output.unwrap();
assert_eq!(None, result);
}

#[test]
fn test_incomplete_beta() {
let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[(0, 1), (1, 2)]);
let beta_map: HashMap<usize, f64> = [(0, 1.0)].iter().cloned().collect();
let output: Result<Option<Vec<f64>>> =
katz_centrality(&g, |_| Ok(1.), None, Some(beta_map), None, None, None);
let result = output.unwrap();
assert_eq!(None, result);
}

#[test]
fn test_complete_beta() {
let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[(0, 1), (1, 2)]);
let beta_map: HashMap<usize, f64> =
[(0, 0.5), (1, 1.0), (2, 0.5)].iter().cloned().collect();
let output: Result<Option<Vec<f64>>> =
katz_centrality(&g, |_| Ok(1.), None, Some(beta_map), None, None, None);
let result = output.unwrap().unwrap();
let expected_values: Vec<f64> =
vec![0.4318894504492167, 0.791797325823564, 0.4318894504492167];
for i in 0..3 {
assert_almost_equal!(expected_values[i], result[i], 1e-4);
}
}

#[test]
fn test_undirected_complete_graph() {
let g = petgraph::graph::UnGraph::<i32, ()>::from_edges([
(0, 1),
(0, 2),
(0, 3),
(0, 4),
(1, 2),
(1, 3),
(1, 4),
(2, 3),
(2, 4),
(3, 4),
]);
let output: Result<Option<Vec<f64>>> =
katz_centrality(&g, |_| Ok(1.), Some(0.2), None, Some(1.1), None, None);
let result = output.unwrap().unwrap();
let expected_value: f64 = (1_f64 / 5_f64).sqrt();
let expected_values: Vec<f64> = vec![expected_value; 5];
for i in 0..5 {
assert_almost_equal!(expected_values[i], result[i], 1e-4);
}
}

#[test]
fn test_directed_graph() {
let g = petgraph::graph::DiGraph::<i32, ()>::from_edges([
(0, 1),
(0, 2),
(1, 3),
(2, 1),
(2, 4),
(3, 1),
(3, 4),
(3, 5),
(4, 5),
(4, 6),
(4, 7),
(5, 7),
(6, 0),
(6, 4),
(6, 7),
(7, 5),
(7, 6),
]);
let output: Result<Option<Vec<f64>>> =
katz_centrality(&g, |_| Ok(1.), None, None, None, None, None);
let result = output.unwrap().unwrap();
let expected_values: Vec<f64> = vec![
0.3135463087489011,
0.3719056758615039,
0.3094350787809586,
0.31527101632646026,
0.3760169058294464,
0.38618584417917906,
0.35465874858087904,
0.38976653416801743,
];

for i in 0..8 {
assert_almost_equal!(expected_values[i], result[i], 1e-4);
}
}
}

/// Compute the closeness centrality of each node in the graph.
///
/// The closeness centrality of a node `u` is the reciprocal of the average
Expand Down
76 changes: 76 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,82 @@ def _graph_eigenvector_centrality(
)


@functools.singledispatch
def katz_centrality(
graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6
):
"""Compute the Katz centrality of a graph.

For details on the Katz centrality refer to:

Leo Katz. “A New Status Index Derived from Sociometric Index.”
Psychometrika 18(1):39–43, 1953
<https://link.springer.com/content/pdf/10.1007/BF02289026.pdf>

This function uses a power iteration method to compute the eigenvector
and convergence is not guaranteed. The function will stop when `max_iter`
iterations is reached or when the computed vector between two iterations
is smaller than the error tolerance multiplied by the number of nodes.
The implementation of this algorithm is based on the NetworkX
`katz_centrality() <https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html>`__
function.

In the case of multigraphs the weights of any parallel edges will be
summed when computing the Katz centrality.

:param graph: Graph to be used. Can either be a
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`.
:param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used.
:param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood
weight is used for all nodes. If a dictionary is provided, it must contain all node indices.
If beta is not specified, a default value of 1.0 is used.
:param weight_fn: An optional input callable that will be passed the edge's
payload object and is expected to return a `float` weight for that edge.
If this is not specified ``default_weight`` will be used as the weight
for every edge in ``graph``
:param float default_weight: If ``weight_fn`` is not set the default weight
value to use for the weight of all edges
:param int max_iter: The maximum number of iterations in the power method. If
not specified a default value of 100 is used.
:param float tol: The error tolerance used when checking for convergence in the
power method. If this is not specified default value of 1e-6 is used.

:returns: a read-only dict-like object whose keys are the node indices and values are the
centrality score for that node.
:rtype: CentralityMapping
"""


@katz_centrality.register(PyDiGraph)
def _digraph_katz_centrality(
graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6
):
return digraph_katz_centrality(
graph,
alpha=alpha,
beta=beta,
weight_fn=weight_fn,
default_weight=default_weight,
max_iter=max_iter,
tol=tol,
)


@katz_centrality.register(PyGraph)
def _graph_katz_centrality(
graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6
):
return graph_katz_centrality(
graph,
alpha=alpha,
beta=beta,
weight_fn=weight_fn,
default_weight=default_weight,
max_iter=max_iter,
tol=tol,
)


@functools.singledispatch
def vf2_mapping(
first,
Expand Down
Loading