Skip to content

Commit

Permalink
[BUG] Output Edge Labels in the Distributed Sampler (#4898)
Browse files Browse the repository at this point in the history
We currently do not output edge labels in the distributed sampler, which breaks some link prediction workflows where the graph contains pre-labeled edges. This PR adds support for outputting edge labels so that these workflows can be enabled.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4898
  • Loading branch information
alexbarghi-nv authored Jan 29, 2025
1 parent b64b04f commit 9e3a457
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __get_call_groups(
input_id: TensorType,
seeds_per_call: int,
assume_equal_input_size: bool = False,
label: Optional[TensorType] = None,
):
torch = import_optional("torch")

Expand All @@ -231,6 +232,8 @@ def __get_call_groups(
# many batches.
seeds_call_groups = torch.split(seeds, seeds_per_call, dim=-1)
index_call_groups = torch.split(input_id, seeds_per_call, dim=-1)
if label is not None:
label_call_groups = torch.split(label, seeds_per_call, dim=-1)

# Need to add empties to the list of call groups to handle the case
# where not all ranks have the same number of call groups. This
Expand All @@ -251,8 +254,16 @@ def __get_call_groups(
[torch.tensor([], dtype=torch.int64, device=input_id.device)]
* (int(num_call_groups) - len(index_call_groups))
)
if label is not None:
label_call_groups = list(label_call_groups) + (
[torch.tensor([], dtype=label.dtype, device=label.device)]
* (int(num_call_groups) - len(label_call_groups))
)

return seeds_call_groups, index_call_groups
if label is not None:
return seeds_call_groups, index_call_groups, label_call_groups
else:
return seeds_call_groups, index_call_groups

def sample_from_nodes(
self,
Expand Down Expand Up @@ -344,7 +355,7 @@ def sample_from_nodes(
def __sample_from_edges_func(
self,
call_id: int,
current_seeds_and_ix: Tuple["torch.Tensor", "torch.Tensor"],
current_seeds_and_ix: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"],
batch_id_start: int,
batch_size: int,
batches_per_call: int,
Expand All @@ -353,7 +364,7 @@ def __sample_from_edges_func(
) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]:
torch = import_optional("torch")

current_seeds, current_ix = current_seeds_and_ix
current_seeds, current_ix, current_label = current_seeds_and_ix
num_seed_edges = current_ix.numel()

# The index gets stored as-is regardless of what makes it into
Expand Down Expand Up @@ -468,6 +479,7 @@ def __sample_from_edges_func(
random_state=random_state,
)
minibatch_dict["input_index"] = current_ix.cuda()
minibatch_dict["input_label"] = current_label.cuda()
minibatch_dict["input_offsets"] = input_offsets
minibatch_dict[
"edge_inverse"
Expand Down Expand Up @@ -505,6 +517,7 @@ def sample_from_edges(
random_state: int = 62,
assume_equal_input_size: bool = False,
input_id: Optional[TensorType] = None,
input_label: Optional[TensorType] = None,
) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]:
"""
Performs sampling starting from seed edges.
Expand All @@ -527,6 +540,10 @@ def sample_from_edges(
Input ids corresponding to the original batch tensor, if it
was permuted prior to calling this function. If present,
will be saved with the samples.
input_label: Optional[TensorType]
Input labels corresponding to the input seeds. Typically used
for link prediction sampling. If present, will be saved with
the samples. Generally not compatible with negative sampling.
"""

torch = import_optional("torch")
Expand All @@ -545,12 +562,20 @@ def sample_from_edges(
local_num_batches, assume_equal_input_size=assume_equal_input_size
)

edges_call_groups, index_call_groups = self.__get_call_groups(
groups = self.__get_call_groups(
edges,
input_id,
actual_seed_edges_per_call,
assume_equal_input_size=input_size_is_equal,
label=input_label,
)
if len(groups) == 2:
edges_call_groups, index_call_groups = groups
label_call_groups = [torch.tensor([], dtype=torch.int32)] * len(
edges_call_groups
)
else:
edges_call_groups, index_call_groups, label_call_groups = groups

sample_args = [
batch_id_start,
Expand All @@ -563,14 +588,14 @@ def sample_from_edges(
if self.__writer is None:
# Buffered sampling
return BufferedSampleReader(
zip(edges_call_groups, index_call_groups),
zip(edges_call_groups, index_call_groups, label_call_groups),
self.__sample_from_edges_func,
*sample_args,
)
else:
# Unbuffered sampling
for i, current_seeds_and_ix in enumerate(
zip(edges_call_groups, index_call_groups)
zip(edges_call_groups, index_call_groups, label_call_groups)
):
sample_args[0] = self.__sample_from_edges_func(
i,
Expand Down

0 comments on commit 9e3a457

Please sign in to comment.