WholeGraph Feature Store for cuGraph-PyG and cuGraph-DGL #3874

Merged: 31 commits merged into branch-23.10 from wholegraph-fs on Sep 30, 2023
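In short, this PR adds a "wholegraph" backend to cugraph's FeatureStore so that feature tensors are held in WholeGraph (WholeMemory) embeddings and gathered on the GPU. A minimal usage sketch, based on the test added in this PR; it assumes pylibwholegraph and torch are installed, a WholeGraph communicator has already been created for the process (e.g. via init_torch_env_and_create_wm_comm, as the test does), and the type/feature names are illustrative:

import numpy as np
from cugraph.gnn import FeatureStore

features = np.arange(100_000, dtype="float64").reshape(10_000, 10)

fs = FeatureStore(backend="wholegraph")  # previously only "numpy" or "torch"
fs.add_data(features, type_name="paper", feat_name="x")

# get_data gathers the requested rows through WholeGraph and
# returns a CUDA torch.Tensor
emb = fs.get_data(np.array([0, 5, 42]), type_name="paper", feat_name="x")
assert emb.is_cuda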
Changes shown from 18 of 31 commits.

Commits:
e192beb
WholeGraph Feature Store for cuGraph-PyG and cuGraph-DGL
alexbarghi-nv Sep 23, 2023
b618940
wholegraph
alexbarghi-nv Sep 25, 2023
5778176
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Sep 25, 2023
224f133
style
alexbarghi-nv Sep 25, 2023
e17b587
remove c code
alexbarghi-nv Sep 25, 2023
b64326d
wholegraph test
alexbarghi-nv Sep 25, 2023
57662d8
testing
alexbarghi-nv Sep 26, 2023
a485a07
commit
alexbarghi-nv Sep 26, 2023
0324099
change function name
alexbarghi-nv Sep 26, 2023
1546071
Merge branch 'branch-23.10' into wholegraph-fs
alexbarghi-nv Sep 26, 2023
905aaca
correct test
alexbarghi-nv Sep 26, 2023
18dbfee
Merge branch 'wholegraph-fs' of https://github.com/alexbarghi-nv/cugr…
alexbarghi-nv Sep 26, 2023
e46b5d8
style
alexbarghi-nv Sep 26, 2023
404405d
update test to use import_optional
alexbarghi-nv Sep 26, 2023
5ae7fec
add pylibwholegraph to update-version.sh
alexbarghi-nv Sep 26, 2023
4cc8730
style
alexbarghi-nv Sep 26, 2023
a30d829
fix type in dependencies.yaml
alexbarghi-nv Sep 26, 2023
a99b6f5
update update-version.sh
alexbarghi-nv Sep 26, 2023
9967da4
change to set wg properties at init
alexbarghi-nv Sep 26, 2023
3dee752
Merge branch 'branch-23.10' into wholegraph-fs
alexbarghi-nv Sep 26, 2023
78254c1
fix bad kwargs
alexbarghi-nv Sep 27, 2023
f882153
Merge branch 'wholegraph-fs' of https://github.com/alexbarghi-nv/cugr…
alexbarghi-nv Sep 27, 2023
70e0e16
Merge branch 'branch-23.10' into wholegraph-fs
alexbarghi-nv Sep 27, 2023
a65baa0
add missing module check in feature store
alexbarghi-nv Sep 27, 2023
a4f4aa5
Merge branch 'wholegraph-fs' of https://github.com/alexbarghi-nv/cugr…
alexbarghi-nv Sep 27, 2023
2b1c0d3
style
alexbarghi-nv Sep 27, 2023
4caa02e
update dependencies.yaml
alexbarghi-nv Sep 28, 2023
63b0bf8
rerun generator
alexbarghi-nv Sep 28, 2023
6e2a4a3
change pyproject to conda (wrong tag)
alexbarghi-nv Sep 29, 2023
61f8f61
run generator
alexbarghi-nv Sep 29, 2023
ce11c0c
Merge branch 'branch-23.10' into wholegraph-fs
alexbarghi-nv Sep 29, 2023
4 changes: 4 additions & 0 deletions ci/release/update-version.sh
@@ -81,6 +81,9 @@ NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; prin
DEPENDENCIES=(
cudf
cugraph
cugraph-dgl
cugraph-pyg
cugraph-service-server
cugraph-service-client
cuxfilter
dask-cuda
@@ -92,6 +95,7 @@ DEPENDENCIES=(
librmm
pylibcugraph
pylibcugraphops
pylibwholegraph
pylibraft
pyraft
raft-dask
1 change: 1 addition & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -53,6 +53,7 @@ dependencies:
- pydata-sphinx-theme
- pylibcugraphops==23.10.*
- pylibraft==23.10.*
- pylibwholegraph==23.10.*
- pytest
- pytest-benchmark
- pytest-cov
1 change: 1 addition & 0 deletions conda/environments/all_cuda-120_arch-x86_64.yaml
@@ -52,6 +52,7 @@ dependencies:
- pydata-sphinx-theme
- pylibcugraphops==23.10.*
- pylibraft==23.10.*
- pylibwholegraph==23.10.*
- pytest
- pytest-benchmark
- pytest-cov
1 change: 1 addition & 0 deletions dependencies.yaml
@@ -488,6 +488,7 @@ dependencies:
- *numpy
- python-louvain
- scikit-learn>=0.23.1
- pylibwholegraph==23.10.*
test_python_pylibcugraph:
common:
- output_types: [conda, pyproject]
1 change: 1 addition & 0 deletions python/cugraph-service/server/pyproject.toml
@@ -46,6 +46,7 @@ test = [
"networkx>=2.5.1",
"numpy>=1.21",
"pandas",
"pylibwholegraph==23.10.*",
"pytest",
"pytest-benchmark",
"pytest-cov",
64 changes: 59 additions & 5 deletions python/cugraph/cugraph/gnn/feature_storage/feat_storage.py
@@ -20,20 +20,23 @@
from cugraph.utilities.utils import import_optional

torch = import_optional("torch")
wgth = import_optional("pylibwholegraph.torch")


class FeatureStore:
"""The feature-store class used to store feature data for GNNS"""

def __init__(self, backend="numpy"):
self.fd = defaultdict(dict)
if backend not in ["numpy", "torch"]:
if backend not in ["numpy", "torch", "wholegraph"]:
raise ValueError(
f"backend {backend} not supported. Supported backends are numpy, torch"
)
self.backend = backend

def add_data(self, feat_obj: Sequence, type_name: str, feat_name: str) -> None:
def add_data(
self, feat_obj: Sequence, type_name: str, feat_name: str, **kwargs
) -> None:
"""
Add the feature data to the feature_storage class
Parameters:
@@ -49,9 +52,26 @@ def add_data(self, feat_obj: Sequence, type_name: str, feat_name: str) -> None:
None
"""
self.fd[feat_name][type_name] = self._cast_feat_obj_to_backend(
feat_obj, self.backend
feat_obj, self.backend, **kwargs
)

def add_data_no_cast(self, feat_obj, type_name: str, feat_name: str) -> None:
"""
Add the feature data to the feature_storage class directly, without casting
Parameters:
----------
feat_obj : array_like object
The feature object to save in feature store
type_name : str
The node-type/edge-type of the feature
feat_name: str
The name of the feature being stored
Returns:
-------
None
"""
self.fd[feat_name][type_name] = feat_obj

def get_data(
self,
indices: Union[np.ndarray, torch.Tensor],
@@ -87,13 +107,22 @@ def get_data(
f" feature: {list(self.fd[feat_name].keys())}"
)

return self.fd[feat_name][type_name][indices]
feat = self.fd[feat_name][type_name]
if isinstance(feat, wgth.WholeMemoryEmbedding):
indices_tensor = (
indices
if isinstance(indices, torch.Tensor)
else torch.as_tensor(indices, device="cuda")
)
return feat.gather(indices_tensor)
else:
return feat[indices]

def get_feature_list(self) -> dict:
return {feat_name: feats.keys() for feat_name, feats in self.fd.items()}

@staticmethod
def _cast_feat_obj_to_backend(feat_obj, backend: str):
def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs):
if backend == "numpy":
if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)):
return _cast_to_numpy_ar(feat_obj.values)
@@ -104,6 +133,31 @@ def _cast_feat_obj_to_backend(feat_obj, backend: str):
return _cast_to_torch_tensor(feat_obj.values)
else:
return _cast_to_torch_tensor(feat_obj)
elif backend == "wholegraph":
wg_comm_obj = kwargs.get("wg_comm", wgth.get_local_node_communicator())
wg_type_str = kwargs.get("wg_type", "distributed")
wg_location_str = kwargs.get("wg_location", "cuda")
if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)):
th_tensor = _cast_to_torch_tensor(feat_obj.values)
else:
th_tensor = _cast_to_torch_tensor(feat_obj)
wg_embedding = wgth.create_embedding(
wg_comm_obj,
wg_type_str,
wg_location_str,
th_tensor.dtype,
th_tensor.shape,
)
(
local_wg_tensor,
local_ld_offset,
) = wg_embedding.get_embedding_tensor().get_local_tensor()
local_th_tensor = th_tensor[
local_ld_offset : local_ld_offset + local_wg_tensor.shape[0]
]
local_wg_tensor.copy_(local_th_tensor)
wg_comm_obj.barrier()
return wg_embedding


def _cast_to_torch_tensor(ar):
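The "wholegraph" branch above consumes three optional keyword arguments forwarded by add_data. A hedged sketch of passing them explicitly; the values shown are simply the defaults the code falls back to when the kwargs are omitted, and the data and names are illustrative:

import numpy as np
import pylibwholegraph.torch as wgth
from cugraph.gnn import FeatureStore

features = np.random.default_rng(0).random((10_000, 10))

fs = FeatureStore(backend="wholegraph")
fs.add_data(
    features,                                    # array-like or DataFrame
    type_name="paper",                           # illustrative type name
    feat_name="x",                               # illustrative feature name
    wg_comm=wgth.get_local_node_communicator(),  # default communicator
    wg_type="distributed",                       # default WholeMemory type
    wg_location="cuda",                          # default WholeMemory location
)

Note how the cast works: each rank copies only its local shard of the source tensor into the embedding (get_local_tensor returns the shard and its row offset), and the final barrier keeps ranks from gathering before every shard has been written.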
@@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Import FeatureStore class

import pytest
import numpy as np
@@ -0,0 +1,85 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import numpy as np

from cugraph.gnn import FeatureStore

from cugraph.utilities.utils import import_optional, MissingModule

pylibwholegraph = import_optional("pylibwholegraph")
wmb = import_optional("pylibwholegraph.binding.wholememory_binding")
torch = import_optional("torch")


def runtest(world_rank: int, world_size: int):
from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm

wm_comm, _ = init_torch_env_and_create_wm_comm(
world_rank,
world_size,
world_rank,
world_size,
)
wm_comm = wm_comm.wmb_comm

generator = np.random.default_rng(62)
arr = (
generator.integers(low=0, high=100, size=100_000)
.reshape(10_000, -1)
.astype("float64")
)

fs = FeatureStore(backend="wholegraph")
fs.add_data(arr, "type2", "feat1")
wm_comm.barrier()

indices_to_fetch = np.random.randint(low=0, high=len(arr), size=1024)
output_fs = fs.get_data(indices_to_fetch, type_name="type2", feat_name="feat1")
assert isinstance(output_fs, torch.Tensor)
assert output_fs.is_cuda
expected = arr[indices_to_fetch]
np.testing.assert_array_equal(output_fs.cpu().numpy(), expected)

wmb.finalize()


@pytest.mark.sg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
def test_feature_storage_wholegraph_backend():
from pylibwholegraph.utils.multiprocess import multiprocess_run

gpu_count = wmb.fork_get_gpu_count()
print("gpu count:", gpu_count)
assert gpu_count > 0

multiprocess_run(1, runtest)


@pytest.mark.mg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
def test_feature_storage_wholegraph_backend_mg():
from pylibwholegraph.utils.multiprocess import multiprocess_run

gpu_count = wmb.fork_get_gpu_count()
print("gpu count:", gpu_count)
assert gpu_count > 0

multiprocess_run(gpu_count, runtest)
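The two tests differ only in worker count: the sg test runs runtest in a single spawned process, while the mg test spawns one per detected GPU. Every worker must build its own WholeGraph communicator before touching the store. A condensed sketch of that per-rank setup, using the same call the test makes; passing the world rank/size for the local values is an assumption that holds here because multiprocess_run launches all workers on a single node:

from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm

def per_rank_setup(world_rank: int, world_size: int):
    # Initializes the torch environment and WholeMemory for this rank,
    # then returns the low-level communicator the test uses for barriers.
    wm_comm, _ = init_torch_env_and_create_wm_comm(
        world_rank, world_size,  # global rank / size
        world_rank, world_size,  # local rank / size (single node)
    )
    return wm_comm.wmb_comm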
1 change: 1 addition & 0 deletions python/cugraph/pyproject.toml
@@ -54,6 +54,7 @@ test = [
"networkx>=2.5.1",
"numpy>=1.21",
"pandas",
"pylibwholegraph==23.10.*",
"pytest",
"pytest-benchmark",
"pytest-cov",