-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nx-cugraph: add ego_graph
#4395
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
5c6dc05
Add `ego_graph`
eriknw 3030944
Merge branch 'branch-24.06' into add_ego_graph
eriknw ac11cb1
oops forgot `__all__`
eriknw 2e77df5
better docstring formatting (if a little awkward)
eriknw 0de940c
Update copyright years
eriknw 6ce039b
Merge branch 'branch-24.06' into add_ego_graph
eriknw dacc926
oops fix for Python 3.9 and nx 3.2
eriknw 71c98ad
Merge branch 'add_ego_graph' of github.com:eriknw/cugraph into add_eg…
eriknw 17baf34
Merge branch 'branch-24.06' into add_ego_graph
eriknw 3e3b7cb
Merge branch 'branch-24.06' into add_ego_graph
eriknw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import math | ||
|
||
import cupy as cp | ||
import networkx as nx | ||
import numpy as np | ||
import pylibcugraph as plc | ||
|
||
import nx_cugraph as nxcg | ||
|
||
from ..utils import _dtype_param, _get_float_dtype, index_dtype, networkx_algorithm | ||
|
||
__all__ = ["ego_graph"] | ||
|
||
|
||
@networkx_algorithm( | ||
extra_params=_dtype_param, version_added="24.06", _plc={"bfs", "ego_graph", "sssp"} | ||
) | ||
def ego_graph( | ||
G, n, radius=1, center=True, undirected=False, distance=None, *, dtype=None | ||
): | ||
"""Weighted ego_graph with negative cycles is not yet supported. `NotImplementedError` will be raised if there are negative `distance` edge weights.""" # noqa: E501 | ||
if isinstance(G, nx.Graph): | ||
G = nxcg.from_networkx(G, preserve_all_attrs=True) | ||
if n not in G: | ||
if distance is None: | ||
raise nx.NodeNotFound(f"Source {n} is not in G") | ||
raise nx.NodeNotFound(f"Node {n} not found in graph") | ||
src_index = n if G.key_to_id is None else G.key_to_id[n] | ||
symmetrize = "union" if undirected and G.is_directed() else None | ||
if distance is None or distance not in G.edge_values: | ||
# Simple BFS to determine nodes | ||
if radius is not None and radius <= 0: | ||
if center: | ||
node_ids = cp.array([src_index], dtype=index_dtype) | ||
else: | ||
node_ids = cp.empty(0, dtype=index_dtype) | ||
node_mask = None | ||
else: | ||
if radius is None or np.isinf(radius): | ||
radius = -1 | ||
else: | ||
radius = math.ceil(radius) | ||
distances, unused_predecessors, node_ids = plc.bfs( | ||
handle=plc.ResourceHandle(), | ||
graph=G._get_plc_graph(symmetrize=symmetrize), | ||
sources=cp.array([src_index], index_dtype), | ||
direction_optimizing=False, # True for undirected only; what's best? | ||
depth_limit=radius, | ||
compute_predecessors=False, | ||
do_expensive_check=False, | ||
) | ||
node_mask = distances != np.iinfo(distances.dtype).max | ||
else: | ||
# SSSP to determine nodes | ||
if callable(distance): | ||
raise NotImplementedError("callable `distance` argument is not supported") | ||
if symmetrize and G.is_multigraph(): | ||
# G._get_plc_graph does not implement `symmetrize=True` w/ edge array | ||
raise NotImplementedError( | ||
"Weighted ego_graph with undirected=True not implemented" | ||
) | ||
# Check for negative values since we don't support negative cycles | ||
edge_vals = G.edge_values[distance] | ||
if distance in G.edge_masks: | ||
edge_vals = edge_vals[G.edge_masks[distance]] | ||
if (edge_vals < 0).any(): | ||
raise NotImplementedError( | ||
"Negative edge weights not yet supported by ego_graph" | ||
) | ||
# PERF: we could use BFS if all edges are equal | ||
if radius is None: | ||
radius = np.inf | ||
dtype = _get_float_dtype(dtype, graph=G, weight=distance) | ||
node_ids, distances, unused_predecessors = plc.sssp( | ||
resource_handle=plc.ResourceHandle(), | ||
graph=(G.to_undirected() if symmetrize else G)._get_plc_graph( | ||
distance, 1, dtype | ||
), | ||
source=src_index, | ||
cutoff=np.nextafter(radius, np.inf, dtype=np.float64), | ||
compute_predecessors=True, # TODO: False is not yet supported | ||
do_expensive_check=False, | ||
) | ||
node_mask = distances != np.finfo(distances.dtype).max | ||
|
||
if node_mask is not None: | ||
if not center: | ||
node_mask &= node_ids != src_index | ||
node_ids = node_ids[node_mask] | ||
if node_ids.size == G._N: | ||
return G.copy() | ||
# TODO: create renumbering helper function(s) | ||
node_ids.sort() # TODO: is this ever necessary? Keep for safety | ||
node_values = {key: val[node_ids] for key, val in G.node_values.items()} | ||
node_masks = {key: val[node_ids] for key, val in G.node_masks.items()} | ||
|
||
G._sort_edge_indices() # TODO: is this ever necessary? Keep for safety | ||
edge_mask = cp.isin(G.src_indices, node_ids) & cp.isin(G.dst_indices, node_ids) | ||
src_indices = cp.searchsorted(node_ids, G.src_indices[edge_mask]).astype( | ||
index_dtype | ||
) | ||
dst_indices = cp.searchsorted(node_ids, G.dst_indices[edge_mask]).astype( | ||
index_dtype | ||
) | ||
edge_values = {key: val[edge_mask] for key, val in G.edge_values.items()} | ||
edge_masks = {key: val[edge_mask] for key, val in G.edge_masks.items()} | ||
|
||
# Renumber nodes | ||
if (id_to_key := G.id_to_key) is not None: | ||
key_to_id = { | ||
id_to_key[old_index]: new_index | ||
for new_index, old_index in enumerate(node_ids.tolist()) | ||
} | ||
else: | ||
key_to_id = { | ||
old_index: new_index | ||
for new_index, old_index in enumerate(node_ids.tolist()) | ||
} | ||
kwargs = { | ||
"N": node_ids.size, | ||
"src_indices": src_indices, | ||
"dst_indices": dst_indices, | ||
"edge_values": edge_values, | ||
"edge_masks": edge_masks, | ||
"node_values": node_values, | ||
"node_masks": node_masks, | ||
"key_to_id": key_to_id, | ||
} | ||
if G.is_multigraph(): | ||
if G.edge_keys is not None: | ||
kwargs["edge_keys"] = [ | ||
x for x, m in zip(G.edge_keys, edge_mask.tolist()) if m | ||
] | ||
if G.edge_indices is not None: | ||
kwargs["edge_indices"] = G.edge_indices[edge_mask] | ||
rv = G.__class__.from_coo(**kwargs) | ||
rv.graph.update(G.graph) | ||
return rv | ||
|
||
|
||
@ego_graph._can_run | ||
def _(G, n, radius=1, center=True, undirected=False, distance=None, *, dtype=None): | ||
if distance is not None and undirected and G.is_directed() and G.is_multigraph(): | ||
return "Weighted ego_graph with undirected=True not implemented" | ||
if distance is not None and nx.is_negatively_weighted(G, weight=distance): | ||
return "Weighted ego_graph with negative cycles not yet supported" | ||
if callable(distance): | ||
return "callable `distance` argument is not supported" | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import networkx as nx | ||
import pytest | ||
from packaging.version import parse | ||
|
||
import nx_cugraph as nxcg | ||
|
||
from .testing_utils import assert_graphs_equal | ||
|
||
nxver = parse(nx.__version__) | ||
|
||
|
||
if nxver.major == 3 and nxver.minor < 2: | ||
pytest.skip("Need NetworkX >=3.2 to test ego_graph", allow_module_level=True) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"create_using", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] | ||
) | ||
@pytest.mark.parametrize("radius", [-1, 0, 1, 1.5, 2, float("inf"), None]) | ||
@pytest.mark.parametrize("center", [True, False]) | ||
@pytest.mark.parametrize("undirected", [False, True]) | ||
@pytest.mark.parametrize("multiple_edges", [False, True]) | ||
@pytest.mark.parametrize("n", [0, 3]) | ||
def test_ego_graph_cycle_graph( | ||
create_using, radius, center, undirected, multiple_edges, n | ||
): | ||
Gnx = nx.cycle_graph(7, create_using=create_using) | ||
if multiple_edges: | ||
# Test multigraph with multiple edges | ||
if not Gnx.is_multigraph(): | ||
return | ||
Gnx.add_edges_from(nx.cycle_graph(7, create_using=nx.DiGraph).edges) | ||
Gnx.add_edge(0, 1, 10) | ||
Gcg = nxcg.from_networkx(Gnx, preserve_all_attrs=True) | ||
assert_graphs_equal(Gnx, Gcg) # Sanity check | ||
|
||
kwargs = {"radius": radius, "center": center, "undirected": undirected} | ||
Hnx = nx.ego_graph(Gnx, n, **kwargs) | ||
Hcg = nx.ego_graph(Gnx, n, **kwargs, backend="cugraph") | ||
assert_graphs_equal(Hnx, Hcg) | ||
with pytest.raises(nx.NodeNotFound, match="not in G"): | ||
nx.ego_graph(Gnx, -1, **kwargs) | ||
with pytest.raises(nx.NodeNotFound, match="not in G"): | ||
nx.ego_graph(Gnx, -1, **kwargs, backend="cugraph") | ||
# Using sssp with default weight of 1 should give same answer as bfs | ||
nx.set_edge_attributes(Gnx, 1, name="weight") | ||
Gcg = nxcg.from_networkx(Gnx, preserve_all_attrs=True) | ||
assert_graphs_equal(Gnx, Gcg) # Sanity check | ||
|
||
kwargs["distance"] = "weight" | ||
H2nx = nx.ego_graph(Gnx, n, **kwargs) | ||
is_nx32 = nxver.major == 3 and nxver.minor == 2 | ||
if undirected and Gnx.is_directed() and Gnx.is_multigraph(): | ||
if is_nx32: | ||
# `should_run` was added in nx 3.3 | ||
match = "Weighted ego_graph with undirected=True not implemented" | ||
else: | ||
match = "not implemented by cugraph" | ||
with pytest.raises(RuntimeError, match=match): | ||
nx.ego_graph(Gnx, n, **kwargs, backend="cugraph") | ||
with pytest.raises(NotImplementedError, match="ego_graph"): | ||
nx.ego_graph(Gcg, n, **kwargs) | ||
else: | ||
H2cg = nx.ego_graph(Gnx, n, **kwargs, backend="cugraph") | ||
assert_graphs_equal(H2nx, H2cg) | ||
with pytest.raises(nx.NodeNotFound, match="not found in graph"): | ||
nx.ego_graph(Gnx, -1, **kwargs) | ||
with pytest.raises(nx.NodeNotFound, match="not found in graph"): | ||
nx.ego_graph(Gnx, -1, **kwargs, backend="cugraph") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent, thanks!