cuGraph-DGL and WholeGraph Performance Testing with Feature Store Performance Improvements (#4081)

Adds large-scale cuGraph-DGL performance testing scripts. Also changes the DGL and PyG scripts to evaluate on all ranks and reuse the test samples, and adds support for benchmarking cuGraph-DGL/cuGraph-PyG with WholeGraph.

Updates `cuGraph.gnn.FeatureStore` and `cuGraph-PyG` for increased performance:
* Supporting passing in a WG embedding directly to cugraph.gnn.FeatureStore
* Simplifying how cuGraph-PyG handles filtering and using a cache to prevent repeatedly copying data between the device and host
* Fix bug in cugraph.gnn.FeatureStore where indexing with a gpu tensor would raise an exception, especially with WG
* Add a function to cugraph.gnn.FeatureStore to check where data is stored, which is used by cuGraph-PyG to prevent unnecessary d2h and h2d copies
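
To make the feature-store changes concrete, here is a minimal, hypothetical usage sketch (not code from this commit). It assumes the WholeGraph communicators are already initialized and that `FeatureStore` keeps its existing `add_data`/`get_data` interface; `get_storage` is a made-up name standing in for the new storage-location helper, whose actual name is not shown in this excerpt.

```python
import torch
from cugraph.gnn import FeatureStore

# Back the store with WholeGraph instead of plain torch tensors
# (assumes pylibwholegraph communicators were initialized beforehand).
fs = FeatureStore(backend="wholegraph")

# Register a toy node-feature matrix under (type_name, feat_name).
features = torch.rand(1_000, 128)
fs.add_data(features, type_name="paper", feat_name="x")

# Indexing with a GPU tensor should work after the bug fix noted above.
idx = torch.arange(10, device="cuda")
x = fs.get_data(idx, type_name="paper", feat_name="x")

# Hypothetical helper: a loader such as cuGraph-PyG could consult it to skip
# d2h/h2d copies when the data already lives on the right device.
print(fs.get_storage(type_name="paper", feat_name="x"))
```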

Merge after #3584

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Seunghwa Kang (https://github.com/seunghwak)
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Brad Rees (https://github.com/BradReesWork)

Approvers:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Don Acosta (https://github.com/acostadon)
  - Brad Rees (https://github.com/BradReesWork)
  - Naim (https://github.com/naimnv)
  - Joseph Nke (https://github.com/jnke2016)

URL: #4081
alexbarghi-nv authored Mar 11, 2024
1 parent 6c4f881 commit 4f4be6e
Showing 19 changed files with 1,514 additions and 265 deletions.
benchmarks/cugraph/standalone/bulk_sampling/README.md (2 changes: 1 addition & 1 deletion)
@@ -152,7 +152,7 @@ Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. Y
the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument, which
can be used to create replications of the dataset for scale testing purposes.

The final two arguments are `FRAMEWORK` which can be either "cuGraphPyG" or "PyG", and `GPUS_PER_NODE`
The final two arguments are `FRAMEWORK` which can be "cugraph_dgl_csr", "cugraph_pyg" or "pyg", and `GPUS_PER_NODE`
which must be set to the correct value, even if this is provided by a SLURM argument. If `GPUS_PER_NODE`
is not set to the correct number of GPUs, the script will hang indefinitely until it times out. Mismatched
GPUs per node is currently unsupported by this script but should be possible in practice.
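
As an illustration of the warning above (and not part of the repository), a launcher could fail fast on a bad `FRAMEWORK` or `GPUS_PER_NODE` value instead of hanging; the helper below is hypothetical:

```python
import torch

ALLOWED_FRAMEWORKS = {"cugraph_dgl_csr", "cugraph_pyg", "pyg"}


def check_launch_config(framework: str, gpus_per_node: int) -> None:
    """Hypothetical pre-flight check for the benchmark launch settings."""
    if framework.lower() not in ALLOWED_FRAMEWORKS:
        raise ValueError(f"unsupported framework: {framework!r}")
    # A GPUS_PER_NODE value that does not match the GPUs actually visible on
    # the node is what causes the indefinite hang described above.
    visible = torch.cuda.device_count()
    if gpus_per_node != visible:
        raise RuntimeError(
            f"GPUS_PER_NODE={gpus_per_node}, but {visible} GPUs are visible"
        )
```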
@@ -43,8 +43,9 @@ def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> Non

rmm.reinitialize(
devices=[rank],
pool_allocator=True,
initial_pool_size=pool_size,
pool_allocator=False,
# pool_allocator=True,
# initial_pool_size=pool_size,
)

if use_rmm_torch_allocator:
@@ -119,10 +120,17 @@ def parse_args():
parser.add_argument(
"--framework",
type=str,
help="The framework to test (PyG, cuGraphPyG)",
help="The framework to test (PyG, cugraph_pyg, cugraph_dgl_csr)",
required=True,
)

parser.add_argument(
"--use_wholegraph",
action="store_true",
help="Whether to use WholeGraph feature storage",
required=False,
)

parser.add_argument(
"--model",
type=str,
@@ -162,6 +170,13 @@ def parse_args():
required=False,
)

parser.add_argument(
"--skip_download",
action="store_true",
help="Whether to skip downloading",
required=False,
)

return parser.parse_args()


@@ -186,21 +201,43 @@ def main(args):

world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node

if args.use_wholegraph:
# TODO support WG without cuGraph
if args.framework.lower() not in ["cugraph_pyg", "cugraph_dgl_csr"]:
raise ValueError("WG feature store only supported with cuGraph backends")
from pylibwholegraph.torch.initialize import (
get_global_communicator,
get_local_node_communicator,
init,
)

logger.info("initializing WG comms...")
init(global_rank, world_size, local_rank, args.gpus_per_node)
wm_comm = get_global_communicator()
get_local_node_communicator()

wm_comm = wm_comm.wmb_comm
logger.info(f"rank {global_rank} successfully initialized WG comms")
wm_comm.barrier()

dataset = OGBNPapers100MDataset(
replication_factor=args.replication_factor,
dataset_dir=args.dataset_dir,
train_split=args.train_split,
val_split=args.val_split,
load_edge_index=(args.framework == "PyG"),
load_edge_index=(args.framework.lower() == "pyg"),
backend="wholegraph" if args.use_wholegraph else "torch",
)

if global_rank == 0:
# Note: this does not generate WG files
if global_rank == 0 and not args.skip_download:
dataset.download()

dist.barrier()

fanout = [int(f) for f in args.fanout.split("_")]

if args.framework == "PyG":
if args.framework.lower() == "pyg":
from trainers.pyg import PyGNativeTrainer

trainer = PyGNativeTrainer(
@@ -215,7 +252,7 @@
num_neighbors=fanout,
batch_size=args.batch_size,
)
elif args.framework == "cuGraphPyG":
elif args.framework.lower() == "cugraph_pyg":
sample_dir = os.path.join(
args.sample_dir,
f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}",
@@ -229,11 +266,35 @@
device=local_rank,
rank=global_rank,
world_size=world_size,
gpus_per_node=args.gpus_per_node,
num_epochs=args.num_epochs,
shuffle=True,
replace=False,
num_neighbors=fanout,
batch_size=args.batch_size,
backend="wholegraph" if args.use_wholegraph else "torch",
)
elif args.framework.lower() == "cugraph_dgl_csr":
sample_dir = os.path.join(
args.sample_dir,
f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}",
)
from trainers.dgl import DGLCuGraphTrainer

trainer = DGLCuGraphTrainer(
model=args.model,
dataset=dataset,
sample_dir=sample_dir,
device=local_rank,
rank=global_rank,
world_size=world_size,
gpus_per_node=args.gpus_per_node,
num_epochs=args.num_epochs,
shuffle=True,
replace=False,
num_neighbors=[int(f) for f in args.fanout.split("_")],
batch_size=args.batch_size,
backend="wholegraph" if args.use_wholegraph else "torch",
)
else:
raise ValueError("unsupported framework")