Skip to content

Commit

Permalink
get hetero input ids working
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Dec 5, 2024
1 parent 9eb3319 commit ef57559
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def __next__(self):
data.set_value_dict("num_sampled_edges", next_sample.num_sampled_edges)

# TODO figure out how to set input_id for heterogeneous output
input_type, input_id = next_sample.metadata[0]
data[input_type].input_id = input_id
data[input_type].batch_size = input_id.size(0)
else:
raise ValueError("Invalid output type")

Expand Down Expand Up @@ -287,6 +290,8 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
col = {}
edge = {}

input_type = None

for etype in range(num_edge_types):
pyg_can_etype = self.__edge_types[etype]

Expand Down Expand Up @@ -349,11 +354,15 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):

ux = col[pyg_can_etype][: num_sampled_edges[pyg_can_etype][0]]
if ux.numel() > 0:
input_type = pyg_can_etype[0] # can only ever be 1
num_sampled_nodes[self.__src_types[etype]][0] = torch.max(
num_sampled_nodes[self.__src_types[etype]][0],
(ux.max() + 1).reshape((1,)),
)

if input_type is None:
raise ValueError("No input type found!")

num_sampled_nodes = {
self.__vertex_types[i]: z.diff(
prepend=torch.zeros((1,), dtype=torch.int64, device="cuda")
Expand All @@ -362,11 +371,14 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
}
num_sampled_edges = {k: v.cpu() for k, v in num_sampled_edges.items()}

input_index = raw_sample_data["input_index"][
raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
index + 1
]
]
input_index = (
input_type,
raw_sample_data["input_index"][
raw_sample_data["input_offsets"][index] : raw_sample_data[
"input_offsets"
][index + 1]
],
)

edge_inverse = (
(
Expand Down

0 comments on commit ef57559

Please sign in to comment.