diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py index dbddf02..40819cd 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -153,6 +153,9 @@ def __next__(self): data.set_value_dict("num_sampled_edges", next_sample.num_sampled_edges) # TODO figure out how to set input_id for heterogeneous output + input_type, input_id = next_sample.metadata[0] + data[input_type].input_id = input_id + data[input_type].batch_size = input_id.size(0) else: raise ValueError("Invalid output type") @@ -287,6 +290,8 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): col = {} edge = {} + input_type = None + for etype in range(num_edge_types): pyg_can_etype = self.__edge_types[etype] @@ -349,11 +354,15 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): ux = col[pyg_can_etype][: num_sampled_edges[pyg_can_etype][0]] if ux.numel() > 0: + input_type = pyg_can_etype[0] # can only ever be 1 num_sampled_nodes[self.__src_types[etype]][0] = torch.max( num_sampled_nodes[self.__src_types[etype]][0], (ux.max() + 1).reshape((1,)), ) + if input_type is None: + raise ValueError("No input type found!") + num_sampled_nodes = { self.__vertex_types[i]: z.diff( prepend=torch.zeros((1,), dtype=torch.int64, device="cuda") @@ -362,11 +371,14 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): } num_sampled_edges = {k: v.cpu() for k, v in num_sampled_edges.items()} - input_index = raw_sample_data["input_index"][ - raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][ - index + 1 - ] - ] + input_index = ( + input_type, + raw_sample_data["input_index"][ + raw_sample_data["input_offsets"][index] : raw_sample_data[ + "input_offsets" + ][index + 1] + ], + ) edge_inverse = ( (