diff --git a/ci/test_python.sh b/ci/test_python.sh index f62caca92a6..500bc2f3467 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -111,6 +111,11 @@ popd rapids-logger "pytest networkx using nx-cugraph backend" pushd python/nx-cugraph ./run_nx_tests.sh +# Individually run tests that are skipped above b/c they may run out of memory +PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAG and test_antichains" +PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestMultiDiGraph_DAGLCA and test_all_pairs_lca_pairs_without_lca" +PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAGLCA and test_all_pairs_lca_pairs_without_lca" +PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestEfficiency and test_using_ego_graph" # run_nx_tests.sh outputs coverage data, so check that total coverage is >0.0% # in case nx-cugraph failed to load but fallback mode allowed the run to pass. _coverage=$(coverage report|grep "^TOTAL") diff --git a/python/nx-cugraph/_nx_cugraph/__init__.py b/python/nx-cugraph/_nx_cugraph/__init__.py index 1fd436bb845..d02c9c3e940 100644 --- a/python/nx-cugraph/_nx_cugraph/__init__.py +++ b/python/nx-cugraph/_nx_cugraph/__init__.py @@ -29,8 +29,14 @@ # "description": "TODO", "functions": { # BEGIN: functions + "ancestors", "barbell_graph", "betweenness_centrality", + "bfs_edges", + "bfs_layers", + "bfs_predecessors", + "bfs_successors", + "bfs_tree", "bull_graph", "caveman_graph", "chvatal_graph", @@ -44,6 +50,8 @@ "davis_southern_women_graph", "degree_centrality", "desargues_graph", + "descendants", + "descendants_at_distance", "diamond_graph", "dodecahedral_graph", "edge_betweenness_centrality", @@ -53,6 +61,7 @@ "from_pandas_edgelist", "from_scipy_sparse_array", "frucht_graph", + "generic_bfs_edges", "heawood_graph", "hits", "house_graph", @@ -99,9 +108,14 @@ "extra_docstrings": { # BEGIN: extra_docstrings "betweenness_centrality": "`weight` parameter is not yet supported.", + "bfs_edges": "`sort_neighbors` parameter is not yet supported.", + "bfs_predecessors": "`sort_neighbors` parameter is not yet supported.", + "bfs_successors": "`sort_neighbors` parameter is not yet supported.", + "bfs_tree": "`sort_neighbors` parameter is not yet supported.", "edge_betweenness_centrality": "`weight` parameter is not yet supported.", "eigenvector_centrality": "`nstart` parameter is not used, but it is checked for validity.", "from_pandas_edgelist": "cudf.DataFrame inputs also supported.", + "generic_bfs_edges": "`neighbors` and `sort_neighbors` parameters are not yet supported.", "k_truss": ( "Currently raises `NotImplementedError` for graphs with more than one connected\n" "component when k >= 3. We expect to fix this soon." diff --git a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py index e4947491555..d28a629fe63 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py @@ -17,11 +17,14 @@ components, link_analysis, shortest_paths, + traversal, ) from .bipartite import complete_bipartite_graph from .centrality import * from .components import * from .core import * +from .dag import * from .isolate import * from .link_analysis import * from .shortest_paths import * +from .traversal import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/core.py b/python/nx-cugraph/nx_cugraph/algorithms/core.py index 390598d070e..c00df2d832f 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/core.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/core.py @@ -15,7 +15,12 @@ import pylibcugraph as plc import nx_cugraph as nxcg -from nx_cugraph.utils import _get_int_dtype, networkx_algorithm, not_implemented_for +from nx_cugraph.utils import ( + _get_int_dtype, + index_dtype, + networkx_algorithm, + not_implemented_for, +) __all__ = ["k_truss"] @@ -81,10 +86,8 @@ def k_truss(G, k): edge_values = {key: val[edge_indices] for key, val in G.edge_values.items()} edge_masks = {key: val[edge_indices] for key, val in G.edge_masks.items()} # Renumber step 2: edge indices - mapper = cp.zeros(len(G), src_indices.dtype) - mapper[node_indices] = cp.arange(node_indices.size, dtype=mapper.dtype) - src_indices = mapper[src_indices] - dst_indices = mapper[dst_indices] + src_indices = cp.searchsorted(node_indices, src_indices).astype(index_dtype) + dst_indices = cp.searchsorted(node_indices, dst_indices).astype(index_dtype) # Renumber step 3: node values node_values = {key: val[node_indices] for key, val in G.node_values.items()} node_masks = {key: val[node_indices] for key, val in G.node_masks.items()} diff --git a/python/nx-cugraph/nx_cugraph/algorithms/dag.py b/python/nx-cugraph/nx_cugraph/algorithms/dag.py new file mode 100644 index 00000000000..067cfed9101 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/dag.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cupy as cp +import networkx as nx +import numpy as np +import pylibcugraph as plc + +from nx_cugraph.convert import _to_graph +from nx_cugraph.utils import index_dtype, networkx_algorithm + +__all__ = [ + "descendants", + "ancestors", +] + + +def _ancestors_and_descendants(G, source, *, is_ancestors): + G = _to_graph(G) + if source not in G: + hash(source) # To raise TypeError if appropriate + raise nx.NetworkXError( + f"The node {source} is not in the {G.__class__.__name__.lower()}." + ) + src_index = source if G.key_to_id is None else G.key_to_id[source] + distances, predecessors, node_ids = plc.bfs( + handle=plc.ResourceHandle(), + graph=G._get_plc_graph(switch_indices=is_ancestors), + sources=cp.array([src_index], dtype=index_dtype), + direction_optimizing=False, + depth_limit=-1, + compute_predecessors=False, + do_expensive_check=False, + ) + mask = (distances != np.iinfo(distances.dtype).max) & (distances != 0) + return G._nodearray_to_set(node_ids[mask]) + + +@networkx_algorithm +def descendants(G, source): + return _ancestors_and_descendants(G, source, is_ancestors=False) + + +@networkx_algorithm +def ancestors(G, source): + return _ancestors_and_descendants(G, source, is_ancestors=True) diff --git a/python/nx-cugraph/nx_cugraph/algorithms/traversal/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/traversal/__init__.py new file mode 100644 index 00000000000..1751cd46919 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/traversal/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .breadth_first_search import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/traversal/breadth_first_search.py b/python/nx-cugraph/nx_cugraph/algorithms/traversal/breadth_first_search.py new file mode 100644 index 00000000000..e2a7d46f462 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/traversal/breadth_first_search.py @@ -0,0 +1,250 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import repeat + +import cupy as cp +import networkx as nx +import numpy as np +import pylibcugraph as plc + +import nx_cugraph as nxcg +from nx_cugraph.convert import _to_graph +from nx_cugraph.utils import _groupby, index_dtype, networkx_algorithm + +__all__ = [ + "bfs_edges", + "bfs_tree", + "bfs_predecessors", + "bfs_successors", + "descendants_at_distance", + "bfs_layers", + "generic_bfs_edges", +] + + +def _check_G_and_source(G, source): + G = _to_graph(G) + if source not in G: + hash(source) # To raise TypeError if appropriate + raise nx.NetworkXError( + f"The node {source} is not in the {G.__class__.__name__.lower()}." + ) + return G + + +def _bfs(G, source, *, depth_limit=None, reverse=False): + src_index = source if G.key_to_id is None else G.key_to_id[source] + distances, predecessors, node_ids = plc.bfs( + handle=plc.ResourceHandle(), + graph=G._get_plc_graph(switch_indices=reverse), + sources=cp.array([src_index], dtype=index_dtype), + direction_optimizing=False, + depth_limit=-1 if depth_limit is None else depth_limit, + compute_predecessors=True, + do_expensive_check=False, + ) + mask = predecessors >= 0 + return distances[mask], predecessors[mask], node_ids[mask] + + +@networkx_algorithm +def generic_bfs_edges(G, source, neighbors=None, depth_limit=None, sort_neighbors=None): + """`neighbors` and `sort_neighbors` parameters are not yet supported.""" + return bfs_edges(source, depth_limit=depth_limit) + + +@generic_bfs_edges._can_run +def _(G, source, neighbors=None, depth_limit=None, sort_neighbors=None): + return neighbors is None and sort_neighbors is None + + +@networkx_algorithm +def bfs_edges(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + """`sort_neighbors` parameter is not yet supported.""" + G = _check_G_and_source(G, source) + if depth_limit is not None and depth_limit < 1: + return + distances, predecessors, node_ids = _bfs( + G, source, depth_limit=depth_limit, reverse=reverse + ) + # Using groupby like this is similar to bfs_predecessors + groups = _groupby([distances, predecessors], node_ids) + id_to_key = G.id_to_key + for key in sorted(groups): + children_ids = groups[key] + parent_id = key[1] + parent = id_to_key[parent_id] if id_to_key is not None else parent_id + yield from zip( + repeat(parent, children_ids.size), + G._nodeiter_to_iter(children_ids.tolist()), + ) + + +@bfs_edges._can_run +def _(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + return sort_neighbors is None + + +@networkx_algorithm +def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + """`sort_neighbors` parameter is not yet supported.""" + G = _check_G_and_source(G, source) + if depth_limit is not None and depth_limit < 1: + return nxcg.DiGraph.from_coo( + 1, + cp.array([], dtype=index_dtype), + cp.array([], dtype=index_dtype), + id_to_key=[source], + ) + + distances, predecessors, node_ids = _bfs( + G, + source, + depth_limit=depth_limit, + reverse=reverse, + ) + if predecessors.size == 0: + return nxcg.DiGraph.from_coo( + 1, + cp.array([], dtype=index_dtype), + cp.array([], dtype=index_dtype), + id_to_key=[source], + ) + # TODO: create renumbering helper function(s) + unique_node_ids = cp.unique(cp.hstack((predecessors, node_ids))) + # Renumber edges + src_indices = cp.searchsorted(unique_node_ids, predecessors).astype(index_dtype) + dst_indices = cp.searchsorted(unique_node_ids, node_ids).astype(index_dtype) + # Renumber nodes + if (id_to_key := G.id_to_key) is not None: + key_to_id = { + id_to_key[old_index]: new_index + for new_index, old_index in enumerate(unique_node_ids.tolist()) + } + else: + key_to_id = { + old_index: new_index + for new_index, old_index in enumerate(unique_node_ids.tolist()) + } + return nxcg.DiGraph.from_coo( + unique_node_ids.size, + src_indices, + dst_indices, + key_to_id=key_to_id, + ) + + +@bfs_tree._can_run +def _(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + return sort_neighbors is None + + +@networkx_algorithm +def bfs_successors(G, source, depth_limit=None, sort_neighbors=None): + """`sort_neighbors` parameter is not yet supported.""" + G = _check_G_and_source(G, source) + if depth_limit is not None and depth_limit < 1: + yield (source, []) + return + + distances, predecessors, node_ids = _bfs(G, source, depth_limit=depth_limit) + groups = _groupby([distances, predecessors], node_ids) + id_to_key = G.id_to_key + for key in sorted(groups): + children_ids = groups[key] + parent_id = key[1] + parent = id_to_key[parent_id] if id_to_key is not None else parent_id + children = G._nodearray_to_list(children_ids) + yield (parent, children) + + +@bfs_successors._can_run +def _(G, source, depth_limit=None, sort_neighbors=None): + return sort_neighbors is None + + +@networkx_algorithm +def bfs_layers(G, sources): + G = _to_graph(G) + if sources in G: + sources = [sources] + else: + sources = set(sources) + if not all(source in G for source in sources): + node = next(source for source in sources if source not in G) + raise nx.NetworkXError(f"The node {node} is not in the graph.") + sources = list(sources) + source_ids = G._list_to_nodearray(sources) + distances, predecessors, node_ids = plc.bfs( + handle=plc.ResourceHandle(), + graph=G._get_plc_graph(), + sources=source_ids, + direction_optimizing=False, + depth_limit=-1, + compute_predecessors=False, + do_expensive_check=False, + ) + mask = distances != np.iinfo(distances.dtype).max + distances = distances[mask] + node_ids = node_ids[mask] + groups = _groupby(distances, node_ids) + return (G._nodearray_to_list(groups[key]) for key in range(len(groups))) + + +@networkx_algorithm +def bfs_predecessors(G, source, depth_limit=None, sort_neighbors=None): + """`sort_neighbors` parameter is not yet supported.""" + G = _check_G_and_source(G, source) + if depth_limit is not None and depth_limit < 1: + return + + distances, predecessors, node_ids = _bfs(G, source, depth_limit=depth_limit) + # We include `predecessors` in the groupby for "nicer" iteration order + groups = _groupby([distances, predecessors], node_ids) + id_to_key = G.id_to_key + for key in sorted(groups): + children_ids = groups[key] + parent_id = key[1] + parent = id_to_key[parent_id] if id_to_key is not None else parent_id + yield from zip( + G._nodeiter_to_iter(children_ids.tolist()), + repeat(parent, children_ids.size), + ) + + +@bfs_predecessors._can_run +def _(G, source, depth_limit=None, sort_neighbors=None): + return sort_neighbors is None + + +@networkx_algorithm +def descendants_at_distance(G, source, distance): + G = _check_G_and_source(G, source) + if distance is None or distance < 0: + return set() + if distance == 0: + return {source} + + src_index = source if G.key_to_id is None else G.key_to_id[source] + distances, predecessors, node_ids = plc.bfs( + handle=plc.ResourceHandle(), + graph=G._get_plc_graph(), + sources=cp.array([src_index], dtype=index_dtype), + direction_optimizing=False, + depth_limit=distance, + compute_predecessors=False, + do_expensive_check=False, + ) + mask = distances == distance + node_ids = node_ids[mask] + return G._nodearray_to_set(node_ids) diff --git a/python/nx-cugraph/nx_cugraph/classes/graph.py b/python/nx-cugraph/nx_cugraph/classes/graph.py index e32f93d8bfe..cdd3f744f24 100644 --- a/python/nx-cugraph/nx_cugraph/classes/graph.py +++ b/python/nx-cugraph/nx_cugraph/classes/graph.py @@ -458,6 +458,24 @@ def has_edge(self, u: NodeKey, v: NodeKey) -> bool: return False return bool(((self.src_indices == u) & (self.dst_indices == v)).any()) + def _neighbors(self, n: NodeKey) -> cp.ndarray[NodeValue]: + if n not in self: + hash(n) # To raise TypeError if appropriate + raise nx.NetworkXError( + f"The node {n} is not in the {self.__class__.__name__.lower()}." + ) + if self.key_to_id is not None: + n = self.key_to_id[n] + nbrs = self.dst_indices[self.src_indices == n] + if self.is_multigraph(): + nbrs = cp.unique(nbrs) + return nbrs + + @networkx_api + def neighbors(self, n: NodeKey) -> Iterator[NodeKey]: + nbrs = self._neighbors(n) + return iter(self._nodeiter_to_iter(nbrs.tolist())) + @networkx_api def has_node(self, n: NodeKey) -> bool: return n in self @@ -701,6 +719,11 @@ def _nodearray_to_list(self, node_ids: cp.ndarray[IndexValue]) -> list[NodeKey]: return node_ids.tolist() return list(self._nodeiter_to_iter(node_ids.tolist())) + def _list_to_nodearray(self, nodes: list[NodeKey]) -> cp.ndarray[IndexValue]: + if (key_to_id := self.key_to_id) is not None: + nodes = [key_to_id[node] for node in nodes] + return cp.array(nodes, dtype=index_dtype) + def _nodearray_to_set(self, node_ids: cp.ndarray[IndexValue]) -> set[NodeKey]: if self.key_to_id is None: return set(node_ids.tolist()) diff --git a/python/nx-cugraph/nx_cugraph/convert.py b/python/nx-cugraph/nx_cugraph/convert.py index 3c0814370d3..f265540a161 100644 --- a/python/nx-cugraph/nx_cugraph/convert.py +++ b/python/nx-cugraph/nx_cugraph/convert.py @@ -39,6 +39,24 @@ REQUIRED = ... +def _iterate_values(graph, adj, is_dicts, func): + # Using `dict.values` is faster and is the common case, but it doesn't always work + if is_dicts is not False: + it = concat(map(dict.values, adj.values())) + if graph is not None and graph.is_multigraph(): + it = concat(map(dict.values, it)) + try: + return func(it), True + except TypeError: + if is_dicts is True: + raise + # May not be regular dicts + it = concat(x.values() for x in adj.values()) + if graph is not None and graph.is_multigraph(): + it = concat(x.values() for x in it) + return func(it), False + + def from_networkx( graph: nx.Graph, edge_attrs: AttrKey | dict[AttrKey, EdgeValue | None] | None = None, @@ -152,6 +170,7 @@ def from_networkx( if isinstance(adj, nx.classes.coreviews.FilterAdjacency): adj = {k: dict(v) for k, v in adj.items()} + is_dicts = None N = len(adj) if ( not preserve_edge_attrs @@ -162,12 +181,9 @@ def from_networkx( # Either we weren't asked to preserve edge attributes, or there are no edges edge_attrs = None elif preserve_edge_attrs: - # Using comprehensions should be just as fast starting in Python 3.11 - it = concat(map(dict.values, adj.values())) - if graph.is_multigraph(): - it = concat(map(dict.values, it)) - # PERF: should we add `filter(None, ...)` to remove empty data dicts? - attr_sets = set(map(frozenset, it)) + attr_sets, is_dicts = _iterate_values( + graph, adj, is_dicts, lambda it: set(map(frozenset, it)) + ) attrs = frozenset.union(*attr_sets) edge_attrs = dict.fromkeys(attrs, REQUIRED) if len(attr_sets) > 1: @@ -207,10 +223,9 @@ def from_networkx( del edge_attrs[attr] # Else some edges have attribute (default already None) else: - it = concat(map(dict.values, adj.values())) - if graph.is_multigraph(): - it = concat(map(dict.values, it)) - attr_sets = set(map(required.intersection, it)) + attr_sets, is_dicts = _iterate_values( + graph, adj, is_dicts, lambda it: set(map(required.intersection, it)) + ) for attr in required - frozenset.union(*attr_sets): # No edges have these attributes del edge_attrs[attr] @@ -269,17 +284,19 @@ def from_networkx( dst_iter = map(key_to_id.__getitem__, dst_iter) if graph.is_multigraph(): dst_indices = np.fromiter(dst_iter, index_dtype) - num_multiedges = np.fromiter( - map(len, concat(map(dict.values, adj.values()))), index_dtype + num_multiedges, is_dicts = _iterate_values( + None, adj, is_dicts, lambda it: np.fromiter(map(len, it), index_dtype) ) # cp.repeat is slow to use here, so use numpy instead dst_indices = cp.array(np.repeat(dst_indices, num_multiedges)) # Determine edge keys and edge ids for multigraphs - edge_keys = list(concat(concat(map(dict.values, adj.values())))) - edge_indices = cp.fromiter( - concat(map(range, map(len, concat(map(dict.values, adj.values()))))), - index_dtype, - ) + if is_dicts: + edge_keys = list(concat(concat(map(dict.values, adj.values())))) + it = concat(map(dict.values, adj.values())) + else: + edge_keys = list(concat(concat(x.values() for x in adj.values()))) + it = concat(x.values() for x in adj.values()) + edge_indices = cp.fromiter(concat(map(range, map(len, it))), index_dtype) if edge_keys == edge_indices.tolist(): edge_keys = None # Prefer edge_indices else: @@ -323,19 +340,21 @@ def from_networkx( edge_masks[edge_attr] = cp.fromiter(iter_mask, bool) edge_values[edge_attr] = cp.array(vals, dtype) # if vals.ndim > 1: ... + elif edge_default is REQUIRED: + if dtype is None: + + def func(it, edge_attr=edge_attr): + return cp.array(list(map(op.itemgetter(edge_attr), it))) + + else: + + def func(it, edge_attr=edge_attr, dtype=dtype): + return cp.fromiter(map(op.itemgetter(edge_attr), it), dtype) + + edge_value, is_dicts = _iterate_values(graph, adj, is_dicts, func) + edge_values[edge_attr] = edge_value else: - if edge_default is REQUIRED: - # Using comprehensions should be fast starting in Python 3.11 - # iter_values = ( - # edgedata[edge_attr] - # for rowdata in adj.values() - # for edgedata in rowdata.values() - # ) - it = concat(map(dict.values, adj.values())) - if graph.is_multigraph(): - it = concat(map(dict.values, it)) - iter_values = map(op.itemgetter(edge_attr), it) - elif graph.is_multigraph(): + if graph.is_multigraph(): iter_values = ( edgedata.get(edge_attr, edge_default) for rowdata in adj.values() @@ -352,7 +371,7 @@ def from_networkx( edge_values[edge_attr] = cp.array(list(iter_values)) else: edge_values[edge_attr] = cp.fromiter(iter_values, dtype) - # if vals.ndim > 1: ... + # if vals.ndim > 1: ... # cp.repeat is slow to use here, so use numpy instead src_indices = np.repeat( diff --git a/python/nx-cugraph/nx_cugraph/convert_matrix.py b/python/nx-cugraph/nx_cugraph/convert_matrix.py index 6c8b8fb4a1d..80ca0c2fa4b 100644 --- a/python/nx-cugraph/nx_cugraph/convert_matrix.py +++ b/python/nx-cugraph/nx_cugraph/convert_matrix.py @@ -36,6 +36,7 @@ def from_pandas_edgelist( graph_class, inplace = _create_using_class(create_using) src_array = df[source].to_numpy() dst_array = df[target].to_numpy() + # TODO: create renumbering helper function(s) # Renumber step 0: node keys nodes = np.unique(np.concatenate([src_array, dst_array])) N = nodes.size diff --git a/python/nx-cugraph/nx_cugraph/interface.py b/python/nx-cugraph/nx_cugraph/interface.py index be6b3596030..3f6449f571a 100644 --- a/python/nx-cugraph/nx_cugraph/interface.py +++ b/python/nx-cugraph/nx_cugraph/interface.py @@ -12,6 +12,7 @@ # limitations under the License. from __future__ import annotations +import os import sys import networkx as nx @@ -246,7 +247,51 @@ def key(testpath): key("test_tree_isomorphism.py:test_positive"): too_slow, key("test_tree_isomorphism.py:test_negative"): too_slow, key("test_efficiency.py:TestEfficiency.test_using_ego_graph"): maybe_oom, + key("test_dag.py:TestDAG.test_antichains"): maybe_oom, + key( + "test_lowest_common_ancestors.py:" + "TestDAGLCA.test_all_pairs_lca_pairs_without_lca" + ): maybe_oom, + key( + "test_lowest_common_ancestors.py:" + "TestMultiDiGraph_DAGLCA.test_all_pairs_lca_pairs_without_lca" + ): maybe_oom, + # These repeatedly call `bfs_layers`, which converts the graph every call + key( + "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph2_different_labels" + ): too_slow, + key( + "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph3_same_labels" + ): too_slow, + key( + "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph3_different_labels" + ): too_slow, + key( + "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph4_same_labels" + ): too_slow, + key( + "test_vf2pp.py:TestGraphISOVF2pp." + "test_disconnected_graph_all_same_labels" + ): too_slow, + key( + "test_vf2pp.py:TestGraphISOVF2pp." + "test_disconnected_graph_all_different_labels" + ): too_slow, + key( + "test_vf2pp.py:TestGraphISOVF2pp." + "test_disconnected_graph_some_same_labels" + ): too_slow, + key( + "test_vf2pp.py:TestMultiGraphISOVF2pp." + "test_custom_multigraph3_same_labels" + ): too_slow, + key( + "test_vf2pp_helpers.py:TestNodeOrdering." + "test_matching_order_all_branches" + ): too_slow, } + if os.environ.get("PYTEST_NO_SKIP", False): + skip.clear() for item in items: kset = set(item.keywords) diff --git a/python/nx-cugraph/nx_cugraph/utils/misc.py b/python/nx-cugraph/nx_cugraph/utils/misc.py index e303375918d..aa06d7fd29b 100644 --- a/python/nx-cugraph/nx_cugraph/utils/misc.py +++ b/python/nx-cugraph/nx_cugraph/utils/misc.py @@ -58,16 +58,18 @@ def pairwise(it): def _groupby( - groups: cp.ndarray, values: cp.ndarray, groups_are_canonical: bool = False + groups: cp.ndarray | list[cp.ndarray], + values: cp.ndarray | list[cp.ndarray], + groups_are_canonical: bool = False, ) -> dict[int, cp.ndarray]: """Perform a groupby operation given an array of group IDs and array of values. Parameters ---------- - groups : cp.ndarray - Array that holds the group IDs. - values : cp.ndarray - Array of values to be grouped according to groups. + groups : cp.ndarray or list of cp.ndarray + Array or list of arrays that holds the group IDs. + values : cp.ndarray or list of cp.ndarray + Array or list of arrays of values to be grouped according to groups. Must be the same size as groups array. groups_are_canonical : bool, default False Whether the group IDs are consecutive integers beginning with 0. @@ -76,18 +78,42 @@ def _groupby( ------- dict with group IDs as keys and cp.ndarray as values. """ - if groups.size == 0: - return {} - sort_indices = cp.argsort(groups) - sorted_groups = groups[sort_indices] - sorted_values = values[sort_indices] - prepend = 1 if groups_are_canonical else sorted_groups[0] + 1 - left_bounds = cp.nonzero(cp.diff(sorted_groups, prepend=prepend))[0] - boundaries = pairwise(itertools.chain(left_bounds.tolist(), [groups.size])) + if isinstance(groups, list): + if groups_are_canonical: + raise ValueError( + "`groups_are_canonical=True` is not allowed when `groups` is a list." + ) + if len(groups) == 0 or (size := groups[0].size) == 0: + return {} + sort_indices = cp.lexsort(cp.vstack(groups[::-1])) + sorted_groups = cp.vstack([group[sort_indices] for group in groups]) + prepend = sorted_groups[:, 0].max() + 1 + changed = cp.abs(cp.diff(sorted_groups, prepend=prepend)).sum(axis=0) + changed[0] = 1 + left_bounds = cp.nonzero(changed)[0] + else: + if (size := groups.size) == 0: + return {} + sort_indices = cp.argsort(groups) + sorted_groups = groups[sort_indices] + prepend = 1 if groups_are_canonical else sorted_groups[0] + 1 + left_bounds = cp.nonzero(cp.diff(sorted_groups, prepend=prepend))[0] + if isinstance(values, list): + sorted_values = [vals[sort_indices] for vals in values] + else: + sorted_values = values[sort_indices] + boundaries = pairwise(itertools.chain(left_bounds.tolist(), [size])) if groups_are_canonical: it = enumerate(boundaries) + elif isinstance(groups, list): + it = zip(map(tuple, sorted_groups.T[left_bounds].tolist()), boundaries) else: it = zip(sorted_groups[left_bounds].tolist(), boundaries) + if isinstance(values, list): + return { + group: [sorted_vals[start:end] for sorted_vals in sorted_values] + for group, (start, end) in it + } return {group: sorted_values[start:end] for group, (start, end) in it}