Skip to content

Commit

Permalink
Rely on PyO3 type conversion for HashMap and HashSet (#171)
Browse files Browse the repository at this point in the history
* Rely on PyO3 type conversion for HashMap and HashSet

Since PyO3 0.12.0 the trait implementations for converting from
hashbrown's HashMap and HashSet types. [1] This commit leverage
this so we do not have to internally convert these objects to and from
python types.

[1] PyO3/pyo3#1114

* Revert PyDiGraph compose to use PyObject return

This commit reverts the change for the PyDiGraph.compose() method to use
a PyObject return type of a PyDict again. Doing the HashMap return
resulted in a double iteration over the output map, the first to convert
NodeIndex to a usize and the second in the PyO3 wrapper to convert the
HashMap to a PyDict in the python ffi. Returning a PyObject enables to
do the NodeIndex to usize conversion and PyDict generation at the same
time with a single loop over the contents.

* Fix PyGraph.compose() too

* Revert to PyObject for graph_greedy_color

* Standardize naming on weight_transform_callable
  • Loading branch information
mtreinish authored Oct 18, 2020
1 parent ba6f6e1 commit b92db69
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 122 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ rayon = "1.4"

[dependencies.pyo3]
version = "0.12.3"
features = ["extension-module"]
features = ["extension-module", "hashbrown"]

[dependencies.hashbrown]
version = "0.9"
Expand Down
73 changes: 22 additions & 51 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ impl PyDiGraph {
/// specified node.
/// :rtype: dict
#[text_signature = "(node, /)"]
pub fn adj(&mut self, py: Python, node: usize) -> PyResult<PyObject> {
pub fn adj(&mut self, node: usize) -> HashMap<usize, &PyObject> {
let index = NodeIndex::new(node);
let neighbors = self.graph.neighbors(index);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
Expand All @@ -985,11 +985,7 @@ impl PyDiGraph {
let edge_w = self.graph.edge_weight(edge.unwrap());
out_map.insert(neighbor.index(), edge_w.unwrap());
}
let out_dict = PyDict::new(py);
for (index, value) in out_map {
out_dict.set_item(index, value)?;
}
Ok(out_dict.into())
out_map
}

/// Get the index and data for either the parent or children of a node.
Expand All @@ -1011,10 +1007,9 @@ impl PyDiGraph {
#[text_signature = "(node, direction, /)"]
pub fn adj_direction(
&mut self,
py: Python,
node: usize,
direction: bool,
) -> PyResult<PyObject> {
) -> PyResult<HashMap<usize, &PyObject>> {
let index = NodeIndex::new(node);
let dir = if direction {
petgraph::Direction::Incoming
Expand Down Expand Up @@ -1046,11 +1041,7 @@ impl PyDiGraph {
let edge_w = self.graph.edge_weight(edge);
out_map.insert(neighbor.index(), edge_w.unwrap());
}
let out_dict = PyDict::new(py);
for (index, value) in out_map {
out_dict.set_item(index, value)?;
}
Ok(out_dict.into())
Ok(out_map)
}

/// Get the index and edge data for all parents of a node.
Expand Down Expand Up @@ -1479,67 +1470,33 @@ impl PyDiGraph {
&mut self,
py: Python,
other: &PyDiGraph,
node_map: PyObject,
node_map: HashMap<usize, (usize, PyObject)>,
node_map_func: Option<PyObject>,
edge_map_func: Option<PyObject>,
) -> PyResult<PyObject> {
let mut new_node_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let node_map_dict = node_map.cast_as::<PyDict>(py)?;
let mut node_map_hashmap: HashMap<usize, (usize, PyObject)> =
HashMap::default();
for (k, v) in node_map_dict.iter() {
node_map_hashmap.insert(k.extract()?, v.extract()?);
}

fn node_weight_callable(
py: Python,
node_map: &Option<PyObject>,
node: &PyObject,
) -> PyResult<PyObject> {
match node_map {
Some(node_map) => {
let res = node_map.call1(py, (node,))?;
Ok(res.to_object(py))
}
None => Ok(node.clone_ref(py)),
}
}

// TODO: Reimplement this without looping over the graphs
// Loop over other nodes add add to self graph
for node in other.graph.node_indices() {
let new_index = self.graph.add_node(node_weight_callable(
let new_index = self.graph.add_node(weight_transform_callable(
py,
&node_map_func,
&other.graph[node],
)?);
new_node_map.insert(node, new_index);
}

fn edge_weight_callable(
py: Python,
edge_map: &Option<PyObject>,
edge: &PyObject,
) -> PyResult<PyObject> {
match edge_map {
Some(edge_map) => {
let res = edge_map.call1(py, (edge,))?;
Ok(res.to_object(py))
}
None => Ok(edge.clone_ref(py)),
}
}

// loop over other edges and add to self graph
for edge in other.graph.edge_references() {
let new_p_index = new_node_map.get(&edge.source()).unwrap();
let new_c_index = new_node_map.get(&edge.target()).unwrap();
let weight =
edge_weight_callable(py, &edge_map_func, edge.weight())?;
weight_transform_callable(py, &edge_map_func, edge.weight())?;
self.graph.add_edge(*new_p_index, *new_c_index, weight);
}
// Add edges from map
for (this_index, (index, weight)) in node_map_hashmap.iter() {
for (this_index, (index, weight)) in node_map.iter() {
let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap();
self.graph.add_edge(
NodeIndex::new(*this_index),
Expand Down Expand Up @@ -1627,3 +1584,17 @@ fn is_cycle_check_required(
&& children_b.next().is_some()
&& dag.graph.find_edge(a, b).is_none()
}

fn weight_transform_callable(
py: Python,
map_fn: &Option<PyObject>,
value: &PyObject,
) -> PyResult<PyObject> {
match map_fn {
Some(map_fn) => {
let res = map_fn.call1(py, (value,))?;
Ok(res.to_object(py))
}
None => Ok(value.clone_ref(py)),
}
}
71 changes: 22 additions & 49 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ impl PyGraph {
/// edge with the specified node.
/// :rtype: dict
#[text_signature = "(node, /)"]
pub fn adj(&mut self, py: Python, node: usize) -> PyResult<PyObject> {
pub fn adj(&mut self, node: usize) -> PyResult<HashMap<usize, &PyObject>> {
let index = NodeIndex::new(node);
let neighbors = self.graph.neighbors(index);
let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
Expand All @@ -744,11 +744,7 @@ impl PyGraph {
let edge_w = self.graph.edge_weight(edge.unwrap());
out_map.insert(neighbor.index(), edge_w.unwrap());
}
let out_dict = PyDict::new(py);
for (index, value) in out_map {
out_dict.set_item(index, value)?;
}
Ok(out_dict.into())
Ok(out_map)
}

/// Get the degree for a node
Expand Down Expand Up @@ -1046,67 +1042,33 @@ impl PyGraph {
&mut self,
py: Python,
other: &PyGraph,
node_map: PyObject,
node_map: HashMap<usize, (usize, PyObject)>,
node_map_func: Option<PyObject>,
edge_map_func: Option<PyObject>,
) -> PyResult<PyObject> {
let mut new_node_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let node_map_dict = node_map.cast_as::<PyDict>(py)?;
let mut node_map_hashmap: HashMap<usize, (usize, PyObject)> =
HashMap::default();
for (k, v) in node_map_dict.iter() {
node_map_hashmap.insert(k.extract()?, v.extract()?);
}

fn node_weight_callable(
py: Python,
node_map: &Option<PyObject>,
node: &PyObject,
) -> PyResult<PyObject> {
match node_map {
Some(node_map) => {
let res = node_map.call1(py, (node,))?;
Ok(res.to_object(py))
}
None => Ok(node.clone_ref(py)),
}
}

// TODO: Reimplement this without looping over the graphs
// Loop over other nodes add add to self graph
for node in other.graph.node_indices() {
let new_index = self.graph.add_node(node_weight_callable(
let new_index = self.graph.add_node(weight_transform_callable(
py,
&node_map_func,
&other.graph[node],
)?);
new_node_map.insert(node, new_index);
}

fn edge_weight_callable(
py: Python,
edge_map: &Option<PyObject>,
edge: &PyObject,
) -> PyResult<PyObject> {
match edge_map {
Some(edge_map) => {
let res = edge_map.call1(py, (edge,))?;
Ok(res.to_object(py))
}
None => Ok(edge.clone_ref(py)),
}
}

// loop over other edges and add to self graph
for edge in other.graph.edge_references() {
let new_p_index = new_node_map.get(&edge.source()).unwrap();
let new_c_index = new_node_map.get(&edge.target()).unwrap();
let weight =
edge_weight_callable(py, &edge_map_func, edge.weight())?;
weight_transform_callable(py, &edge_map_func, edge.weight())?;
self.graph.add_edge(*new_p_index, *new_c_index, weight);
}
// Add edges from map
for (this_index, (index, weight)) in node_map_hashmap.iter() {
for (this_index, (index, weight)) in node_map.iter() {
let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap();
self.graph.add_edge(
NodeIndex::new(*this_index),
Expand All @@ -1129,17 +1091,14 @@ impl PyMappingProtocol for PyGraph {
Ok(self.graph.node_count())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {
match self.graph.node_weight(NodeIndex::new(idx as usize)) {
match self.graph.node_weight(NodeIndex::new(idx)) {
Some(data) => Ok(data),
None => Err(PyIndexError::new_err("No node found for index")),
}
}

fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
let data = match self
.graph
.node_weight_mut(NodeIndex::new(idx as usize))
{
let data = match self.graph.node_weight_mut(NodeIndex::new(idx)) {
Some(node_data) => node_data,
None => {
return Err(PyIndexError::new_err("No node found for index"))
Expand All @@ -1156,3 +1115,17 @@ impl PyMappingProtocol for PyGraph {
}
}
}

fn weight_transform_callable(
py: Python,
map_fn: &Option<PyObject>,
value: &PyObject,
) -> PyResult<PyObject> {
match map_fn {
Some(map_fn) => {
let res = map_fn.call1(py, (value,))?;
Ok(res.to_object(py))
}
None => Ok(value.clone_ref(py)),
}
}
26 changes: 5 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use hashbrown::{HashMap, HashSet};
use pyo3::create_exception;
use pyo3::exceptions::{PyException, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PySet};
use pyo3::types::{PyDict, PyList};
use pyo3::wrap_pyfunction;
use pyo3::wrap_pymodule;
use pyo3::Python;
Expand Down Expand Up @@ -305,7 +305,7 @@ fn bfs_successors(
/// :rtype: list
#[pyfunction]
#[text_signature = "(graph, node, /)"]
fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject {
fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
let index = NodeIndex::new(node);
let mut out_set: HashSet<usize> = HashSet::new();
let reverse_graph = Reversed(graph);
Expand All @@ -315,13 +315,7 @@ fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject {
out_set.insert(n_int);
}
out_set.remove(&node);
let set = PySet::empty(py).expect("Failed to construct empty set");
{
for val in out_set {
set.add(val).expect("Failed to add to set");
}
}
set.into()
out_set
}

/// Return the descendants of a node in a graph.
Expand All @@ -338,11 +332,7 @@ fn ancestors(py: Python, graph: &digraph::PyDiGraph, node: usize) -> PyObject {
/// :rtype: list
#[pyfunction]
#[text_signature = "(graph, node, /)"]
fn descendants(
py: Python,
graph: &digraph::PyDiGraph,
node: usize,
) -> PyObject {
fn descendants(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
let index = NodeIndex::new(node);
let mut out_set: HashSet<usize> = HashSet::new();
let res = algo::dijkstra(graph, index, None, |_| 1);
Expand All @@ -351,13 +341,7 @@ fn descendants(
out_set.insert(n_int);
}
out_set.remove(&node);
let set = PySet::empty(py).expect("Failed to construct empty set");
{
for val in out_set {
set.add(val).expect("Failed to add to set");
}
}
set.into()
out_set
}

/// Get the lexicographical topological sorted nodes from the provided DAG
Expand Down

0 comments on commit b92db69

Please sign in to comment.