diff --git a/ci/run_cugraph_pyg_pytests.sh b/ci/run_cugraph_pyg_pytests.sh index 88642e6ceb6..fb27f16d79e 100755 --- a/ci/run_cugraph_pyg_pytests.sh +++ b/ci/run_cugraph_pyg_pytests.sh @@ -6,7 +6,10 @@ set -euo pipefail # Support invoking run_cugraph_pyg_pytests.sh outside the script directory cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cugraph-pyg/cugraph_pyg -pytest --cache-clear --ignore=tests/mg "$@" . +pytest --cache-clear --benchmark-disable "$@" . + +# Used to skip certain examples in CI due to memory limitations +export CI_RUN=1 # Test examples for e in "$(pwd)"/examples/*.py; do diff --git a/ci/test.sh b/ci/test.sh index f20fc40f85a..884ed7ac881 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -103,7 +103,7 @@ if hasArg "--run-python-tests"; then conda list cd ${CUGRAPH_ROOT}/python/cugraph-pyg/cugraph_pyg # rmat is not tested because of MG testing - pytest --cache-clear --junitxml=${CUGRAPH_ROOT}/junit-cugraph-pytests.xml -v --cov-config=.coveragerc --cov=cugraph_pyg --cov-report=xml:${WORKSPACE}/python/cugraph_pyg/cugraph-coverage.xml --cov-report term --ignore=raft --ignore=tests/mg --ignore=tests/int --ignore=tests/generators --benchmark-disable + pytest -sv -m sg --cache-clear --junitxml=${CUGRAPH_ROOT}/junit-cugraph-pytests.xml -v --cov-config=.coveragerc --cov=cugraph_pyg --cov-report=xml:${WORKSPACE}/python/cugraph_pyg/cugraph-coverage.xml --cov-report term --ignore=raft --benchmark-disable echo "Ran Python pytest for cugraph_pyg : return code was: $?, test script exit code is now: $EXITCODE" echo "Python pytest for cugraph-service (single-GPU only)..." diff --git a/ci/test_python.sh b/ci/test_python.sh index 9537f66e825..c215e25c526 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -215,13 +215,14 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then # Install pyg dependencies (which requires pip) - pip install ogb + pip install \ + ogb \ + tensordict + pip install \ pyg_lib \ torch_scatter \ torch_sparse \ - torch_cluster \ - torch_spline_conv \ -f ${PYG_URL} rapids-print-env diff --git a/ci/test_wheel_cugraph-pyg.sh b/ci/test_wheel_cugraph-pyg.sh index f45112dd80b..1004063cc38 100755 --- a/ci/test_wheel_cugraph-pyg.sh +++ b/ci/test_wheel_cugraph-pyg.sh @@ -24,6 +24,9 @@ python -m pip install $(ls ./dist/${python_package_name}*.whl)[test] # RAPIDS_DATASET_ROOT_DIR is used by test scripts export RAPIDS_DATASET_ROOT_DIR="$(realpath datasets)" +# Used to skip certain examples in CI due to memory limitations +export CI_RUN=1 + if [[ "${CUDA_VERSION}" == "11.8.0" ]]; then PYTORCH_URL="https://download.pytorch.org/whl/cu118" PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu118.html" @@ -39,15 +42,14 @@ rapids-retry python -m pip install \ pyg_lib \ torch_scatter \ torch_sparse \ - torch_cluster \ - torch_spline_conv \ + tensordict \ -f ${PYG_URL} rapids-logger "pytest cugraph-pyg (single GPU)" pushd python/cugraph-pyg/cugraph_pyg python -m pytest \ --cache-clear \ - --ignore=tests/mg \ + --benchmark-disable \ tests # Test examples for e in "$(pwd)"/examples/*.py; do diff --git a/conda/recipes/cugraph-pyg/meta.yaml b/conda/recipes/cugraph-pyg/meta.yaml index c02e8391eb2..64091ff4782 100644 --- a/conda/recipes/cugraph-pyg/meta.yaml +++ b/conda/recipes/cugraph-pyg/meta.yaml @@ -34,6 +34,7 @@ requirements: - cupy >=12.0.0 - cugraph ={{ version }} - pylibcugraphops ={{ minor_version }} + - tensordict >=0.1.2 - pyg >=2.5,<2.6 tests: diff --git a/dependencies.yaml b/dependencies.yaml index c0699fdb1c5..3c2622fde9f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -565,6 +565,7 @@ dependencies: - cugraph==24.6.* - pytorch>=2.0 - pytorch-cuda==11.8 + - tensordict>=0.1.2 - pyg>=2.5,<2.6 depends_on_rmm: diff --git a/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst b/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst index a150d4db9fe..5475fd6c581 100644 --- a/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst +++ b/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst @@ -6,8 +6,37 @@ cugraph-pyg .. currentmodule:: cugraph_pyg +Graph Storage +------------- .. autosummary:: :toctree: ../api/cugraph-pyg/ -.. cugraph_pyg.data.cugraph_store.EXPERIMENTAL__CuGraphStore -.. cugraph_pyg.sampler.cugraph_sampler.EXPERIMENTAL__CuGraphSampler + cugraph_pyg.data.dask_graph_store.DaskGraphStore + cugraph_pyg.data.graph_store.GraphStore + +Feature Storage +--------------- +.. autosummary:: + :toctree: ../api/cugraph-pyg/ + + cugraph_pyg.data.feature_store.TensorDictFeatureStore + +Data Loaders +------------ +.. autosummary:: + :toctree: ../api/cugraph-pyg/ + + cugraph_pyg.loader.dask_node_loader.DaskNeighborLoader + cugraph_pyg.loader.dask_node_loader.BulkSampleLoader + cugraph_pyg.loader.node_loader.NodeLoader + cugraph_pyg.loader.neighbor_loader.NeighborLoader + +Samplers +-------- +.. autosummary:: + :toctree: ../api/cugraph-pyg/ + + cugraph_pyg.sampler.sampler.BaseSampler + cugraph_pyg.sampler.sampler.SampleReader + cugraph_pyg.sampler.sampler.HomogeneousSampleReader + cugraph_pyg.sampler.sampler.SampleIterator diff --git a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml index ebef0094cfa..922d92f069a 100644 --- a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml +++ b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml @@ -21,4 +21,5 @@ dependencies: - pytorch-cuda==11.8 - pytorch>=2.0 - scipy +- tensordict>=0.1.2 name: cugraph_pyg_dev_cuda-118 diff --git a/python/cugraph-pyg/cugraph_pyg/data/__init__.py b/python/cugraph-pyg/cugraph_pyg/data/__init__.py index 66a9843c047..4c6f267410d 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cugraph_pyg.data.cugraph_store import CuGraphStore +import warnings + +from cugraph_pyg.data.dask_graph_store import DaskGraphStore +from cugraph_pyg.data.graph_store import GraphStore +from cugraph_pyg.data.feature_store import TensorDictFeatureStore + + +def CuGraphStore(*args, **kwargs): + warnings.warn("CuGraphStore has been renamed to DaskGraphStore", FutureWarning) + return DaskGraphStore(*args, **kwargs) diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/dask_graph_store.py similarity index 99% rename from python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py rename to python/cugraph-pyg/cugraph_pyg/data/dask_graph_store.py index 354eea8ee6b..ef22982c4da 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/dask_graph_store.py @@ -199,7 +199,7 @@ def cast(cls, *args, **kwargs): return cls(*args, **kwargs) -class CuGraphStore: +class DaskGraphStore: """ Duck-typed version of PyG's GraphStore and FeatureStore. """ @@ -221,7 +221,7 @@ def __init__( order: str = "CSR", ): """ - Constructs a new CuGraphStore from the provided + Constructs a new DaskGraphStore from the provided arguments. Parameters diff --git a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py new file mode 100644 index 00000000000..42dda42a9e1 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +from typing import Optional, Tuple, List + +from cugraph.utilities.utils import import_optional, MissingModule + +torch = import_optional("torch") +torch_geometric = import_optional("torch_geometric") +tensordict = import_optional("tensordict") + + +class TensorDictFeatureStore( + object + if isinstance(torch_geometric, MissingModule) + else torch_geometric.data.FeatureStore +): + """ + A basic implementation of the PyG FeatureStore interface that stores + feature data in a single TensorDict. This type of feature store is + not distributed, so each node will have to load the entire graph's + features into memory. + """ + + def __init__(self): + super().__init__() + + self.__features = {} + + def _put_tensor( + self, + tensor: "torch_geometric.typing.FeatureTensorType", + attr: "torch_geometric.data.feature_store.TensorAttr", + ) -> bool: + if attr.group_name in self.__features: + td = self.__features[attr.group_name] + batch_size = td.batch_size[0] + + if attr.is_set("index"): + if attr.attr_name in td.keys(): + if attr.index.shape[0] != batch_size: + raise ValueError( + "Leading size of index tensor " + "does not match existing tensors for group name " + f"{attr.group_name}; Expected {batch_size}, " + f"got {attr.index.shape[0]}" + ) + td[attr.attr_name][attr.index] = tensor + return True + else: + warnings.warn( + "Ignoring index parameter " + f"(attribute does not exist for group {attr.group_name})" + ) + + if tensor.shape[0] != batch_size: + raise ValueError( + "Leading size of input tensor does not match " + f"existing tensors for group name {attr.group_name};" + f" Expected {batch_size}, got {tensor.shape[0]}" + ) + else: + batch_size = tensor.shape[0] + self.__features[attr.group_name] = tensordict.TensorDict( + {}, batch_size=batch_size + ) + + self.__features[attr.group_name][attr.attr_name] = tensor + return True + + def _get_tensor( + self, attr: "torch_geometric.data.feature_store.TensorAttr" + ) -> Optional["torch_geometric.typing.FeatureTensorType"]: + if attr.group_name not in self.__features: + return None + + if attr.attr_name not in self.__features[attr.group_name].keys(): + return None + + tensor = self.__features[attr.group_name][attr.attr_name] + return ( + tensor + if (attr.index is None or (not attr.is_set("index"))) + else tensor[attr.index] + ) + + def _remove_tensor( + self, attr: "torch_geometric.data.feature_store.TensorAttr" + ) -> bool: + if attr.group_name not in self.__features: + return False + + if attr.attr_name not in self.__features[attr.group_name].keys(): + return False + + del self.__features[attr.group_name][attr.attr_name] + return True + + def _get_tensor_size( + self, attr: "torch_geometric.data.feature_store.TensorAttr" + ) -> Tuple: + return self._get_tensor(attr).size() + + def get_all_tensor_attrs( + self, + ) -> List["torch_geometric.data.feature_store.TensorAttr"]: + attrs = [] + for group_name, td in self.__features.items(): + for attr_name in td.keys(): + attrs.append( + torch_geometric.data.feature_store.TensorAttr( + group_name, + attr_name, + ) + ) + + return attrs diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py new file mode 100644 index 00000000000..01af7fd6ed0 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -0,0 +1,322 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cupy +import cudf +import pandas + +import pylibcugraph + +from cugraph.utilities.utils import import_optional, MissingModule +from cugraph.gnn.comms import cugraph_comms_get_raft_handle + +from typing import Union, Optional, List, Dict + + +# Have to use import_optional even though these are required +# dependencies in order to build properly. +torch_geometric = import_optional("torch_geometric") +torch = import_optional("torch") +tensordict = import_optional("tensordict") + +TensorType = Union["torch.Tensor", cupy.ndarray, np.ndarray, cudf.Series, pandas.Series] + + +class GraphStore( + object + if isinstance(torch_geometric, MissingModule) + else torch_geometric.data.GraphStore +): + """ + This object uses lazy graph creation. Users can repeatedly call + put_edge_index, and the tensors won't be converted into a cuGraph + graph until one is needed (i.e. when creating a loader). + """ + + def __init__(self, is_multi_gpu: bool = False): + self.__edge_indices = tensordict.TensorDict({}, batch_size=(2,)) + self.__sizes = {} + self.__graph = None + self.__vertex_offsets = None + self.__handle = None + self.__is_multi_gpu = is_multi_gpu + + super().__init__() + + def _put_edge_index( + self, + edge_index: "torch_geometric.typing.EdgeTensorType", + edge_attr: "torch_geometric.data.EdgeAttr", + ) -> bool: + if edge_attr.layout != torch_geometric.data.graph_store.EdgeLayout.COO: + raise ValueError("Only COO format supported") + + if isinstance(edge_index, (cupy.ndarray, cudf.Series)): + edge_index = torch.as_tensor(edge_index, device="cuda") + elif isinstance(edge_index, (np.ndarray)): + edge_index = torch.as_tensor(edge_index, device="cpu") + elif isinstance(edge_index, pandas.Series): + edge_index = torch.as_tensor(edge_index.values, device="cpu") + elif isinstance(edge_index, cudf.Series): + edge_index = torch.as_tensor(edge_index.values, device="cuda") + + self.__edge_indices[edge_attr.edge_type] = torch.stack( + [edge_index[0], edge_index[1]] + ) + self.__sizes[edge_attr.edge_type] = edge_attr.size + + # invalidate the graph + self.__graph = None + self.__vertex_offsets = None + return True + + def _get_edge_index( + self, edge_attr: "torch_geometric.data.EdgeAttr" + ) -> Optional["torch_geometric.typing.EdgeTensorType"]: + ei = torch_geometric.EdgeIndex(self.__edge_indices[edge_attr.edge_type]) + + if edge_attr.layout == "csr": + return ei.sort_by("row").values.get_csr() + elif edge_attr.layout == "csc": + return ei.sort_by("col").values.get_csc() + + return ei + + def _remove_edge_index(self, edge_attr: "torch_geometric.data.EdgeAttr") -> bool: + del self.__edge_indices[edge_attr.edge_type] + + # invalidate the graph + self.__graph = None + return True + + def get_all_edge_attrs(self) -> List["torch_geometric.data.EdgeAttr"]: + attrs = [] + for et in self.__edge_indices.keys(leaves_only=True, include_nested=True): + attrs.append( + torch_geometric.data.EdgeAttr( + edge_type=et, layout="coo", is_sorted=False, size=self.__sizes[et] + ) + ) + + return attrs + + @property + def is_multi_gpu(self): + return self.__is_multi_gpu + + @property + def _resource_handle(self): + if self.__handle is None: + if self.is_multi_gpu: + self.__handle = pylibcugraph.ResourceHandle( + cugraph_comms_get_raft_handle().getHandle() + ) + else: + self.__handle = pylibcugraph.ResourceHandle() + return self.__handle + + @property + def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]: + graph_properties = pylibcugraph.GraphProperties( + is_multigraph=True, is_symmetric=False + ) + + if self.__graph is None: + edgelist_dict = self.__get_edgelist() + + if self.is_multi_gpu: + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + vertices_array = cupy.arange( + sum(self._num_vertices().values()), dtype="int64" + ) + vertices_array = cupy.array_split(vertices_array, world_size)[rank] + + self.__graph = pylibcugraph.MGGraph( + self._resource_handle, + graph_properties, + [cupy.asarray(edgelist_dict["src"]).astype("int64")], + [cupy.asarray(edgelist_dict["dst"]).astype("int64")], + vertices_array=[vertices_array], + edge_id_array=[cupy.asarray(edgelist_dict["eid"])], + edge_type_array=[cupy.asarray(edgelist_dict["etp"])], + ) + else: + self.__graph = pylibcugraph.SGGraph( + self._resource_handle, + graph_properties, + cupy.asarray(edgelist_dict["src"]).astype("int64"), + cupy.asarray(edgelist_dict["dst"]).astype("int64"), + vertices_array=cupy.arange( + sum(self._num_vertices().values()), dtype="int64" + ), + edge_id_array=cupy.asarray(edgelist_dict["eid"]), + edge_type_array=cupy.asarray(edgelist_dict["etp"]), + ) + + return self.__graph + + def _num_vertices(self) -> Dict[str, int]: + num_vertices = {} + for edge_attr in self.get_all_edge_attrs(): + if edge_attr.size is not None: + num_vertices[edge_attr.edge_type[0]] = ( + max(num_vertices[edge_attr.edge_type[0]], edge_attr.size[0]) + if edge_attr.edge_type[0] in num_vertices + else edge_attr.size[0] + ) + num_vertices[edge_attr.edge_type[2]] = ( + max(num_vertices[edge_attr.edge_type[2]], edge_attr.size[1]) + if edge_attr.edge_type[2] in num_vertices + else edge_attr.size[1] + ) + else: + if edge_attr.edge_type[0] not in num_vertices: + num_vertices[edge_attr.edge_type[0]] = int( + self.__edge_indices[edge_attr.edge_type][0].max() + 1 + ) + if edge_attr.edge_type[2] not in num_vertices: + num_vertices[edge_attr.edge_type[1]] = int( + self.__edge_indices[edge_attr.edge_type][1].max() + 1 + ) + + if self.is_multi_gpu: + vtypes = num_vertices.keys() + for vtype in vtypes: + sz = torch.tensor(num_vertices[vtype], device="cuda") + torch.distributed.all_reduce(sz, op=torch.distributed.ReduceOp.MAX) + num_vertices[vtype] = int(sz) + return num_vertices + + @property + def _vertex_offsets(self) -> Dict[str, int]: + if self.__vertex_offsets is None: + num_vertices = self._num_vertices() + ordered_keys = sorted(list(num_vertices.keys())) + self.__vertex_offsets = {} + offset = 0 + for vtype in ordered_keys: + self.__vertex_offsets[vtype] = offset + offset += num_vertices[vtype] + + return dict(self.__vertex_offsets) + + @property + def is_homogeneous(self) -> bool: + return len(self._vertex_offsets) == 1 + + def __get_edgelist(self): + """ + Returns + ------- + Dict[str, torch.Tensor] with the following keys: + src: source vertices (int64) + Note that src is the 2nd element of the PyG edge index. + dst: destination vertices (int64) + Note that dst is the 1st element of the PyG edge index. + eid: edge ids for each edge (int64) + Note that these start from 0 for each edge type. + etp: edge types for each edge (int32) + Note that these are in lexicographic order. + """ + sorted_keys = sorted( + list(self.__edge_indices.keys(leaves_only=True, include_nested=True)) + ) + + # note that this still follows the PyG convention of (dst, rel, src) + # i.e. (author, writes, paper): [[0,1,2],[2,0,1]] is referring to a + # cuGraph graph where (paper 2) -> (author 0), (paper 0) -> (author 1), + # and (paper 1) -> (author 0) + edge_index = torch.concat( + [ + torch.stack( + [ + self.__edge_indices[dst_type, rel_type, src_type][0] + + self._vertex_offsets[dst_type], + self.__edge_indices[dst_type, rel_type, src_type][1] + + self._vertex_offsets[src_type], + ] + ) + for (dst_type, rel_type, src_type) in sorted_keys + ], + axis=1, + ).cuda() + + edge_type_array = torch.arange( + len(sorted_keys), dtype=torch.int32, device="cuda" + ).repeat_interleave( + torch.tensor( + [self.__edge_indices[et].shape[1] for et in sorted_keys], + device="cuda", + dtype=torch.int32, + ) + ) + + if self.is_multi_gpu: + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + num_edges_t = torch.tensor( + [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda" + ) + num_edges_all_t = torch.empty( + world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda" + ) + torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t) + + if rank > 0: + start_offsets = num_edges_all_t[:rank].T.sum(axis=1) + edge_id_array = torch.concat( + [ + torch.arange( + start_offsets[i], + start_offsets[i] + num_edges_all_t[rank][i], + dtype=torch.int64, + device="cuda", + ) + for i in range(len(sorted_keys)) + ] + ) + else: + edge_id_array = torch.concat( + [ + torch.arange( + self.__edge_indices[et].shape[1], + dtype=torch.int64, + device="cuda", + ) + for et in sorted_keys + ] + ) + + else: + # single GPU + edge_id_array = torch.concat( + [ + torch.arange( + self.__edge_indices[et].shape[1], + dtype=torch.int64, + device="cuda", + ) + for et in sorted_keys + ] + ) + + return { + "dst": edge_index[0], + "src": edge_index[1], + "etp": edge_type_array, + "eid": edge_id_array, + } diff --git a/python/cugraph-pyg/cugraph_pyg/examples/README.md b/python/cugraph-pyg/cugraph_pyg/examples/README.md deleted file mode 100644 index 572111ac26a..00000000000 --- a/python/cugraph-pyg/cugraph_pyg/examples/README.md +++ /dev/null @@ -1,11 +0,0 @@ -This directory contains examples for running cugraph-pyg training. - -For single-GPU (SG) scripts, no special configuration is required. - -For multi-GPU (MG) scripts, dask must be started first in a separate process. -To do this, the `start_dask.sh` script has been provided. This scripts starts -a dask scheduler and dask workers. To select the GPUs and amount of memory -allocated to dask per GPU, the `CUDA_VISIBLE_DEVICES` and `WORKER_RMM_POOL_SIZE` -arguments in that script can be modified. -To connect to dask, the scheduler JSON file must be provided. This can be done -using the `--dask_scheduler_file` argument in the mg python script being run. diff --git a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py index 29a6cc2b464..31cbaf69ca5 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py @@ -95,7 +95,7 @@ def main(): with tempfile.TemporaryDirectory() as directory: tmp.spawn( sample, - args=(world_size, uid, el, "."), + args=(world_size, uid, el, directory), nprocs=world_size, ) diff --git a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py index 8366ff44233..de45acc7456 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py @@ -55,6 +55,8 @@ def sample(edgelist, directory): G, sample_writer, fanout=[5, 5], + compression="CSR", + retain_original_seeds=True, ) sampler.sample_from_nodes(seeds, batch_size=16, random_state=62) diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py new file mode 100644 index 00000000000..71b0e4bb2fb --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py @@ -0,0 +1,178 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import argparse +import tempfile +import os + +from typing import Optional + +import torch +import cupy + +import rmm +from rmm.allocators.cupy import rmm_cupy_allocator +from rmm.allocators.torch import rmm_torch_allocator + +# Must change allocators immediately upon import +# or else other imports will cause memory to be +# allocated and prevent changing the allocator +rmm.reinitialize(devices=[0], pool_allocator=True, managed_memory=True) +cupy.cuda.set_allocator(rmm_cupy_allocator) +torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + +import torch.nn.functional as F # noqa: E402 +import torch_geometric # noqa: E402 +import cugraph_pyg # noqa: E402 +from cugraph_pyg.loader import NeighborLoader # noqa: E402 + +# Enable cudf spilling to save gpu memory +from cugraph.testing.mg_utils import enable_spilling # noqa: E402 + +enable_spilling() + +parser = argparse.ArgumentParser() +parser.add_argument("--hidden_channels", type=int, default=256) +parser.add_argument("--num_layers", type=int, default=2) +parser.add_argument("--lr", type=float, default=0.001) +parser.add_argument("--epochs", type=int, default=4) +parser.add_argument("--batch_size", type=int, default=1024) +parser.add_argument("--fan_out", type=int, default=30) +parser.add_argument("--tempdir_root", type=str, default=None) +parser.add_argument("--dataset_root", type=str, default="dataset") +parser.add_argument("--dataset", type=str, default="ogbn-products") + +args = parser.parse_args() + +wall_clock_start = time.perf_counter() +device = torch.device("cuda") + +from ogb.nodeproppred import PygNodePropPredDataset # noqa: E402 + +dataset = PygNodePropPredDataset(name=args.dataset, root=args.dataset_root) +split_idx = dataset.get_idx_split() +data = dataset[0] + +graph_store = cugraph_pyg.data.GraphStore() +graph_store[ + ("node", "rel", "node"), "coo", False, (data.num_nodes, data.num_nodes) +] = data.edge_index + +feature_store = cugraph_pyg.data.TensorDictFeatureStore() +feature_store["node", "x"] = data.x +feature_store["node", "y"] = data.y + +with tempfile.TemporaryDirectory(dir=args.tempdir_root) as samples_dir: + train_dir = os.path.join(samples_dir, "train") + os.mkdir(train_dir) + train_loader = NeighborLoader( + data=(feature_store, graph_store), + num_neighbors=[args.fan_out] * args.num_layers, + input_nodes=split_idx["train"], + replace=False, + batch_size=args.batch_size, + directory=train_dir, + ) + + val_dir = os.path.join(samples_dir, "val") + os.mkdir(val_dir) + val_loader = NeighborLoader( + data=(feature_store, graph_store), + num_neighbors=[args.fan_out] * args.num_layers, + input_nodes=split_idx["valid"], + replace=False, + batch_size=args.batch_size, + directory=val_dir, + ) + + test_dir = os.path.join(samples_dir, "test") + os.mkdir(test_dir) + test_loader = NeighborLoader( + data=(feature_store, graph_store), + num_neighbors=[args.fan_out] * args.num_layers, + input_nodes=split_idx["test"], + replace=False, + batch_size=args.batch_size, + directory=test_dir, + ) + + model = torch_geometric.nn.models.GCN( + dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005) + + warmup_steps = 20 + + def train(epoch: int): + model.train() + for i, batch in enumerate(train_loader): + if i == warmup_steps: + torch.cuda.synchronize() + start_avg_time = time.perf_counter() + batch = batch.to(device) + + optimizer.zero_grad() + batch_size = batch.batch_size + out = model(batch.x, batch.edge_index)[:batch_size] + y = batch.y[:batch_size].view(-1).to(torch.long) + + loss = F.cross_entropy(out, y) + loss.backward() + optimizer.step() + + if i % 10 == 0: + print(f"Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}") + torch.cuda.synchronize() + print( + f"Average Training Iteration Time (s/iter): \ + {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}" + ) + + @torch.no_grad() + def test(loader: NeighborLoader, val_steps: Optional[int] = None): + model.eval() + + total_correct = total_examples = 0 + for i, batch in enumerate(loader): + if val_steps is not None and i >= val_steps: + break + batch = batch.to(device) + batch_size = batch.batch_size + out = model(batch.x, batch.edge_index)[:batch_size] + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred == y).sum()) + total_examples += y.size(0) + + return total_correct / total_examples + + torch.cuda.synchronize() + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time)=", prep_time, "seconds") + print("Beginning training...") + for epoch in range(1, 1 + args.epochs): + train(epoch) + val_acc = test(val_loader, val_steps=100) + print(f"Val Acc: ~{val_acc:.4f}") + + test_acc = test(test_loader) + print(f"Test Acc: {test_acc:.4f}") + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py new file mode 100644 index 00000000000..b1bb0240e71 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py @@ -0,0 +1,328 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Single-node, multi-GPU example. + +import argparse +import os +import tempfile +import time +import warnings + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F +from ogb.nodeproppred import PygNodePropPredDataset +from torch.nn.parallel import DistributedDataParallel + +import torch_geometric + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, +) + +# Allow computation on objects that are larger than GPU memory +# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory +os.environ["CUDF_SPILL"] = "1" + +# Ensures that a CUDA context is not created on import of rapids. +# Allows pytorch to create the context instead +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + + +def init_pytorch_worker(rank, world_size, cugraph_id): + import rmm + + rmm.reinitialize( + devices=rank, + managed_memory=True, + pool_allocator=True, + ) + + import cupy + + cupy.cuda.Device(rank).use() + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + from cugraph.testing.mg_utils import enable_spilling + + enable_spilling() + + torch.cuda.set_device(rank) + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id, device=rank) + + +def run_train( + rank, + data, + world_size, + cugraph_id, + model, + epochs, + batch_size, + fan_out, + split_idx, + num_classes, + wall_clock_start, + tempdir=None, + num_layers=3, +): + + init_pytorch_worker( + rank, + world_size, + cugraph_id, + ) + + model = model.to(rank) + model = DistributedDataParallel(model, device_ids=[rank]) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) + + kwargs = dict( + num_neighbors=[fan_out] * num_layers, + batch_size=batch_size, + ) + # Set Up Neighbor Loading + from cugraph_pyg.data import GraphStore, TensorDictFeatureStore + from cugraph_pyg.loader import NeighborLoader + + graph_store = GraphStore(is_multi_gpu=True) + ixr = torch.tensor_split(data.edge_index, world_size, dim=1)[rank] + graph_store[ + ("node", "rel", "node"), "coo", False, (data.num_nodes, data.num_nodes) + ] = ixr + + feature_store = TensorDictFeatureStore() + feature_store["node", "x"] = data.x + feature_store["node", "y"] = data.y + + dist.barrier() + + ix_train = torch.tensor_split(split_idx["train"], world_size)[rank].cuda() + train_path = os.path.join(tempdir, f"train_{rank}") + os.mkdir(train_path) + train_loader = NeighborLoader( + (feature_store, graph_store), + input_nodes=ix_train, + directory=train_path, + shuffle=True, + drop_last=True, + **kwargs, + ) + + ix_test = torch.tensor_split(split_idx["test"], world_size)[rank].cuda() + test_path = os.path.join(tempdir, f"test_{rank}") + os.mkdir(test_path) + test_loader = NeighborLoader( + (feature_store, graph_store), + input_nodes=ix_test, + directory=test_path, + shuffle=True, + drop_last=True, + local_seeds_per_call=80000, + **kwargs, + ) + + ix_valid = torch.tensor_split(split_idx["valid"], world_size)[rank].cuda() + valid_path = os.path.join(tempdir, f"valid_{rank}") + os.mkdir(valid_path) + valid_loader = NeighborLoader( + (feature_store, graph_store), + input_nodes=ix_valid, + directory=valid_path, + shuffle=True, + drop_last=True, + **kwargs, + ) + + dist.barrier() + + eval_steps = 1000 + warmup_steps = 20 + dist.barrier() + torch.cuda.synchronize() + + if rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time) =", prep_time, "seconds") + print("Beginning training...") + for epoch in range(epochs): + for i, batch in enumerate(train_loader): + if i == warmup_steps: + torch.cuda.synchronize() + start = time.time() + + batch = batch.to(rank) + batch_size = batch.batch_size + + batch.y = batch.y.to(torch.long) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index) + loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size]) + loss.backward() + optimizer.step() + if rank == 0 and i % 10 == 0: + print( + "Epoch: " + + str(epoch) + + ", Iteration: " + + str(i) + + ", Loss: " + + str(loss) + ) + nb = i + 1.0 + + if rank == 0: + print( + "Average Training Iteration Time:", + (time.time() - start) / (nb - warmup_steps), + "s/iter", + ) + + with torch.no_grad(): + total_correct = total_examples = 0 + for i, batch in enumerate(valid_loader): + if i >= eval_steps: + break + + batch = batch.to(rank) + batch_size = batch.batch_size + + batch.y = batch.y.to(torch.long) + out = model(batch.x, batch.edge_index)[:batch_size] + + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred == y).sum()) + total_examples += y.size(0) + + acc_val = total_correct / total_examples + if rank == 0: + print( + f"Validation Accuracy: {acc_val * 100.0:.4f}%", + ) + + torch.cuda.synchronize() + + with torch.no_grad(): + total_correct = total_examples = 0 + for i, batch in enumerate(test_loader): + batch = batch.to(rank) + batch_size = batch.batch_size + + batch.y = batch.y.to(torch.long) + out = model(batch.x, batch.edge_index)[:batch_size] + + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred == y).sum()) + total_examples += y.size(0) + + acc_test = total_correct / total_examples + if rank == 0: + print( + f"Test Accuracy: {acc_test * 100.0:.4f}%", + ) + + if rank == 0: + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") + + cugraph_comms_shutdown() + dist.destroy_process_group() + + +if __name__ == "__main__": + if "CI_RUN" in os.environ and os.environ["CI_RUN"] == "1": + warnings.warn("Skipping SMNG example in CI due to memory limit") + else: + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_channels", type=int, default=256) + parser.add_argument("--num_layers", type=int, default=2) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=1024) + parser.add_argument("--fan_out", type=int, default=30) + parser.add_argument("--tempdir_root", type=str, default=None) + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--dataset", type=str, default="ogbn-products") + + parser.add_argument( + "--n_devices", + type=int, + default=-1, + help="1-8 to use that many GPUs. Defaults to all available GPUs", + ) + + args = parser.parse_args() + wall_clock_start = time.perf_counter() + + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + dataset = PygNodePropPredDataset(name=args.dataset, root=args.dataset_root) + split_idx = dataset.get_idx_split() + data = dataset[0] + data.y = data.y.reshape(-1) + + model = torch_geometric.nn.models.GCN( + dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + ) + + print("Data =", data) + if args.n_devices == -1: + world_size = torch.cuda.device_count() + else: + world_size = args.n_devices + print("Using", world_size, "GPUs...") + + # Create the uid needed for cuGraph comms + cugraph_id = cugraph_comms_create_unique_id() + + with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir: + mp.spawn( + run_train, + args=( + data, + world_size, + cugraph_id, + model, + args.epochs, + args.batch_size, + args.fan_out, + split_idx, + dataset.num_classes, + wall_clock_start, + tempdir, + args.num_layers, + ), + nprocs=world_size, + join=True, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py index 80d683e6c79..145675c8a06 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_mg.py @@ -11,6 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +# For this script, dask must be started first in a separate process. +# To do this, the `start_dask.sh` script has been provided. This scripts starts +# a dask scheduler and dask workers. To select the GPUs and amount of memory +# allocated to dask per GPU, the `CUDA_VISIBLE_DEVICES` and `WORKER_RMM_POOL_SIZE` +# arguments in that script can be modified. +# To connect to dask, the scheduler JSON file must be provided. This can be done +# using the `--dask_scheduler_file` argument in the mg python script being run. from ogb.nodeproppred import NodePropPredDataset @@ -159,8 +166,8 @@ def train( td.barrier() import cugraph - from cugraph_pyg.data import CuGraphStore - from cugraph_pyg.loader import CuGraphNeighborLoader + from cugraph_pyg.data import DaskGraphStore + from cugraph_pyg.loader import DaskNeighborLoader if rank == 0: print("Rank 0 downloading dataset") @@ -212,7 +219,7 @@ def train( # Rank 0 will initialize the distributed cugraph graph. cugraph_store_create_start = time.perf_counter_ns() print("G:", G[("paper", "cites", "paper")].shape) - cugraph_store = CuGraphStore(fs, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(fs, G, N, multi_gpu=True) cugraph_store_create_end = time.perf_counter_ns() print( "cuGraph Store created on rank 0 in " @@ -237,7 +244,7 @@ def train( # Will automatically use the stored distributed cugraph graph on rank 0. cugraph_store_create_start = time.perf_counter_ns() - cugraph_store = CuGraphStore(fs, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(fs, G, N, multi_gpu=True) cugraph_store_create_end = time.perf_counter_ns() print( f"Rank {rank} created cugraph store in " @@ -269,7 +276,7 @@ def train( model.train() start_time_loader = time.perf_counter_ns() - cugraph_bulk_loader = CuGraphNeighborLoader( + cugraph_bulk_loader = DaskNeighborLoader( cugraph_store, train_nodes, batch_size=250, diff --git a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py index 58a403084df..e0169ee2c25 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/graph_sage_sg.py @@ -74,8 +74,8 @@ def train(device: int, features_device: Union[str, int] = "cpu", num_epochs=2) - init_pytorch_worker(device) import cugraph - from cugraph_pyg.data import CuGraphStore - from cugraph_pyg.loader import CuGraphNeighborLoader + from cugraph_pyg.data import DaskGraphStore + from cugraph_pyg.loader import DaskNeighborLoader from ogb.nodeproppred import NodePropPredDataset @@ -106,7 +106,7 @@ def train(device: int, features_device: Union[str, int] = "cpu", num_epochs=2) - fs.add_data(train_mask, "paper", "train") - cugraph_store = CuGraphStore(fs, G, N) + cugraph_store = DaskGraphStore(fs, G, N) model = ( CuGraphSAGE(in_channels=128, hidden_channels=64, out_channels=349, num_layers=3) @@ -120,7 +120,7 @@ def train(device: int, features_device: Union[str, int] = "cpu", num_epochs=2) - start_time_train = time.perf_counter_ns() model.train() - cugraph_bulk_loader = CuGraphNeighborLoader( + cugraph_bulk_loader = DaskNeighborLoader( cugraph_store, train_nodes, batch_size=500, num_neighbors=[10, 25] ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/__init__.py b/python/cugraph-pyg/cugraph_pyg/loader/__init__.py index 2c3d7eff89e..cad66aaa183 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cugraph_pyg.loader.cugraph_node_loader import CuGraphNeighborLoader +import warnings -from cugraph_pyg.loader.cugraph_node_loader import BulkSampleLoader +from cugraph_pyg.loader.node_loader import NodeLoader +from cugraph_pyg.loader.neighbor_loader import NeighborLoader + +from cugraph_pyg.loader.dask_node_loader import DaskNeighborLoader + +from cugraph_pyg.loader.dask_node_loader import BulkSampleLoader + + +def CuGraphNeighborLoader(*args, **kwargs): + warnings.warn( + "CuGraphNeighborLoader has been renamed to DaskNeighborLoader", FutureWarning + ) + return DaskNeighborLoader(*args, **kwargs) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/dask_node_loader.py similarity index 97% rename from python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py rename to python/cugraph-pyg/cugraph_pyg/loader/dask_node_loader.py index 55c9e9b3329..aaf82dd46bb 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/dask_node_loader.py @@ -23,8 +23,8 @@ from cugraph.gnn import BulkSampler from cugraph.utilities.utils import import_optional, MissingModule -from cugraph_pyg.data import CuGraphStore -from cugraph_pyg.sampler.cugraph_sampler import ( +from cugraph_pyg.data import DaskGraphStore +from cugraph_pyg.sampler.sampler_utils import ( _sampler_output_from_sampling_results_heterogeneous, _sampler_output_from_sampling_results_homogeneous_csr, _sampler_output_from_sampling_results_homogeneous_coo, @@ -47,8 +47,8 @@ class BulkSampleLoader: def __init__( self, - feature_store: CuGraphStore, - graph_store: CuGraphStore, + feature_store: DaskGraphStore, + graph_store: DaskGraphStore, input_nodes: InputNodes = None, batch_size: int = 0, *, @@ -72,10 +72,10 @@ def __init__( Parameters ---------- - feature_store: CuGraphStore + feature_store: DaskGraphStore The feature store containing features for the graph. - graph_store: CuGraphStore + graph_store: DaskGraphStore The graph store containing the graph structure. input_nodes: InputNodes @@ -487,10 +487,10 @@ def __iter__(self): return self -class CuGraphNeighborLoader: +class DaskNeighborLoader: def __init__( self, - data: Union[CuGraphStore, Tuple[CuGraphStore, CuGraphStore]], + data: Union[DaskGraphStore, Tuple[DaskGraphStore, DaskGraphStore]], input_nodes: Union[InputNodes, int] = None, batch_size: int = None, **kwargs, @@ -498,8 +498,8 @@ def __init__( """ Parameters ---------- - data: CuGraphStore or (CuGraphStore, CuGraphStore) - The CuGraphStore or stores where the graph/feature data is held. + data: DaskGraphStore or (DaskGraphStore, DaskGraphStore) + The DaskGraphStore or stores where the graph/feature data is held. batch_size: int (required) The number of input nodes in each batch. diff --git a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py new file mode 100644 index 00000000000..3d29ee3aca3 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py @@ -0,0 +1,232 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +import tempfile + +from typing import Union, Tuple, Optional, Callable, List, Dict + +import cugraph_pyg +from cugraph_pyg.loader import NodeLoader +from cugraph_pyg.sampler import BaseSampler + +from cugraph.gnn import UniformNeighborSampler, DistSampleWriter +from cugraph.utilities.utils import import_optional + +torch_geometric = import_optional("torch_geometric") + + +class NeighborLoader(NodeLoader): + """ + Node loader that implements the neighbor sampling + algorithm used in GraphSAGE. + + Duck-typed version of torch_geometric.loader.NeighborLoader + """ + + def __init__( + self, + data: Union[ + "torch_geometric.data.Data", + "torch_geometric.data.HeteroData", + Tuple[ + "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore" + ], + ], + num_neighbors: Union[ + List[int], Dict["torch_geometric.typing.EdgeType", List[int]] + ], + input_nodes: "torch_geometric.typing.InputNodes" = None, + input_time: "torch_geometric.typing.OptTensor" = None, + replace: bool = False, + subgraph_type: Union[ + "torch_geometric.typing.SubgraphType", str + ] = "directional", + disjoint: bool = False, + temporal_strategy: str = "uniform", + time_attr: Optional[str] = None, + weight_attr: Optional[str] = None, + transform: Optional[Callable] = None, + transform_sampler_output: Optional[Callable] = None, + is_sorted: bool = False, + filter_per_worker: Optional[bool] = None, + neighbor_sampler: Optional["torch_geometric.sampler.NeighborSampler"] = None, + directed: bool = True, # Deprecated. + batch_size: int = 16, + directory: str = None, + batches_per_partition=256, + format: str = "parquet", + compression: Optional[str] = None, + local_seeds_per_call: Optional[int] = None, + **kwargs, + ): + """ + data: Data, HeteroData, or Tuple[FeatureStore, GraphStore] + See torch_geometric.loader.NeighborLoader. + num_neighbors: List[int] or Dict[EdgeType, List[int]] + Fanout values. + See torch_geometric.loader.NeighborLoader. + input_nodes: InputNodes + Input nodes for sampling. + See torch_geometric.loader.NeighborLoader. + input_time: OptTensor (optional) + See torch_geometric.loader.NeighborLoader. + replace: bool (optional, default=False) + Whether to sample with replacement. + See torch_geometric.loader.NeighborLoader. + subgraph_type: Union[SubgraphType, str] (optional, default='directional') + The type of subgraph to return. + Currently only 'directional' is supported. + See torch_geometric.loader.NeighborLoader. + disjoint: bool (optional, default=False) + Whether to perform disjoint sampling. + Currently unsupported. + See torch_geometric.loader.NeighborLoader. + temporal_strategy: str (optional, default='uniform') + Currently only 'uniform' is suppported. + See torch_geometric.loader.NeighborLoader. + time_attr: str (optional, default=None) + Used for temporal sampling. + See torch_geometric.loader.NeighborLoader. + weight_attr: str (optional, default=None) + Used for biased sampling. + See torch_geometric.loader.NeighborLoader. + transform: Callable (optional, default=None) + See torch_geometric.loader.NeighborLoader. + transform_sampler_output: Callable (optional, default=None) + See torch_geometric.loader.NeighborLoader. + is_sorted: bool (optional, default=False) + Ignored by cuGraph. + See torch_geometric.loader.NeighborLoader. + filter_per_worker: bool (optional, default=False) + Currently ignored by cuGraph, but this may + change once in-memory sampling is implemented. + See torch_geometric.loader.NeighborLoader. + neighbor_sampler: torch_geometric.sampler.NeighborSampler + (optional, default=None) + Not supported by cuGraph. + See torch_geometric.loader.NeighborLoader. + directed: bool (optional, default=True) + Deprecated. + See torch_geometric.loader.NeighborLoader. + batch_size: int (optional, default=16) + The number of input nodes per output minibatch. + See torch.utils.dataloader. + directory: str (optional, default=None) + The directory where samples will be temporarily stored. + It is recommend that this be set by the user, usually + setting it to a tempfile.TemporaryDirectory with a context + manager is a good option but depending on the filesystem, + you may want to choose an alternative location with fast I/O + intead. + If not set, this will create a TemporaryDirectory that will + persist until this object is garbage collected. + See cugraph.gnn.DistSampleWriter. + batches_per_partition: int (optional, default=256) + The number of batches per partition if writing samples to + disk. Manually tuning this parameter is not recommended + but reducing it may help conserve GPU memory. + See cugraph.gnn.DistSampleWriter. + format: str (optional, default='parquet') + If writing samples to disk, they will be written in this + file format. + See cugraph.gnn.DistSampleWriter. + compression: str (optional, default=None) + The compression type to use if writing samples to disk. + If not provided, it is automatically chosen. + local_seeds_per_call: int (optional, default=None) + The number of seeds to process within a single sampling call. + Manually tuning this parameter is not recommended but reducing + it may conserve GPU memory. The total number of seeds processed + per sampling call is equal to the sum of this parameter across + all workers. If not provided, it will be automatically + calculated. + See cugraph.gnn.DistSampler. + **kwargs + Other keyword arguments passed to the superclass. + """ + + subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type) + + if not directed: + subgraph_type = torch_geometric.sampler.base.SubgraphType.induced + warnings.warn( + "The 'directed' argument is deprecated. " + "Use subgraph_type='induced' instead." + ) + if subgraph_type != torch_geometric.sampler.base.SubgraphType.directional: + raise ValueError("Only directional subgraphs are currently supported") + if disjoint: + raise ValueError("Disjoint sampling is currently unsupported") + if temporal_strategy != "uniform": + warnings.warn("Only the uniform temporal strategy is currently supported") + if neighbor_sampler is not None: + raise ValueError("Passing a neighbor sampler is currently unsupported") + if time_attr is not None: + raise ValueError("Temporal sampling is currently unsupported") + if weight_attr is not None: + raise ValueError("Biased sampling is currently unsupported") + if is_sorted: + warnings.warn("The 'is_sorted' argument is ignored by cuGraph.") + if not isinstance(data, (list, tuple)) or not isinstance( + data[1], cugraph_pyg.data.GraphStore + ): + # Will eventually automatically convert these objects to cuGraph objects. + raise NotImplementedError("Currently can't accept non-cugraph graphs") + + if directory is None: + warnings.warn("Setting a directory to store samples is recommended.") + self._tempdir = tempfile.TemporaryDirectory() + directory = self._tempdir.name + + if compression is None: + compression = "CSR" + elif compression not in ["CSR", "COO"]: + raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") + + writer = DistSampleWriter( + directory=directory, + batches_per_partition=batches_per_partition, + format=format, + ) + + feature_store, graph_store = data + sampler = BaseSampler( + UniformNeighborSampler( + graph_store._graph, + writer, + retain_original_seeds=True, + fanout=num_neighbors, + prior_sources_behavior="exclude", + deduplicate_sources=True, + compression=compression, + compress_per_hop=False, + with_replacement=replace, + local_seeds_per_call=local_seeds_per_call, + ), + (feature_store, graph_store), + batch_size=batch_size, + ) + # TODO add heterogeneous support and pass graph_store._vertex_offsets + + super().__init__( + (feature_store, graph_store), + sampler, + input_nodes=input_nodes, + input_time=input_time, + transform=transform, + transform_sampler_output=transform_sampler_output, + filter_per_worker=filter_per_worker, + batch_size=batch_size, + **kwargs, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py new file mode 100644 index 00000000000..56b58352a7c --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import cugraph_pyg +from typing import Union, Tuple, Callable, Optional + +from cugraph.utilities.utils import import_optional + +torch_geometric = import_optional("torch_geometric") +torch = import_optional("torch") + + +class NodeLoader: + """ + Duck-typed version of torch_geometric.loader.NodeLoader + """ + + def __init__( + self, + data: Union[ + "torch_geometric.data.Data", + "torch_geometric.data.HeteroData", + Tuple[ + "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore" + ], + ], + node_sampler: "cugraph_pyg.sampler.BaseSampler", + input_nodes: "torch_geometric.typing.InputNodes" = None, + input_time: "torch_geometric.typing.OptTensor" = None, + transform: Optional[Callable] = None, + transform_sampler_output: Optional[Callable] = None, + filter_per_worker: Optional[bool] = None, + custom_cls: Optional["torch_geometric.data.HeteroData"] = None, + input_id: "torch_geometric.typing.OptTensor" = None, + batch_size: int = 1, + shuffle: bool = False, + drop_last: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + data: Data, HeteroData, or Tuple[FeatureStore, GraphStore] + See torch_geometric.loader.NodeLoader. + node_sampler: BaseSampler + See torch_geometric.loader.NodeLoader. + input_nodes: InputNodes + See torch_geometric.loader.NodeLoader. + input_time: OptTensor + See torch_geometric.loader.NodeLoader. + transform: Callable (optional, default=None) + This argument currently has no effect. + transform_sampler_output: Callable (optional, default=None) + This argument currently has no effect. + filter_per_worker: bool (optional, default=False) + This argument currently has no effect. + custom_cls: HeteroData + This argument currently has no effect. This loader will + always return a Data or HeteroData object. + input_id: OptTensor + See torch_geometric.loader.NodeLoader. + + """ + if not isinstance(data, (list, tuple)) or not isinstance( + data[1], cugraph_pyg.data.GraphStore + ): + # Will eventually automatically convert these objects to cuGraph objects. + raise NotImplementedError("Currently can't accept non-cugraph graphs") + + if not isinstance(node_sampler, cugraph_pyg.sampler.BaseSampler): + raise NotImplementedError("Must provide a cuGraph sampler") + + if input_time is not None: + raise ValueError("Temporal sampling is currently unsupported") + + if filter_per_worker: + warnings.warn("filter_per_worker is currently ignored") + + if custom_cls is not None: + warnings.warn("custom_cls is currently ignored") + + if transform is not None: + warnings.warn("transform is currently ignored.") + + if transform_sampler_output is not None: + warnings.warn("transform_sampler_output is currently ignored.") + + ( + input_type, + input_nodes, + input_id, + ) = torch_geometric.loader.utils.get_input_nodes( + data, + input_nodes, + input_id, + ) + + self.__input_data = torch_geometric.loader.node_loader.NodeSamplerInput( + input_id=input_id, + node=input_nodes, + time=None, + input_type=input_type, + ) + + self.__data = data + + self.__node_sampler = node_sampler + + self.__batch_size = batch_size + self.__shuffle = shuffle + self.__drop_last = drop_last + + def __iter__(self): + if self.__shuffle: + perm = torch.randperm(self.__input_data.node.numel()) + else: + perm = torch.arange(self.__input_data.node.numel()) + + if self.__drop_last: + d = perm.numel() % self.__batch_size + perm = perm[:-d] + + input_data = torch_geometric.loader.node_loader.NodeSamplerInput( + input_id=None + if self.__input_data.input_id is None + else self.__input_data.input_id[perm], + node=self.__input_data.node[perm], + time=None + if self.__input_data.time is None + else self.__input_data.time[perm], + input_type=self.__input_data.input_type, + ) + + return cugraph_pyg.sampler.SampleIterator( + self.__data, self.__node_sampler.sample_from_nodes(input_data) + ) diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/__init__.py b/python/cugraph-pyg/cugraph_pyg/sampler/__init__.py index 2ec68a8b4ac..34fe9c4463e 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,3 +10,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from cugraph_pyg.sampler.sampler import BaseSampler, SampleIterator diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py new file mode 100644 index 00000000000..101f7b042be --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -0,0 +1,323 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Iterator, Union, Dict, Tuple + +from cugraph.utilities.utils import import_optional +from cugraph.gnn import DistSampler, DistSampleReader + +torch = import_optional("torch") +torch_geometric = import_optional("torch_geometric") + + +class SampleIterator: + def __init__( + self, + data: Tuple[ + "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore" + ], + output_iter: Iterator[ + Union[ + "torch_geometric.sampler.HeteroSamplerOutput", + "torch_geometric.sampler.SamplerOutput", + ] + ], + ): + self.__feature_store, self.__graph_store = data + self.__output_iter = output_iter + + def __next__(self): + next_sample = next(self.__output_iter) + if isinstance(next_sample, torch_geometric.sampler.SamplerOutput): + sz = next_sample.edge.numel() + if sz == next_sample.col.numel(): + col = next_sample.col + else: + col = torch_geometric.edge_index.ptr2index( + next_sample.col, next_sample.edge.numel() + ) + + data = torch_geometric.loader.utils.filter_custom_store( + self.__feature_store, + self.__graph_store, + next_sample.node, + next_sample.row, + col, + next_sample.edge, + None, + ) + + if "n_id" not in data: + data.n_id = next_sample.node + if next_sample.edge is not None and "e_id" not in data: + edge = next_sample.edge.to(torch.long) + data.e_id = edge + + data.batch = next_sample.batch + data.num_sampled_nodes = next_sample.num_sampled_nodes + data.num_sampled_edges = next_sample.num_sampled_edges + + data.input_id = data.batch + data.seed_time = None + data.batch_size = data.input_id.size(0) + + elif isinstance(next_sample, torch_geometric.sampler.HeteroSamplerOutput): + col = {} + for edge_type, col_idx in next_sample.col: + sz = next_sample.edge[edge_type].numel() + if sz == col_idx.numel(): + col[edge_type] = col_idx + else: + col[edge_type] = torch_geometric.edge_index.ptr2index(col_idx, sz) + + data = torch_geometric.loader.utils.filter_custom_hetero_store( + self.__feature_store, + self.__graph_store, + next_sample.node, + next_sample.row, + col, + next_sample.edge, + None, + ) + + for key, node in next_sample.node.items(): + if "n_id" not in data[key]: + data[key].n_id = node + + for key, edge in (next_sample.edge or {}).items(): + if edge is not None and "e_id" not in data[key]: + edge = edge.to(torch.long) + data[key].e_id = edge + + data.set_value_dict("batch", next_sample.batch) + data.set_value_dict("num_sampled_nodes", next_sample.num_sampled_nodes) + data.set_value_dict("num_sampled_edges", next_sample.num_sampled_edges) + + # TODO figure out how to set input_id for heterogeneous output + else: + raise ValueError("Invalid output type") + + return data + + def __iter__(self): + return self + + +class SampleReader: + def __init__(self, base_reader: DistSampleReader): + self.__base_reader = base_reader + self.__num_samples_remaining = 0 + self.__index = 0 + + def __next__(self): + if self.__num_samples_remaining == 0: + # raw_sample_data is already a dict of tensors + self.__raw_sample_data, start_inclusive, end_inclusive = next( + self.__base_reader + ) + + self.__raw_sample_data["label_hop_offsets"] -= self.__raw_sample_data[ + "label_hop_offsets" + ][0].clone() + self.__raw_sample_data["renumber_map_offsets"] -= self.__raw_sample_data[ + "renumber_map_offsets" + ][0].clone() + if "major_offsets" in self.__raw_sample_data: + self.__raw_sample_data["major_offsets"] -= self.__raw_sample_data[ + "major_offsets" + ][0].clone() + + self.__num_samples_remaining = end_inclusive - start_inclusive + 1 + self.__index = 0 + + out = self._decode(self.__raw_sample_data, self.__index) + self.__index += 1 + self.__num_samples_remaining -= 1 + return out + + def __iter__(self): + return self + + +class HomogeneousSampleReader(SampleReader): + def __init__(self, base_reader: DistSampleReader): + super().__init__(base_reader) + + def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): + fanout_length = (raw_sample_data["label_hop_offsets"].numel() - 1) // ( + raw_sample_data["renumber_map_offsets"].numel() - 1 + ) + + major_offsets_start_incl = raw_sample_data["label_hop_offsets"][ + index * fanout_length + ] + major_offsets_end_incl = raw_sample_data["label_hop_offsets"][ + (index + 1) * fanout_length + ] + + major_offsets = raw_sample_data["major_offsets"][ + major_offsets_start_incl : major_offsets_end_incl + 1 + ].clone() + minors = raw_sample_data["minors"][major_offsets[0] : major_offsets[-1]] + edge_id = raw_sample_data["edge_id"][major_offsets[0] : major_offsets[-1]] + # don't retrieve edge type for a homogeneous graph + + major_offsets -= major_offsets[0].clone() + + renumber_map_start = raw_sample_data["renumber_map_offsets"][index] + renumber_map_end = raw_sample_data["renumber_map_offsets"][index + 1] + + renumber_map = raw_sample_data["map"][renumber_map_start:renumber_map_end] + + current_label_hop_offsets = raw_sample_data["label_hop_offsets"][ + index * fanout_length : (index + 1) * fanout_length + 1 + ].clone() + current_label_hop_offsets -= current_label_hop_offsets[0].clone() + + num_sampled_edges = major_offsets[current_label_hop_offsets].diff() + + num_sampled_nodes_hops = torch.tensor( + [ + minors[: num_sampled_edges[:i].sum()].max() + 1 + for i in range(1, fanout_length + 1) + ], + device="cpu", + ) + + num_seeds = ( + torch.searchsorted(major_offsets, num_sampled_edges[0]).reshape((1,)).cpu() + ) + num_sampled_nodes = torch.concat( + [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)] + ) + + return torch_geometric.sampler.SamplerOutput( + node=renumber_map.cpu(), + row=minors, + col=major_offsets, + edge=edge_id, + batch=renumber_map[:num_seeds], + num_sampled_nodes=num_sampled_nodes.cpu(), + num_sampled_edges=num_sampled_edges.cpu(), + ) + + def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): + fanout_length = (raw_sample_data["label_hop_offsets"].numel() - 1) // ( + raw_sample_data["renumber_map_offsets"].numel() - 1 + ) + + major_minor_start = raw_sample_data["label_hop_offsets"][index * fanout_length] + ix_end = (index + 1) * fanout_length + if ix_end == raw_sample_data["label_hop_offsets"].numel(): + major_minor_end = raw_sample_data["majors"].numel() + else: + major_minor_end = raw_sample_data["label_hop_offsets"][ix_end] + + majors = raw_sample_data["majors"][major_minor_start:major_minor_end] + minors = raw_sample_data["minors"][major_minor_start:major_minor_end] + edge_id = raw_sample_data["edge_id"][major_minor_start:major_minor_end] + # don't retrieve edge type for a homogeneous graph + + renumber_map_start = raw_sample_data["renumber_map_offsets"][index] + renumber_map_end = raw_sample_data["renumber_map_offsets"][index + 1] + + renumber_map = raw_sample_data["map"][renumber_map_start:renumber_map_end] + + num_sampled_edges = ( + raw_sample_data["label_hop_offsets"][ + index * fanout_length : (index + 1) * fanout_length + 1 + ] + .diff() + .cpu() + ) + + num_seeds = (majors[: num_sampled_edges[0]].max() + 1).reshape((1,)).cpu() + num_sampled_nodes_hops = torch.tensor( + [ + minors[: num_sampled_edges[:i].sum()].max() + 1 + for i in range(1, fanout_length + 1) + ], + device="cpu", + ) + + num_sampled_nodes = torch.concat( + [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)] + ) + + return torch_geometric.sampler.SamplerOutput( + node=renumber_map.cpu(), + row=minors, + col=majors, + edge=edge_id, + batch=renumber_map[:num_seeds], + num_sampled_nodes=num_sampled_nodes, + num_sampled_edges=num_sampled_edges, + ) + + def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): + if "major_offsets" in raw_sample_data: + return self.__decode_csc(raw_sample_data, index) + else: + return self.__decode_coo(raw_sample_data, index) + + +class BaseSampler: + def __init__( + self, + sampler: DistSampler, + data: Tuple[ + "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore" + ], + batch_size: int = 16, + ): + self.__sampler = sampler + self.__feature_store, self.__graph_store = data + self.__batch_size = batch_size + + def sample_from_nodes( + self, index: "torch_geometric.sampler.NodeSamplerInput", **kwargs + ) -> Iterator[ + Union[ + "torch_geometric.sampler.HeteroSamplerOutput", + "torch_geometric.sampler.SamplerOutput", + ] + ]: + self.__sampler.sample_from_nodes( + index.node, batch_size=self.__batch_size, **kwargs + ) + + edge_attrs = self.__graph_store.get_all_edge_attrs() + if ( + len(edge_attrs) == 1 + and edge_attrs[0].edge_type[0] == edge_attrs[0].edge_type[2] + ): + return HomogeneousSampleReader(self.__sampler.get_reader()) + else: + # TODO implement heterogeneous sampling + raise NotImplementedError( + "Sampling heterogeneous graphs is currently" + " unsupported in the non-dask API" + ) + + def sample_from_edges( + self, + index: "torch_geometric.sampler.EdgeSamplerInput", + neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"], + **kwargs, + ) -> Iterator[ + Union[ + "torch_geometric.sampler.HeteroSamplerOutput", + "torch_geometric.sampler.SamplerOutput", + ] + ]: + raise NotImplementedError("Edge sampling is currently unimplemented.") diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py similarity index 89% rename from python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py rename to python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py index 8bcfb783ae1..c3e19393970 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py @@ -14,7 +14,7 @@ from typing import Sequence, Dict, Tuple -from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.data import DaskGraphStore from cugraph.utilities.utils import import_optional import cudf @@ -28,7 +28,7 @@ def _get_unique_nodes( sampling_results: cudf.DataFrame, - graph_store: CuGraphStore, + graph_store: DaskGraphStore, node_type: str, node_position: str, ) -> int: @@ -40,7 +40,7 @@ def _get_unique_nodes( sampling_results: cudf.DataFrame The dataframe containing sampling results or filtered sampling results (i.e. sampling results for hop 2) - graph_store: CuGraphStore + graph_store: DaskGraphStore The graph store containing the structure of the sampled graph. node_type: str The node type to count the number of unique nodes of. @@ -81,7 +81,7 @@ def _get_unique_nodes( def _sampler_output_from_sampling_results_homogeneous_coo( sampling_results: cudf.DataFrame, renumber_map: torch.Tensor, - graph_store: CuGraphStore, + graph_store: DaskGraphStore, data_index: Dict[Tuple[int, int], Dict[str, int]], batch_id: int, metadata: Sequence = None, @@ -94,7 +94,7 @@ def _sampler_output_from_sampling_results_homogeneous_coo( renumber_map: torch.Tensor The tensor containing the renumber map, or None if there is no renumber map. - graph_store: CuGraphStore + graph_store: DaskGraphStore The graph store containing the structure of the sampled graph. data_index: Dict[Tuple[int, int], Dict[str, int]] Dictionary where keys are the batch id and hop id, @@ -181,7 +181,7 @@ def _sampler_output_from_sampling_results_homogeneous_csr( major_offsets: torch.Tensor, minors: torch.Tensor, renumber_map: torch.Tensor, - graph_store: CuGraphStore, + graph_store: DaskGraphStore, label_hop_offsets: torch.Tensor, batch_id: int, metadata: Sequence = None, @@ -196,7 +196,7 @@ def _sampler_output_from_sampling_results_homogeneous_csr( renumber_map: torch.Tensor The tensor containing the renumber map. Required. - graph_store: CuGraphStore + graph_store: DaskGraphStore The graph store containing the structure of the sampled graph. label_hop_offsets: torch.Tensor The tensor containing the label-hop offsets. @@ -263,7 +263,7 @@ def _sampler_output_from_sampling_results_homogeneous_csr( def _sampler_output_from_sampling_results_heterogeneous( sampling_results: cudf.DataFrame, renumber_map: cudf.Series, - graph_store: CuGraphStore, + graph_store: DaskGraphStore, metadata: Sequence = None, ) -> HeteroSamplerOutput: """ @@ -274,7 +274,7 @@ def _sampler_output_from_sampling_results_heterogeneous( renumber_map: cudf.Series The series containing the renumber map, or None if there is no renumber map. - graph_store: CuGraphStore + graph_store: DaskGraphStore The graph store containing the structure of the sampled graph. metadata: Tensor The metadata for the sampled batch. @@ -403,41 +403,3 @@ def _sampler_output_from_sampling_results_heterogeneous( num_sampled_edges={k: t.tolist() for k, t in num_edges_per_hop_dict.items()}, metadata=metadata, ) - - -def filter_cugraph_store_csc( - feature_store: torch_geometric.data.FeatureStore, - graph_store: torch_geometric.data.GraphStore, - node_dict: Dict[str, torch.Tensor], - row_dict: Dict[str, torch.Tensor], - col_dict: Dict[str, torch.Tensor], - edge_dict: Dict[str, Tuple[torch.Tensor]], -) -> torch_geometric.data.HeteroData: - """ - Deprecated - """ - - data = torch_geometric.data.HeteroData() - - for attr in graph_store.get_all_edge_attrs(): - key = attr.edge_type - if key in row_dict and key in col_dict: - data.put_edge_index( - (row_dict[key], col_dict[key]), - edge_type=key, - layout="csc", - is_sorted=True, - ) - - required_attrs = [] - for attr in feature_store.get_all_tensor_attrs(): - if attr.group_name in node_dict: - attr.index = node_dict[attr.group_name] - required_attrs.append(attr) - data[attr.group_name].num_nodes = attr.index.size(0) - - tensors = feature_store.multi_get_tensor(required_attrs) - for i, attr in enumerate(required_attrs): - data[attr.group_name][attr.attr_name] = tensors[i] - - return data diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_dask_graph_store.py similarity index 92% rename from python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py rename to python/cugraph-pyg/cugraph_pyg/tests/data/test_dask_graph_store.py index c99fd447aa0..0a997a960b8 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/data/test_dask_graph_store.py @@ -12,12 +12,12 @@ # limitations under the License. import cugraph -from cugraph_pyg.data.cugraph_store import ( +from cugraph_pyg.data.dask_graph_store import ( CuGraphTensorAttr, CuGraphEdgeAttr, EdgeLayout, ) -from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.data import DaskGraphStore import cudf import cupy @@ -33,6 +33,7 @@ @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_tensor_attr(): ta = CuGraphTensorAttr("group0", "property1") assert not ta.is_fully_specified() @@ -63,6 +64,7 @@ def test_tensor_attr(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_edge_attr(): ea = CuGraphEdgeAttr("type0", EdgeLayout.COO, False, 10) assert ea.edge_type == "type0" @@ -98,6 +100,7 @@ def single_vertex_graph(request): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.parametrize("edge_index_type", ["numpy", "torch-cpu", "torch-gpu", "cudf"]) +@pytest.mark.sg def test_get_edge_index(graph, edge_index_type): F, G, N = graph if "torch" in edge_index_type: @@ -113,7 +116,7 @@ def test_get_edge_index(graph, edge_index_type): G[et][0] = cudf.Series(G[et][0]) G[et][1] = cudf.Series(G[et][1]) - cugraph_store = CuGraphStore(F, G, N, order="CSC") + cugraph_store = DaskGraphStore(F, G, N, order="CSC") for pyg_can_edge_type in G: src, dst = cugraph_store.get_edge_index( @@ -129,9 +132,10 @@ def test_get_edge_index(graph, edge_index_type): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_edge_types(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) eta = cugraph_store._edge_types_to_attrs assert eta.keys() == G.keys() @@ -145,9 +149,10 @@ def test_edge_types(graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_get_subgraph(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) if len(G.keys()) > 1: for edge_type in G.keys(): @@ -163,9 +168,10 @@ def test_get_subgraph(graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_renumber_vertices_basic(single_vertex_graph): F, G, N = single_vertex_graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) nodes_of_interest = torch.as_tensor( cupy.random.randint(0, sum(N.values()), 3), device="cuda" @@ -176,9 +182,10 @@ def test_renumber_vertices_basic(single_vertex_graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_renumber_vertices_multi_edge_multi_vertex(multi_edge_multi_vertex_graph_1): F, G, N = multi_edge_multi_vertex_graph_1 - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) nodes_of_interest = torch.as_tensor( cupy.random.randint(0, sum(N.values()), 3), device="cuda" @@ -196,10 +203,11 @@ def test_renumber_vertices_multi_edge_multi_vertex(multi_edge_multi_vertex_graph @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_renumber_edges(abc_graph): F, G, N = abc_graph - graph_store = CuGraphStore(F, G, N, order="CSR") + graph_store = DaskGraphStore(F, G, N, order="CSR") # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( @@ -232,9 +240,10 @@ def test_renumber_edges(abc_graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_get_tensor(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) for feature_name, feature_on_types in F.get_feature_list().items(): for type_name in feature_on_types: @@ -253,9 +262,10 @@ def test_get_tensor(graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_get_tensor_empty_idx(karate_gnn): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) t = cugraph_store.get_tensor( CuGraphTensorAttr(group_name="type0", attr_name="prop0", index=None) @@ -264,9 +274,10 @@ def test_get_tensor_empty_idx(karate_gnn): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_multi_get_tensor(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) for vertex_type in sorted(N.keys()): v_ids = np.arange(N[vertex_type]) @@ -291,9 +302,10 @@ def test_multi_get_tensor(graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_get_all_tensor_attrs(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) tensor_attrs = [] for vertex_type in sorted(N.keys()): @@ -320,20 +332,11 @@ def test_get_all_tensor_attrs(graph): ) -@pytest.mark.skip("not implemented") -def test_get_tensor_spec_props(graph): - raise NotImplementedError("not implemented") - - -@pytest.mark.skip("not implemented") -def test_multi_get_tensor_spec_props(multi_edge_multi_vertex_graph_1): - raise NotImplementedError("not implemented") - - @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_get_tensor_from_tensor_attrs(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) tensor_attrs = cugraph_store.get_all_tensor_attrs() for tensor_attr in tensor_attrs: @@ -345,9 +348,10 @@ def test_get_tensor_from_tensor_attrs(graph): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_get_tensor_size(graph): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) tensor_attrs = cugraph_store.get_all_tensor_attrs() for tensor_attr in tensor_attrs: @@ -361,9 +365,10 @@ def test_get_tensor_size(graph): @pytest.mark.skipif( isinstance(torch_geometric, MissingModule), reason="pyg not available" ) +@pytest.mark.sg def test_get_input_nodes(karate_gnn): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) input_node_info = torch_geometric.loader.utils.get_input_nodes( (cugraph_store, cugraph_store), "type0" @@ -383,11 +388,12 @@ def test_get_input_nodes(karate_gnn): assert input_nodes.tolist() == torch.arange(17, dtype=torch.int32).tolist() +@pytest.mark.sg def test_serialize(multi_edge_multi_vertex_no_graph_1): import pickle F, G, N = multi_edge_multi_vertex_no_graph_1 - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) cugraph_store_copy = pickle.loads(pickle.dumps(cugraph_store)) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_dask_graph_store_mg.py similarity index 90% rename from python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py rename to python/cugraph-pyg/cugraph_pyg/tests/data/test_dask_graph_store_mg.py index 85acbebc3ec..65cb8984586 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/data/test_dask_graph_store_mg.py @@ -12,12 +12,12 @@ # limitations under the License. import cugraph -from cugraph_pyg.data.cugraph_store import ( +from cugraph_pyg.data.dask_graph_store import ( CuGraphTensorAttr, CuGraphEdgeAttr, EdgeLayout, ) -from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.data import DaskGraphStore import cudf import dask_cudf @@ -101,6 +101,7 @@ def single_vertex_graph(request): @pytest.mark.parametrize( "edge_index_type", ["numpy", "torch-cpu", "torch-gpu", "cudf", "dask-cudf"] ) +@pytest.mark.mg def test_get_edge_index(graph, edge_index_type, dask_client): F, G, N = graph if "torch" in edge_index_type: @@ -120,7 +121,7 @@ def test_get_edge_index(graph, edge_index_type, dask_client): G[et][0] = dask_cudf.from_cudf(cudf.Series(G[et][0]), npartitions=1) G[et][1] = dask_cudf.from_cudf(cudf.Series(G[et][1]), npartitions=1) - cugraph_store = CuGraphStore(F, G, N, order="CSC", multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, order="CSC", multi_gpu=True) for pyg_can_edge_type in G: src, dst = cugraph_store.get_edge_index( @@ -143,9 +144,10 @@ def test_get_edge_index(graph, edge_index_type, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_edge_types(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) eta = cugraph_store._edge_types_to_attrs assert eta.keys() == G.keys() @@ -159,9 +161,10 @@ def test_edge_types(graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_get_subgraph(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) if len(G.keys()) > 1: for edge_type in G.keys(): @@ -177,9 +180,10 @@ def test_get_subgraph(graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_renumber_vertices_basic(single_vertex_graph, dask_client): F, G, N = single_vertex_graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) nodes_of_interest = torch.as_tensor( cupy.random.randint(0, sum(N.values()), 3), device="cuda" @@ -190,11 +194,12 @@ def test_renumber_vertices_basic(single_vertex_graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_renumber_vertices_multi_edge_multi_vertex( multi_edge_multi_vertex_graph_1, dask_client ): F, G, N = multi_edge_multi_vertex_graph_1 - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) nodes_of_interest = torch.as_tensor( cupy.random.randint(0, sum(N.values()), 3), device="cuda" @@ -212,10 +217,11 @@ def test_renumber_vertices_multi_edge_multi_vertex( @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_renumber_edges(abc_graph, dask_client): F, G, N = abc_graph - graph_store = CuGraphStore(F, G, N, multi_gpu=True, order="CSR") + graph_store = DaskGraphStore(F, G, N, multi_gpu=True, order="CSR") # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( @@ -248,9 +254,10 @@ def test_renumber_edges(abc_graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_get_tensor(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) for feature_name, feature_on_types in F.get_feature_list().items(): for type_name in feature_on_types: @@ -269,9 +276,10 @@ def test_get_tensor(graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_get_tensor_empty_idx(karate_gnn, dask_client): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) t = cugraph_store.get_tensor( CuGraphTensorAttr(group_name="type0", attr_name="prop0", index=None) @@ -280,9 +288,10 @@ def test_get_tensor_empty_idx(karate_gnn, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_multi_get_tensor(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) for vertex_type in sorted(N.keys()): v_ids = np.arange(N[vertex_type]) @@ -307,9 +316,10 @@ def test_multi_get_tensor(graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_get_all_tensor_attrs(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) tensor_attrs = [] for vertex_type in sorted(N.keys()): @@ -328,20 +338,11 @@ def test_get_all_tensor_attrs(graph, dask_client): ) -@pytest.mark.skip("not implemented") -def test_get_tensor_spec_props(graph, dask_client): - raise NotImplementedError("not implemented") - - -@pytest.mark.skip("not implemented") -def test_multi_get_tensor_spec_props(multi_edge_multi_vertex_graph_1, dask_client): - raise NotImplementedError("not implemented") - - @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_get_tensor_from_tensor_attrs(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) tensor_attrs = cugraph_store.get_all_tensor_attrs() for tensor_attr in tensor_attrs: @@ -353,9 +354,10 @@ def test_get_tensor_from_tensor_attrs(graph, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_get_tensor_size(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) tensor_attrs = cugraph_store.get_all_tensor_attrs() for tensor_attr in tensor_attrs: @@ -369,9 +371,10 @@ def test_get_tensor_size(graph, dask_client): @pytest.mark.skipif( isinstance(torch_geometric, MissingModule), reason="pyg not available" ) +@pytest.mark.mg def test_get_input_nodes(karate_gnn, dask_client): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) nodes = torch_geometric.loader.utils.get_input_nodes( (cugraph_store, cugraph_store), "type0" @@ -387,13 +390,15 @@ def test_get_input_nodes(karate_gnn, dask_client): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_mg_frame_handle(graph, dask_client): F, G, N = graph - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) - assert isinstance(cugraph_store._CuGraphStore__graph._plc_graph, dict) + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True) + assert isinstance(cugraph_store._DaskGraphStore__graph._plc_graph, dict) @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_cugraph_loader_large_index(dask_client): large_index = ( np.random.randint(0, 1_000_000, (100_000_000,)), @@ -404,7 +409,7 @@ def test_cugraph_loader_large_index(dask_client): F = cugraph.gnn.FeatureStore(backend="torch") F.add_data(large_features, "N", "f") - store = CuGraphStore( + store = DaskGraphStore( F, {("N", "e", "N"): large_index}, {"N": 1_000_000}, diff --git a/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py new file mode 100644 index 00000000000..ab5f1e217bb --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from cugraph.utilities.utils import import_optional, MissingModule + +from cugraph_pyg.data import TensorDictFeatureStore + +torch = import_optional("torch") + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +def test_tensordict_feature_store_basic_api(): + feature_store = TensorDictFeatureStore() + + node_features_0 = torch.randint(128, (100, 1000)) + node_features_1 = torch.randint(256, (100, 10)) + + other_features = torch.randint(1024, (10, 5)) + + feature_store["node", "feat0"] = node_features_0 + feature_store["node", "feat1"] = node_features_1 + feature_store["other", "feat"] = other_features + + assert (feature_store["node"]["feat0"][:] == node_features_0).all() + assert (feature_store["node"]["feat1"][:] == node_features_1).all() + assert (feature_store["other"]["feat"][:] == other_features).all() + + assert len(feature_store.get_all_tensor_attrs()) == 3 + + del feature_store["node", "feat0"] + assert len(feature_store.get_all_tensor_attrs()) == 2 diff --git a/python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store.py new file mode 100644 index 00000000000..a8b93665aad --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from cugraph.datasets import karate +from cugraph.utilities.utils import import_optional, MissingModule + +from cugraph_pyg.data import GraphStore + +torch = import_optional("torch") + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +def test_graph_store_basic_api(): + df = karate.get_edgelist() + src = torch.as_tensor(df["src"], device="cuda") + dst = torch.as_tensor(df["dst"], device="cuda") + + ei = torch.stack([dst, src]) + + graph_store = GraphStore() + graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo") + + rei = graph_store.get_edge_index(("person", "knows", "person"), "coo") + + assert (ei == rei).all() + + edge_attrs = graph_store.get_all_edge_attrs() + assert len(edge_attrs) == 1 + + graph_store.remove_edge_index(("person", "knows", "person"), "coo") + edge_attrs = graph_store.get_all_edge_attrs() + assert len(edge_attrs) == 0 diff --git a/python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store_mg.py new file mode 100644 index 00000000000..14540b7e17d --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store_mg.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from cugraph.datasets import karate +from cugraph.utilities.utils import import_optional, MissingModule + +from cugraph_pyg.data import GraphStore + +torch = import_optional("torch") + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg +def test_graph_store_basic_api_mg(): + df = karate.get_edgelist() + src = torch.as_tensor(df["src"], device="cuda") + dst = torch.as_tensor(df["dst"], device="cuda") + + ei = torch.stack([dst, src]) + + graph_store = GraphStore(is_multi_gpu=True) + graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo") + + rei = graph_store.get_edge_index(("person", "knows", "person"), "coo") + + assert (ei == rei).all() + + edge_attrs = graph_store.get_all_edge_attrs() + assert len(edge_attrs) == 1 + + graph_store.remove_edge_index(("person", "knows", "person"), "coo") + edge_attrs = graph_store.get_all_edge_attrs() + assert len(edge_attrs) == 0 diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_dask_neighbor_loader.py similarity index 95% rename from python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py rename to python/cugraph-pyg/cugraph_pyg/tests/loader/test_dask_neighbor_loader.py index ab20ef01fd3..34ef6a59511 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_dask_neighbor_loader.py @@ -20,9 +20,9 @@ import cupy import numpy as np -from cugraph_pyg.loader import CuGraphNeighborLoader +from cugraph_pyg.loader import DaskNeighborLoader from cugraph_pyg.loader import BulkSampleLoader -from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.data import DaskGraphStore from cugraph_pyg.nn import SAGEConv as CuGraphSAGEConv from cugraph.gnn import FeatureStore @@ -47,14 +47,15 @@ @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_cugraph_loader_basic( karate_gnn: Tuple[ FeatureStore, Dict[Tuple[str, str, str], np.ndarray], Dict[str, int] ] ): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N, order="CSR") - loader = CuGraphNeighborLoader( + cugraph_store = DaskGraphStore(F, G, N, order="CSR") + loader = DaskNeighborLoader( (cugraph_store, cugraph_store), torch.arange(N["type0"] + N["type1"], dtype=torch.int64), 10, @@ -77,14 +78,15 @@ def test_cugraph_loader_basic( @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_cugraph_loader_hetero( karate_gnn: Tuple[ FeatureStore, Dict[Tuple[str, str, str], np.ndarray], Dict[str, int] ] ): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N, order="CSR") - loader = CuGraphNeighborLoader( + cugraph_store = DaskGraphStore(F, G, N, order="CSR") + loader = DaskNeighborLoader( (cugraph_store, cugraph_store), input_nodes=("type1", torch.tensor([0, 1, 2, 5], device="cuda")), batch_size=2, @@ -107,6 +109,7 @@ def test_cugraph_loader_hetero( @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_cugraph_loader_from_disk(): m = [2, 9, 99, 82, 9, 3, 18, 1, 12] n = torch.arange(1, 1 + len(m), dtype=torch.int32) @@ -118,7 +121,7 @@ def test_cugraph_loader_from_disk(): G = {("t0", "knows", "t0"): 9080} N = {"t0": 256} - cugraph_store = CuGraphStore(F, G, N, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, order="CSR") bogus_samples = cudf.DataFrame( { @@ -164,6 +167,7 @@ def test_cugraph_loader_from_disk(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_cugraph_loader_from_disk_subset(): m = [2, 9, 99, 82, 9, 3, 18, 1, 12] n = torch.arange(1, 1 + len(m), dtype=torch.int32) @@ -175,7 +179,7 @@ def test_cugraph_loader_from_disk_subset(): G = {("t0", "knows", "t0"): 9080} N = {"t0": 256} - cugraph_store = CuGraphStore(F, G, N, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, order="CSR") bogus_samples = cudf.DataFrame( { @@ -223,6 +227,7 @@ def test_cugraph_loader_from_disk_subset(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available") +@pytest.mark.sg def test_cugraph_loader_from_disk_subset_csr(): m = [2, 9, 99, 82, 11, 13] n = torch.arange(1, 1 + len(m), dtype=torch.int32) @@ -234,7 +239,7 @@ def test_cugraph_loader_from_disk_subset_csr(): G = {("t0", "knows", "t0"): 9080} N = {"t0": 256} - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) bogus_samples = cudf.DataFrame( { @@ -289,6 +294,7 @@ def test_cugraph_loader_from_disk_subset_csr(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_cugraph_loader_e2e_coo(): m = [2, 9, 99, 82, 9, 3, 18, 1, 12] x = torch.randint(3000, (256, 256)).to(torch.float32) @@ -298,7 +304,7 @@ def test_cugraph_loader_e2e_coo(): G = {("t0", "knows", "t0"): 9999} N = {"t0": 256} - cugraph_store = CuGraphStore(F, G, N, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, order="CSR") bogus_samples = cudf.DataFrame( { @@ -357,6 +363,7 @@ def test_cugraph_loader_e2e_coo(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available") @pytest.mark.parametrize("framework", ["pyg", "cugraph-ops"]) +@pytest.mark.sg def test_cugraph_loader_e2e_csc(framework: str): m = [2, 9, 99, 82, 9, 3, 18, 1, 12] x = torch.randint(3000, (256, 256)).to(torch.float32) @@ -366,7 +373,7 @@ def test_cugraph_loader_e2e_csc(framework: str): G = {("t0", "knows", "t0"): 9999} N = {"t0": 256} - cugraph_store = CuGraphStore(F, G, N) + cugraph_store = DaskGraphStore(F, G, N) bogus_samples = cudf.DataFrame( { @@ -461,6 +468,7 @@ def test_cugraph_loader_e2e_csc(framework: str): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.parametrize("drop_last", [True, False]) +@pytest.mark.sg def test_drop_last(drop_last): N = {"N": 10} G = { @@ -471,9 +479,9 @@ def test_drop_last(drop_last): F = FeatureStore(backend="torch") F.add_data(torch.arange(10), "N", "z") - store = CuGraphStore(F, G, N) + store = DaskGraphStore(F, G, N) with tempfile.TemporaryDirectory() as dir: - loader = CuGraphNeighborLoader( + loader = DaskNeighborLoader( (store, store), input_nodes=torch.tensor([0, 1, 2, 3, 4]), num_neighbors=[1], @@ -499,6 +507,7 @@ def test_drop_last(drop_last): @pytest.mark.parametrize("directory", ["local", "temp"]) +@pytest.mark.sg def test_load_directory( karate_gnn: Tuple[ FeatureStore, Dict[Tuple[str, str, str], np.ndarray], Dict[str, int] @@ -508,8 +517,8 @@ def test_load_directory( if directory == "local": local_dir = tempfile.TemporaryDirectory(dir=".") - cugraph_store = CuGraphStore(*karate_gnn) - cugraph_loader = CuGraphNeighborLoader( + cugraph_store = DaskGraphStore(*karate_gnn) + cugraph_loader = DaskNeighborLoader( (cugraph_store, cugraph_store), torch.arange(8, dtype=torch.int64), 2, diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_dask_neighbor_loader_mg.py similarity index 85% rename from python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py rename to python/cugraph-pyg/cugraph_pyg/tests/loader/test_dask_neighbor_loader_mg.py index f5035a38621..9e8a85a5b67 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_dask_neighbor_loader_mg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,18 +13,19 @@ import pytest -from cugraph_pyg.loader import CuGraphNeighborLoader -from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.loader import DaskNeighborLoader +from cugraph_pyg.data import DaskGraphStore from cugraph.utilities.utils import import_optional, MissingModule torch = import_optional("torch") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_cugraph_loader_basic(dask_client, karate_gnn): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True, order="CSR") - loader = CuGraphNeighborLoader( + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True, order="CSR") + loader = DaskNeighborLoader( (cugraph_store, cugraph_store), torch.arange(N["type0"] + N["type1"], dtype=torch.int64), 10, @@ -49,10 +50,11 @@ def test_cugraph_loader_basic(dask_client, karate_gnn): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_cugraph_loader_hetero(dask_client, karate_gnn): F, G, N = karate_gnn - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True, order="CSR") - loader = CuGraphNeighborLoader( + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True, order="CSR") + loader = DaskNeighborLoader( (cugraph_store, cugraph_store), input_nodes=("type1", torch.tensor([0, 1, 2, 5], device="cuda")), batch_size=2, diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py new file mode 100644 index 00000000000..8edb5276953 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from cugraph.datasets import karate +from cugraph.utilities.utils import import_optional, MissingModule + +from cugraph_pyg.data import TensorDictFeatureStore, GraphStore +from cugraph_pyg.loader import NeighborLoader + +torch = import_optional("torch") +torch_geometric = import_optional("torch_geometric") + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +def test_neighbor_loader(): + """ + Basic e2e test that covers loading and sampling. + """ + + df = karate.get_edgelist() + src = torch.as_tensor(df["src"], device="cuda") + dst = torch.as_tensor(df["dst"], device="cuda") + + ei = torch.stack([dst, src]) + + graph_store = GraphStore() + graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo") + + feature_store = TensorDictFeatureStore() + feature_store["person", "feat"] = torch.randint(128, (34, 16)) + + loader = NeighborLoader( + (feature_store, graph_store), + [5, 5], + input_nodes=torch.arange(34), + directory=".", + ) + + for batch in loader: + assert isinstance(batch, torch_geometric.data.Data) + assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all() diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py new file mode 100644 index 00000000000..6a5f46b0940 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py @@ -0,0 +1,111 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import os + +from cugraph.datasets import karate +from cugraph.utilities.utils import import_optional, MissingModule + +from cugraph_pyg.data import TensorDictFeatureStore, GraphStore +from cugraph_pyg.loader import NeighborLoader + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, +) + +torch = import_optional("torch") +torch_geometric = import_optional("torch_geometric") + + +def init_pytorch_worker(rank, world_size, cugraph_id): + import rmm + + rmm.reinitialize( + devices=rank, + ) + + import cupy + + cupy.cuda.Device(rank).use() + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + from cugraph.testing.mg_utils import enable_spilling + + enable_spilling() + + torch.cuda.set_device(rank) + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id, device=rank) + + +def run_test_neighbor_loader_mg(rank, uid, world_size, specify_size): + """ + Basic e2e test that covers loading and sampling. + """ + init_pytorch_worker(rank, world_size, uid) + + df = karate.get_edgelist() + src = torch.as_tensor(df["src"], device="cuda") + dst = torch.as_tensor(df["dst"], device="cuda") + + ei = torch.stack([dst, src]) + ei = torch.tensor_split(ei.clone(), world_size, axis=1)[rank] + + sz = (34, 34) if specify_size else None + graph_store = GraphStore(is_multi_gpu=True) + graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo", False, sz) + + feature_store = TensorDictFeatureStore() + feature_store["person", "feat"] = torch.randint(128, (34, 16)) + + ix_train = torch.tensor_split(torch.arange(34), world_size, axis=0)[rank] + + loader = NeighborLoader( + (feature_store, graph_store), + [5, 5], + input_nodes=ix_train, + ) + + for batch in loader: + assert isinstance(batch, torch_geometric.data.Data) + assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all() + + cugraph_comms_shutdown() + + +@pytest.mark.parametrize("specify_size", [True, False]) +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg +def test_neighbor_loader_mg(specify_size): + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_neighbor_loader_mg, + args=( + uid, + world_size, + specify_size, + ), + nprocs=world_size, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py index a26063f62fa..92d216fefa3 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py @@ -30,6 +30,7 @@ @pytest.mark.parametrize("max_num_neighbors", [8, None]) @pytest.mark.parametrize("use_edge_attr", [True, False]) @pytest.mark.parametrize("graph", ["basic_pyg_graph_1", "basic_pyg_graph_2"]) +@pytest.mark.sg def test_gat_conv_equality( use_edge_index, bias, diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py index a62f2fed2f7..2e221922add 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py @@ -24,6 +24,7 @@ @pytest.mark.parametrize("heads", [1, 2, 3, 5, 10, 16]) @pytest.mark.parametrize("use_edge_attr", [True, False]) @pytest.mark.parametrize("graph", ["basic_pyg_graph_1", "basic_pyg_graph_2"]) +@pytest.mark.sg def test_gatv2_conv_equality( use_edge_index, bipartite, concat, heads, use_edge_attr, graph, request ): diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py index d8190ea345f..f182869002a 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py @@ -25,6 +25,7 @@ ) @pytest.mark.parametrize("heads", [1, 3, 10]) @pytest.mark.parametrize("aggr", ["sum", "mean"]) +@pytest.mark.sg def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads): import torch from torch_geometric.data import HeteroData diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py index fc0aaf25b7b..8b06cb2e180 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py @@ -25,6 +25,7 @@ @pytest.mark.parametrize("num_bases", [1, 2, None]) @pytest.mark.parametrize("root_weight", [True, False]) @pytest.mark.parametrize("graph", ["basic_pyg_graph_1", "basic_pyg_graph_2"]) +@pytest.mark.sg def test_rgcn_conv_equality( use_edge_index, aggr, diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py index 9d8d413c590..878ceff632a 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py @@ -26,6 +26,7 @@ @pytest.mark.parametrize("normalize", [True, False]) @pytest.mark.parametrize("root_weight", [True, False]) @pytest.mark.parametrize("graph", ["basic_pyg_graph_1", "basic_pyg_graph_2"]) +@pytest.mark.sg def test_sage_conv_equality( use_edge_index, aggr, diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py index 1776b691c87..d207a4d7947 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py @@ -24,6 +24,7 @@ @pytest.mark.parametrize("concat", [True, False]) @pytest.mark.parametrize("heads", [1, 2, 3, 5, 10, 16]) @pytest.mark.parametrize("graph", ["basic_pyg_graph_1", "basic_pyg_graph_2"]) +@pytest.mark.sg def test_transformer_conv_equality( use_edge_index, use_edge_attr, bipartite, concat, heads, graph, request ): diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/sampler/test_sampler_utils.py similarity index 93% rename from python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py rename to python/cugraph-pyg/cugraph_pyg/tests/sampler/test_sampler_utils.py index ed011a658a9..7659fdc386f 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/sampler/test_sampler_utils.py @@ -16,8 +16,8 @@ import pytest -from cugraph_pyg.data import CuGraphStore -from cugraph_pyg.sampler.cugraph_sampler import ( +from cugraph_pyg.data import DaskGraphStore +from cugraph_pyg.sampler.sampler_utils import ( _sampler_output_from_sampling_results_heterogeneous, ) @@ -29,9 +29,10 @@ @pytest.mark.cugraph_ops @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_neighbor_sample(basic_graph_1): F, G, N = basic_graph_1 - cugraph_store = CuGraphStore(F, G, N, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, order="CSR") batches = cudf.DataFrame( { @@ -88,9 +89,10 @@ def test_neighbor_sample(basic_graph_1): @pytest.mark.cugraph_ops @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1): F, G, N = multi_edge_multi_vertex_graph_1 - cugraph_store = CuGraphStore(F, G, N, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, order="CSR") batches = cudf.DataFrame( { @@ -148,10 +150,11 @@ def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg def test_neighbor_sample_mock_sampling_results(abc_graph): F, G, N = abc_graph - graph_store = CuGraphStore(F, G, N, order="CSR") + graph_store = DaskGraphStore(F, G, N, order="CSR") # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( @@ -191,9 +194,3 @@ def test_neighbor_sample_mock_sampling_results(abc_graph): assert out.num_sampled_edges[("A", "ab", "B")] == [3, 0, 1, 0] assert out.num_sampled_edges[("B", "ba", "A")] == [0, 1, 0, 1] assert out.num_sampled_edges[("B", "bc", "C")] == [0, 2, 0, 2] - - -@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -@pytest.mark.skip("needs to be written") -def test_neighbor_sample_renumbered(): - pass diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/sampler/test_sampler_utils_mg.py similarity index 86% rename from python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py rename to python/cugraph-pyg/cugraph_pyg/tests/sampler/test_sampler_utils_mg.py index 80a2d0a6c79..91e0668b3c1 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/sampler/test_sampler_utils_mg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,8 +16,8 @@ import pytest -from cugraph_pyg.data import CuGraphStore -from cugraph_pyg.sampler.cugraph_sampler import ( +from cugraph_pyg.data import DaskGraphStore +from cugraph_pyg.sampler.sampler_utils import ( _sampler_output_from_sampling_results_heterogeneous, ) @@ -31,9 +31,10 @@ @pytest.mark.cugraph_ops @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_neighbor_sample(dask_client, basic_graph_1): F, G, N = basic_graph_1 - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True, order="CSR") batches = cudf.DataFrame( { @@ -87,18 +88,19 @@ def test_neighbor_sample(dask_client, basic_graph_1): # check the hop dictionaries assert len(out.num_sampled_nodes) == 1 - assert out.num_sampled_nodes["vt1"].tolist() == [4, 1] + assert out.num_sampled_nodes["vt1"] == [4, 1] assert len(out.num_sampled_edges) == 1 - assert out.num_sampled_edges[("vt1", "pig", "vt1")].tolist() == [6] + assert out.num_sampled_edges[("vt1", "pig", "vt1")] == [6] @pytest.mark.cugraph_ops @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skip(reason="broken") +@pytest.mark.mg def test_neighbor_sample_multi_vertex(dask_client, multi_edge_multi_vertex_graph_1): F, G, N = multi_edge_multi_vertex_graph_1 - cugraph_store = CuGraphStore(F, G, N, multi_gpu=True, order="CSR") + cugraph_store = DaskGraphStore(F, G, N, multi_gpu=True, order="CSR") batches = cudf.DataFrame( { @@ -160,6 +162,7 @@ def test_neighbor_sample_multi_vertex(dask_client, multi_edge_multi_vertex_graph @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg def test_neighbor_sample_mock_sampling_results(dask_client): N = { "A": 2, # 0, 1 @@ -190,7 +193,7 @@ def test_neighbor_sample_mock_sampling_results(dask_client): torch.tensor([3.2, 2.1], dtype=torch.float32), type_name="A", feat_name="prop1" ) - graph_store = CuGraphStore(F, G, N, multi_gpu=True, order="CSR") + graph_store = DaskGraphStore(F, G, N, multi_gpu=True, order="CSR") # let 0, 1 be the start vertices, fanout = [2, 1, 2, 3] mock_sampling_results = cudf.DataFrame( @@ -222,17 +225,11 @@ def test_neighbor_sample_mock_sampling_results(dask_client): assert out.col[("B", "ba", "A")].tolist() == [1, 1] assert len(out.num_sampled_nodes) == 3 - assert out.num_sampled_nodes["A"].tolist() == [2, 0, 0, 0, 0] - assert out.num_sampled_nodes["B"].tolist() == [0, 2, 0, 0, 0] - assert out.num_sampled_nodes["C"].tolist() == [0, 0, 2, 0, 1] + assert out.num_sampled_nodes["A"] == [2, 0, 0, 0, 0] + assert out.num_sampled_nodes["B"] == [0, 2, 0, 0, 0] + assert out.num_sampled_nodes["C"] == [0, 0, 2, 0, 1] assert len(out.num_sampled_edges) == 3 - assert out.num_sampled_edges[("A", "ab", "B")].tolist() == [3, 0, 1, 0] - assert out.num_sampled_edges[("B", "ba", "A")].tolist() == [0, 1, 0, 1] - assert out.num_sampled_edges[("B", "bc", "C")].tolist() == [0, 2, 0, 2] - - -@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -@pytest.mark.skip("needs to be written") -def test_neighbor_sample_renumbered(dask_client): - pass + assert out.num_sampled_edges[("A", "ab", "B")] == [3, 0, 1, 0] + assert out.num_sampled_edges[("B", "ba", "A")] == [0, 1, 0, 1] + assert out.num_sampled_edges[("B", "bc", "C")] == [0, 2, 0, 2] diff --git a/python/cugraph-pyg/pytest.ini b/python/cugraph-pyg/pytest.ini index 579b2245308..db99a54ae49 100644 --- a/python/cugraph-pyg/pytest.ini +++ b/python/cugraph-pyg/pytest.ini @@ -23,6 +23,8 @@ addopts = markers = slow: slow-running tests/benchmarks cugraph_ops: Tests requiring cugraph-ops + mg: Test MG code paths - number of gpu > 1 + sg: Test SG code paths and dask sg tests - number of gpu == 1 python_classes = Bench* diff --git a/python/cugraph/cugraph/gnn/__init__.py b/python/cugraph/cugraph/gnn/__init__.py index 1f4d98f0230..b6c8e1981d0 100644 --- a/python/cugraph/cugraph/gnn/__init__.py +++ b/python/cugraph/cugraph/gnn/__init__.py @@ -16,6 +16,7 @@ from .data_loading.dist_sampler import ( DistSampler, DistSampleWriter, + DistSampleReader, UniformNeighborSampler, ) from .comms.cugraph_nccl_comms import ( diff --git a/python/cugraph/cugraph/gnn/data_loading/__init__.py b/python/cugraph/cugraph/gnn/data_loading/__init__.py index a50f6085e9a..98c547a0083 100644 --- a/python/cugraph/cugraph/gnn/data_loading/__init__.py +++ b/python/cugraph/cugraph/gnn/data_loading/__init__.py @@ -15,5 +15,6 @@ from cugraph.gnn.data_loading.dist_sampler import ( DistSampler, DistSampleWriter, + DistSampleReader, UniformNeighborSampler, ) diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py index e57e195a4b8..52638230b9b 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -12,15 +12,18 @@ # limitations under the License. import os +import re import warnings from math import ceil +from functools import reduce import pylibcugraph import numpy as np import cupy import cudf -from typing import Union, List, Dict, Tuple +from typing import Union, List, Dict, Tuple, Iterator, Optional + from cugraph.utilities import import_optional from cugraph.gnn.comms import cugraph_comms_get_raft_handle @@ -32,6 +35,73 @@ TensorType = Union["torch.Tensor", cupy.ndarray, cudf.Series] +class DistSampleReader: + def __init__( + self, + directory: str, + *, + format: str = "parquet", + rank: Optional[int] = None, + filelist=None, + ): + self.__format = format + self.__directory = directory + + if format != "parquet": + raise ValueError("Invalid format (currently supported: 'parquet')") + + if filelist is None: + files = os.listdir(directory) + ex = re.compile(r"batch\=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet") + filematch = [ex.match(f) for f in files] + filematch = [f for f in filematch if f] + + if rank is not None: + filematch = [f for f in filematch if int(f[1]) == rank] + + batch_count = sum([int(f[4]) - int(f[2]) + 1 for f in filematch]) + filematch = sorted(filematch, key=lambda f: int(f[2]), reverse=True) + + self.__files = filematch + else: + self.__files = list(filelist) + + if rank is None: + self.__batch_count = batch_count + else: + batch_count = torch.tensor([batch_count], device="cuda") + torch.distributed.all_reduce(batch_count, torch.distributed.ReduceOp.MIN) + self.__batch_count = int(batch_count) + + def __iter__(self): + return self + + def __next__(self): + if len(self.__files) > 0: + f = self.__files.pop() + fname = f[0] + start_inclusive = int(f[2]) + end_inclusive = int(f[4]) + + if (end_inclusive - start_inclusive + 1) > self.__batch_count: + end_inclusive = start_inclusive + self.__batch_count - 1 + self.__batch_count = 0 + else: + self.__batch_count -= end_inclusive - start_inclusive + 1 + + df = cudf.read_parquet(os.path.join(self.__directory, fname)) + tensors = {} + for col in list(df.columns): + s = df[col].dropna() + if len(s) > 0: + tensors[col] = torch.as_tensor(s, device="cuda") + df.drop(col, axis=1, inplace=True) + + return tensors, start_inclusive, end_inclusive + + raise StopIteration + + class DistSampleWriter: def __init__( self, @@ -72,6 +142,16 @@ def _directory(self): def _batches_per_partition(self): return self.__batches_per_partition + def get_reader( + self, rank: int + ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: + """ + Returns an iterator over sampled data. + """ + + # currently only disk reading is supported + return DistSampleReader(self._directory, format=self._format, rank=rank) + def __write_minibatches_coo(self, minibatch_dict): has_edge_ids = minibatch_dict["edge_id"] is not None has_edge_types = minibatch_dict["edge_type"] is not None @@ -166,10 +246,109 @@ def __write_minibatches_coo(self, minibatch_dict): ) def __write_minibatches_csr(self, minibatch_dict): - raise NotImplementedError( - "CSR format currently not supported for distributed sampling" + has_edge_ids = minibatch_dict["edge_id"] is not None + has_edge_types = minibatch_dict["edge_type"] is not None + has_weights = minibatch_dict["weight"] is not None + + if minibatch_dict["renumber_map"] is None: + raise ValueError( + "Distributed sampling without renumbering is not supported" + ) + + # Quit if there are no batches to write. + if len(minibatch_dict["batch_id"]) == 0: + return + + fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( + minibatch_dict["batch_id"] ) + for p in range( + 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) + ): + partition_start = p * (self.__batches_per_partition) + partition_end = (p + 1) * (self.__batches_per_partition) + + label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ + partition_start * fanout_length : partition_end * fanout_length + 1 + ] + + batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] + start_batch_id = batch_id_array_p[0] + + # major offsets and minors + ( + major_offsets_start_incl, + major_offsets_end_incl, + ) = label_hop_offsets_array_p[[0, -1]] + + start_ix, end_ix = minibatch_dict["major_offsets"][ + [major_offsets_start_incl, major_offsets_end_incl] + ] + + major_offsets_array_p = minibatch_dict["major_offsets"][ + major_offsets_start_incl : major_offsets_end_incl + 1 + ] + + minors_array_p = minibatch_dict["minors"][start_ix:end_ix] + edge_id_array_p = ( + minibatch_dict["edge_id"][start_ix:end_ix] + if has_edge_ids + else cupy.array([], dtype="int64") + ) + edge_type_array_p = ( + minibatch_dict["edge_type"][start_ix:end_ix] + if has_edge_types + else cupy.array([], dtype="int32") + ) + weight_array_p = ( + minibatch_dict["weight"][start_ix:end_ix] + if has_weights + else cupy.array([], dtype="float32") + ) + + # create the renumber map offsets + renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ + partition_start : partition_end + 1 + ] + + renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ + [0, -1] + ] + + renumber_map_array_p = minibatch_dict["renumber_map"][ + renumber_map_start_ix:renumber_map_end_ix + ] + + results_dataframe_p = create_df_from_disjoint_arrays( + { + "major_offsets": major_offsets_array_p, + "minors": minors_array_p, + "map": renumber_map_array_p, + "label_hop_offsets": label_hop_offsets_array_p, + "weight": weight_array_p, + "edge_id": edge_id_array_p, + "edge_type": edge_type_array_p, + "renumber_map_offsets": renumber_map_offsets_array_p, + } + ) + + end_batch_id = start_batch_id + len(batch_id_array_p) - 1 + rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 + + full_output_path = os.path.join( + self.__directory, + f"batch={rank:05d}.{start_batch_id:08d}-" + f"{rank:05d}.{end_batch_id:08d}.parquet", + ) + + results_dataframe_p.to_parquet( + full_output_path, + compression=None, + index=False, + force_nullable_schema=True, + ) + def write_minibatches(self, minibatch_dict): if (minibatch_dict["majors"] is not None) and ( minibatch_dict["minors"] is not None @@ -188,8 +367,8 @@ def __init__( self, graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph], writer: DistSampleWriter, - local_seeds_per_call: int = 32768, - retain_original_seeds: bool = False, # TODO See #4329, needs C API + local_seeds_per_call: int, + retain_original_seeds: bool = False, ): """ Parameters @@ -199,14 +378,16 @@ def __init__( writer: DistSampleWriter (required) The writer responsible for writing samples to disk or, in the future, device or host memory. - local_seeds_per_call: int (optional, default=32768) + local_seeds_per_call: int The number of seeds on this rank this sampler will process in a single sampling call. Batches will get split into multiple sampling calls based on this parameter. This parameter must be the same across all ranks. The total number of seeds processed per sampling call is this - parameter times the world size. + parameter times the world size. Subclasses should + generally calculate the appropriate number of + seeds. retain_original_seeds: bool (optional, default=False) Whether to retain the original seeds even if they do not appear in the output minibatch. This will @@ -219,6 +400,13 @@ def __init__( self.__handle = None self.__retain_original_seeds = retain_original_seeds + def get_reader(self) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: + """ + Returns an iterator over sampled data. + """ + rank = torch.distributed.get_rank() if self.is_multi_gpu else None + return self.__writer.get_reader(rank) + def sample_batches( self, seeds: TensorType, @@ -438,13 +626,6 @@ def sample_from_nodes( : len(current_seeds) ] - # Handle the case where not all ranks have the same number of call groups, - # in which case there will be some empty groups that get submitted on the - # ranks with fewer call groups. - label_start, label_end = ( - current_batches[[0, -1]] if len(current_batches) > 0 else (0, -1) - ) - minibatch_dict = self.sample_batches( seeds=current_seeds, batch_ids=current_batches, @@ -482,12 +663,20 @@ def _retain_original_seeds(self): class UniformNeighborSampler(DistSampler): + # Number of vertices in the output minibatch, based + # on benchmarking. + BASE_VERTICES_PER_BYTE = 0.1107662486009992 + + # Default number of seeds if the output minibatch + # size can't be estimated. + UNKNOWN_VERTICES_DEFAULT = 32768 + def __init__( self, graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph], writer: DistSampleWriter, *, - local_seeds_per_call: int = 32768, + local_seeds_per_call: Optional[int] = None, retain_original_seeds: bool = False, fanout: List[int] = [-1], prior_sources_behavior: str = "exclude", @@ -496,12 +685,6 @@ def __init__( compress_per_hop: bool = False, with_replacement: bool = False, ): - super().__init__( - graph, - writer, - local_seeds_per_call=local_seeds_per_call, - retain_original_seeds=retain_original_seeds, - ) self.__fanout = fanout self.__prior_sources_behavior = prior_sources_behavior self.__deduplicate_sources = deduplicate_sources @@ -509,6 +692,28 @@ def __init__( self.__compression = compression self.__with_replacement = with_replacement + super().__init__( + graph, + writer, + local_seeds_per_call=self.__calc_local_seeds_per_call(local_seeds_per_call), + retain_original_seeds=retain_original_seeds, + ) + + def __calc_local_seeds_per_call(self, local_seeds_per_call: Optional[int] = None): + if local_seeds_per_call is None: + if len([x for x in self.__fanout if x <= 0]) > 0: + return UniformNeighborSampler.UNKNOWN_VERTICES_DEFAULT + + total_memory = torch.cuda.get_device_properties(0).total_memory + fanout_prod = reduce(lambda x, y: x * y, self.__fanout) + return int( + UniformNeighborSampler.BASE_VERTICES_PER_BYTE + * total_memory + / fanout_prod + ) + + return local_seeds_per_call + def sample_batches( self, seeds: TensorType, @@ -526,12 +731,17 @@ def sample_batches( local_label_list, assume_equal_input_size=assume_equal_input_size ) - # TODO add calculation of seed vertex label offsets if self._retain_original_seeds: - warnings.warn( - "The 'retain_original_seeds` parameter is currently ignored " - "since seed retention is not implemented yet." + label_offsets = torch.concat( + [ + torch.searchsorted(batch_ids, local_label_list), + torch.tensor( + [batch_ids.shape[0]], device="cuda", dtype=torch.int64 + ), + ] ) + else: + label_offsets = None sampling_results_dict = pylibcugraph.uniform_neighbor_sample( self._resource_handle, @@ -542,7 +752,7 @@ def sample_batches( label_to_output_comm_rank=cupy.asarray(label_to_output_comm_rank), h_fan_out=np.array(self.__fanout, dtype="int32"), with_replacement=self.__with_replacement, - do_expensive_check=False, + do_expensive_check=True, with_edge_properties=True, random_state=random_state + rank, prior_sources_behavior=self.__prior_sources_behavior, @@ -551,10 +761,28 @@ def sample_batches( renumber=True, compression=self.__compression, compress_per_hop=self.__compress_per_hop, + retain_seeds=self._retain_original_seeds, + label_offsets=None + if label_offsets is None + else cupy.asarray(label_offsets), return_dict=True, ) sampling_results_dict["rank"] = rank else: + if self._retain_original_seeds: + batch_ids = batch_ids.to(device="cuda", dtype=torch.int32) + local_label_list = torch.unique(batch_ids) + label_offsets = torch.concat( + [ + torch.searchsorted(batch_ids, local_label_list), + torch.tensor( + [batch_ids.shape[0]], device="cuda", dtype=torch.int64 + ), + ] + ) + else: + label_offsets = None + sampling_results_dict = pylibcugraph.uniform_neighbor_sample( self._resource_handle, self._graph, @@ -571,6 +799,8 @@ def sample_batches( renumber=True, compression=self.__compression, compress_per_hop=self.__compress_per_hop, + retain_seeds=self._retain_original_seeds, + label_offsets=cupy.asarray(label_offsets), return_dict=True, )