diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 587c5fb38e7..c980ed320dc 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -40,7 +40,7 @@ if ! rapids-is-release-build; then alpha_spec=',>=0.0.0a0' fi -for dep in rmm cudf cugraph raft-dask pylibcugraph pylibcugraphops pylibraft ucx-py; do +for dep in rmm cudf cugraph raft-dask pylibcugraph pylibcugraphops pylibwholegraph pylibraft ucx-py; do sed -r -i "s/${dep}==(.*)\"/${dep}${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} done diff --git a/dependencies.yaml b/dependencies.yaml index 3c2622fde9f..19634420520 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -20,6 +20,7 @@ files: - depends_on_pylibraft - depends_on_raft_dask - depends_on_pylibcugraphops + - depends_on_pylibwholegraph - depends_on_cupy - python_run_cugraph - python_run_nx_cugraph @@ -60,6 +61,7 @@ files: includes: - cuda_version - depends_on_cudf + - depends_on_pylibwholegraph - py_version - test_python_common - test_python_cugraph @@ -98,6 +100,7 @@ files: includes: - test_python_common - test_python_cugraph + - depends_on_pylibwholegraph py_build_pylibcugraph: output: pyproject pyproject_dir: python/pylibcugraph @@ -175,6 +178,7 @@ files: key: test includes: - test_python_common + - depends_on_pylibwholegraph py_build_cugraph_pyg: output: pyproject pyproject_dir: python/cugraph-pyg @@ -198,6 +202,7 @@ files: key: test includes: - test_python_common + - depends_on_pylibwholegraph py_build_cugraph_equivariant: output: pyproject pyproject_dir: python/cugraph-equivariant @@ -535,9 +540,7 @@ dependencies: - *numpy - python-louvain - scikit-learn>=0.23.1 - - output_types: [conda] - packages: - - pylibwholegraph==24.6.* + test_python_pylibcugraph: common: - output_types: [conda, pyproject] @@ -568,6 +571,27 @@ dependencies: - tensordict>=0.1.2 - pyg>=2.5,<2.6 + depends_on_pylibwholegraph: + common: + - output_types: conda + packages: + - &pylibwholegraph_conda pylibwholegraph==24.6.* + - output_types: requirements + packages: + # pip recognizes the index as a global option for the requirements.txt file + - --extra-index-url=https://pypi.nvidia.com + - --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple + specific: + - output_types: [requirements, pyproject] + matrices: + - matrix: {cuda: "12.*"} + packages: + - pylibwholegraph-cu12==24.6.* + - matrix: {cuda: "11.*"} + packages: + - pylibwholegraph-cu11==24.6.* + - {matrix: null, packages: [*pylibwholegraph_conda]} + depends_on_rmm: common: - output_types: conda diff --git a/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst b/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst index 5475fd6c581..d2b1d124ccb 100644 --- a/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst +++ b/docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst @@ -20,6 +20,7 @@ Feature Storage :toctree: ../api/cugraph-pyg/ cugraph_pyg.data.feature_store.TensorDictFeatureStore + cugraph_pyg.data.feature_store.WholeFeatureStore Data Loaders ------------ diff --git a/python/cugraph-dgl/pyproject.toml b/python/cugraph-dgl/pyproject.toml index 37ea8b850bd..534106eb87f 100644 --- a/python/cugraph-dgl/pyproject.toml +++ b/python/cugraph-dgl/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ [project.optional-dependencies] test = [ "pandas", + "pylibwholegraph==24.6.*", "pytest", "pytest-benchmark", "pytest-cov", diff --git a/python/cugraph-pyg/cugraph_pyg/data/__init__.py b/python/cugraph-pyg/cugraph_pyg/data/__init__.py index 4c6f267410d..6d51fd5ea01 100644 --- 
a/python/cugraph-pyg/cugraph_pyg/data/__init__.py
+++ b/python/cugraph-pyg/cugraph_pyg/data/__init__.py
@@ -15,7 +15,10 @@
 from cugraph_pyg.data.dask_graph_store import DaskGraphStore
 from cugraph_pyg.data.graph_store import GraphStore
-from cugraph_pyg.data.feature_store import TensorDictFeatureStore
+from cugraph_pyg.data.feature_store import (
+    TensorDictFeatureStore,
+    WholeFeatureStore,
+)
 
 
 def CuGraphStore(*args, **kwargs):
diff --git a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py
index 42dda42a9e1..0adef9f9135 100644
--- a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py
+++ b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py
@@ -20,6 +20,7 @@
 torch = import_optional("torch")
 torch_geometric = import_optional("torch_geometric")
 tensordict = import_optional("tensordict")
+wgth = import_optional("pylibwholegraph.torch")
 
 
 class TensorDictFeatureStore(
@@ -127,3 +128,149 @@ def get_all_tensor_attrs(
             )
 
         return attrs
+
+
+class WholeFeatureStore(
+    object
+    if isinstance(torch_geometric, MissingModule)
+    else torch_geometric.data.FeatureStore
+):
+    """
+    A basic implementation of the PyG FeatureStore interface that stores
+    feature data in WholeGraph WholeMemory. This type of feature store is
+    distributed, and avoids data replication across workers.
+
+    Data should be sliced before being passed into this feature store.
+    That means each worker should have its own partition.
+    """
+
+    def __init__(self, memory_type="distributed", location="cpu"):
+        """
+        Parameters
+        ----------
+        memory_type: str (optional, default='distributed')
+            The WholeGraph memory type of this store
+            (e.g. 'distributed' or 'chunked').
+        location: str (optional, default='cpu')
+            The location ('cpu' or 'cuda') where data is stored.
+        """
+        super().__init__()
+
+        self.__features = {}
+
+        self.__wg_comm = wgth.get_local_node_communicator()
+        self.__wg_type = memory_type
+        self.__wg_location = location
+
+    def _put_tensor(
+        self,
+        tensor: "torch_geometric.typing.FeatureTensorType",
+        attr: "torch_geometric.data.feature_store.TensorAttr",
+    ) -> bool:
+        wg_comm_obj = self.__wg_comm
+
+        if attr.is_set("index"):
+            if (attr.group_name, attr.attr_name) in self.__features:
+                raise NotImplementedError(
+                    "Updating an embedding from an index"
+                    " is not supported by WholeGraph."
+                )
+            else:
+                warnings.warn(
+                    "Ignoring index parameter "
+                    f"(attribute does not exist for group {attr.group_name})"
+                )
+
+        if len(tensor.shape) > 2:
+            raise ValueError("Only 1-D or 2-D tensors are supported by WholeGraph.")
+
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+
+        ld = torch.tensor(tensor.shape[0], device="cuda", dtype=torch.int64)
+        sizes = torch.empty((world_size,), device="cuda", dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(sizes, ld)
+
+        sizes = sizes.cpu()
+        ld = sizes.sum()
+
+        td = -1 if len(tensor.shape) == 1 else tensor.shape[1]
+        global_shape = [
+            int(ld),
+            td if td > 0 else 1,
+        ]
+
+        if td < 0:
+            tensor = tensor.reshape((tensor.shape[0], 1))
+
+        wg_embedding = wgth.create_wholememory_tensor(
+            wg_comm_obj,
+            self.__wg_type,
+            self.__wg_location,
+            global_shape,
+            tensor.dtype,
+            [global_shape[1], 1],
+        )
+
+        offset = sizes[:rank].sum() if rank > 0 else 0
+
+        wg_embedding.scatter(
+            tensor.clone(memory_format=torch.contiguous_format).cuda(),
+            torch.arange(
+                offset, offset + tensor.shape[0], dtype=torch.int64, device="cuda"
+            ).contiguous(),
+        )
+
+        wg_comm_obj.barrier()
+
+        self.__features[attr.group_name, attr.attr_name] = (wg_embedding, td)
+        return True
+
+    def _get_tensor(
+        self, attr: "torch_geometric.data.feature_store.TensorAttr"
+    ) -> Optional["torch_geometric.typing.FeatureTensorType"]:
+        if (attr.group_name, attr.attr_name) not in self.__features:
+            return None
+
+        emb, td = self.__features[attr.group_name, attr.attr_name]
+
+        if attr.index is None or (not attr.is_set("index")):
+            attr.index = torch.arange(emb.shape[0], dtype=torch.int64)
+
+        attr.index = attr.index.cuda()
+        t = emb.gather(
+            attr.index,
+            force_dtype=emb.dtype,
+        )
+
+        if td < 0:
+            t = t.reshape((t.shape[0],))
+
+        return t
+
+    def _remove_tensor(
+        self, attr: "torch_geometric.data.feature_store.TensorAttr"
+    ) -> bool:
+        if (attr.group_name, attr.attr_name) not in self.__features:
+            return False
+
+        del self.__features[attr.group_name, attr.attr_name]
+        return True
+
+    def _get_tensor_size(
+        self, attr: "torch_geometric.data.feature_store.TensorAttr"
+    ) -> Tuple:
+        return self.__features[attr.group_name, attr.attr_name][0].shape
+
+    def get_all_tensor_attrs(
+        self,
+    ) -> List["torch_geometric.data.feature_store.TensorAttr"]:
+        attrs = []
+        for (group_name, attr_name) in self.__features.keys():
+            attrs.append(
+                torch_geometric.data.feature_store.TensorAttr(
+                    group_name=group_name,
+                    attr_name=attr_name,
+                )
+            )
+
+        return attrs
diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
new file mode 100644
index 00000000000..be6447208ce
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
@@ -0,0 +1,434 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Multi-node, multi-GPU example with WholeGraph feature storage.
+# Can be run with torchrun.
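A minimal usage sketch of the new WholeFeatureStore, separate from the full example
that follows (it mirrors the MG test added later in this diff; the worker function,
master port, and tensor sizes are illustrative, not part of the change):

import os
import torch
import torch.distributed as dist
from pylibwholegraph.torch.initialize import init as wm_init, finalize as wm_finalize
from cugraph_pyg.data import WholeFeatureStore


def example_worker(rank: int, world_size: int):
    # One process per GPU; torch.distributed and WholeGraph must both be initialized.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    wm_init(rank, world_size, rank, world_size)

    features = torch.arange(world_size * 400, dtype=torch.float32).reshape(-1, 4)

    # memory_type defaults to 'distributed'; each rank writes only its own shard.
    store = WholeFeatureStore()
    store["node", "x"] = torch.tensor_split(features, world_size)[rank]

    # Reads are global: any rank can gather any rows of the assembled tensor.
    rows = torch.arange(features.shape[0])
    x = store["node", "x"][rows]  # returned on the GPU
    assert x.shape == (features.shape[0], 4)

    wm_finalize()
    dist.destroy_process_group()


if __name__ == "__main__":
    n = torch.cuda.device_count()
    torch.multiprocessing.spawn(example_worker, args=(n,), nprocs=n)

The multi-node example below is instead launched with torchrun, for instance
torchrun --nnodes 1 --nproc_per_node 8 gcn_dist_mnmg.py (illustrative), which sets
the LOCAL_RANK environment variable the script checks for.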
+ +import argparse +import os +import warnings +import tempfile +import time +import json + + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from ogb.nodeproppred import PygNodePropPredDataset +from torch.nn.parallel import DistributedDataParallel + +import torch_geometric + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, +) + +from pylibwholegraph.torch.initialize import ( + init as wm_init, + finalize as wm_finalize, +) + +# Allow computation on objects that are larger than GPU memory +# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory +os.environ["CUDF_SPILL"] = "1" + +# Ensures that a CUDA context is not created on import of rapids. +# Allows pytorch to create the context instead +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + + +def init_pytorch_worker(global_rank, local_rank, world_size, cugraph_id): + import rmm + + rmm.reinitialize( + devices=local_rank, + managed_memory=True, + pool_allocator=True, + ) + + import cupy + + cupy.cuda.Device(local_rank).use() + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + from cugraph.testing.mg_utils import enable_spilling + + enable_spilling() + + torch.cuda.set_device(local_rank) + + cugraph_comms_init( + rank=global_rank, world_size=world_size, uid=cugraph_id, device=local_rank + ) + + wm_init(global_rank, world_size, local_rank, torch.cuda.device_count()) + + +def partition_data(dataset, split_idx, edge_path, feature_path, label_path, meta_path): + data = dataset[0] + + # Split and save edge index + os.makedirs( + edge_path, + exist_ok=True, + ) + for (r, e) in enumerate(torch.tensor_split(data.edge_index, world_size, dim=1)): + rank_path = os.path.join(edge_path, f"rank={r}.pt") + torch.save( + e.clone(), + rank_path, + ) + + # Split and save features + os.makedirs( + feature_path, + exist_ok=True, + ) + + for (r, f) in enumerate(torch.tensor_split(data.x, world_size)): + rank_path = os.path.join(feature_path, f"rank={r}_x.pt") + torch.save( + f.clone(), + rank_path, + ) + for (r, f) in enumerate(torch.tensor_split(data.y, world_size)): + rank_path = os.path.join(feature_path, f"rank={r}_y.pt") + torch.save( + f.clone(), + rank_path, + ) + + # Split and save labels + os.makedirs( + label_path, + exist_ok=True, + ) + for (d, i) in split_idx.items(): + i_parts = torch.tensor_split(i, world_size) + for r, i_part in enumerate(i_parts): + rank_path = os.path.join(label_path, f"rank={r}") + os.makedirs(rank_path, exist_ok=True) + torch.save(i_part, os.path.join(rank_path, f"{d}.pt")) + + # Save metadata + meta = { + "num_classes": int(dataset.num_classes), + "num_features": int(dataset.num_features), + "num_nodes": int(data.num_nodes), + } + with open(meta_path, "w") as f: + json.dump(meta, f) + + +def load_partitioned_data( + rank, edge_path, feature_path, label_path, meta_path, wg_mem_type +): + from cugraph_pyg.data import GraphStore, WholeFeatureStore + + graph_store = GraphStore(is_multi_gpu=True) + feature_store = WholeFeatureStore(memory_type=wg_mem_type) + + # Load metadata + with open(meta_path, "r") as f: + meta = json.load(f) + + # Load labels + split_idx = {} + for split in ["train", "test", "valid"]: + split_idx[split] = torch.load( + os.path.join(label_path, f"rank={rank}", f"{split}.pt") + ) + + # Load features + feature_store["node", "x"] = torch.load( + os.path.join(feature_path, f"rank={rank}_x.pt") + ) + feature_store["node", "y"] = 
torch.load( + os.path.join(feature_path, f"rank={rank}_y.pt") + ) + + # Load edge index + eix = torch.load(os.path.join(edge_path, f"rank={rank}.pt")) + graph_store[ + ("node", "rel", "node"), "coo", False, (meta["num_nodes"], meta["num_nodes"]) + ] = eix + + return (feature_store, graph_store), split_idx, meta + + +def run_train( + global_rank, + data, + split_idx, + world_size, + device, + model, + epochs, + batch_size, + fan_out, + num_classes, + wall_clock_start, + tempdir=None, + num_layers=3, +): + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) + + kwargs = dict( + num_neighbors=[fan_out] * num_layers, + batch_size=batch_size, + ) + # Set Up Neighbor Loading + from cugraph_pyg.loader import NeighborLoader + + ix_train = split_idx["train"].cuda() + train_path = os.path.join(tempdir, f"train_{global_rank}") + os.mkdir(train_path) + train_loader = NeighborLoader( + data, + input_nodes=ix_train, + directory=train_path, + shuffle=True, + drop_last=True, + **kwargs, + ) + + ix_test = split_idx["test"].cuda() + test_path = os.path.join(tempdir, f"test_{global_rank}") + os.mkdir(test_path) + test_loader = NeighborLoader( + data, + input_nodes=ix_test, + directory=test_path, + shuffle=True, + drop_last=True, + local_seeds_per_call=80000, + **kwargs, + ) + + ix_valid = split_idx["valid"].cuda() + valid_path = os.path.join(tempdir, f"valid_{global_rank}") + os.mkdir(valid_path) + valid_loader = NeighborLoader( + data, + input_nodes=ix_valid, + directory=valid_path, + shuffle=True, + drop_last=True, + **kwargs, + ) + + dist.barrier() + + eval_steps = 1000 + warmup_steps = 20 + dist.barrier() + torch.cuda.synchronize() + + if global_rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time) =", prep_time, "seconds") + print("Beginning training...") + + for epoch in range(epochs): + for i, batch in enumerate(train_loader): + if i == warmup_steps: + torch.cuda.synchronize() + start = time.time() + + batch = batch.to(device) + batch_size = batch.batch_size + + batch.y = batch.y.view(-1).to(torch.long) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index) + loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size]) + loss.backward() + optimizer.step() + if global_rank == 0 and i % 10 == 0: + print( + "Epoch: " + + str(epoch) + + ", Iteration: " + + str(i) + + ", Loss: " + + str(loss) + ) + nb = i + 1.0 + + if global_rank == 0: + print( + "Average Training Iteration Time:", + (time.time() - start) / (nb - warmup_steps), + "s/iter", + ) + + with torch.no_grad(): + total_correct = total_examples = 0 + for i, batch in enumerate(valid_loader): + if i >= eval_steps: + break + + batch = batch.to(device) + batch_size = batch.batch_size + + batch.y = batch.y.to(torch.long) + out = model(batch.x, batch.edge_index)[:batch_size] + + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred == y).sum()) + total_examples += y.size(0) + + acc_val = total_correct / total_examples + if global_rank == 0: + print( + f"Validation Accuracy: {acc_val * 100.0:.4f}%", + ) + + torch.cuda.synchronize() + + with torch.no_grad(): + total_correct = total_examples = 0 + for i, batch in enumerate(test_loader): + batch = batch.to(device) + batch_size = batch.batch_size + + batch.y = batch.y.to(torch.long) + out = model(batch.x, batch.edge_index)[:batch_size] + + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred 
== y).sum()) + total_examples += y.size(0) + + acc_test = total_correct / total_examples + if global_rank == 0: + print( + f"Test Accuracy: {acc_test * 100.0:.4f}%", + ) + + if global_rank == 0: + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") + + wm_finalize() + cugraph_comms_shutdown() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_channels", type=int, default=256) + parser.add_argument("--num_layers", type=int, default=2) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=1024) + parser.add_argument("--fan_out", type=int, default=30) + parser.add_argument("--tempdir_root", type=str, default=None) + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--dataset", type=str, default="ogbn-products") + parser.add_argument("--skip_partition", action="store_true") + parser.add_argument("--wg_mem_type", type=str, default="chunked") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + wall_clock_start = time.perf_counter() + + if "LOCAL_RANK" in os.environ: + dist.init_process_group("nccl") + world_size = dist.get_world_size() + global_rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(local_rank) + + # Create the uid needed for cuGraph comms + if global_rank == 0: + cugraph_id = [cugraph_comms_create_unique_id()] + else: + cugraph_id = [None] + dist.broadcast_object_list(cugraph_id, src=0, device=device) + cugraph_id = cugraph_id[0] + + init_pytorch_worker(global_rank, local_rank, world_size, cugraph_id) + + # Split the data + edge_path = os.path.join(args.dataset_root, args.dataset + "_eix_part") + feature_path = os.path.join(args.dataset_root, args.dataset + "_fea_part") + label_path = os.path.join(args.dataset_root, args.dataset + "_label_part") + meta_path = os.path.join(args.dataset_root, args.dataset + "_meta.json") + + # We partition the data to avoid loading it in every worker, which will + # waste memory and can lead to an out of memory exception. + # cugraph_pyg.GraphStore and cugraph_pyg.WholeFeatureStore are always + # constructed from partitions of the edge index and features, respectively, + # so this works well. 
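        # For reference, with the default arguments the partition step writes a
        # per-rank layout like the following (illustrative, first rank shown):
        #   dataset/ogbn-products_eix_part/rank=0.pt                  edge-index shard
        #   dataset/ogbn-products_fea_part/rank=0_x.pt, rank=0_y.pt   feature/label shards
        #   dataset/ogbn-products_label_part/rank=0/{train,valid,test}.pt   split indices
        #   dataset/ogbn-products_meta.json                           num_nodes/num_features/num_classes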
+ if not args.skip_partition and global_rank == 0: + dataset = PygNodePropPredDataset(name=args.dataset, root=args.dataset_root) + split_idx = dataset.get_idx_split() + + partition_data( + dataset, + split_idx, + meta_path=meta_path, + label_path=label_path, + feature_path=feature_path, + edge_path=edge_path, + ) + + dist.barrier() + data, split_idx, meta = load_partitioned_data( + rank=global_rank, + edge_path=edge_path, + feature_path=feature_path, + label_path=label_path, + meta_path=meta_path, + wg_mem_type=args.wg_mem_type, + ) + dist.barrier() + + model = torch_geometric.nn.models.GCN( + meta["num_features"], + args.hidden_channels, + args.num_layers, + meta["num_classes"], + ).to(device) + model = DistributedDataParallel(model, device_ids=[local_rank]) + + with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir: + run_train( + global_rank, + data, + split_idx, + world_size, + device, + model, + args.epochs, + args.batch_size, + args.fan_out, + meta["num_classes"], + wall_clock_start, + tempdir, + args.num_layers, + ) + else: + warnings.warn("This script should be run with 'torchrun`. Exiting.") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py index 71b0e4bb2fb..82a612622a1 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py @@ -16,7 +16,7 @@ import tempfile import os -from typing import Optional +from typing import Optional, Tuple, Dict import torch import cupy @@ -42,137 +42,174 @@ enable_spilling() -parser = argparse.ArgumentParser() -parser.add_argument("--hidden_channels", type=int, default=256) -parser.add_argument("--num_layers", type=int, default=2) -parser.add_argument("--lr", type=float, default=0.001) -parser.add_argument("--epochs", type=int, default=4) -parser.add_argument("--batch_size", type=int, default=1024) -parser.add_argument("--fan_out", type=int, default=30) -parser.add_argument("--tempdir_root", type=str, default=None) -parser.add_argument("--dataset_root", type=str, default="dataset") -parser.add_argument("--dataset", type=str, default="ogbn-products") - -args = parser.parse_args() - -wall_clock_start = time.perf_counter() -device = torch.device("cuda") - -from ogb.nodeproppred import PygNodePropPredDataset # noqa: E402 - -dataset = PygNodePropPredDataset(name=args.dataset, root=args.dataset_root) -split_idx = dataset.get_idx_split() -data = dataset[0] - -graph_store = cugraph_pyg.data.GraphStore() -graph_store[ - ("node", "rel", "node"), "coo", False, (data.num_nodes, data.num_nodes) -] = data.edge_index - -feature_store = cugraph_pyg.data.TensorDictFeatureStore() -feature_store["node", "x"] = data.x -feature_store["node", "y"] = data.y - -with tempfile.TemporaryDirectory(dir=args.tempdir_root) as samples_dir: - train_dir = os.path.join(samples_dir, "train") - os.mkdir(train_dir) - train_loader = NeighborLoader( - data=(feature_store, graph_store), - num_neighbors=[args.fan_out] * args.num_layers, - input_nodes=split_idx["train"], - replace=False, - batch_size=args.batch_size, - directory=train_dir, - ) - val_dir = os.path.join(samples_dir, "val") - os.mkdir(val_dir) - val_loader = NeighborLoader( - data=(feature_store, graph_store), - num_neighbors=[args.fan_out] * args.num_layers, - input_nodes=split_idx["valid"], - replace=False, - batch_size=args.batch_size, - directory=val_dir, +def train(epoch: int): + model.train() + for i, batch in enumerate(train_loader): + if i == warmup_steps: + 
torch.cuda.synchronize() + start_avg_time = time.perf_counter() + batch = batch.to(device) + + optimizer.zero_grad() + batch_size = batch.batch_size + out = model(batch.x, batch.edge_index)[:batch_size] + y = batch.y[:batch_size].view(-1).to(torch.long) + + loss = F.cross_entropy(out, y) + loss.backward() + optimizer.step() + + if i % 10 == 0: + print(f"Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}") + torch.cuda.synchronize() + print( + f"Average Training Iteration Time (s/iter): \ + {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}" ) - test_dir = os.path.join(samples_dir, "test") - os.mkdir(test_dir) - test_loader = NeighborLoader( - data=(feature_store, graph_store), - num_neighbors=[args.fan_out] * args.num_layers, - input_nodes=split_idx["test"], - replace=False, - batch_size=args.batch_size, - directory=test_dir, + +@torch.no_grad() +def test(loader: NeighborLoader, val_steps: Optional[int] = None): + model.eval() + + total_correct = total_examples = 0 + for i, batch in enumerate(loader): + if val_steps is not None and i >= val_steps: + break + batch = batch.to(device) + batch_size = batch.batch_size + out = model(batch.x, batch.edge_index)[:batch_size] + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred == y).sum()) + total_examples += y.size(0) + + return total_correct / total_examples + + +def create_loader( + data, num_neighbors, input_nodes, replace, batch_size, samples_dir, stage_name +): + directory = os.path.join(samples_dir, stage_name) + os.mkdir(directory) + return NeighborLoader( + data, + num_neighbors=num_neighbors, + input_nodes=input_nodes, + replace=replace, + batch_size=batch_size, + directory=directory, ) - model = torch_geometric.nn.models.GCN( + +def load_data( + dataset, dataset_root +) -> Tuple[ + Tuple[torch_geometric.data.FeatureStore, torch_geometric.data.GraphStore], + Dict[str, torch.Tensor], + int, + int, +]: + from ogb.nodeproppred import PygNodePropPredDataset + + dataset = PygNodePropPredDataset(dataset, root=dataset_root) + split_idx = dataset.get_idx_split() + data = dataset[0] + + graph_store = cugraph_pyg.data.GraphStore() + graph_store[ + ("node", "rel", "node"), "coo", False, (data.num_nodes, data.num_nodes) + ] = data.edge_index + + feature_store = cugraph_pyg.data.TensorDictFeatureStore() + feature_store["node", "x"] = data.x + feature_store["node", "y"] = data.y + + return ( + (feature_store, graph_store), + split_idx, dataset.num_features, - args.hidden_channels, - args.num_layers, dataset.num_classes, - ).to(device) + ) - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005) - warmup_steps = 20 +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_channels", type=int, default=256) + parser.add_argument("--num_layers", type=int, default=2) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=1024) + parser.add_argument("--fan_out", type=int, default=30) + parser.add_argument("--tempdir_root", type=str, default=None) + parser.add_argument("--dataset_root", type=str, default="dataset") + parser.add_argument("--dataset", type=str, default="ogbn-products") - def train(epoch: int): - model.train() - for i, batch in enumerate(train_loader): - if i == warmup_steps: - torch.cuda.synchronize() - start_avg_time = time.perf_counter() - batch = batch.to(device) + return parser.parse_args() - 
optimizer.zero_grad() - batch_size = batch.batch_size - out = model(batch.x, batch.edge_index)[:batch_size] - y = batch.y[:batch_size].view(-1).to(torch.long) - loss = F.cross_entropy(out, y) - loss.backward() - optimizer.step() +if __name__ == "__main__": + args = parse_args() + wall_clock_start = time.perf_counter() + device = torch.device("cuda") - if i % 10 == 0: - print(f"Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}") - torch.cuda.synchronize() - print( - f"Average Training Iteration Time (s/iter): \ - {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}" + data, split_idx, num_features, num_classes = load_data( + args.dataset, args.dataset_root + ) + + with tempfile.TemporaryDirectory(dir=args.tempdir_root) as samples_dir: + loader_kwargs = { + "data": data, + "num_neighbors": [args.fan_out] * args.num_layers, + "replace": False, + "batch_size": args.batch_size, + "samples_dir": samples_dir, + } + + train_loader = create_loader( + input_nodes=split_idx["train"], + stage_name="train", + **loader_kwargs, + ) + + val_loader = create_loader( + input_nodes=split_idx["valid"], + stage_name="val", + **loader_kwargs, ) - @torch.no_grad() - def test(loader: NeighborLoader, val_steps: Optional[int] = None): - model.eval() + test_loader = create_loader( + input_nodes=split_idx["test"], + stage_name="test", + **loader_kwargs, + ) - total_correct = total_examples = 0 - for i, batch in enumerate(loader): - if val_steps is not None and i >= val_steps: - break - batch = batch.to(device) - batch_size = batch.batch_size - out = model(batch.x, batch.edge_index)[:batch_size] - pred = out.argmax(dim=-1) - y = batch.y[:batch_size].view(-1).to(torch.long) + model = torch_geometric.nn.models.GCN( + num_features, + args.hidden_channels, + args.num_layers, + num_classes, + ).to(device) - total_correct += int((pred == y).sum()) - total_examples += y.size(0) + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=0.0005 + ) - return total_correct / total_examples + warmup_steps = 20 - torch.cuda.synchronize() - prep_time = round(time.perf_counter() - wall_clock_start, 2) - print("Total time before training begins (prep_time)=", prep_time, "seconds") - print("Beginning training...") - for epoch in range(1, 1 + args.epochs): - train(epoch) - val_acc = test(val_loader, val_steps=100) - print(f"Val Acc: ~{val_acc:.4f}") - - test_acc = test(test_loader) - print(f"Test Acc: {test_acc:.4f}") - total_time = round(time.perf_counter() - wall_clock_start, 2) - print("Total Program Runtime (total_time) =", total_time, "seconds") - print("total_time - prep_time =", total_time - prep_time, "seconds") + torch.cuda.synchronize() + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time)=", prep_time, "seconds") + print("Beginning training...") + for epoch in range(1, 1 + args.epochs): + train(epoch) + val_acc = test(val_loader, val_steps=100) + print(f"Val Acc: ~{val_acc:.4f}") + + test_acc = test(test_loader) + print(f"Test Acc: {test_acc:.4f}") + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") diff --git a/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store_mg.py new file mode 100644 index 00000000000..f1f514560c8 --- /dev/null +++ 
b/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store_mg.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from cugraph.utilities.utils import import_optional, MissingModule + +from cugraph_pyg.data import TensorDictFeatureStore, WholeFeatureStore + +torch = import_optional("torch") +pylibwholegraph = import_optional("pylibwholegraph") + + +def run_test_wholegraph_feature_store_basic_api(rank, world_size, dtype): + if dtype == "float32": + torch_dtype = torch.float32 + elif dtype == "int64": + torch_dtype = torch.int64 + + torch.cuda.set_device(rank) + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + pylibwholegraph.torch.initialize.init( + rank, + world_size, + rank, + world_size, + ) + + features = torch.arange(0, world_size * 2000) + features = features.reshape((features.numel() // 100, 100)).to(torch_dtype) + + tensordict_store = TensorDictFeatureStore() + tensordict_store["node", "fea"] = features + + whole_store = WholeFeatureStore() + whole_store["node", "fea"] = torch.tensor_split(features, world_size)[rank] + + ix = torch.arange(features.shape[0]) + assert ( + whole_store["node", "fea"][ix].cpu() == tensordict_store["node", "fea"][ix] + ).all() + + label = torch.arange(0, features.shape[0]).reshape((features.shape[0], 1)) + tensordict_store["node", "label"] = label + whole_store["node", "label"] = torch.tensor_split(label, world_size)[rank] + + assert ( + whole_store["node", "fea"][ix].cpu() == tensordict_store["node", "fea"][ix] + ).all() + + pylibwholegraph.torch.initialize.finalize() + + +@pytest.mark.skipif( + isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" +) +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("dtype", ["float32", "int64"]) +@pytest.mark.mg +def test_wholegraph_feature_store_basic_api(dtype): + world_size = torch.cuda.device_count() + torch.multiprocessing.spawn( + run_test_wholegraph_feature_store_basic_api, + args=( + world_size, + dtype, + ), + nprocs=world_size, + ) diff --git a/python/cugraph-pyg/pyproject.toml b/python/cugraph-pyg/pyproject.toml index dfa522e6047..b41911b5f80 100644 --- a/python/cugraph-pyg/pyproject.toml +++ b/python/cugraph-pyg/pyproject.toml @@ -40,6 +40,7 @@ Documentation = "https://docs.rapids.ai/api/cugraph/stable/" [project.optional-dependencies] test = [ "pandas", + "pylibwholegraph==24.6.*", "pytest", "pytest-benchmark", "pytest-cov", diff --git a/python/cugraph/pyproject.toml b/python/cugraph/pyproject.toml index b29d6f80ff0..8f9a6214ace 100644 --- a/python/cugraph/pyproject.toml +++ b/python/cugraph/pyproject.toml @@ -55,6 +55,7 @@ test = [ "networkx>=2.5.1", "numpy>=1.23,<2.0a0", "pandas", + "pylibwholegraph==24.6.*", "pytest", "pytest-benchmark", "pytest-cov",
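Because the matching pylibwholegraph wheels are CUDA-suffixed (pylibwholegraph-cu11 /
pylibwholegraph-cu12) and hosted on the NVIDIA index, installing the new test extras
from a source checkout generally needs that index on the command line as well; a
sketch, using the index URL referenced in dependencies.yaml above:

    pip install --extra-index-url=https://pypi.nvidia.com "./python/cugraph-pyg[test]"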