Skip to content

Commit

Permalink
modify the api layer, type annotations, the binder and the rust core
Browse files Browse the repository at this point in the history
  • Loading branch information
gluonhiggs committed Jun 5, 2024
1 parent f2238fd commit 8001f6e
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::connectivity::conn_components::connected_components;
use crate::dictmap::*;
use crate::shortest_path::dijkstra;
use crate::shortest_path::{astar, dijkstra};
use crate::Result;
use hashbrown::{HashMap, HashSet};
use petgraph::algo::{astar, min_spanning_tree, Measure};
use petgraph::algo::{min_spanning_tree, Measure};
use petgraph::csr::{DefaultIx, IndexType};
use petgraph::data::{DataMap, Element};
use petgraph::graph::Graph;
Expand All @@ -13,6 +13,7 @@ use petgraph::visit::{
IntoNeighborsDirected, IntoNodeIdentifiers, IntoNodeReferences, NodeIndexable, Visitable,
};
use petgraph::Undirected;
use std::cmp::Ordering;
use std::convert::Infallible;
use std::hash::Hash;

Expand All @@ -39,7 +40,7 @@ where
G::NodeId: Eq + Hash,
G::EdgeWeight: Clone,
F: FnMut(G::EdgeRef) -> Result<K, E>,
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
K: Clone + PartialOrd + Copy + Measure + Default,
{
components
.into_iter()
Expand Down Expand Up @@ -77,7 +78,7 @@ where
})
.collect()
}
pub fn minimum_cycle_basis<G, F, K, E>(graph: G, mut weight_fn: F) -> Result<Vec<Vec<NodeIndex>>, E>
pub fn minimal_cycle_basis<G, F, K, E>(graph: G, mut weight_fn: F) -> Result<Vec<Vec<NodeIndex>>, E>
where
G: EdgeCount
+ IntoNodeIdentifiers
Expand All @@ -88,10 +89,10 @@ where
+ IntoNeighborsDirected
+ Visitable
+ IntoEdges,
G::EdgeWeight: Clone + PartialOrd,
G::EdgeWeight: Clone,
G::NodeId: Eq + Hash,
F: FnMut(G::EdgeRef) -> Result<K, E>,
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
K: Clone + PartialOrd + Copy + Measure + Default,
{
let conn_components = connected_components(&graph);
let mut min_cycle_basis = Vec::new();
Expand Down Expand Up @@ -136,7 +137,7 @@ where
H::EdgeWeight: Clone + PartialOrd,
H::NodeId: Eq + Hash,
F: FnMut(H::EdgeRef) -> Result<K, E>,
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
K: Clone + PartialOrd + Copy + Measure + Default,
{
let mut sub_cb: Vec<Vec<usize>> = Vec::new();
let num_edges = subgraph.edge_count();
Expand Down Expand Up @@ -243,7 +244,7 @@ where
H: IntoNodeReferences + IntoEdgeReferences + DataMap + NodeIndexable + EdgeIndexable,
H::NodeId: Eq + Hash,
F: FnMut(H::EdgeRef) -> Result<K, E>,
K: Clone + PartialOrd + Copy + Measure + Default + Ord,
K: Clone + PartialOrd + Copy + Measure + Default,
{
let mut gi = Graph::<_, _, petgraph::Undirected>::default();
let mut subgraph_gi_map = HashMap::new();
Expand Down Expand Up @@ -290,31 +291,36 @@ where
|edge| Ok(*edge.weight()),
None,
);
// Find the shortest distance in the result and store it in the shortest_path_map
let spl = result.unwrap()[&gi_lifted_nodeidx];
shortest_path_map.insert(subnodeid, spl);
}
let min_start = shortest_path_map.iter().min_by_key(|x| x.1).unwrap().0;
let min_start = shortest_path_map
.iter()
.min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
.unwrap()
.0;
let min_start_node = subgraph_gi_map[min_start].0;
let min_start_lifted_node = subgraph_gi_map[min_start].1;
let result = astar(
let result: Result<Option<(K, Vec<NodeIndex>)>> = astar(
&gi,
min_start_node,
|finish| finish == min_start_lifted_node,
|e| *e.weight(),
|_| K::default(),
min_start_node.clone(),
|finish| Ok(finish == min_start_lifted_node.clone()),
|e| Ok(*e.weight()),
|_| Ok(K::default()),
);

let mut min_path: Vec<usize> = Vec::new();
match result {
Some((_cost, path)) => {
Ok(Some((_cost, path))) => {
for node in path {
if let Some(&subgraph_nodeid) = gi_subgraph_map.get(&node) {
let subgraph_node = NodeIndexable::to_index(&subgraph, subgraph_nodeid);
min_path.push(subgraph_node.index());
}
}
}
None => {}
Ok(None) => {}
Err(_) => {}
}
let edgelist = min_path
.windows(2)
Expand Down Expand Up @@ -344,9 +350,9 @@ where
}

#[cfg(test)]
mod test_minimum_cycle_basis {
use crate::connectivity::minimum_cycle_basis::minimum_cycle_basis;
use petgraph::graph::Graph;
mod test_minimal_cycle_basis {
use crate::connectivity::minimal_cycle_basis::minimal_cycle_basis;
use petgraph::graph::{Graph, NodeIndex};
use petgraph::Undirected;
use std::convert::Infallible;

Expand All @@ -356,7 +362,7 @@ mod test_minimum_cycle_basis {
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
Ok(*edge.weight())
};
let output = minimum_cycle_basis(&graph, weight_fn).unwrap();
let output = minimal_cycle_basis(&graph, weight_fn).unwrap();
assert_eq!(output.len(), 0);
}

Expand All @@ -372,8 +378,7 @@ mod test_minimum_cycle_basis {
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
Ok(*edge.weight())
};
let cycles = minimum_cycle_basis(&graph, weight_fn);
println!("Cycles {:?}", cycles.as_ref().unwrap());
let cycles = minimal_cycle_basis(&graph, weight_fn);
assert_eq!(cycles.unwrap().len(), 1);
}

Expand All @@ -393,10 +398,60 @@ mod test_minimum_cycle_basis {
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
Ok(*edge.weight())
};
let cycles = minimum_cycle_basis(&graph, weight_fn);
let cycles = minimal_cycle_basis(&graph, weight_fn);
assert_eq!(cycles.unwrap().len(), 2);
}

#[test]
fn test_non_trivial_graph() {
let mut g = Graph::<&str, i32, Undirected>::new_undirected();
let a = g.add_node("A");
let b = g.add_node("B");
let c = g.add_node("C");
let d = g.add_node("D");
let e = g.add_node("E");
let f = g.add_node("F");

g.add_edge(a, b, 7);
g.add_edge(c, a, 9);
g.add_edge(a, d, 11);
g.add_edge(b, c, 10);
g.add_edge(d, c, 2);
g.add_edge(d, e, 9);
g.add_edge(b, f, 15);
g.add_edge(c, f, 11);
g.add_edge(e, f, 6);

let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
Ok(*edge.weight())
};
let output = minimal_cycle_basis(&g, weight_fn);
let mut actual_output = output.unwrap();
for cycle in &mut actual_output {
cycle.sort();
}
actual_output.sort();

let expected_output: Vec<Vec<NodeIndex>> = vec![
vec![
NodeIndex::new(5),
NodeIndex::new(2),
NodeIndex::new(3),
NodeIndex::new(4),
],
vec![NodeIndex::new(2), NodeIndex::new(5), NodeIndex::new(1)],
vec![NodeIndex::new(0), NodeIndex::new(2), NodeIndex::new(1)],
vec![NodeIndex::new(2), NodeIndex::new(3), NodeIndex::new(0)],
];
let mut sorted_expected_output = expected_output.clone();
for cycle in &mut sorted_expected_output {
cycle.sort();
}
sorted_expected_output.sort();

assert_eq!(actual_output, sorted_expected_output);
}

#[test]
fn test_weighted_diamond_graph() {
let mut weighted_diamond = Graph::<(), i32, Undirected>::new_undirected();
Expand All @@ -412,20 +467,19 @@ mod test_minimum_cycle_basis {
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
Ok(*edge.weight())
};
let output = minimum_cycle_basis(&weighted_diamond, weight_fn);
let expected_output: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![0, 1, 2, 3]];
let output = minimal_cycle_basis(&weighted_diamond, weight_fn);
let expected_output1: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![0, 1, 2, 3]];
let expected_output2: Vec<Vec<usize>> = vec![vec![1, 2, 3], vec![0, 1, 2, 3]];
for cycle in output.unwrap().iter() {
println!("{:?}", cycle);
let mut node_indices: Vec<usize> = Vec::new();
for node in cycle.iter() {
node_indices.push(node.index());
}
node_indices.sort();
println!("Node indices {:?}", node_indices);
if expected_output.contains(&node_indices) {
println!("Found cycle {:?}", node_indices);
}
assert!(expected_output.contains(&node_indices));
assert!(
expected_output1.contains(&node_indices)
|| expected_output2.contains(&node_indices)
);
}
}

Expand All @@ -444,7 +498,7 @@ mod test_minimum_cycle_basis {
let weight_fn =
|_edge: petgraph::graph::EdgeReference<()>| -> Result<i32, Infallible> { Ok(1) };

let output = minimum_cycle_basis(&unweighted_diamond, weight_fn);
let output = minimal_cycle_basis(&unweighted_diamond, weight_fn);
let expected_output: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![1, 2, 3]];
for cycle in output.unwrap().iter() {
let mut node_indices: Vec<usize> = Vec::new();
Expand Down Expand Up @@ -476,7 +530,7 @@ mod test_minimum_cycle_basis {
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| -> Result<i32, Infallible> {
Ok(*edge.weight())
};
let output = minimum_cycle_basis(&complete_graph, weight_fn);
let output = minimal_cycle_basis(&complete_graph, weight_fn);
for cycle in output.unwrap().iter() {
assert_eq!(cycle.len(), 3);
}
Expand Down
4 changes: 2 additions & 2 deletions rustworkx-core/src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ mod cycle_basis;
mod find_cycle;
mod isolates;
mod min_cut;
mod minimum_cycle_basis;
mod minimal_cycle_basis;

pub use all_simple_paths::{
all_simple_paths_multiple_targets, longest_simple_path_multiple_targets,
Expand All @@ -37,4 +37,4 @@ pub use cycle_basis::cycle_basis;
pub use find_cycle::find_cycle;
pub use isolates::isolates;
pub use min_cut::stoer_wagner_min_cut;
pub use minimum_cycle_basis::minimum_cycle_basis;
pub use minimal_cycle_basis::minimal_cycle_basis;
26 changes: 26 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,32 @@ def all_pairs_dijkstra_path_lengths(graph, edge_cost_fn):
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@_rustworkx_dispatch
def minimum_cycle_basis(graph, edge_cost_fn):
"""Find the minimum cycle basis of a graph.
This function will find the minimum cycle basis of a graph based on the
following papers
References:
[1] Kavitha, Telikepalli, et al. "An O(m^2n) Algorithm for
Minimum Cycle Basis of Graphs."
http://link.springer.com/article/10.1007/s00453-007-9064-z
[2] de Pina, J. 1995. Applications of shortest path methods.
Ph.D. thesis, University of Amsterdam, Netherlands
:param graph: The input graph to use. Can either be a
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
:param edge_cost_fn: A callable object that acts as a weight function for
an edge. It will accept a single positional argument, the edge's weight
object and will return a float which will be used to represent the
weight/cost of the edge
:return: A list of cycles where each cycle is a list of node indices
:rtype: list
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))

@_rustworkx_dispatch
def dijkstra_shortest_path_lengths(graph, node, edge_cost_fn, goal=None):
"""Compute the lengths of the shortest paths for a graph object using
Expand Down
2 changes: 1 addition & 1 deletion rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ from .rustworkx import graph_longest_simple_path as graph_longest_simple_path
from .rustworkx import digraph_core_number as digraph_core_number
from .rustworkx import graph_core_number as graph_core_number
from .rustworkx import stoer_wagner_min_cut as stoer_wagner_min_cut
from .rustworkx import minimum_cycle_basis as minimum_cycle_basis
from .rustworkx import graph_minimum_cycle_basis as graph_minimum_cycle_basis
from .rustworkx import simple_cycles as simple_cycles
from .rustworkx import digraph_isolates as digraph_isolates
from .rustworkx import graph_isolates as graph_isolates
Expand Down
2 changes: 1 addition & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def stoer_wagner_min_cut(
/,
weight_fn: Callable[[_T], float] | None = ...,
) -> tuple[float, NodeIndices] | None: ...
def minimum_cycle_basis(
def graph_minimum_cycle_basis(
graph: PyGraph[_S, _T],
/,
weight_fn: Callable[[_T], float] | None = ...
Expand Down
41 changes: 41 additions & 0 deletions src/connectivity/minimum_cycle_basis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use rustworkx_core::connectivity::minimal_cycle_basis;

use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
use petgraph::visit::EdgeIndexable;
use petgraph::EdgeType;

use crate::{CostFn, StablePyGraph};

pub fn minimum_cycle_basis_map<Ty: EdgeType + Sync>(
py: Python,
graph: &StablePyGraph<Ty>,
edge_cost_fn: PyObject,
) -> PyResult<Vec<Vec<NodeIndex>>> {
if graph.node_count() == 0 {
return Ok(vec![]);
} else if graph.edge_count() == 0 {
return Ok(vec![]);
}
let edge_cost_callable = CostFn::from(edge_cost_fn);
let mut edge_weights: Vec<Option<f64>> = Vec::with_capacity(graph.edge_bound());
for index in 0..=graph.edge_bound() {
let raw_weight = graph.edge_weight(EdgeIndex::new(index));
match raw_weight {
Some(weight) => edge_weights.push(Some(edge_cost_callable.call(py, weight)?)),
None => edge_weights.push(None),
};
}
let edge_cost = |e: EdgeIndex| -> PyResult<f64> {
match edge_weights[e.index()] {
Some(weight) => Ok(weight),
None => Err(PyIndexError::new_err("No edge found for index")),
}
};
let cycle_basis = minimal_cycle_basis(graph, |e| edge_cost(e.id())).unwrap();
Ok(cycle_basis)
}
Loading

0 comments on commit 8001f6e

Please sign in to comment.