Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Biased Sampling in cuGraph-PyG #4586

Merged
merged 16 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 72 additions & 45 deletions python/cugraph-pyg/cugraph_pyg/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cugraph.utilities.utils import import_optional, MissingModule
from cugraph.gnn.comms import cugraph_comms_get_raft_handle

from typing import Union, Optional, List, Dict
from typing import Union, Optional, List, Dict, Tuple


# Have to use import_optional even though these are required
Expand Down Expand Up @@ -58,13 +58,19 @@ def __init__(self, is_multi_gpu: bool = False):
"""
self.__edge_indices = tensordict.TensorDict({}, batch_size=(2,))
self.__sizes = {}
self.__graph = None
self.__vertex_offsets = None

self.__handle = None
self.__is_multi_gpu = is_multi_gpu

self.__clear_graph()

super().__init__()

def __clear_graph(self):
self.__graph = None
self.__vertex_offsets = None
self.__weight_attr = None

def _put_edge_index(
self,
edge_index: "torch_geometric.typing.EdgeTensorType",
Expand All @@ -88,8 +94,7 @@ def _put_edge_index(
self.__sizes[edge_attr.edge_type] = edge_attr.size

# invalidate the graph
self.__graph = None
self.__vertex_offsets = None
self.__clear_graph()
return True

def _get_edge_index(
Expand All @@ -108,7 +113,7 @@ def _remove_edge_index(self, edge_attr: "torch_geometric.data.EdgeAttr") -> bool
del self.__edge_indices[edge_attr.edge_type]

# invalidate the graph
self.__graph = None
self.__clear_graph()
return True

def get_all_edge_attrs(self) -> List["torch_geometric.data.EdgeAttr"]:
Expand Down Expand Up @@ -163,6 +168,9 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
vertices_array=[vertices_array],
edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
weight_array=[cupy.asarray(edgelist_dict["wgt"])]
if "wgt" in edgelist_dict
else None,
)
else:
self.__graph = pylibcugraph.SGGraph(
Expand All @@ -175,6 +183,9 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
),
edge_id_array=cupy.asarray(edgelist_dict["eid"]),
edge_type_array=cupy.asarray(edgelist_dict["etp"]),
weight_array=cupy.asarray(edgelist_dict["wgt"])
if "wgt" in edgelist_dict
else None,
)

return self.__graph
Expand Down Expand Up @@ -228,6 +239,32 @@ def _vertex_offsets(self) -> Dict[str, int]:
def is_homogeneous(self) -> bool:
return len(self._vertex_offsets) == 1

def _set_weight_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
if attr != self.__weight_attr:
self.__clear_graph()
self.__weight_attr = attr

def __get_weight_tensor(
self,
sorted_keys: List[Tuple[str, str, str]],
start_offsets: "torch.Tensor",
num_edges_t: "torch.Tensor",
):
feature_store, attr_name = self.__weight_attr

weights = []
for i, et in enumerate(sorted_keys):
ix = torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_t[i],
dtype=torch.int64,
device="cpu",
)

weights.append(feature_store[et, attr_name][ix])

return torch.concat(weights)

def __get_edgelist(self):
"""
Returns
Expand Down Expand Up @@ -275,59 +312,49 @@ def __get_edgelist(self):
)
)

num_edges_t = torch.tensor(
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
)

if self.is_multi_gpu:
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

num_edges_t = torch.tensor(
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
)
num_edges_all_t = torch.empty(
world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda"
)
torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t)

if rank > 0:
start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
edge_id_array = torch.concat(
[
torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_all_t[rank][i],
dtype=torch.int64,
device="cuda",
)
for i in range(len(sorted_keys))
]
)
else:
edge_id_array = torch.concat(
[
torch.arange(
self.__edge_indices[et].shape[1],
dtype=torch.int64,
device="cuda",
)
for et in sorted_keys
]
)

start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
else:
# single GPU
edge_id_array = torch.concat(
[
torch.arange(
self.__edge_indices[et].shape[1],
dtype=torch.int64,
device="cuda",
)
for et in sorted_keys
]
rank = 0
start_offsets = torch.zeros(
(len(sorted_keys),), dtype=torch.int64, device="cuda"
)
num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel()))

edge_id_array = torch.concat(
[
torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_all_t[rank][i],
dtype=torch.int64,
device="cuda",
)
for i in range(len(sorted_keys))
]
)

return {
d = {
"dst": edge_index[0],
"src": edge_index[1],
"etp": edge_type_array,
"eid": edge_id_array,
}

if self.__weight_attr is not None:
d["wgt"] = self.__get_weight_tensor(
sorted_keys, start_offsets.cpu(), num_edges_t.cpu()
).cuda()

return d
13 changes: 8 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cugraph_pyg.loader import NodeLoader
from cugraph_pyg.sampler import BaseSampler

from cugraph.gnn import UniformNeighborSampler, DistSampleWriter
from cugraph.gnn import NeighborSampler, DistSampleWriter
from cugraph.utilities.utils import import_optional

torch_geometric = import_optional("torch_geometric")
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
neighbor_sampler: Optional["torch_geometric.sampler.NeighborSampler"] = None,
directed: bool = True, # Deprecated.
batch_size: int = 16,
directory: str = None,
directory: Optional[str] = None,
batches_per_partition=256,
format: str = "parquet",
compression: Optional[str] = None,
Expand Down Expand Up @@ -174,8 +174,6 @@ def __init__(
raise ValueError("Passing a neighbor sampler is currently unsupported")
if time_attr is not None:
raise ValueError("Temporal sampling is currently unsupported")
if weight_attr is not None:
raise ValueError("Biased sampling is currently unsupported")
if is_sorted:
warnings.warn("The 'is_sorted' argument is ignored by cuGraph.")
if not isinstance(data, (list, tuple)) or not isinstance(
Expand All @@ -201,8 +199,12 @@ def __init__(
)

feature_store, graph_store = data

if weight_attr is not None:
graph_store._set_weight_attr((feature_store, weight_attr))

sampler = BaseSampler(
UniformNeighborSampler(
NeighborSampler(
graph_store._graph,
writer,
retain_original_seeds=True,
Expand All @@ -213,6 +215,7 @@ def __init__(
compress_per_hop=False,
with_replacement=replace,
local_seeds_per_call=local_seeds_per_call,
biased=(weight_attr is not None),
),
(feature_store, graph_store),
batch_size=batch_size,
Expand Down
18 changes: 17 additions & 1 deletion python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from cugraph.utilities.utils import import_optional
from cugraph.gnn import DistSampler, DistSampleReader

from .sampler_utils import filter_cugraph_pyg_store

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")

Expand Down Expand Up @@ -65,6 +67,19 @@ def __next__(self):
next_sample.col, next_sample.edge.numel()
)

data = filter_cugraph_pyg_store(
self.__feature_store,
self.__graph_store,
next_sample.node,
next_sample.row,
col,
next_sample.edge,
None,
)

"""
# TODO Re-enable this once PyG resolves
# the issue with edge features (9566)
data = torch_geometric.loader.utils.filter_custom_store(
self.__feature_store,
self.__graph_store,
Expand All @@ -74,6 +89,7 @@ def __next__(self):
next_sample.edge,
None,
)
"""

if "n_id" not in data:
data.n_id = next_sample.node
Expand Down Expand Up @@ -250,7 +266,7 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
node=renumber_map.cpu(),
row=minors,
col=major_offsets,
edge=edge_id,
edge=edge_id.cpu(),
batch=renumber_map[:num_seeds],
num_sampled_nodes=num_sampled_nodes.cpu(),
num_sampled_edges=num_sampled_edges.cpu(),
Expand Down
26 changes: 26 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,29 @@ def _sampler_output_from_sampling_results_heterogeneous(
num_sampled_edges={k: t.tolist() for k, t in num_edges_per_hop_dict.items()},
metadata=metadata,
)


def filter_cugraph_pyg_store(
feature_store,
graph_store,
node,
row,
col,
edge,
clx,
) -> "torch_geometric.data.Data":
data = torch_geometric.data.Data()

data.edge_index = torch.stack([row, col], dim=0)

required_attrs = []
for attr in feature_store.get_all_tensor_attrs():
attr.index = edge if isinstance(attr.group_name, tuple) else node
required_attrs.append(attr)
data.num_nodes = attr.index.size(0)

tensors = feature_store.multi_get_tensor(required_attrs)
for i, attr in enumerate(required_attrs):
data[attr.attr_name] = tensors[i]

return data
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,43 @@ def test_neighbor_loader():
(feature_store, graph_store),
[5, 5],
input_nodes=torch.arange(34),
directory=".",
)

for batch in loader:
assert isinstance(batch, torch_geometric.data.Data)
assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all()


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.sg
def test_neighbor_loader_biased():
eix = torch.tensor(
[
[3, 4, 5],
[0, 1, 2],
]
)

graph_store = GraphStore()
graph_store.put_edge_index(eix, ("person", "knows", "person"), "coo")

feature_store = TensorDictFeatureStore()
feature_store["person", "feat"] = torch.randint(128, (6, 12))
feature_store[("person", "knows", "person"), "bias"] = torch.tensor(
[0, 12, 14], dtype=torch.float32
)

loader = NeighborLoader(
(feature_store, graph_store),
[1],
input_nodes=torch.tensor([0, 1, 2], dtype=torch.int64),
batch_size=3,
weight_attr="bias",
)

out = list(iter(loader))
assert len(out) == 1
out = out[0]

assert out.edge_index.shape[1] == 2
assert (out.edge_index.cpu() == torch.tensor([[3, 4], [1, 2]])).all()
Loading
Loading