Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nx-cugraph: add ego_graph #4395

Merged
merged 10 commits into from
May 21, 2024
2 changes: 2 additions & 0 deletions python/nx-cugraph/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ Below is the list of algorithms that are currently supported in nx-cugraph.
└─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.classic.wheel_graph.html#networkx.generators.classic.wheel_graph">wheel_graph</a>
<a href="https://networkx.org/documentation/stable/reference/generators.html#module-networkx.generators.community">community</a>
└─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.community.caveman_graph.html#networkx.generators.community.caveman_graph">caveman_graph</a>
<a href="https://networkx.org/documentation/stable/reference/generators.html#module-networkx.generators.ego">ego</a>
└─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.ego.ego_graph.html#networkx.generators.ego.ego_graph">ego_graph</a>
<a href="https://networkx.org/documentation/stable/reference/generators.html#module-networkx.generators.small">small</a>
├─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.small.bull_graph.html#networkx.generators.small.bull_graph">bull_graph</a>
├─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.small.chvatal_graph.html#networkx.generators.small.chvatal_graph">chvatal_graph</a>
Expand Down
5 changes: 5 additions & 0 deletions python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"diamond_graph",
"dodecahedral_graph",
"edge_betweenness_centrality",
"ego_graph",
"eigenvector_centrality",
"empty_graph",
"florentine_families_graph",
Expand Down Expand Up @@ -163,6 +164,7 @@
"clustering": "Directed graphs and `weight` parameter are not yet supported.",
"core_number": "Directed graphs are not yet supported.",
"edge_betweenness_centrality": "`weight` parameter is not yet supported, and RNG with seed may be different.",
"ego_graph": "Weighted ego_graph with negative cycles is not yet supported. `NotImplementedError` will be raised if there are negative `distance` edge weights.",
"eigenvector_centrality": "`nstart` parameter is not used, but it is checked for validity.",
"from_pandas_edgelist": "cudf.DataFrame inputs also supported; value columns with str is unsuppported.",
"generic_bfs_edges": "`neighbors` and `sort_neighbors` parameters are not yet supported.",
Expand Down Expand Up @@ -191,6 +193,9 @@
"bellman_ford_path_length": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
},
"ego_graph": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
},
"eigenvector_centrality": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
},
Expand Down
8 changes: 4 additions & 4 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.16
rev: v0.17
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -50,7 +50,7 @@ repos:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.4
hooks:
- id: ruff
args: [--fix-only, --show-fixes] # --unsafe-fixes]
Expand All @@ -62,7 +62,7 @@ repos:
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==7.0.0
- flake8-bugbear==24.4.21
- flake8-bugbear==24.4.26
- flake8-simplify==0.21.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
Expand All @@ -77,7 +77,7 @@ repos:
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.4
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
9 changes: 7 additions & 2 deletions python/nx-cugraph/nx_cugraph/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -562,7 +562,12 @@ def to_networkx(G: nxcg.Graph, *, sort_edges: bool = False) -> nx.Graph:
dst_iter = map(id_to_key.__getitem__, dst_indices)
if G.is_multigraph() and (G.edge_keys is not None or G.edge_indices is not None):
if G.edge_keys is not None:
edge_keys = G.edge_keys
if not G.is_directed():
edge_keys = [k for k, m in zip(G.edge_keys, mask.tolist()) if m]
else:
edge_keys = G.edge_keys
elif not G.is_directed():
edge_keys = G.edge_indices[mask].tolist()
else:
edge_keys = G.edge_indices.tolist()
if edge_values:
Expand Down
3 changes: 2 additions & 1 deletion python/nx-cugraph/nx_cugraph/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -12,5 +12,6 @@
# limitations under the License.
from .classic import *
from .community import *
from .ego import *
from .small import *
from .social import *
161 changes: 161 additions & 0 deletions python/nx-cugraph/nx_cugraph/generators/ego.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import cupy as cp
import networkx as nx
import numpy as np
import pylibcugraph as plc

import nx_cugraph as nxcg

from ..utils import _dtype_param, _get_float_dtype, index_dtype, networkx_algorithm

__all__ = ["ego_graph"]


@networkx_algorithm(
extra_params=_dtype_param, version_added="24.06", _plc={"bfs", "ego_graph", "sssp"}
)
def ego_graph(
G, n, radius=1, center=True, undirected=False, distance=None, *, dtype=None
):
"""Weighted ego_graph with negative cycles is not yet supported. `NotImplementedError` will be raised if there are negative `distance` edge weights.""" # noqa: E501
if isinstance(G, nx.Graph):
G = nxcg.from_networkx(G, preserve_all_attrs=True)
if n not in G:
if distance is None:
raise nx.NodeNotFound(f"Source {n} is not in G")
raise nx.NodeNotFound(f"Node {n} not found in graph")
src_index = n if G.key_to_id is None else G.key_to_id[n]
symmetrize = "union" if undirected and G.is_directed() else None
if distance is None or distance not in G.edge_values:
# Simple BFS to determine nodes
if radius is not None and radius <= 0:
if center:
node_ids = cp.array([src_index], dtype=index_dtype)
else:
node_ids = cp.empty(0, dtype=index_dtype)
node_mask = None
else:
if radius is None or np.isinf(radius):
radius = -1
else:
radius = math.ceil(radius)
distances, unused_predecessors, node_ids = plc.bfs(
handle=plc.ResourceHandle(),
graph=G._get_plc_graph(symmetrize=symmetrize),
sources=cp.array([src_index], index_dtype),
direction_optimizing=False, # True for undirected only; what's best?
depth_limit=radius,
compute_predecessors=False,
do_expensive_check=False,
)
node_mask = distances != np.iinfo(distances.dtype).max
else:
# SSSP to determine nodes
if callable(distance):
raise NotImplementedError("callable `distance` argument is not supported")
if symmetrize and G.is_multigraph():
# G._get_plc_graph does not implement `symmetrize=True` w/ edge array
raise NotImplementedError(
"Weighted ego_graph with undirected=True not implemented"
)
# Check for negative values since we don't support negative cycles
edge_vals = G.edge_values[distance]
if distance in G.edge_masks:
edge_vals = edge_vals[G.edge_masks[distance]]
if (edge_vals < 0).any():
raise NotImplementedError(
"Negative edge weights not yet supported by ego_graph"
)
# PERF: we could use BFS if all edges are equal
if radius is None:
radius = np.inf
dtype = _get_float_dtype(dtype, graph=G, weight=distance)
node_ids, distances, unused_predecessors = plc.sssp(
resource_handle=plc.ResourceHandle(),
graph=(G.to_undirected() if symmetrize else G)._get_plc_graph(
distance, 1, dtype
),
source=src_index,
cutoff=np.nextafter(radius, np.inf, dtype=np.float64),
compute_predecessors=True, # TODO: False is not yet supported
do_expensive_check=False,
)
node_mask = distances != np.finfo(distances.dtype).max

if node_mask is not None:
if not center:
node_mask &= node_ids != src_index
node_ids = node_ids[node_mask]
if node_ids.size == G._N:
return G.copy()
# TODO: create renumbering helper function(s)
node_ids.sort() # TODO: is this ever necessary? Keep for safety
node_values = {key: val[node_ids] for key, val in G.node_values.items()}
node_masks = {key: val[node_ids] for key, val in G.node_masks.items()}

G._sort_edge_indices() # TODO: is this ever necessary? Keep for safety
edge_mask = cp.isin(G.src_indices, node_ids) & cp.isin(G.dst_indices, node_ids)
src_indices = cp.searchsorted(node_ids, G.src_indices[edge_mask]).astype(
index_dtype
)
dst_indices = cp.searchsorted(node_ids, G.dst_indices[edge_mask]).astype(
index_dtype
)
edge_values = {key: val[edge_mask] for key, val in G.edge_values.items()}
edge_masks = {key: val[edge_mask] for key, val in G.edge_masks.items()}

# Renumber nodes
if (id_to_key := G.id_to_key) is not None:
key_to_id = {
id_to_key[old_index]: new_index
for new_index, old_index in enumerate(node_ids.tolist())
}
else:
key_to_id = {
old_index: new_index
for new_index, old_index in enumerate(node_ids.tolist())
}
kwargs = {
"N": node_ids.size,
"src_indices": src_indices,
"dst_indices": dst_indices,
"edge_values": edge_values,
"edge_masks": edge_masks,
"node_values": node_values,
"node_masks": node_masks,
"key_to_id": key_to_id,
}
if G.is_multigraph():
if G.edge_keys is not None:
kwargs["edge_keys"] = [
x for x, m in zip(G.edge_keys, edge_mask.tolist()) if m
]
if G.edge_indices is not None:
kwargs["edge_indices"] = G.edge_indices[edge_mask]
rv = G.__class__.from_coo(**kwargs)
rv.graph.update(G.graph)
return rv


@ego_graph._can_run
def _(G, n, radius=1, center=True, undirected=False, distance=None, *, dtype=None):
if distance is not None and undirected and G.is_directed() and G.is_multigraph():
return "Weighted ego_graph with undirected=True not implemented"
if distance is not None and nx.is_negatively_weighted(G, weight=distance):
return "Weighted ego_graph with negative cycles not yet supported"
if callable(distance):
return "callable `distance` argument is not supported"
return True
81 changes: 81 additions & 0 deletions python/nx-cugraph/nx_cugraph/tests/test_ego_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx
import pytest
from packaging.version import parse

import nx_cugraph as nxcg

from .testing_utils import assert_graphs_equal

nxver = parse(nx.__version__)


if nxver.major == 3 and nxver.minor < 2:
pytest.skip("Need NetworkX >=3.2 to test ego_graph", allow_module_level=True)


@pytest.mark.parametrize(
"create_using", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph]
)
@pytest.mark.parametrize("radius", [-1, 0, 1, 1.5, 2, float("inf"), None])
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("undirected", [False, True])
@pytest.mark.parametrize("multiple_edges", [False, True])
@pytest.mark.parametrize("n", [0, 3])
def test_ego_graph_cycle_graph(
create_using, radius, center, undirected, multiple_edges, n
):
Gnx = nx.cycle_graph(7, create_using=create_using)
if multiple_edges:
# Test multigraph with multiple edges
if not Gnx.is_multigraph():
return
Gnx.add_edges_from(nx.cycle_graph(7, create_using=nx.DiGraph).edges)
Gnx.add_edge(0, 1, 10)
Gcg = nxcg.from_networkx(Gnx, preserve_all_attrs=True)
assert_graphs_equal(Gnx, Gcg) # Sanity check

kwargs = {"radius": radius, "center": center, "undirected": undirected}
Hnx = nx.ego_graph(Gnx, n, **kwargs)
Hcg = nx.ego_graph(Gnx, n, **kwargs, backend="cugraph")
assert_graphs_equal(Hnx, Hcg)
with pytest.raises(nx.NodeNotFound, match="not in G"):
nx.ego_graph(Gnx, -1, **kwargs)
with pytest.raises(nx.NodeNotFound, match="not in G"):
nx.ego_graph(Gnx, -1, **kwargs, backend="cugraph")
# Using sssp with default weight of 1 should give same answer as bfs
nx.set_edge_attributes(Gnx, 1, name="weight")
Gcg = nxcg.from_networkx(Gnx, preserve_all_attrs=True)
assert_graphs_equal(Gnx, Gcg) # Sanity check

kwargs["distance"] = "weight"
H2nx = nx.ego_graph(Gnx, n, **kwargs)
is_nx32 = nxver.major == 3 and nxver.minor == 2
if undirected and Gnx.is_directed() and Gnx.is_multigraph():
if is_nx32:
# `should_run` was added in nx 3.3
match = "Weighted ego_graph with undirected=True not implemented"
else:
match = "not implemented by cugraph"
with pytest.raises(RuntimeError, match=match):
nx.ego_graph(Gnx, n, **kwargs, backend="cugraph")
with pytest.raises(NotImplementedError, match="ego_graph"):
nx.ego_graph(Gcg, n, **kwargs)
else:
H2cg = nx.ego_graph(Gnx, n, **kwargs, backend="cugraph")
assert_graphs_equal(H2nx, H2cg)
with pytest.raises(nx.NodeNotFound, match="not found in graph"):
nx.ego_graph(Gnx, -1, **kwargs)
with pytest.raises(nx.NodeNotFound, match="not found in graph"):
nx.ego_graph(Gnx, -1, **kwargs, backend="cugraph")
5 changes: 3 additions & 2 deletions python/nx-cugraph/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

[build-system]

Expand All @@ -19,7 +19,7 @@ authors = [
license = { text = "Apache 2.0" }
requires-python = ">=3.9"
classifiers = [
"Development Status :: 3 - Alpha",
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent, thanks!

"Programming Language :: Python",
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -233,6 +233,7 @@ ignore = [
"nx_cugraph/**/tests/*py" = ["S101", "S311", "T201", "D103", "D100"]
"_nx_cugraph/__init__.py" = ["E501"]
"nx_cugraph/algorithms/**/*py" = ["D205", "D401"] # Allow flexible docstrings for algorithms
"nx_cugraph/generators/**/*py" = ["D205", "D401"] # Allow flexible docstrings for generators
"nx_cugraph/interface.py" = ["D401"] # Flexible docstrings
"scripts/update_readme.py" = ["INP001"] # Not part of a package

Expand Down
Loading