
Commit

fixes, add/improve tests
alexbarghi-nv committed Oct 18, 2024
1 parent 2cf4252 commit b4cd8de
Showing 3 changed files with 198 additions and 100 deletions.
91 changes: 63 additions & 28 deletions python/cugraph-dgl/cugraph_dgl/graph.py
@@ -983,27 +983,57 @@ def global_uniform_negative_sampling(
 
         if len(self.ntypes) == 1:
             vertices = torch.arange(self.num_nodes())
+            src_vertex_offset = 0
+            dst_vertex_offset = 0
+            src_bias = cupy.ones(len(vertices), dtype="float32")
+            dst_bias = src_bias
         else:
             can_edge_type = self.to_canonical_etype(etype)
+            src_vertex_offset = self._vertex_offsets[can_edge_type[0]]
+            dst_vertex_offset = self._vertex_offsets[can_edge_type[2]]
 
             # Limit sampled vertices to those of the given edge type.
-            vertices = torch.concat(
-                [
-                    torch.arange(
-                        self._vertex_offsets[can_edge_type[0]],
-                        self._vertex_offsets[can_edge_type[0]]
-                        + self.num_nodes(can_edge_type[0]),
-                        dtype=torch.int64,
-                        device="cuda",
-                    ),
-                    torch.arange(
-                        self._vertex_offsets[can_edge_type[2]],
-                        self._vertex_offsets[can_edge_type[2]]
-                        + self.num_nodes(can_edge_type[2]),
-                        dtype=torch.int64,
-                        device="cuda",
-                    ),
-                ]
-            )
+            if can_edge_type[0] == can_edge_type[2]:
+                vertices = torch.arange(
+                    src_vertex_offset,
+                    src_vertex_offset + self.num_nodes(can_edge_type[0]),
+                    dtype=torch.int64,
+                    device="cuda",
+                )
+                src_bias = cupy.ones(self.num_nodes(can_edge_type[0]), dtype="float32")
+                dst_bias = src_bias
+
+            else:
+                vertices = torch.concat(
+                    [
+                        torch.arange(
+                            src_vertex_offset,
+                            src_vertex_offset + self.num_nodes(can_edge_type[0]),
+                            dtype=torch.int64,
+                            device="cuda",
+                        ),
+                        torch.arange(
+                            dst_vertex_offset,
+                            dst_vertex_offset + self.num_nodes(can_edge_type[2]),
+                            dtype=torch.int64,
+                            device="cuda",
+                        ),
+                    ]
+                )
+
+                src_bias = cupy.concatenate(
+                    [
+                        cupy.ones(self.num_nodes(can_edge_type[0]), dtype="float32"),
+                        cupy.zeros(self.num_nodes(can_edge_type[2]), dtype="float32"),
+                    ]
+                )
+
+                dst_bias = cupy.concatenate(
+                    [
+                        cupy.zeros(self.num_nodes(can_edge_type[0]), dtype="float32"),
+                        cupy.ones(self.num_nodes(can_edge_type[2]), dtype="float32"),
+                    ]
+                )
 
         if self.is_multi_gpu:
             rank = torch.distributed.get_rank()
@@ -1020,19 +1050,20 @@ def global_uniform_negative_sampling(
             num_samples_global = num_samples
 
         graph = (
-            self.__graph
-            if self.__graph["direction"] == "out"
-            else self._graph("out", self.__graph["prob_attr"])
+            self.__graph["graph"]
+            if self.__graph is not None and self.__graph["direction"] == "out"
+            else self._graph(
+                "out", None if self.__graph is None else self.__graph["prob_attr"]
+            )
         )
-        bias = cupy.ones(len(vertices), dtype="float32")
 
         result_dict = pylibcugraph.negative_sampling(
             self._resource_handle,
             graph,
             num_samples_global,
             vertices=cupy.asarray(vertices),
-            src_bias=bias,
-            dst_bias=bias,
+            src_bias=src_bias,
+            dst_bias=dst_bias,
             remove_duplicates=True,
             remove_false_negatives=True,
             exact_number_of_samples=True,
@@ -1041,10 +1072,14 @@ def global_uniform_negative_sampling(
 
         # TODO remove this workaround once the C API is updated to take a local number
        # of negatives (rapidsai/cugraph#4672)
-        src_neg = torch.as_tensor(result_dict["sources"], device="cuda")[:num_samples]
-        dst_neg = torch.as_tensor(result_dict["destinations"], device="cuda")[
-            :num_samples
-        ]
+        src_neg = (
+            torch.as_tensor(result_dict["sources"], device="cuda")[:num_samples]
+            - src_vertex_offset
+        )
+        dst_neg = (
+            torch.as_tensor(result_dict["destinations"], device="cuda")[:num_samples]
+            - dst_vertex_offset
+        )
 
         if exclude_self_loops:
             f = src_neg != dst_neg
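Taken together, the changes to graph.py above make global_uniform_negative_sampling respect the requested canonical edge type: candidate vertices are drawn from the global ID ranges of the source and destination node types, the src_bias/dst_bias vectors passed to pylibcugraph.negative_sampling zero out vertices of the wrong type, and the sampled global IDs are shifted back into per-type local IDs. Below is a small, CPU-only illustration of that bookkeeping; it uses numpy in place of cupy/torch purely for readability, with a made-up layout of 2 source-type and 3 destination-type vertices, and is not code from this commit.

# Illustrative sketch of the offset/bias bookkeeping; not part of this commit.
import numpy as np

# Toy layout: 2 "source type" vertices, then 3 "destination type" vertices,
# contiguous in global ID space (offsets 0 and 2).
num_type1, num_type2 = 2, 3
src_vertex_offset, dst_vertex_offset = 0, num_type1

# Candidate vertices: the union of both types' global ID ranges.
vertices = np.concatenate(
    [
        np.arange(src_vertex_offset, src_vertex_offset + num_type1),
        np.arange(dst_vertex_offset, dst_vertex_offset + num_type2),
    ]
)

# Sources may only be drawn from the source type, destinations from the
# destination type.
src_bias = np.concatenate([np.ones(num_type1), np.zeros(num_type2)])
dst_bias = np.concatenate([np.zeros(num_type1), np.ones(num_type2)])

print(vertices)  # [0 1 2 3 4]
print(src_bias)  # [1. 1. 0. 0. 0.]
print(dst_bias)  # [0. 0. 1. 1. 1.]

# Sampled global IDs become per-type local IDs by subtracting the offsets,
# as in the src_neg/dst_neg change above.
sampled_src_global = np.array([0, 1, 1])
sampled_dst_global = np.array([2, 4, 3])
print(sampled_src_global - src_vertex_offset)  # [0 1 1]
print(sampled_dst_global - dst_vertex_offset)  # [0 2 1]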
68 changes: 68 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/tests/conftest.py
@@ -15,12 +15,17 @@
 
 import dgl
 import torch
+import numpy as np
 
+import cugraph_dgl
+
 from cugraph.testing.mg_utils import (
     start_dask_client,
     stop_dask_client,
 )
 
+from cugraph.datasets import karate
+
 
 @pytest.fixture(scope="module")
 def dask_client():
@@ -66,3 +71,66 @@ def dgl_graph_1():
     src = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
     dst = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
     return dgl.graph((src, dst))
+
+
+def create_karate_bipartite(multi_gpu: bool = False):
+    df = karate.get_edgelist()
+    df.src = df.src.astype("int64")
+    df.dst = df.dst.astype("int64")
+
+    graph = cugraph_dgl.Graph(is_multi_gpu=multi_gpu)
+    total_num_nodes = max(df.src.max(), df.dst.max()) + 1
+
+    num_nodes_group_1 = total_num_nodes // 2
+    num_nodes_group_2 = total_num_nodes - num_nodes_group_1
+
+    node_x_1 = np.random.random((num_nodes_group_1,))
+    node_x_2 = np.random.random((num_nodes_group_2,))
+
+    graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1")
+    graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2")
+
+    edges = {}
+    edges["type1", "e1", "type1"] = df[
+        (df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1)
+    ]
+    edges["type1", "e2", "type2"] = df[
+        (df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1)
+    ]
+    edges["type2", "e3", "type1"] = df[
+        (df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1)
+    ]
+    edges["type2", "e4", "type2"] = df[
+        (df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1)
+    ]
+
+    edges["type1", "e2", "type2"].dst -= num_nodes_group_1
+    edges["type2", "e3", "type1"].src -= num_nodes_group_1
+    edges["type2", "e4", "type2"].dst -= num_nodes_group_1
+    edges["type2", "e4", "type2"].src -= num_nodes_group_1
+
+    if multi_gpu:
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+
+        edges_local = {
+            etype: edf.iloc[np.array_split(np.arange(len(edf)), world_size)[rank]]
+            for etype, edf in edges.items()
+        }
+    else:
+        edges_local = edges
+
+    for etype, edf in edges_local.items():
+        graph.add_edges(edf.src, edf.dst, etype=etype)
+
+    return graph, edges, (num_nodes_group_1, num_nodes_group_2)
+
+
+@pytest.fixture
+def karate_bipartite():
+    return create_karate_bipartite(False)
+
+
+@pytest.fixture
+def karate_bipartite_mg():
+    return create_karate_bipartite(True)
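
The conftest.py additions above build a two-node-type version of the karate graph so that edge-type-aware sampling can be tested. The commit's actual test changes live in a third file that is not rendered here, so the snippet below is only an illustrative sketch of how the karate_bipartite fixture might be exercised, assuming global_uniform_negative_sampling follows the DGL-style signature (sample count first, etype as a keyword) and returns a (src, dst) pair of tensors in per-type local IDs.

# Illustrative sketch only; not the test added by this commit.
# Assumes a DGL-style call signature and a (src, dst) tensor pair as the result.
def test_global_uniform_negative_sampling_bipartite(karate_bipartite):
    graph, edges, (num_type1, num_type2) = karate_bipartite

    src_neg, dst_neg = graph.global_uniform_negative_sampling(
        20, exclude_self_loops=False, etype=("type1", "e2", "type2")
    )

    # With the per-type offsets and bias vectors, negatives should be local
    # IDs confined to the source and destination node types of the edge type.
    assert (src_neg >= 0).all()
    assert (src_neg < num_type1).all()
    assert (dst_neg >= 0).all()
    assert (dst_neg < num_type2).all()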
