Skip to content

Commit

Permalink
Merge pull request #31 from alexbarghi-nv/pyg-biased
Browse files Browse the repository at this point in the history
Add PyG Biased Sampling (rapidsai/cugraph#4586)
  • Loading branch information
alexbarghi-nv authored Aug 19, 2024
2 parents 755c2e3 + 3e5df7c commit ef8d1e4
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 52 deletions.
117 changes: 72 additions & 45 deletions python/cugraph-pyg/cugraph_pyg/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cugraph.utilities.utils import import_optional, MissingModule
from cugraph.gnn.comms import cugraph_comms_get_raft_handle

from typing import Union, Optional, List, Dict
from typing import Union, Optional, List, Dict, Tuple


# Have to use import_optional even though these are required
Expand Down Expand Up @@ -58,13 +58,19 @@ def __init__(self, is_multi_gpu: bool = False):
"""
self.__edge_indices = tensordict.TensorDict({}, batch_size=(2,))
self.__sizes = {}
self.__graph = None
self.__vertex_offsets = None

self.__handle = None
self.__is_multi_gpu = is_multi_gpu

self.__clear_graph()

super().__init__()

def __clear_graph(self):
    """
    Invalidate the state derived from the stored edge indices.

    Resets the graph, the vertex offsets, and the weight attribute
    held on this object to ``None``.
    """
    self.__graph = self.__vertex_offsets = self.__weight_attr = None

def _put_edge_index(
self,
edge_index: "torch_geometric.typing.EdgeTensorType",
Expand All @@ -88,8 +94,7 @@ def _put_edge_index(
self.__sizes[edge_attr.edge_type] = edge_attr.size

# invalidate the graph
self.__graph = None
self.__vertex_offsets = None
self.__clear_graph()
return True

def _get_edge_index(
Expand All @@ -108,7 +113,7 @@ def _remove_edge_index(self, edge_attr: "torch_geometric.data.EdgeAttr") -> bool
del self.__edge_indices[edge_attr.edge_type]

# invalidate the graph
self.__graph = None
self.__clear_graph()
return True

def get_all_edge_attrs(self) -> List["torch_geometric.data.EdgeAttr"]:
Expand Down Expand Up @@ -163,6 +168,9 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
vertices_array=[vertices_array],
edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
weight_array=[cupy.asarray(edgelist_dict["wgt"])]
if "wgt" in edgelist_dict
else None,
)
else:
self.__graph = pylibcugraph.SGGraph(
Expand All @@ -175,6 +183,9 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
),
edge_id_array=cupy.asarray(edgelist_dict["eid"]),
edge_type_array=cupy.asarray(edgelist_dict["etp"]),
weight_array=cupy.asarray(edgelist_dict["wgt"])
if "wgt" in edgelist_dict
else None,
)

return self.__graph
Expand Down Expand Up @@ -228,6 +239,32 @@ def _vertex_offsets(self) -> Dict[str, int]:
def is_homogeneous(self) -> bool:
return len(self._vertex_offsets) == 1

def _set_weight_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
    """
    Record which feature-store attribute supplies the edge weights.

    Parameters
    ----------
    attr: Tuple[FeatureStore, str]
        The feature store and the attribute name to read weights from.

    Notes
    -----
    If ``attr`` compares equal to the pair already recorded, nothing
    happens. Otherwise the graph state is cleared first, and only then
    is the new pair stored (clearing also resets the weight attribute,
    so the assignment has to come last).
    """
    weight_source_changed = attr != self.__weight_attr
    if weight_source_changed:
        self.__clear_graph()
        self.__weight_attr = attr

def __get_weight_tensor(
    self,
    sorted_keys: List[Tuple[str, str, str]],
    start_offsets: "torch.Tensor",
    num_edges_t: "torch.Tensor",
):
    """
    Collect the weight values for every edge type into one tensor.

    Parameters
    ----------
    sorted_keys: List[Tuple[str, str, str]]
        Edge types, in the order their weights are concatenated.
    start_offsets: torch.Tensor
        For each position in ``sorted_keys``, the first index to read
        from that edge type's weight attribute.
    num_edges_t: torch.Tensor
        For each position in ``sorted_keys``, how many consecutive
        entries to read starting at the matching offset.

    Returns
    -------
    torch.Tensor
        The per-edge-type weight slices concatenated in the order of
        ``sorted_keys``.
    """
    store, weight_name = self.__weight_attr

    def index_range(pos):
        # Indices [offset, offset + count) for this edge type, built on
        # the CPU as int64.
        first = start_offsets[pos]
        return torch.arange(
            first,
            first + num_edges_t[pos],
            dtype=torch.int64,
            device="cpu",
        )

    per_type_weights = [
        store[edge_type, weight_name][index_range(pos)]
        for pos, edge_type in enumerate(sorted_keys)
    ]
    return torch.concat(per_type_weights)

def __get_edgelist(self):
"""
Returns
Expand Down Expand Up @@ -275,59 +312,49 @@ def __get_edgelist(self):
)
)

num_edges_t = torch.tensor(
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
)

if self.is_multi_gpu:
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

num_edges_t = torch.tensor(
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
)
num_edges_all_t = torch.empty(
world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda"
)
torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t)

if rank > 0:
start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
edge_id_array = torch.concat(
[
torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_all_t[rank][i],
dtype=torch.int64,
device="cuda",
)
for i in range(len(sorted_keys))
]
)
else:
edge_id_array = torch.concat(
[
torch.arange(
self.__edge_indices[et].shape[1],
dtype=torch.int64,
device="cuda",
)
for et in sorted_keys
]
)

start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
else:
# single GPU
edge_id_array = torch.concat(
[
torch.arange(
self.__edge_indices[et].shape[1],
dtype=torch.int64,
device="cuda",
)
for et in sorted_keys
]
rank = 0
start_offsets = torch.zeros(
(len(sorted_keys),), dtype=torch.int64, device="cuda"
)
num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel()))

edge_id_array = torch.concat(
[
torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_all_t[rank][i],
dtype=torch.int64,
device="cuda",
)
for i in range(len(sorted_keys))
]
)

return {
d = {
"dst": edge_index[0],
"src": edge_index[1],
"etp": edge_type_array,
"eid": edge_id_array,
}

if self.__weight_attr is not None:
d["wgt"] = self.__get_weight_tensor(
sorted_keys, start_offsets.cpu(), num_edges_t.cpu()
).cuda()

return d
13 changes: 8 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cugraph_pyg.loader import NodeLoader
from cugraph_pyg.sampler import BaseSampler

from cugraph.gnn import UniformNeighborSampler, DistSampleWriter
from cugraph.gnn import NeighborSampler, DistSampleWriter
from cugraph.utilities.utils import import_optional

torch_geometric = import_optional("torch_geometric")
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
neighbor_sampler: Optional["torch_geometric.sampler.NeighborSampler"] = None,
directed: bool = True, # Deprecated.
batch_size: int = 16,
directory: str = None,
directory: Optional[str] = None,
batches_per_partition=256,
format: str = "parquet",
compression: Optional[str] = None,
Expand Down Expand Up @@ -174,8 +174,6 @@ def __init__(
raise ValueError("Passing a neighbor sampler is currently unsupported")
if time_attr is not None:
raise ValueError("Temporal sampling is currently unsupported")
if weight_attr is not None:
raise ValueError("Biased sampling is currently unsupported")
if is_sorted:
warnings.warn("The 'is_sorted' argument is ignored by cuGraph.")
if not isinstance(data, (list, tuple)) or not isinstance(
Expand All @@ -201,8 +199,12 @@ def __init__(
)

feature_store, graph_store = data

if weight_attr is not None:
graph_store._set_weight_attr((feature_store, weight_attr))

sampler = BaseSampler(
UniformNeighborSampler(
NeighborSampler(
graph_store._graph,
writer,
retain_original_seeds=True,
Expand All @@ -213,6 +215,7 @@ def __init__(
compress_per_hop=False,
with_replacement=replace,
local_seeds_per_call=local_seeds_per_call,
biased=(weight_attr is not None),
),
(feature_store, graph_store),
batch_size=batch_size,
Expand Down
18 changes: 17 additions & 1 deletion python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from cugraph.utilities.utils import import_optional
from cugraph.gnn import DistSampler, DistSampleReader

from .sampler_utils import filter_cugraph_pyg_store

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")

Expand Down Expand Up @@ -65,6 +67,19 @@ def __next__(self):
next_sample.col, next_sample.edge.numel()
)

data = filter_cugraph_pyg_store(
self.__feature_store,
self.__graph_store,
next_sample.node,
next_sample.row,
col,
next_sample.edge,
None,
)

"""
# TODO Re-enable this once PyG resolves
# the issue with edge features (9566)
data = torch_geometric.loader.utils.filter_custom_store(
self.__feature_store,
self.__graph_store,
Expand All @@ -74,6 +89,7 @@ def __next__(self):
next_sample.edge,
None,
)
"""

if "n_id" not in data:
data.n_id = next_sample.node
Expand Down Expand Up @@ -250,7 +266,7 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
node=renumber_map.cpu(),
row=minors,
col=major_offsets,
edge=edge_id,
edge=edge_id.cpu(),
batch=renumber_map[:num_seeds],
num_sampled_nodes=num_sampled_nodes.cpu(),
num_sampled_edges=num_sampled_edges.cpu(),
Expand Down
26 changes: 26 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,29 @@ def _sampler_output_from_sampling_results_heterogeneous(
num_sampled_edges={k: t.tolist() for k, t in num_edges_per_hop_dict.items()},
metadata=metadata,
)


def filter_cugraph_pyg_store(
    feature_store,
    graph_store,
    node,
    row,
    col,
    edge,
    clx,
) -> "torch_geometric.data.Data":
    """
    Build a ``torch_geometric.data.Data`` object for one sampled batch.

    Parameters
    ----------
    feature_store
        Store queried through ``get_all_tensor_attrs`` and
        ``multi_get_tensor`` for every attribute it holds.
    graph_store
        Not used by this function; kept so the signature stays unchanged.
    node
        Tensor of sampled node ids; used as the index for node attributes
        and to determine ``num_nodes``.
    row, col
        Tensors stacked (in that order) to form ``edge_index``.
    edge
        Tensor of sampled edge ids; used as the index for attributes whose
        ``group_name`` is a tuple (i.e. edge attributes).
    clx
        Not used by this function; kept so the signature stays unchanged.

    Returns
    -------
    torch_geometric.data.Data
        Object with ``edge_index``, ``num_nodes``, and one entry per
        feature-store attribute, keyed by ``attr_name``.
    """
    data = torch_geometric.data.Data()

    data.edge_index = torch.stack([row, col], dim=0)

    # The node count must come from the sampled nodes. Setting it inside
    # the attribute loop from ``attr.index`` would leave it holding the
    # number of sampled *edges* whenever the last attribute visited is an
    # edge attribute, and would leave it unset for an empty feature store.
    data.num_nodes = node.size(0)

    required_attrs = []
    for attr in feature_store.get_all_tensor_attrs():
        # Tuple group names identify edge types; everything else is
        # treated as a node type.
        attr.index = edge if isinstance(attr.group_name, tuple) else node
        required_attrs.append(attr)

    tensors = feature_store.multi_get_tensor(required_attrs)
    for i, attr in enumerate(required_attrs):
        data[attr.attr_name] = tensors[i]

    return data
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,43 @@ def test_neighbor_loader():
(feature_store, graph_store),
[5, 5],
input_nodes=torch.arange(34),
directory=".",
)

for batch in loader:
assert isinstance(batch, torch_geometric.data.Data)
assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all()


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.sg
def test_neighbor_loader_biased():
    """
    Check that NeighborLoader accepts ``weight_attr`` and returns the
    expected edges for a three-edge graph with biases 0, 12 and 14.
    """
    knows = ("person", "knows", "person")

    # Three edges, one per seed node.
    eix = torch.tensor(
        [
            [3, 4, 5],
            [0, 1, 2],
        ]
    )

    graph_store = GraphStore()
    graph_store.put_edge_index(eix, knows, "coo")

    feature_store = TensorDictFeatureStore()
    feature_store["person", "feat"] = torch.randint(128, (6, 12))
    # NOTE(review): presumably the zero bias on the first edge is what
    # keeps it out of the sampled output below; confirm against the
    # sampler's biased-sampling semantics.
    feature_store[knows, "bias"] = torch.tensor([0, 12, 14], dtype=torch.float32)

    seeds = torch.tensor([0, 1, 2], dtype=torch.int64)
    loader = NeighborLoader(
        (feature_store, graph_store),
        [1],
        input_nodes=seeds,
        batch_size=3,
        weight_attr="bias",
    )

    # All three seeds fit in one batch.
    batches = list(loader)
    assert len(batches) == 1
    batch = batches[0]

    expected_edge_index = torch.tensor([[3, 4], [1, 2]])
    assert batch.edge_index.shape[1] == 2
    assert (batch.edge_index.cpu() == expected_edge_index).all()
Loading

0 comments on commit ef8d1e4

Please sign in to comment.