Skip to content

Commit

Permalink
Addressed using ntype and etype
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Nov 7, 2022
1 parent 5b36168 commit d24ab52
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def get_node_storage(self, key, ntype=None, indices_offset=0):
storage_type="node",
indices_offset=indices_offset,
backend_lib=self.backend_lib,
types_to_fetch=[ntype],
)

def get_edge_storage(self, key, etype=None, indices_offset=0):
Expand Down Expand Up @@ -310,6 +311,7 @@ def get_edge_storage(self, key, etype=None, indices_offset=0):
storage_type="edge",
backend_lib=self.backend_lib,
indices_offset=indices_offset,
types_to_fetch=[etype],
)

######################################
Expand Down
4 changes: 3 additions & 1 deletion python/cugraph/cugraph/gnn/dgl_extensions/cugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def get_node_storage(self, key, ntype=None, indices_offset=0):
storage_type="node",
indices_offset=indices_offset,
backend_lib=self.backend_lib,
types_to_fetch=[ntype],
)

def get_edge_storage(self, key, etype=None, indices_offset=0):
Expand Down Expand Up @@ -302,6 +303,7 @@ def get_edge_storage(self, key, etype=None, indices_offset=0):
storage_type="edge",
backend_lib=self.backend_lib,
indices_offset=indices_offset,
types_to_fetch=[etype],
)

######################################
Expand Down Expand Up @@ -377,7 +379,7 @@ def sample_neighbors(
sgs_obj=sgs_obj,
sgs_src_range_obj=sgs_src_range_obj,
sg_node_dtype=self._sg_node_dtype,
nodes_cap=nodes_cap,
nodes_ar=nodes_cap,
replace=replace,
fanout=fanout,
edge_dir=edge_dir,
Expand Down
18 changes: 14 additions & 4 deletions python/cugraph/cugraph/gnn/dgl_extensions/feature_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ class CuFeatureStorage:
"""

def __init__(
self, pg, columns, storage_type, backend_lib="torch", indices_offset=0
self,
pg,
columns,
storage_type,
backend_lib="torch",
indices_offset=0,
types_to_fetch=None,
):
self.pg = pg
self.columns = columns
Expand All @@ -89,6 +95,7 @@ def __init__(

self.from_dlpack = from_dlpack
self.indices_offset = indices_offset
self.types_to_fetch = types_to_fetch

def fetch(self, indices, device=None, pin_memory=False, **kwargs):
"""Fetch the features of the given node/edge IDs to the
Expand Down Expand Up @@ -136,10 +143,13 @@ def fetch(self, indices, device=None, pin_memory=False, **kwargs):
indices = indices + self.indices_offset

if self.storage_type == "node":
result = self.pg.get_vertex_data(vertex_ids=indices, columns=self.columns)
result = self.pg.get_vertex_data(
vertex_ids=indices, columns=self.columns, types=self.types_to_fetch
)
else:
result = self.pg.get_edge_data(edge_ids=indices, columns=self.columns)

result = self.pg.get_edge_data(
edge_ids=indices, columns=self.columns, types=self.types_to_fetch
)
if type(result).__name__ == "DataFrame":
result = result[self.columns]
if hasattr(result, "compute"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,7 @@ def test_ntypes(dataset1_CuGraphStore):

def test_get_node_storage_gs(dataset1_CuGraphStore):
fs = dataset1_CuGraphStore.get_node_storage(key="merchant_k", ntype="merchant")
# indices = [11, 4, 21, 316, 11]
indices = [11, 4, 21, 316]
indices = [11, 4, 21, 316, 11]

merchant_gs = fs.fetch(indices, device="cuda")
merchant_df = create_df_from_dataset(
Expand All @@ -389,8 +388,31 @@ def test_get_node_storage_gs(dataset1_CuGraphStore):
assert cp.allclose(cudf_ar, merchant_gs)


def test_get_node_storage_ntypes():
    """Verify that node feature storage is restricted to its own ntype.

    Two node types with disjoint ids and distinct constant ``feat`` values
    are added to a fresh store. A fetch through a storage bound to one
    ntype is expected to return only the rows belonging to that ntype.
    """
    first_frame = cudf.DataFrame(
        {
            "node_ids": cudf.Series([1, 2, 3]),
            "feat": cudf.Series([1.0, 1.0, 1.0]),
        }
    )
    pg = PropertyGraph()
    gs = CuGraphStore(pg, backend_lib="cupy")
    gs.add_node_data(first_frame, "node_ids", ntype="nt.a")

    second_frame = cudf.DataFrame(
        {
            "node_ids": cudf.Series([4, 5, 6]),
            "feat": cudf.Series([2.0, 2.0, 2.0]),
        }
    )
    gs.add_node_data(second_frame, "node_ids", ntype="nt.b")

    # Every requested id belongs to "nt.a", so all three rows come back.
    storage_a = gs.get_node_storage(key="feat", ntype="nt.a")
    fetched_a = storage_a.fetch([1, 2, 3])
    expected_a = cp.asarray([1, 1, 1], dtype=cp.float32)
    cp.testing.assert_array_equal(expected_a, fetched_a)

    # Ids 1 and 2 belong to "nt.a"; only id 5 is an "nt.b" node.
    storage_b = gs.get_node_storage(key="feat", ntype="nt.b")
    fetched_b = storage_b.fetch([1, 2, 5])
    expected_b = cp.asarray([2.0], dtype=cp.float32)
    cp.testing.assert_array_equal(expected_b, fetched_b)


def test_get_edge_storage_gs(dataset1_CuGraphStore):
fs = dataset1_CuGraphStore.get_edge_storage("relationships_k", "relationships")
etype = "('user', 'relationship', 'user')"
fs = dataset1_CuGraphStore.get_edge_storage("relationships_k", etype)
relationship_t = fs.fetch([6, 7, 8], device="cuda")

relationships_df = create_df_from_dataset(
Expand Down

0 comments on commit d24ab52

Please sign in to comment.