forked from Qiskit/rustworkx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree.rs
137 lines (123 loc) · 4.9 KB
/
tree.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
use std::cmp::Ordering;
use super::{graph, weight_callable};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::Python;
use petgraph::prelude::*;
use petgraph::stable_graph::EdgeReference;
use petgraph::unionfind::UnionFind;
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
use rayon::prelude::*;
use crate::iterators::WeightedEdgeList;
/// Find the edges in the minimum spanning tree or forest of a graph
/// using Kruskal's algorithm.
///
/// :param PyGraph graph: Undirected graph
/// :param weight_fn: A callable object (function, lambda, etc) which
///     will be passed the edge object and expected to return a ``float``. This
///     tells rustworkx/rust how to extract a numerical weight as a ``float``
///     for edge object. Some simple examples are::
///
///         minimum_spanning_edges(graph, weight_fn=lambda x: 1)
///
///     to return a weight of 1 for all edges. Also::
///
///         minimum_spanning_edges(graph, weight_fn=float)
///
///     to cast the edge object as a float as the weight.
/// :param float default_weight: If ``weight_fn`` isn't specified this optional
///     float value will be used for the weight/cost of each edge.
///
/// :returns: The :math:`N - |c|` edges of the Minimum Spanning Tree (or Forest, if :math:`|c| > 1`)
///     where :math:`N` is the number of nodes and :math:`|c|` is the number of connected components of the graph
/// :rtype: WeightedEdgeList
///
/// :raises ValueError: If the weight computed for any edge is ``NaN``. Any error
///     raised while computing an edge weight is propagated unchanged.
#[pyfunction]
#[pyo3(signature=(graph, weight_fn=None, default_weight=1.0), text_signature = "(graph, weight_fn=None, default_weight=1.0)")]
pub fn minimum_spanning_edges(
    py: Python,
    graph: &graph::PyGraph,
    weight_fn: Option<PyObject>,
    default_weight: f64,
) -> PyResult<WeightedEdgeList> {
    // Pair every edge with its numeric weight. NaN is rejected here so that the
    // sort below only ever compares values that `partial_cmp` can order.
    let mut edge_list: Vec<(f64, EdgeReference<PyObject>)> =
        Vec::with_capacity(graph.graph.edge_count());
    for edge in graph.graph.edge_references() {
        let weight = weight_callable(py, &weight_fn, edge.weight(), default_weight)?;
        if weight.is_nan() {
            return Err(PyValueError::new_err("NaN found as an edge weight"));
        }
        edge_list.push((weight, edge));
    }

    // Kruskal: consider edges from lightest to heaviest. `partial_cmp` cannot
    // return `None` because NaN weights were rejected above, so the fallback
    // ordering is never actually used.
    edge_list.par_sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Less));

    // Union-find keyed by raw node index. It is sized by `node_bound()` (not
    // `node_count()`) so that every index in a stable graph, which may be sparse
    // after node removals, is a valid slot.
    let mut subgraphs = UnionFind::<usize>::new(graph.graph.node_bound());
    // A spanning forest never has more than `N - 1` edges.
    let mut answer: Vec<(usize, usize, PyObject)> =
        Vec::with_capacity(graph.graph.node_count().saturating_sub(1));
    for (_, edge) in edge_list.iter() {
        let u = edge.source().index();
        let v = edge.target().index();
        // `union` returns true only when `u` and `v` were in different sets,
        // i.e. adding this edge does not close a cycle.
        if subgraphs.union(u, v) {
            answer.push((u, v, edge.weight().clone_ref(py)));
        }
    }
    Ok(WeightedEdgeList { edges: answer })
}
/// Find the minimum spanning tree or forest of a graph
/// using Kruskal's algorithm.
///
/// :param PyGraph graph: Undirected graph
/// :param weight_fn: A callable object (function, lambda, etc) which
///     will be passed the edge object and expected to return a ``float``. This
///     tells rustworkx/rust how to extract a numerical weight as a ``float``
///     for edge object. Some simple examples are::
///
///         minimum_spanning_tree(graph, weight_fn=lambda x: 1)
///
///     to return a weight of 1 for all edges. Also::
///
///         minimum_spanning_tree(graph, weight_fn=float)
///
///     to cast the edge object as a float as the weight.
/// :param float default_weight: If ``weight_fn`` isn't specified this optional
///     float value will be used for the weight/cost of each edge.
///
/// :returns: A Minimum Spanning Tree (or Forest, if the graph is not connected).
///
/// :rtype: PyGraph
///
/// :raises ValueError: If the weight computed for any edge is ``NaN``.
///
/// .. note::
///
///     The new graph will keep the same node indices, but edge indices might differ.
#[pyfunction]
#[pyo3(signature=(graph, weight_fn=None, default_weight=1.0), text_signature = "(graph, weight_fn=None, default_weight=1.0)")]
pub fn minimum_spanning_tree(
    py: Python,
    graph: &graph::PyGraph,
    weight_fn: Option<PyObject>,
    default_weight: f64,
) -> PyResult<graph::PyGraph> {
    // Compute the spanning edges first: if a weight is NaN or `weight_fn`
    // raises, we fail before paying for a full clone of the input graph.
    let tree_edges = minimum_spanning_edges(py, graph, weight_fn, default_weight)?.edges;

    // Start from a copy of the input so nodes (and their indices) are preserved,
    // then replace the full edge set with only the spanning edges.
    let mut spanning_tree = (*graph).clone();
    spanning_tree.graph.clear_edges();
    // `tree_edges` is owned, so each edge payload can be moved in directly
    // instead of taking an extra Python reference per edge.
    for (source, target, weight) in tree_edges {
        spanning_tree.add_edge(source, target, weight)?;
    }
    Ok(spanning_tree)
}