[REVIEW] Optimize cugraph-DGL csc codepath #3977

Merged: 12 commits, Nov 8, 2023
152 changes: 152 additions & 0 deletions benchmarks/cugraph-dgl/scale-benchmarks/cugraph_dgl_benchmark.py
@@ -0,0 +1,152 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

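# These must be set before cudf/cugraph are imported: route cuDF file I/O
# through KvikIO, and defer CUDA context creation until first GPU use.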
os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO"
os.environ["KVIKIO_NTHREADS"] = "64"
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
import json
import pandas as pd
import time
from rmm.allocators.torch import rmm_torch_allocator
import rmm
import torch
from cugraph_dgl.dataloading import HomogenousBulkSamplerDataset
from model import run_1_epoch
from argparse import ArgumentParser
from load_graph_feats import load_node_labels, load_node_features


def create_dataloader(sampled_dir, total_num_nodes, sparse_format, return_type):
    print("Creating dataloader", flush=True)
    st = time.time()
    dataset = HomogenousBulkSamplerDataset(
        total_num_nodes,
        edge_dir="in",
        sparse_format=sparse_format,
        return_type=return_type,
    )

    dataset.set_input_files(sampled_dir)
    # batch_size=None: the samples on disk are already batched by the bulk
    # sampler, so the DataLoader must not batch them again.
    dataloader = torch.utils.data.DataLoader(
        dataset, collate_fn=lambda x: x, shuffle=False, num_workers=0, batch_size=None
    )
    et = time.time()
    print(f"Time to create dataloader = {et - st:.2f} seconds", flush=True)
    return dataloader


def setup_common_pool():
    # Share one 5 GB RMM pool between RAPIDS and PyTorch so cuDF and torch
    # allocations draw from the same pool instead of competing for memory.
    rmm.reinitialize(initial_pool_size=5e9, pool_allocator=True)
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


def main(args):
    print(
        f"Running cugraph-dgl dataloading benchmark with the following parameters:\n"
        f"Dataset path = {args.dataset_path}\n"
        f"Sampling path = {args.sampling_path}\n"
    )
    with open(os.path.join(args.dataset_path, "meta.json"), "r") as f:
        input_meta = json.load(f)

    sampled_dirs = [
        os.path.join(args.sampling_path, f) for f in os.listdir(args.sampling_path)
    ]

    time_ls = []
    for sampled_dir in sampled_dirs:
        with open(os.path.join(sampled_dir, "output_meta.json"), "r") as f:
            sampled_meta_d = json.load(f)

        replication_factor = sampled_meta_d["replication_factor"]
        feat_load_st = time.time()
        label_data = load_node_labels(
            args.dataset_path, replication_factor, input_meta
        )["paper"]["y"]
        feat_data = load_node_features(
            args.dataset_path, replication_factor, node_type="paper"
        )
        print(
            f"Feature and label data loading took = {time.time() - feat_load_st:.2f} seconds",
            flush=True,
        )

        r_time_ls = e2e_benchmark(sampled_dir, feat_data, label_data, sampled_meta_d)
        for x in r_time_ls:
            x.update({"replication_factor": replication_factor})
            x.update({"num_edges": sampled_meta_d["total_num_edges"]})
        time_ls.extend(r_time_ls)

        print(
            f"Benchmark completed for replication factor = {replication_factor}\n{'=' * 30}",
            flush=True,
        )

    df = pd.DataFrame(time_ls)
    df.to_csv("cugraph_dgl_e2e_benchmark.csv", index=False)
    print(f"Benchmark completed for all replication factors\n{'=' * 30}", flush=True)


def e2e_benchmark(
    sampled_dir: str, feat: torch.Tensor, y: torch.Tensor, sampled_meta_d: dict
):
    """
    Run the e2e benchmark.
    Args:
        sampled_dir: directory containing the sampled graph
        feat: node features
        y: node labels
        sampled_meta_d: dictionary containing the sampled graph metadata
    """
    time_ls = []

    # TODO: Make this a parameter in the bulk sampling script
    sampled_meta_d["sparse_format"] = "csc"
    sampled_dir = os.path.join(sampled_dir, "samples")
    dataloader = create_dataloader(
        sampled_dir,
        sampled_meta_d["total_num_nodes"],
        sampled_meta_d["sparse_format"],
        return_type="cugraph_dgl.nn.SparseGraph",
    )
    time_d = run_1_epoch(
        dataloader,
        feat,
        y,
        fanout=sampled_meta_d["fanout"],
        batch_size=sampled_meta_d["batch_size"],
        model_backend="cugraph_dgl",
    )
    time_ls.append(time_d)
    print("=" * 30)
    return time_ls


def parse_arguments():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path", type=str, default="/raid/vjawa/ogbn_papers100M/"
    )
    parser.add_argument(
        "--sampling_path",
        type=str,
        default="/raid/vjawa/nov_1_bulksampling_benchmarks/",
    )
    return parser.parse_args()


if __name__ == "__main__":
    setup_common_pool()
    arguments = parse_arguments()
    main(arguments)
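For orientation, the metadata files read above contain at least the keys this script uses. A minimal illustrative sketch (the field values here are made up, not taken from the PR):

```python
# Illustrative contents of an output_meta.json produced by a bulk
# sampling run; only the keys this benchmark script reads are shown.
sampled_meta_d = {
    "replication_factor": 2,         # dataset scale-up factor
    "total_num_nodes": 111_059_956,  # nodes in the (replicated) graph
    "total_num_edges": 3_228_124_712,
    "fanout": [25, 25],              # neighbors sampled per hop
    "batch_size": 1024,
}
```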
17 changes: 11 additions & 6 deletions benchmarks/cugraph-dgl/scale-benchmarks/model.py
@@ -57,11 +57,11 @@ def create_model(feat_size, num_classes, num_layers, model_backend="dgl"):


 def train_model(model, dataloader, opt, feat, y):
-    times = {key: 0 for key in ["mfg_creation", "feature", "m_fwd", "m_bkwd"]}
+    times_d = {key: 0 for key in ["mfg_creation", "feature", "m_fwd", "m_bkwd"]}
     epoch_st = time.time()
     mfg_st = time.time()
     for input_nodes, output_nodes, blocks in dataloader:
-        times["mfg_creation"] += time.time() - mfg_st
+        times_d["mfg_creation"] += time.time() - mfg_st
         if feat is not None:
             fst = time.time()
             input_nodes = input_nodes.to("cpu")
@@ -71,23 +71,24 @@ def train_model(model, dataloader, opt, feat, y):
             output_nodes = output_nodes["paper"]
             output_nodes = output_nodes.to(y.device)
             y_batch = y[output_nodes].to("cuda")
-            times["feature"] += time.time() - fst
+            times_d["feature"] += time.time() - fst
 
             m_fwd_st = time.time()
             y_hat = model(blocks, input_feat)
-            times["m_fwd"] += time.time() - m_fwd_st
+            times_d["m_fwd"] += time.time() - m_fwd_st
 
             m_bkwd_st = time.time()
             loss = F.cross_entropy(y_hat, y_batch)
             opt.zero_grad()
             loss.backward()
             opt.step()
-            times["m_bkwd"] += time.time() - m_bkwd_st
+            times_d["m_bkwd"] += time.time() - m_bkwd_st
         mfg_st = time.time()
 
     print(f"Epoch time = {time.time() - epoch_st:.2f} seconds")
+    print(f"Time to create MFG = {times_d['mfg_creation']:.2f} seconds")
 
-    return times
+    return times_d


def analyze_time(dataloader, times, epoch_time, fanout, batch_size):
@@ -119,6 +120,10 @@ def run_1_epoch(dataloader, feat, y, fanout, batch_size, model_backend):
     else:
         model = None
         opt = None
 
+    # Warmup run
+    times = train_model(model, dataloader, opt, feat, y)
+
     epoch_st = time.time()
     times = train_model(model, dataloader, opt, feat, y)
     epoch_time = time.time() - epoch_st
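The added warmup pass is what makes the timing above trustworthy: a first epoch pays one-time costs (CUDA context setup, RMM pool growth, page-cache warming) that would otherwise be billed to the measured epoch. A generic sketch of the pattern, independent of this repo's API:

```python
import time

def timed_after_warmup(fn, *args):
    fn(*args)  # warmup: absorb one-time allocator/cache/context costs
    st = time.time()
    result = fn(*args)  # measured run reflects steady-state performance
    return result, time.time() - st
```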
python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py
@@ -14,7 +14,6 @@
 from typing import List, Tuple, Dict, Optional
 from collections import defaultdict
 import cudf
-import cupy
 from cugraph.utilities.utils import import_optional
 from cugraph_dgl.nn import SparseGraph

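The tensor-based code path below leans on this module's cast_to_tensor helper to turn cudf columns into torch CUDA tensors. As a rough sketch of what such a conversion involves (the helper's real implementation may differ), DLPack lets the two libraries share device memory without a copy:

```python
import cudf
import torch

def cast_to_tensor_sketch(ser: cudf.Series) -> torch.Tensor:
    # Hand the column's device buffer to torch via DLPack; no host
    # round-trip and no copy of the underlying GPU memory.
    return torch.utils.dlpack.from_dlpack(ser.to_dlpack())
```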
@@ -444,53 +443,58 @@ def _process_sampled_df_csc(
    destinations, respectively.
    """
     # dropna
-    major_offsets = df.major_offsets.dropna().values
-    label_hop_offsets = df.label_hop_offsets.dropna().values
-    renumber_map_offsets = df.renumber_map_offsets.dropna().values
-    renumber_map = df.map.dropna().values
-    minors = df.minors.dropna().values
+    major_offsets = cast_to_tensor(df.major_offsets.dropna())
+    label_hop_offsets = cast_to_tensor(df.label_hop_offsets.dropna())
+    renumber_map_offsets = cast_to_tensor(df.renumber_map_offsets.dropna())
+    renumber_map = cast_to_tensor(df.map.dropna())
+    minors = cast_to_tensor(df.minors.dropna())
 
-    n_batches = renumber_map_offsets.size - 1
-    n_hops = int((label_hop_offsets.size - 1) / n_batches)
+    n_batches = len(renumber_map_offsets) - 1
+    n_hops = int((len(label_hop_offsets) - 1) / n_batches)
 
     # make global offsets local
-    major_offsets -= major_offsets[0]
-    label_hop_offsets -= label_hop_offsets[0]
-    renumber_map_offsets -= renumber_map_offsets[0]
+    # Have to make a clone: pytorch rejects an in-place operation whose
+    # operand overlaps the tensor being written to
+    major_offsets -= major_offsets[0].clone()
+    label_hop_offsets -= label_hop_offsets[0].clone()
+    renumber_map_offsets -= renumber_map_offsets[0].clone()
 
     # get the sizes of each adjacency matrix (for MFGs)
     mfg_sizes = (label_hop_offsets[1:] - label_hop_offsets[:-1]).reshape(
         (n_batches, n_hops)
     )
     n_nodes = renumber_map_offsets[1:] - renumber_map_offsets[:-1]
-    mfg_sizes = cupy.hstack((mfg_sizes, n_nodes.reshape(n_batches, -1)))
+    mfg_sizes = torch.hstack((mfg_sizes, n_nodes.reshape(n_batches, -1)))
     if reverse_hop_id:
-        mfg_sizes = mfg_sizes[:, ::-1]
+        mfg_sizes = mfg_sizes.flip(1)
 
     tensors_dict = {}
     renumber_map_list = []
+    # Note: minors and major_offsets from BulkSampler are of type int32
+    # and int64, respectively. Since the pylibcugraphops binding code
+    # doesn't support distinct node and edge index types, we simply cast
+    # both to int32 for now.
+    minors = minors.int()
+    major_offsets = major_offsets.int()
+    # Note: we transfer tensors to CPU here to avoid the overhead of
+    # transferring them in each iteration of the for loop below.
+    major_offsets_cpu = major_offsets.to("cpu").numpy()
+    label_hop_offsets_cpu = label_hop_offsets.to("cpu").numpy()
Comment on lines +479 to +482

Member (author): @tingyu66, this is the main optimization: transferring tensors between CPU and GPU one at a time was slow.

Member: Thanks for resolving the .item() overhead. 👍
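As a minimal illustration of the pattern behind this optimization (not code from the PR): every `.item()` on a CUDA tensor performs its own device-to-host copy and synchronization, so calling it per hop inside the batch loop serializes many tiny transfers, while a single bulk copy of the offsets to host turns the loop's lookups into plain numpy indexing:

```python
import torch

offsets = torch.arange(10_000, device="cuda")

# Slow: one device-to-host copy plus a sync per element.
starts = [offsets[i].item() for i in range(len(offsets))]

# Fast: a single bulk transfer, then cheap host-side indexing.
offsets_cpu = offsets.to("cpu").numpy()
starts = [offsets_cpu[i] for i in range(len(offsets_cpu))]
```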
     for batch_id in range(n_batches):
         batch_dict = {}
 
         for hop_id in range(n_hops):
             hop_dict = {}
             idx = batch_id * n_hops + hop_id  # idx in label_hop_offsets
-            major_offsets_start = label_hop_offsets[idx].item()
-            major_offsets_end = label_hop_offsets[idx + 1].item()
-            minors_start = major_offsets[major_offsets_start].item()
-            minors_end = major_offsets[major_offsets_end].item()
-            # Note: minors and major_offsets from BulkSampler are of type int32
-            # and int64 respectively. Since pylibcugraphops binding code doesn't
-            # support distinct node and edge index type, we simply casting both
-            # to int32 for now.
-            hop_dict["minors"] = torch.as_tensor(
-                minors[minors_start:minors_end], device="cuda"
-            ).int()
-            hop_dict["major_offsets"] = torch.as_tensor(
-                major_offsets[major_offsets_start : major_offsets_end + 1]
-                - major_offsets[major_offsets_start],
-                device="cuda",
-            ).int()
+            major_offsets_start = label_hop_offsets_cpu[idx]
+            major_offsets_end = label_hop_offsets_cpu[idx + 1]
+            minors_start = major_offsets_cpu[major_offsets_start]
+            minors_end = major_offsets_cpu[major_offsets_end]
+            hop_dict["minors"] = minors[minors_start:minors_end]
+            hop_dict["major_offsets"] = (
+                major_offsets[major_offsets_start : major_offsets_end + 1]
+                - major_offsets[major_offsets_start]
+            )
             if reverse_hop_id:
                 batch_dict[n_hops - 1 - hop_id] = hop_dict
             else:
@@ -499,12 +503,9 @@
         tensors_dict[batch_id] = batch_dict
 
         renumber_map_list.append(
-            torch.as_tensor(
-                renumber_map[
-                    renumber_map_offsets[batch_id] : renumber_map_offsets[batch_id + 1]
-                ],
-                device="cuda",
-            )
+            renumber_map[
+                renumber_map_offsets[batch_id] : renumber_map_offsets[batch_id + 1]
+            ],
         )
 
     return tensors_dict, renumber_map_list, mfg_sizes.tolist()
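One subtlety in the offsets-localization step above: `major_offsets -= major_offsets[0]` fails in PyTorch because `major_offsets[0]` is a view of the very tensor the in-place subtraction writes to, and PyTorch rejects partially overlapping input and output memory. Cloning the scalar first sidesteps this. A standalone illustration (not code from this PR):

```python
import torch

offsets = torch.tensor([100, 104, 110])

# offsets -= offsets[0] raises a RuntimeError: the 0-dim view offsets[0]
# shares storage with the tensor being written to in place.
offsets -= offsets[0].clone()  # clone() snapshots the scalar first

print(offsets)  # tensor([ 0,  4, 10])
```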