Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MFG creation on sampling gpus for cugraph dgl #3742

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
6263b9f
define sampling output renumbering function API
seunghwak Jul 13, 2023
6c88e86
update the API (remove multi_gpu flag)
seunghwak Jul 17, 2023
055496f
initial draft implementation
seunghwak Jul 17, 2023
27fe4f0
Merge branch 'branch-23.08' of github.com:rapidsai/cugraph into fea_mfg
seunghwak Jul 18, 2023
2211d24
add test code
seunghwak Jul 19, 2023
9fe5e13
bug fixes
seunghwak Jul 19, 2023
9ad3c8f
minor tweaks
seunghwak Jul 19, 2023
dfbc196
define API for MFG renumbering in the C API
ChuckHastings Jul 20, 2023
bc5d3e1
define API for MFG renumbering in the C API
ChuckHastings Jul 20, 2023
012d392
memory footprint cut
seunghwak Jul 20, 2023
63759b8
code improvemnt
seunghwak Jul 20, 2023
e21bcd0
bug fix and memory footprint optimization
seunghwak Jul 20, 2023
7010081
bug fix
seunghwak Jul 20, 2023
d4d5407
clang-format
seunghwak Jul 20, 2023
a746701
Merge branch 'branch-23.08' of github.com:rapidsai/cugraph into fea_mfg
seunghwak Jul 20, 2023
cf7dff8
clang-format
seunghwak Jul 20, 2023
948ad11
remove unnecessary template parameter
seunghwak Jul 20, 2023
b26523b
clang-format
seunghwak Jul 20, 2023
141bab9
Merge branch 'mfg_capi' of https://github.com/chuckhastings/cugraph i…
alexbarghi-nv Jul 21, 2023
ff813dd
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv Jul 21, 2023
37c08e8
Merge branch 'branch-23.08' into mfg_capi
alexbarghi-nv Jul 21, 2023
f0e5779
copyright year
seunghwak Jul 21, 2023
8222c18
Merge remote-tracking branch 'seunghwa/fea_mfg' into mfg_capi
ChuckHastings Jul 21, 2023
9a8abee
move expand_sparse_offsets from src/detail to include/cugraph/utiliti…
seunghwak Jul 21, 2023
d4d4c78
update renumber_sampled_edgelist to use the existing expand_sparse_of…
seunghwak Jul 21, 2023
c6d1d10
Merge branch 'branch-23.08' of github.com:rapidsai/cugraph into fea_mfg
seunghwak Jul 21, 2023
51a05bf
Merge branch 'mfg_capi' of github.com:chuckhastings/cugraph into mfg_…
ChuckHastings Jul 21, 2023
1e4174e
Testing with Seunghwa's branch merged in
ChuckHastings Jul 21, 2023
b1dac25
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv Jul 24, 2023
fa734fb
c
alexbarghi-nv Jul 24, 2023
899c0e3
merge in changes
alexbarghi-nv Jul 24, 2023
4b810fc
revert
alexbarghi-nv Jul 24, 2023
932e1ea
minor
alexbarghi-nv Jul 24, 2023
7609bac
Remove debug notebook
alexbarghi-nv Jul 24, 2023
a212d3b
refactor mg, fix test
alexbarghi-nv Jul 24, 2023
8a92353
basic functionality
alexbarghi-nv Jul 24, 2023
db77c16
Merge branch 'branch-23.08' into mfg_capi
ChuckHastings Jul 24, 2023
7e1d1d1
disable broken cugraph-pyg tests
alexbarghi-nv Jul 24, 2023
3bcdf7e
style
alexbarghi-nv Jul 24, 2023
8a16b95
pull in latest update
alexbarghi-nv Jul 24, 2023
6ffc699
Merge branch 'mfg_capi' of https://github.com/chuckhastings/cugraph i…
alexbarghi-nv Jul 24, 2023
2ceeaf7
fix pyg tests, make renumbering optional
alexbarghi-nv Jul 25, 2023
aab0b56
fix merge
alexbarghi-nv Jul 25, 2023
7979a0a
style
alexbarghi-nv Jul 25, 2023
68ef5ef
update pyg samplers
alexbarghi-nv Jul 25, 2023
665621e
sg loader tests
alexbarghi-nv Jul 25, 2023
6484b12
style
alexbarghi-nv Jul 25, 2023
2ee1ce9
remove prints
alexbarghi-nv Jul 25, 2023
bd3dbc5
add renumbering for homogenous graphs by default
alexbarghi-nv Jul 25, 2023
33c6353
style
alexbarghi-nv Jul 25, 2023
1b6dc5b
reformat
alexbarghi-nv Jul 25, 2023
bc06101
Support renumbering
VibhuJawa Jul 25, 2023
0af5c9e
Merge branch 'branch-23.08' into cugraph-dgl-sample-side-mfg
VibhuJawa Jul 25, 2023
0d0b3c8
Remove clone command
VibhuJawa Jul 25, 2023
e6e2fcd
wqMerge branch 'cugraph-dgl-sample-side-mfg' of https://github.com/Vi…
VibhuJawa Jul 25, 2023
0b3494d
wqMerge branch 'cugraph-sample-side-mfg-updated' into cugraph-dgl-sam…
VibhuJawa Jul 25, 2023
0063362
Update based on reviews
VibhuJawa Jul 26, 2023
927c7b5
Merge in updates
VibhuJawa Jul 26, 2023
2995c62
Make tests same as upstream
VibhuJawa Jul 26, 2023
f12d473
Add test based on reviews
VibhuJawa Jul 26, 2023
8d6d441
cugraph-dgl-sample-side-mfg
VibhuJawa Jul 26, 2023
4c2a42d
Add test for renumbering
VibhuJawa Jul 26, 2023
8ad21db
Add test_get_tensor_d_from_sampled_df
VibhuJawa Jul 26, 2023
cf3eb41
create_homogeneous_sampled_graphs_from_dataframe
VibhuJawa Jul 26, 2023
f62d773
Added nids to block
VibhuJawa Jul 26, 2023
42ebfda
Confirmed that pytests pass
VibhuJawa Jul 27, 2023
ff02d29
Remove print
VibhuJawa Jul 27, 2023
e4fb847
Merge branch 'branch-23.08' into cugraph-dgl-sample-side-mfg
VibhuJawa Jul 27, 2023
54d9d5f
Merge branch 'branch-23.08' into cugraph-dgl-sample-side-mfg
rlratzel Jul 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def __iter__(self):
output_dir = os.path.join(
self._sampling_output_dir, "epoch_" + str(self.epoch_number)
)
if isinstance(self.cugraph_dgl_dataset, HomogenousBulkSamplerDataset):
deduplicate_sources = True
prior_sources_behavior = "carryover"
renumber = True
else:
deduplicate_sources = False
prior_sources_behavior = None
renumber = False

bs = BulkSampler(
output_path=output_dir,
batch_size=self._batch_size,
Expand All @@ -218,6 +227,9 @@ def __iter__(self):
seeds_per_call=self._seeds_per_call,
fanout_vals=self.graph_sampler._reversed_fanout_vals,
with_replacement=self.graph_sampler.replace,
deduplicate_sources=deduplicate_sources,
prior_sources_behavior=prior_sources_behavior,
renumber=renumber,
)
if self.shuffle:
self.tensorized_indices_ds.shuffle()
Expand Down
5 changes: 4 additions & 1 deletion python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def __init__(
total_number_of_nodes: int,
edge_dir: str,
):
# TODO: Deprecate `total_number_of_nodes`
# as it is no longer needed
# in the next release
self.total_number_of_nodes = total_number_of_nodes
self.edge_dir = edge_dir
self._current_batch_fn = None
Expand All @@ -52,7 +55,7 @@ def __getitem__(self, idx: int):
if fn != self._current_batch_fn:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = create_homogeneous_sampled_graphs_from_dataframe(
df, self.total_number_of_nodes, self.edge_dir
df, self.edge_dir
)
current_offset = idx - batch_offset
return self._current_batches[current_offset]
Expand Down
221 changes: 140 additions & 81 deletions python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,130 +29,189 @@ def cast_to_tensor(ser: cudf.Series):
return torch.as_tensor(ser.values, device="cuda")


def _get_tensor_ls_from_sampled_df(df):
def _split_tensor(t, split_indices):
"""
Split a tensor into a list of tensors based on split_indices.
"""
# TODO: Switch to something below
# return [t[i:j] for i, j in zip(split_indices[:-1], split_indices[1:])]
if split_indices.device.type != "cpu":
split_indices = split_indices.to("cpu")
return torch.tensor_split(t, split_indices)


def _get_renumber_map(df):
map = df["map"]
df.drop(columns=["map"], inplace=True)

map_starting_offset = map.iloc[0]
renumber_map = map[map_starting_offset:].dropna().reset_index(drop=True)
renumber_map_batch_indices = map[1 : map_starting_offset - 1].reset_index(drop=True)
renumber_map_batch_indices = renumber_map_batch_indices - map_starting_offset

# Drop all rows with NaN values
df.dropna(axis=0, how="all", inplace=True)
df.reset_index(drop=True, inplace=True)

return df, cast_to_tensor(renumber_map), cast_to_tensor(renumber_map_batch_indices)


def _get_tensor_d_from_sampled_df(df):
"""
Converts a sampled cuDF DataFrame into a list of tensors.

Args:
df (cudf.DataFrame): The sampled cuDF DataFrame containing columns
'batch_id', 'sources', 'destinations', 'edge_id', and 'hop_id'.

Returns:
list: A list of tuples, where each tuple contains three tensors:
'sources', 'destinations', and 'edge_id'.
The tensors are split based on 'batch_id' and 'hop_id'.

dict: A dictionary of tensors, keyed by batch_id and hop_id.
"""
df, renumber_map, renumber_map_batch_indices = _get_renumber_map(df)
batch_id_tensor = cast_to_tensor(df["batch_id"])
batch_id_min = batch_id_tensor.min()
batch_id_max = batch_id_tensor.max()
batch_indices = torch.arange(
start=batch_id_tensor.min() + 1,
end=batch_id_tensor.max() + 1,
start=batch_id_min + 1,
end=batch_id_max + 1,
device=batch_id_tensor.device,
)
batch_indices = torch.searchsorted(batch_id_tensor, batch_indices)

split_d = {}

for column in ["sources", "destinations", "edge_id", "hop_id"]:
if column in df.columns:
tensor = cast_to_tensor(df[column])
split_d[column] = torch.tensor_split(tensor, batch_indices.cpu())
# TODO: Fix below
# batch_indices = _get_id_tensor_boundaries(batch_id_tensor)
batch_indices = torch.searchsorted(batch_id_tensor, batch_indices).to("cpu")
split_d = {i: {} for i in range(batch_id_min, batch_id_max + 1)}

for column in df.columns:
if column != "batch_id":
t = cast_to_tensor(df[column])
split_t = _split_tensor(t, batch_indices)
for bid, batch_t in zip(split_d.keys(), split_t):
split_d[bid][column] = batch_t

split_t = _split_tensor(renumber_map, renumber_map_batch_indices)
for bid, batch_t in zip(split_d.keys(), split_t):
split_d[bid]["map"] = batch_t
del df
result_tensor_d = {}
for batch_id, batch_d in split_d.items():
hop_id_tensor = batch_d["hop_id"]
hop_id_min = hop_id_tensor.min()
hop_id_max = hop_id_tensor.max()

result_tensor_ls = []
for i, hop_id_tensor in enumerate(split_d["hop_id"]):
hop_indices = torch.arange(
start=hop_id_tensor.min() + 1,
end=hop_id_tensor.max() + 1,
start=hop_id_min + 1,
end=hop_id_max + 1,
device=hop_id_tensor.device,
)
hop_indices = torch.searchsorted(hop_id_tensor, hop_indices)
s = torch.tensor_split(split_d["sources"][i], hop_indices.cpu())
d = torch.tensor_split(split_d["destinations"][i], hop_indices.cpu())
if "edge_id" in split_d:
eid = torch.tensor_split(split_d["edge_id"][i], hop_indices.cpu())
else:
eid = [None] * len(s)

result_tensor_ls.append((x, y, z) for x, y, z in zip(s, d, eid))

return result_tensor_ls
# TODO: Fix below
# hop_indices = _get_id_tensor_boundaries(hop_id_tensor)
hop_indices = torch.searchsorted(hop_id_tensor, hop_indices).to("cpu")
hop_split_d = {i: {} for i in range(hop_id_min, hop_id_max + 1)}
for column, t in batch_d.items():
if column not in ["hop_id", "map"]:
split_t = _split_tensor(t, hop_indices)
for hid, ht in zip(hop_split_d.keys(), split_t):
hop_split_d[hid][column] = ht

result_tensor_d[batch_id] = hop_split_d
if "map" in batch_d:
result_tensor_d[batch_id]["map"] = batch_d["map"]
return result_tensor_d


def create_homogeneous_sampled_graphs_from_dataframe(
sampled_df: cudf.DataFrame,
total_number_of_nodes: int,
edge_dir: str = "in",
):
"""
This helper function creates DGL MFGS for
homogeneous graphs from cugraph sampled dataframe

Args:
sampled_df (cudf.DataFrame): The sampled cuDF DataFrame containing
columns `sources`, `destinations`, `edge_id`, `batch_id` and
`hop_id`.
edge_dir (str): Direction of edges from samples
Returns:
list: A list containing three elements:
- input_nodes: The input nodes for the batch.
- output_nodes: The output nodes for the batch.
- graph_per_hop_ls: A list of DGL MFGS for each hop.
"""
result_tensor_ls = _get_tensor_ls_from_sampled_df(sampled_df)
result_tensor_d = _get_tensor_d_from_sampled_df(sampled_df)
del sampled_df
result_mfgs = [
_create_homogeneous_sampled_graphs_from_tensors_perhop(
tensors_perhop_ls, total_number_of_nodes, edge_dir
tensors_batch_d, edge_dir
)
for tensors_perhop_ls in result_tensor_ls
for tensors_batch_d in result_tensor_d.values()
]
del result_tensor_ls
del result_tensor_d
return result_mfgs


def _create_homogeneous_sampled_graphs_from_tensors_perhop(
tensors_perhop_ls, total_number_of_nodes, edge_dir
):
def _create_homogeneous_sampled_graphs_from_tensors_perhop(tensors_batch_d, edge_dir):
"""
This helper function creates sampled DGL MFGS for
homogeneous graphs from tensors per hop for a single
batch

Args:
tensors_batch_d (dict): A dictionary of tensors, keyed by hop_id.
edge_dir (str): Direction of edges from samples
Returns:
tuple: A tuple of three elements:
- input_nodes: The input nodes for the batch.
- output_nodes: The output nodes for the batch.
- graph_per_hop_ls: A list of DGL MFGS for each hop.
"""
if edge_dir not in ["in", "out"]:
raise ValueError(f"Invalid edge_dir {edge_dir} provided")
if edge_dir == "out":
raise ValueError("Outwards edges not supported yet")
graph_per_hop_ls = []
output_nodes = None
seed_nodes = None
for src_ids, dst_ids, edge_ids in tensors_perhop_ls:
# print("Creating block", flush=True)
block = create_homogeneous_dgl_block_from_tensors_ls(
src_ids=src_ids,
dst_ids=dst_ids,
edge_ids=edge_ids,
seed_nodes=seed_nodes,
total_number_of_nodes=total_number_of_nodes,
)
seed_nodes = block.srcdata[dgl.NID]
if output_nodes is None:
output_nodes = block.dstdata[dgl.NID]
graph_per_hop_ls.append(block)
seednodes = None
for hop_id, tensor_per_hop_d in tensors_batch_d.items():
if hop_id != "map":
block = _create_homogeneous_dgl_block_from_tensor_d(
tensor_per_hop_d, tensors_batch_d["map"], seednodes
)
seednodes = torch.concat(
[tensor_per_hop_d["sources"], tensor_per_hop_d["destinations"]]
)
graph_per_hop_ls.append(block)

# default DGL behavior
if edge_dir == "in":
graph_per_hop_ls.reverse()
return seed_nodes, output_nodes, graph_per_hop_ls


def create_homogeneous_dgl_block_from_tensors_ls(
src_ids: torch.Tensor,
dst_ids: torch.Tensor,
edge_ids: Optional[torch.Tensor],
seed_nodes: Optional[torch.Tensor],
total_number_of_nodes: int,
):
sampled_graph = dgl.graph(
(src_ids, dst_ids),
num_nodes=total_number_of_nodes,
)
if edge_ids is not None:
sampled_graph.edata[dgl.EID] = edge_ids
# TODO: Check if unique is needed
if seed_nodes is None:
seed_nodes = dst_ids.unique()

block = dgl.to_block(
sampled_graph,
dst_nodes=seed_nodes,
src_nodes=src_ids.unique(),
include_dst_in_src=True,
input_nodes = graph_per_hop_ls[0].srcdata[dgl.NID]
output_nodes = graph_per_hop_ls[-1].dstdata[dgl.NID]
return input_nodes, output_nodes, graph_per_hop_ls


def _create_homogeneous_dgl_block_from_tensor_d(tensor_d, renumber_map, seednodes=None):
rs = tensor_d["sources"]
rd = tensor_d["destinations"]

max_src_nodes = rs.max()
max_dst_nodes = rd.max()
if seednodes is not None:
# If we have isolated vertices
# sources can be missing from seednodes
# so we add them
# to ensure all the blocks are
# linedup correctly
max_dst_nodes = max(max_dst_nodes, seednodes.max())

data_dict = {("_N", "_E", "_N"): (rs, rd)}
num_src_nodes = {"_N": max_src_nodes.item() + 1}
num_dst_nodes = {"_N": max_dst_nodes.item() + 1}
block = dgl.create_block(
data_dict=data_dict, num_src_nodes=num_src_nodes, num_dst_nodes=num_dst_nodes
)
if edge_ids is not None:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
if "edge_id" in tensor_d:
block.edata[dgl.EID] = tensor_d["edge_id"]
block.srcdata[dgl.NID] = renumber_map[block.srcnodes()]
block.dstdata[dgl.NID] = renumber_map[block.dstnodes()]
return block


Expand Down
14 changes: 7 additions & 7 deletions python/cugraph-dgl/examples/dataset_from_disk_cudf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"import torch\n",
"from rmm.allocators.torch import rmm_torch_allocator\n",
"rmm.reinitialize(initial_pool_size=15e9)\n",
"#Switch to async pool in case of memory issues due to fragmentation of the pool\n",
"#Switch to async pool in case of memory issues due to fragmentation of the pool\n",
"#rmm.mr.set_current_device_resource(rmm.mr.CudaAsyncMemoryResource(initial_pool_size=15e9))\n",
"torch.cuda.memory.change_current_allocator(rmm_torch_allocator)"
]
Expand Down Expand Up @@ -106,7 +106,7 @@
"g, train_idx = load_dgl_dataset()\n",
"g = cugraph_dgl.cugraph_storage_from_heterograph(g, single_gpu=single_gpu)\n",
"\n",
"batch_size = 1024\n",
"batch_size = 1024*2\n",
"fanout_vals=[25, 25]\n",
"sampler = cugraph_dgl.dataloading.NeighborSampler(fanout_vals)\n",
"dataloader = cugraph_dgl.dataloading.DataLoader(\n",
Expand Down Expand Up @@ -135,7 +135,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"7.25 s ± 916 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"7.08 s ± 596 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -190,8 +190,8 @@
"outputs": [],
"source": [
"g, train_idx = load_dgl_dataset()\n",
"batch_size = 1024\n",
"fanout_vals=[25, 25]\n",
"batch_size = 1024*2\n",
"fanout_vals = [25, 25]\n",
"sampler = dgl.dataloading.MultiLayerNeighborSampler(fanout_vals)\n",
"dataloader = dgl.dataloading.DataLoader(\n",
" g, \n",
Expand All @@ -217,7 +217,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"4.22 s ± 345 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"7.34 s ± 353 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -256,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
Loading