[FEA] DGL Examples #4583

Merged: 74 commits, Aug 28, 2024

Commits
b5915ff
merge
alexbarghi-nv Jun 3, 2024
4e3045f
c
alexbarghi-nv Jun 6, 2024
f243351
pull in dependency fixes
alexbarghi-nv Jun 6, 2024
b1adcd3
merge
alexbarghi-nv Jun 6, 2024
4c29329
w
alexbarghi-nv Jun 7, 2024
265f546
basic graph/fs
alexbarghi-nv Jun 7, 2024
b51eda4
dist sampling
alexbarghi-nv Jun 10, 2024
9943260
graph data views
alexbarghi-nv Jun 12, 2024
055db0a
remove unwanted file
alexbarghi-nv Jun 12, 2024
020bf67
Merge branch 'branch-24.08' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Jun 13, 2024
1f76898
revert devcontainer change
alexbarghi-nv Jun 13, 2024
927ee09
tests, bugfixes, resolve indexing problem (sort of)
alexbarghi-nv Jun 14, 2024
ffdb2fa
Merge branch 'branch-24.08' into dgl-dist-sampling
nv-rliu Jun 24, 2024
68129d9
add heterogeneous tests
alexbarghi-nv Jun 25, 2024
20450a3
testing, fixing graph API
alexbarghi-nv Jun 26, 2024
58ec793
Merge branch 'dgl-dist-sampling' of https://github.com/alexbarghi-nv/…
alexbarghi-nv Jun 26, 2024
557d9aa
Loaders
alexbarghi-nv Jun 27, 2024
a8c0848
add todo
alexbarghi-nv Jun 27, 2024
913b8cd
fix block issue, typing
alexbarghi-nv Jun 28, 2024
79c8f78
reorganize tests
alexbarghi-nv Jul 1, 2024
b25128b
Merge branch 'branch-24.08' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Jul 1, 2024
a56b56d
sampling
alexbarghi-nv Jul 2, 2024
8f14f88
revert dependencies.yaml
alexbarghi-nv Jul 2, 2024
5f74252
update tensordict dependency
alexbarghi-nv Jul 2, 2024
ad120ac
Merge branch 'branch-24.08' into dgl-dist-sampling
alexbarghi-nv Jul 2, 2024
b2fdef8
update dependencies
alexbarghi-nv Jul 2, 2024
5ce714d
Merge branch 'dgl-dist-sampling' of https://github.com/alexbarghi-nv/…
alexbarghi-nv Jul 2, 2024
92fd866
update meta files
alexbarghi-nv Jul 2, 2024
6107d82
fix csr/csc issue, wrap up tests
alexbarghi-nv Jul 3, 2024
f04700d
Merge branch 'branch-24.08' into dgl-dist-sampling
alexbarghi-nv Jul 8, 2024
6bc4b4a
m
alexbarghi-nv Jul 8, 2024
faeb4a5
style
alexbarghi-nv Jul 8, 2024
afb9452
revert ci script
alexbarghi-nv Jul 8, 2024
48ba6d4
fix meta.yaml issue
alexbarghi-nv Jul 9, 2024
786c1af
Merge branch 'branch-24.08' into dgl-dist-sampling
alexbarghi-nv Jul 9, 2024
801de87
add type hint
alexbarghi-nv Jul 10, 2024
5e511cc
add missing type hint
alexbarghi-nv Jul 10, 2024
035b69a
remove comment, add issue reference
alexbarghi-nv Jul 10, 2024
ebbc1db
Merge branch 'dgl-dist-sampling' of https://github.com/alexbarghi-nv/…
alexbarghi-nv Jul 10, 2024
b412776
Add type hint
alexbarghi-nv Jul 10, 2024
1c72bd6
add convert function, fix bugs
alexbarghi-nv Jul 10, 2024
18f6ac2
Merge branch 'dgl-dist-sampling' of https://github.com/alexbarghi-nv/…
alexbarghi-nv Jul 10, 2024
9bd0440
Merge branch 'branch-24.08' into dgl-dist-sampling
alexbarghi-nv Jul 10, 2024
2d522b1
move worker init to utility
alexbarghi-nv Jul 10, 2024
4b60d8d
Merge branch 'dgl-dist-sampling' of https://github.com/alexbarghi-nv/…
alexbarghi-nv Jul 10, 2024
e1fa6e0
revert none return, add check
alexbarghi-nv Jul 10, 2024
8529987
style
alexbarghi-nv Jul 10, 2024
89f4ef4
use global communicator
alexbarghi-nv Jul 22, 2024
4d82ee0
global
alexbarghi-nv Jul 22, 2024
b4ed827
Merge branch 'branch-24.08' into use-correct-communicator
alexbarghi-nv Jul 23, 2024
e144ad1
Merge branch 'branch-24.08' into use-correct-communicator
alexbarghi-nv Jul 24, 2024
2b160bf
use int64 to store # edges
alexbarghi-nv Jul 24, 2024
22b85d2
Merge branch 'use-correct-communicator' of https://github.com/alexbar…
alexbarghi-nv Jul 24, 2024
ae9133f
resolve merge conflict
alexbarghi-nv Jul 25, 2024
1cf01c7
Merge branch 'use-correct-communicator' into dgl-examples
alexbarghi-nv Jul 25, 2024
6db236c
example
alexbarghi-nv Jul 25, 2024
7a3d38f
reverse mfgs
alexbarghi-nv Jul 25, 2024
710741c
node classification
alexbarghi-nv Jul 30, 2024
ddb95d6
resolve merge conflict
alexbarghi-nv Jul 30, 2024
f943d91
mnmg
alexbarghi-nv Jul 30, 2024
7ba4d89
use global communicator
alexbarghi-nv Jul 30, 2024
d9e9b50
Merge branch 'dgl-dist-sampling' into dgl-examples
alexbarghi-nv Jul 30, 2024
e0f1a90
examples
alexbarghi-nv Jul 31, 2024
2d3a640
fix partition function
alexbarghi-nv Aug 1, 2024
994aca8
fix minor issues
alexbarghi-nv Aug 1, 2024
d1c8494
remove dask example
alexbarghi-nv Aug 1, 2024
e92aa28
Merge branch 'branch-24.08' into dgl-examples
alexbarghi-nv Aug 1, 2024
05e1da4
use float64
alexbarghi-nv Aug 2, 2024
c5842b7
Merge branch 'dgl-examples' of https://github.com/alexbarghi-nv/cugra…
alexbarghi-nv Aug 2, 2024
5b46f43
set dtype
alexbarghi-nv Aug 2, 2024
2083b26
Merge branch 'branch-24.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Aug 5, 2024
139b3d6
allow setting directories
alexbarghi-nv Aug 5, 2024
6c9a5bf
Merge branch 'branch-24.10' into dgl-examples
alexbarghi-nv Aug 5, 2024
f233ee4
Merge branch 'branch-24.10' into dgl-examples
alexbarghi-nv Aug 19, 2024
Files changed
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -42,6 +42,7 @@ dependencies:
- numpy>=1.23,<2.0a0
- numpydoc
- nvcc_linux-64=11.8
- ogb
- openmpi
- packaging>=21
- pandas
@@ -74,6 +75,7 @@ dependencies:
- sphinxcontrib-websupport
- thriftpy2!=0.5.0,!=0.5.1
- torchdata
- torchmetrics
- ucx-proc=*=gpu
- ucx-py==0.40.*,>=0.0.0a0
- wget
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-x86_64.yaml
@@ -47,6 +47,7 @@ dependencies:
- numba>=0.57
- numpy>=1.23,<2.0a0
- numpydoc
- ogb
- openmpi
- packaging>=21
- pandas
@@ -79,6 +80,7 @@ dependencies:
- sphinxcontrib-websupport
- thriftpy2!=0.5.0,!=0.5.1
- torchdata
- torchmetrics
- ucx-proc=*=gpu
- ucx-py==0.40.*,>=0.0.0a0
- wget
2 changes: 2 additions & 0 deletions dependencies.yaml
@@ -678,6 +678,8 @@ dependencies:
- &pytorch_unsuffixed pytorch>=2.0,<2.2.0a0
- torchdata
- pydantic
- ogb
- torchmetrics

specific:
- output_types: [requirements]
4 changes: 4 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py
@@ -140,6 +140,10 @@ def __init__(
self.__graph = graph
self.__device = device

@property
def _batch_size(self):
return self.__batch_size

@property
def dataset(
self,
@@ -194,7 +194,7 @@ def sample(

if g.is_homogeneous:
indices = torch.concat(list(indices))
ds.sample_from_nodes(indices, batch_size=batch_size)
ds.sample_from_nodes(indices.long(), batch_size=batch_size)
return HomogeneousSampleReader(
ds.get_reader(), self.output_format, self.edge_dir
)
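The one-line change above casts the seed indices to int64 before handing them to the sampler. A minimal sketch of why such a cast matters, using NumPy; `sample_from_nodes` here is a hypothetical stand-in, not the cugraph API, and its int64 requirement is an illustrative assumption:

```python
import numpy as np

def sample_from_nodes(indices, batch_size=2):
    """Hypothetical sampler stand-in that requires int64 seed indices."""
    if indices.dtype != np.int64:
        raise TypeError(f"expected int64 seeds, got {indices.dtype}")
    # Split the seed nodes into fixed-size batches.
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

# Seeds often arrive as int32 (e.g. from a graph whose idtype is int32);
# casting up front (the NumPy analogue of torch's .long()) avoids a
# dtype error inside the sampling call.
seeds = np.array([0, 1, 2, 3, 4], dtype=np.int32)
batches = sample_from_nodes(seeds.astype(np.int64), batch_size=2)
```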
20 changes: 14 additions & 6 deletions python/cugraph-dgl/cugraph_dgl/graph.py
@@ -29,6 +29,7 @@
HeteroNodeDataView,
HeteroEdgeView,
HeteroEdgeDataView,
EmbeddingView,
)


@@ -567,8 +568,8 @@ def _has_n_emb(self, ntype: str, emb_name: str) -> bool:
return (ntype, emb_name) in self.__ndata_storage

def _get_n_emb(
self, ntype: str, emb_name: str, u: Union[str, TensorType]
) -> "torch.Tensor":
self, ntype: Union[str, None], emb_name: str, u: Union[str, TensorType]
) -> Union["torch.Tensor", "EmbeddingView"]:
"""
Gets the embedding of a single node type.
Unlike DGL, this function takes the string node
@@ -583,11 +584,11 @@ def _get_n_emb(
u: Union[str, TensorType]
Nodes to get the representation of, or ALL
to get the representation of all nodes of
the given type.
the given type (returns embedding view).

Returns
-------
torch.Tensor
Union[torch.Tensor, cugraph_dgl.view.EmbeddingView]
The embedding of the given edge type with the given embedding name.
"""

@@ -598,9 +599,14 @@
raise ValueError("Must provide the node type for a heterogeneous graph")

if dgl.base.is_all(u):
u = torch.arange(self.num_nodes(ntype), dtype=self.idtype, device="cpu")
return EmbeddingView(
self.__ndata_storage[ntype, emb_name], self.num_nodes(ntype)
)

try:
return self.__ndata_storage[ntype, emb_name].fetch(
_cast_to_torch_tensor(u), "cuda"
)
@@ -644,7 +650,9 @@ def _get_e_emb(
etype = self.to_canonical_etype(etype)

if dgl.base.is_all(u):
u = torch.arange(self.num_edges(etype), dtype=self.idtype, device="cpu")
return EmbeddingView(
self.__edata_storage[etype, emb_name], self.num_edges(etype)
)

try:
return self.__edata_storage[etype, emb_name].fetch(
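In the hunks above, `_get_n_emb` and `_get_e_emb` stop materializing a full `torch.arange(...)` index when all rows are requested and instead return an `EmbeddingView`. A plain-Python sketch of that deferral pattern; `ListStorage`, `LazyEmbeddingView`, and `get_embedding` are hypothetical stand-ins, not the cugraph-dgl API:

```python
class ListStorage:
    """Hypothetical feature storage backed by a plain list; `fetch` is
    the only method the view below relies on."""

    def __init__(self, rows):
        self._rows = rows

    def fetch(self, indices, device):
        return [self._rows[i] for i in indices]


class LazyEmbeddingView:
    """Wraps storage plus a row count; rows are fetched only on indexing."""

    def __init__(self, storage, num_rows):
        self._storage = storage
        self._num_rows = num_rows

    def __len__(self):
        return self._num_rows

    def __getitem__(self, indices):
        return self._storage.fetch(indices, "cpu")


def get_embedding(storage, num_rows, u="ALL"):
    # Instead of building every index up front and fetching all rows,
    # hand back a view when the caller asks for everything.
    if u == "ALL":
        return LazyEmbeddingView(storage, num_rows)
    return storage.fetch(u, "cpu")


storage = ListStorage([[0.0], [1.0], [2.0]])
view = get_embedding(storage, 3)          # no rows fetched yet
```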
36 changes: 36 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/view.py
@@ -12,6 +12,8 @@
# limitations under the License.


import warnings

from collections import defaultdict
from collections.abc import MutableMapping
from typing import Union, Dict, List, Tuple
@@ -20,11 +22,45 @@

import cugraph_dgl
from cugraph_dgl.typing import TensorType
from cugraph_dgl.utils.cugraph_conversion_utils import _cast_to_torch_tensor

torch = import_optional("torch")
dgl = import_optional("dgl")


class EmbeddingView:
def __init__(self, storage: "dgl.storages.base.FeatureStorage", ld: int):
self.__ld = ld
self.__storage = storage

def __getitem__(self, u: TensorType) -> "torch.Tensor":
u = _cast_to_torch_tensor(u)
try:
return self.__storage.fetch(
u,
"cuda",
)
except RuntimeError as ex:
warnings.warn(
"Got error accessing data, trying again with index on device: "
+ str(ex)
)
return self.__storage.fetch(
u.cuda(),
"cuda",
)

@property
def shape(self) -> "torch.Size":
try:
f = self.__storage.fetch(torch.tensor([0]), "cpu")
except RuntimeError:
f = self.__storage.fetch(torch.tensor([0], device="cuda"), "cuda")
sz = [s for s in f.shape]
sz[0] = self.__ld
return torch.Size(tuple(sz))


class HeteroEdgeDataView(MutableMapping):
"""
Duck-typed version of DGL's HeteroEdgeDataView.
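The `shape` property in `EmbeddingView` above learns the per-row feature dimensions by fetching a single row and then substituting the stored leading dimension. A NumPy sketch of the same probe; `ArrayStorage` and `ShapeProbeView` are hypothetical stand-ins for the wrapped FeatureStorage and the view:

```python
import numpy as np


class ArrayStorage:
    """Hypothetical storage backed by a NumPy array; only `fetch` is assumed."""

    def __init__(self, arr):
        self._arr = arr

    def fetch(self, indices, device):
        return self._arr[indices]


class ShapeProbeView:
    def __init__(self, storage, num_rows):
        self._storage = storage
        self._num_rows = num_rows

    @property
    def shape(self):
        # Fetch one row to discover the feature dimensions, then swap
        # the leading dimension for the stored row count. The full
        # embedding matrix is never loaded.
        probe = self._storage.fetch(np.array([0]), "cpu")
        dims = list(probe.shape)
        dims[0] = self._num_rows
        return tuple(dims)


view = ShapeProbeView(ArrayStorage(np.zeros((1000, 16, 4))), num_rows=1000)
```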