From 1ae6133bf1d2542847129ae6f29c4e1b489d4d52 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 20 Jul 2022 07:03:33 -0700 Subject: [PATCH] Add get_node_storage and get_edge_storage to CuGraphStorage (#2381) This PR enables https://github.com/rapidsai/dgl/pull/14 by adding the following functions/classes to CuGraphStore: 1. add_node_data 2. add_edge_data 3. get_node_storage 4. get_edge_storage 5. CuFeatureStorage Authors: - Vibhu Jawa (https://github.com/VibhuJawa) Approvers: - Xiaoyun Wang (https://github.com/wangxiaoyunNV) - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/cugraph/pull/2381 --- python/cugraph/cugraph/gnn/graph_store.py | 326 +++++++++++++---- .../cugraph/cugraph/tests/test_graph_store.py | 334 ++++++++++++++---- 2 files changed, 516 insertions(+), 144 deletions(-) diff --git a/python/cugraph/cugraph/gnn/graph_store.py b/python/cugraph/cugraph/gnn/graph_store.py index 6bd46e78f90..7e77ffcf594 100644 --- a/python/cugraph/cugraph/gnn/graph_store.py +++ b/python/cugraph/cugraph/gnn/graph_store.py @@ -11,13 +11,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict import cudf import cugraph from cugraph.experimental import PropertyGraph from cugraph.community.egonet import batched_ego_graphs from cugraph.utilities.utils import sample_groups +import cupy as cp -import numpy as np + +src_n = PropertyGraph.src_col_name +dst_n = PropertyGraph.dst_col_name +type_n = PropertyGraph.type_col_name +eid_n = PropertyGraph.edge_id_col_name +vid_n = PropertyGraph.vertex_col_name class CuGraphStore: @@ -34,54 +41,148 @@ class CuGraphStore: hetrogeneous graphs - use PropertyGraph """ + def __init__(self, graph, backend_lib="torch"): + if isinstance(graph, PropertyGraph): + self.__G = graph + else: + raise ValueError("graph must be a PropertyGraph") + # dict to map column names corresponding to edge features + # of each type + self.edata_key_col_d = defaultdict(list) + # dict to map column names corresponding to node features + # of each type + self.ndata_key_col_d = defaultdict(list) + self.backend_lib = backend_lib + + def add_node_data(self, df, node_col_name, node_key, ntype=None): + self.gdata.add_vertex_data( + df, vertex_col_name=node_col_name, type_name=ntype + ) + col_names = list(df.columns) + col_names.remove(node_col_name) + self.ndata_key_col_d[node_key] += col_names + + def add_edge_data(self, df, vertex_col_names, edge_key, etype=None): + self.gdata.add_edge_data( + df, vertex_col_names=vertex_col_names, type_name=etype + ) + col_names = [ + col for col in list(df.columns) if col not in vertex_col_names + ] + self.edata_key_col_d[edge_key] += col_names + + def get_node_storage(self, key, ntype=None): + + if ntype is None: + ntypes = self.ntypes + if len(self.ntypes) > 1: + raise ValueError( + ( + "Node type name must be specified if there " + "are more than one node types." + ) + ) + ntype = ntypes[0] + + df = self.gdata._vertex_prop_dataframe + col_names = self.ndata_key_col_d[key] + return CuFeatureStorage( + df=df, + id_col=vid_n, + _type_=ntype, + col_names=col_names, + backend_lib=self.backend_lib, + ) + + def get_edge_storage(self, key, etype=None): + if etype is None: + etypes = self.etypes + if len(self.etypes) > 1: + raise ValueError( + ( + "Edge type name must be specified if there" + "are more than one edge types." + ) + ) + + etype = etypes[0] + col_names = self.edata_key_col_d[key] + df = self.gdata._edge_prop_dataframe + return CuFeatureStorage( + df=df, + id_col=eid_n, + _type_=etype, + col_names=col_names, + backend_lib=self.backend_lib, + ) + + def num_nodes(self, ntype=None): + if ntype is not None: + s = self.gdata._vertex_prop_dataframe[type_n] == ntype + return s.sum() + else: + return self.gdata.num_vertices + + def num_edges(self, etype=None): + if etype is not None: + s = self.gdata._edge_prop_dataframe[type_n] == etype + return s.sum() + else: + return self.gdata.num_edges + + @property + def ntypes(self): + s = self.gdata._vertex_prop_dataframe[type_n] + ntypes = s.drop_duplicates().to_arrow().to_pylist() + return ntypes + + @property + def etypes(self): + s = self.gdata._edge_prop_dataframe[type_n] + ntypes = s.drop_duplicates().to_arrow().to_pylist() + return ntypes + @property def ndata(self): - return self.__G._vertex_prop_dataframe + return { + k: self.gdata._vertex_prop_dataframe[col_names].dropna(how="all") + for k, col_names in self.ndata_key_col_d.items() + } @property def edata(self): - return self.__G._edge_prop_dataframe + return { + k: self.gdata._edge_prop_dataframe[col_names].dropna(how="all") + for k, col_names in self.edata_key_col_d.items() + } @property def gdata(self): return self.__G - def __init__(self, graph): - if isinstance(graph, PropertyGraph): - self.__G = graph - else: - raise ValueError("graph must be a PropertyGraph") - ###################################### # Utilities ###################################### @property def num_vertices(self): - return self.__G.num_vertices - - @property - def num_edges(self): - return self.__G.num_edges + return self.gdata.num_vertices def get_vertex_ids(self): - return self.__G.vertices_ids() + return self.gdata.vertices_ids() ###################################### # Sampling APIs ###################################### - def sample_neighbors(self, - nodes, - fanout=-1, - edge_dir='in', - prob=None, - replace=False): + def sample_neighbors( + self, nodes, fanout=-1, edge_dir="in", prob=None, replace=False + ): """ Sample neighboring edges of the given nodes and return the subgraph. Parameters ---------- - nodes : array (single dimension) + nodes_cap : Dlpack of Node IDs (single dimension) Node IDs to sample neighbors from. fanout : int The number of edges to be sampled for each node on each edge type. @@ -102,16 +203,24 @@ def sample_neighbors(self, Returns ------- - CuPy array - The sampled arrays for bipartite graph. + DLPack capsule + The src nodes for the sampled bipartite graph. + DLPack capsule + The sampled dst nodes for the sampledbipartite graph. + DLPack capsule + The corresponding eids for the sampled bipartite graph """ + nodes = cudf.from_dlpack(nodes) num_nodes = len(nodes) - current_seeds = nodes.reindex(index=np.arange(0, num_nodes)) - _g = self.__G.extract_subgraph(create_using=cugraph.Graph, - allow_multi_edges=True) - ego_edge_list, seeds_offsets = batched_ego_graphs(_g, - current_seeds, - radius=1) + current_seeds = nodes.reindex(index=cp.arange(0, num_nodes)) + _g = self.__G.extract_subgraph( + create_using=cugraph.Graph, allow_multi_edges=True + ) + ego_edge_list, seeds_offsets = batched_ego_graphs( + _g, current_seeds, radius=1 + ) + + del _g # filter and get a certain size neighborhood # Step 1 @@ -125,20 +234,55 @@ def sample_neighbors(self, dst_seeds.index = ego_edge_list.index filtered_list = ego_edge_list[ego_edge_list["dst"] == dst_seeds] + del dst_seeds, offset_lens, seeds_offsets_s + del ego_edge_list, seeds_offsets + # Step 2 # Sample Fan Out # for each dst take maximum of fanout samples - filtered_list = sample_groups(filtered_list, - by="dst", - n_samples=fanout) - - return filtered_list['dst'].values, filtered_list['src'].values - - def node_subgraph(self, - nodes=None, - create_using=cugraph.Graph, - directed=False, - multigraph=True): + filtered_list = sample_groups( + filtered_list, by="dst", n_samples=fanout + ) + + # TODO: Verify order of execution + sample_df = cudf.DataFrame( + {src_n: filtered_list["src"], dst_n: filtered_list["dst"]} + ) + del filtered_list + + # del parents_nodes, children_nodes + edge_df = sample_df.merge( + self.gdata._edge_prop_dataframe[[src_n, dst_n, eid_n]], + on=[src_n, dst_n], + ) + + return ( + edge_df[src_n].to_dlpack(), + edge_df[dst_n].to_dlpack(), + edge_df[eid_n].to_dlpack(), + ) + + def find_edges(self, edge_ids, etype): + """Return the source and destination node IDs given the edge IDs within + the given edge type. + Return type is + cudf.Series, cudf.Series + """ + edge_df = self.gdata._edge_prop_dataframe[ + [src_n, dst_n, eid_n, type_n] + ] + subset_df = get_subset_df( + edge_df, PropertyGraph.edge_id_col_name, edge_ids, etype + ) + return subset_df[src_n].to_dlpack(), subset_df[dst_n].to_dlpack() + + def node_subgraph( + self, + nodes=None, + create_using=cugraph.Graph, + directed=False, + multigraph=True, + ): """ Return a subgraph induced on the given nodes. @@ -156,12 +300,12 @@ def node_subgraph(self, The sampled subgraph with the same node ID space with the original graph. """ - + # Values vary b/w cugraph and DGL investigate # expr="(_SRC in nodes) | (_DST_ in nodes)" - _g = self.__G.extract_subgraph( - create_using=cugraph.Graph(directed=directed), - allow_multi_edges=multigraph) + create_using=cugraph.Graph(directed=directed), + allow_multi_edges=multigraph, + ) if nodes is None: return _g @@ -192,18 +336,15 @@ def egonet(self, nodes, k): for each seed. """ - _g = self.__G.extract_subgraph(create_using=cugraph.Graph, - allow_multi_edges=True) + _g = self.__G.extract_subgraph( + create_using=cugraph.Graph, allow_multi_edges=True + ) ego_edge_list, seeds_offsets = batched_ego_graphs(_g, nodes, radius=k) return ego_edge_list, seeds_offsets - def randomwalk(self, - nodes, - length, - prob=None, - restart_prob=None): + def randomwalk(self, nodes, length, prob=None, restart_prob=None): """ Perform randomwalks starting from the given nodes and return the traces. @@ -235,11 +376,13 @@ def randomwalk(self, the node IDs reached by the randomwalk starting from nodes[i]. -1 means the walk has stopped. """ - _g = self.__G.extract_subgraph(create_using=cugraph.Graph, - allow_multi_edges=True) + _g = self.__G.extract_subgraph( + create_using=cugraph.Graph, allow_multi_edges=True + ) - p, w, s = cugraph.random_walks(_g, nodes, - max_depth=length, use_padding=True) + p, w, s = cugraph.random_walks( + _g, nodes, max_depth=length, use_padding=True + ) return p, w, s @@ -251,31 +394,35 @@ class CuFeatureStorage: is fine. DGL simply uses duck-typing to implement its sampling pipeline. """ - def __getitem__(self, ids): - """Fetch the features of the given node/edge IDs. + def __init__(self, df, id_col, _type_, col_names, backend_lib="torch"): + self.df = df + self.id_col = id_col + self.type = _type_ + self.col_names = col_names + if backend_lib == "torch": + from torch.utils.dlpack import from_dlpack + elif backend_lib == "tf": + from tensorflow.experimental.dlpack import from_dlpack + elif backend_lib == "cupy": + from cupy import from_dlpack + else: + raise NotImplementedError( + "Only pytorch and tensorflow backends are currently supported" + ) - Parameters - ---------- - ids : Tensor - Node or edge IDs. + self.from_dlpack = from_dlpack - Returns - ------- - Tensor - Feature data stored in PyTorch Tensor. - """ - pass - - async def async_fetch(self, ids, device): - """Asynchronously fetch the features of the given node/edge IDs to the + def fetch(self, indices, device, pin_memory=False, **kwargs): + """Fetch the features of the given node/edge IDs to the given device. Parameters ---------- - ids : Tensor + indices : Tensor Node or edge IDs. device : Device Device context. + pin_memory : Returns ------- @@ -283,4 +430,37 @@ async def async_fetch(self, ids, device): Feature data stored in PyTorch Tensor. """ # Default implementation uses synchronous fetch. - return self.__getitem__(ids).to(device) + + subset_cols = self.col_names + [type_n, self.id_col] + subset_df = get_subset_df( + self.df[subset_cols], self.id_col, indices, self.type + )[self.col_names] + tensor = self.from_dlpack(subset_df.to_dlpack()) + + if isinstance(tensor, cp.ndarray): + # can not transfer to + # a different device for cupy + return tensor + else: + return tensor.to(device) + + +def get_subset_df(df, id_col, indices, _type_): + """ + Util to get the subset dataframe to the indices of the requested type + """ + # We can avoid all of this if we set index to id_col like + # edge_id_col_name and vertex_id_col_name and make it much faster + # by using loc + indices_df = cudf.Series(cp.asarray(indices), name=id_col).to_frame() + id_col_name = id_col + "_index_" + indices_df = indices_df.reset_index(drop=False).rename( + columns={"index": id_col_name} + ) + subset_df = indices_df.merge(df, how="left") + if _type_ is None: + subset_df = subset_df[subset_df[type_n].isnull()] + else: + subset_df = subset_df[subset_df[type_n] == _type_] + subset_df = subset_df.sort_values(by=id_col_name) + return subset_df diff --git a/python/cugraph/cugraph/tests/test_graph_store.py b/python/cugraph/cugraph/tests/test_graph_store.py index 66843812e16..7cb535da5da 100644 --- a/python/cugraph/cugraph/tests/test_graph_store.py +++ b/python/cugraph/cugraph/tests/test_graph_store.py @@ -11,13 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict import pytest - import cugraph from cugraph.testing import utils from cugraph.experimental import PropertyGraph import numpy as np import cudf +import cupy as cp +from cugraph.gnn import CuGraphStore # Test @@ -35,8 +37,9 @@ def test_using_graph(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph() - g.from_cudf_edgelist(cu_M, source='0', - destination='1', edge_attr='2', renumber=True) + g.from_cudf_edgelist( + cu_M, source="0", destination="1", edge_attr="2", renumber=True + ) cugraph.gnn.CuGraphStore(graph=g) @@ -46,19 +49,19 @@ def test_using_pgraph(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', - edge_attr='2', renumber=True) + g.from_cudf_edgelist( + cu_M, source="0", destination="1", edge_attr="2", renumber=True + ) pG = PropertyGraph() - pG.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=None) + pG.add_edge_data( + cu_M, vertex_col_names=("0", "1"), property_columns=None + ) gstore = cugraph.gnn.CuGraphStore(graph=pG) assert g.number_of_edges() == pG.num_edges - assert g.number_of_edges() == gstore.num_edges + assert g.number_of_edges() == gstore.num_edges() assert g.number_of_vertices() == pG.num_vertices assert g.number_of_vertices() == gstore.num_vertices @@ -69,14 +72,11 @@ def test_node_data_pg(graph_file): cu_M = utils.read_csv_file(graph_file) pG = PropertyGraph() - pG.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=None) - gstore = cugraph.gnn.CuGraphStore(graph=pG) - - edata = gstore.edata + gstore.add_edge_data( + cu_M, vertex_col_names=("0", "1"), edge_key="feat" + ) + edata = gstore.edata["feat"] assert edata.shape[0] > 0 @@ -89,15 +89,13 @@ def test_egonet(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True) + g.from_cudf_edgelist(cu_M, source="0", destination="1", renumber=True) pG = PropertyGraph() - pG.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=None) - - gstore = cugraph.gnn.CuGraphStore(graph=pG) + gstore = cugraph.gnn.CuGraphStore(graph=pG, backend_lib='cupy') + gstore.add_edge_data( + cu_M, vertex_col_names=("0", "1"), edge_key="edge_feat" + ) nodes = [1, 2] @@ -115,16 +113,13 @@ def test_workflow(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True) + g.from_cudf_edgelist(cu_M, source="0", destination="1", renumber=True) pg = PropertyGraph() - pg.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=["2"]) - gstore = cugraph.gnn.CuGraphStore(graph=pg) - + gstore.add_edge_data( + cu_M, vertex_col_names=("0", "1"), edge_key="feat" + ) nodes = gstore.get_vertex_ids() num_nodes = len(nodes) @@ -142,25 +137,26 @@ def test_sample_neighbors(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True) + g.from_cudf_edgelist(cu_M, source="0", destination="1", renumber=True) pg = PropertyGraph() - pg.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=["2"]) - gstore = cugraph.gnn.CuGraphStore(graph=pg) + gstore.add_edge_data( + cu_M, edge_key="feat", vertex_col_names=("0", "1") + ) nodes = gstore.get_vertex_ids() num_nodes = len(nodes) assert num_nodes > 0 - sampled_nodes = nodes[:5] + sampled_nodes = nodes[:5].to_dlpack() - parents_list, children_list = gstore.sample_neighbors(sampled_nodes, 2) + parents_cap, children_cap, edge_id_cap = gstore.sample_neighbors( + sampled_nodes, 2 + ) + parents_list = cudf.from_dlpack(parents_cap) assert len(parents_list) > 0 @@ -169,21 +165,21 @@ def test_sample_neighbor_neg_one_fanout(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True) + g.from_cudf_edgelist(cu_M, source="0", destination="1", renumber=True) pg = PropertyGraph() - pg.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=["2"]) - gstore = cugraph.gnn.CuGraphStore(graph=pg) + gstore.add_edge_data( + cu_M, edge_key="edge_k", vertex_col_names=("0", "1") + ) nodes = gstore.get_vertex_ids() - sampled_nodes = nodes[:5] + sampled_nodes = nodes[:5].to_dlpack() # -1, default fan_out - parents_list, children_list = gstore.sample_neighbors(sampled_nodes, -1) - + parents_cap, children_cap, edge_id_cap = gstore.sample_neighbors( + sampled_nodes, -1 + ) + parents_list = cudf.from_dlpack(parents_cap) assert len(parents_list) > 0 @@ -192,26 +188,29 @@ def test_n_data(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True) + g.from_cudf_edgelist(cu_M, source="0", destination="1", renumber=True) pg = PropertyGraph() - pg.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=["2"]) + gstore = cugraph.gnn.CuGraphStore(graph=pg) + + gstore.add_edge_data( + cu_M, + edge_key="feat", + vertex_col_names=("0", "1"), + ) - num_nodes = g.number_of_nodes() + num_nodes = gstore.num_nodes() df_feat = cudf.DataFrame() - df_feat['node_id'] = np.arange(num_nodes) - df_feat['val0'] = [float(i+1) for i in range(num_nodes)] - df_feat['val1'] = [float(i+2) for i in range(num_nodes)] - pg.add_vertex_data(df_feat, - type_name="test_feat", - vertex_col_name="node_id", - property_columns=None) - gstore = cugraph.gnn.CuGraphStore(graph=pg) + df_feat["node_id"] = np.arange(num_nodes) + df_feat["val0"] = [float(i + 1) for i in range(num_nodes)] + df_feat["val1"] = [float(i + 2) for i in range(num_nodes)] + gstore.add_node_data( + df_feat, + node_key="node_feat", + node_col_name="node_id", + ) - ndata = gstore.ndata + ndata = gstore.ndata["node_feat"] assert ndata.shape[0] > 0 @@ -221,16 +220,209 @@ def test_e_data(graph_file): cu_M = utils.read_csv_file(graph_file) g = cugraph.Graph(directed=True) - g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True) + g.from_cudf_edgelist(cu_M, source="0", destination="1", renumber=True) pg = PropertyGraph() - pg.add_edge_data(cu_M, - type_name="edge", - vertex_col_names=("0", "1"), - property_columns=["2"]) - gstore = cugraph.gnn.CuGraphStore(graph=pg) + gstore.add_edge_data( + cu_M, vertex_col_names=("0", "1"), edge_key="edge_k" + ) - edata = gstore.edata - + edata = gstore.edata["edge_k"] assert edata.shape[0] > 0 + + +dataset1 = { + "merchants": [ + [ + "merchant_id", + "merchant_locaton", + "merchant_size", + "merchant_sales", + "merchant_num_employees", + ], + [ + (11, 78750, 44, 123.2, 12), + (4, 78757, 112, 234.99, 18), + (21, 44145, 83, 992.1, 27), + (16, 47906, 92, 32.43, 5), + (86, 47906, 192, 2.43, 51), + ], + ], + "users": [ + ["user_id", "user_location", "vertical"], + [ + (89021, 78757, 0), + (32431, 78750, 1), + (89216, 78757, 1), + (78634, 47906, 0), + ], + ], + "taxpayers": [ + ["payer_id", "amount"], + [ + (11, 1123.98), + (4, 3243.7), + (21, 8932.3), + (16, 3241.77), + (86, 789.2), + (89021, 23.98), + (78634, 41.77), + ], + ], + "transactions": [ + ["user_id", "merchant_id", "volume", "time", "card_num"], + [ + (89021, 11, 33.2, 1639084966.5513437, 123456), + (89216, 4, None, 1639085163.481217, 8832), + (78634, 16, 72.0, 1639084912.567394, 4321), + (32431, 4, 103.2, 1639084721.354346, 98124), + ], + ], + "relationships": [ + ["user_id_1", "user_id_2", "relationship_type"], + [ + (89216, 89021, 9), + (89216, 32431, 9), + (32431, 78634, 8), + (78634, 89216, 8), + ], + ], + "referrals": [ + ["user_id_1", "user_id_2", "merchant_id", "stars"], + [ + (89216, 78634, 11, 5), + (89021, 89216, 4, 4), + (89021, 89216, 21, 3), + (89021, 89216, 11, 3), + (89021, 78634, 21, 4), + (78634, 32431, 11, 4), + ], + ], +} + + +# util to create dataframe +def create_df_from_dataset(col_n, rows): + data_d = defaultdict(list) + for row in rows: + for col_id, col_v in enumerate(row): + data_d[col_n[col_id]].append(col_v) + return cudf.DataFrame(data_d) + + +@pytest.fixture() +def dataset1_CuGraphStore(): + """ + Fixture which returns an instance of a CuGraphStore with vertex and edge + data added from dataset1, parameterized for different DataFrame types. + """ + merchant_df = create_df_from_dataset( + dataset1["merchants"][0], dataset1["merchants"][1] + ) + user_df = create_df_from_dataset( + dataset1["users"][0], dataset1["users"][1] + ) + taxpayers_df = create_df_from_dataset( + dataset1["taxpayers"][0], dataset1["taxpayers"][1] + ) + transactions_df = create_df_from_dataset( + dataset1["transactions"][0], dataset1["transactions"][1] + ) + relationships_df = create_df_from_dataset( + dataset1["relationships"][0], dataset1["relationships"][1] + ) + referrals_df = create_df_from_dataset( + dataset1["referrals"][0], dataset1["referrals"][1] + ) + + pG = PropertyGraph() + graph = CuGraphStore(pG, backend_lib='cupy') + # Vertex and edge data is added as one or more DataFrames; either a Pandas + # DataFrame to keep data on the CPU, a cuDF DataFrame to keep data on GPU, + # or a dask_cudf DataFrame to keep data on distributed GPUs. + + # For dataset1: vertices are merchants and users, edges are transactions, + # relationships, and referrals. + + # property_columns=None (the default) means all columns except + # vertex_col_name will be used as properties for the vertices/edges. + + graph.add_node_data( + merchant_df, "merchant_id", "merchant_k", "merchant" + ) + graph.add_node_data(user_df, "user_id", "user_k", "user") + graph.add_node_data( + taxpayers_df, "payer_id", "taxpayers_k", "taxpayers" + ) + + graph.add_edge_data( + referrals_df, + ("user_id_1", "user_id_2"), + "referrals_k", + "referrals", + ) + graph.add_edge_data( + relationships_df, + ("user_id_1", "user_id_2"), + "relationships_k", + "relationships", + ) + graph.add_edge_data( + transactions_df, + ("user_id", "merchant_id"), + "transactions_k", + "transactions", + ) + + return graph + + +def test_num_nodes_gs(dataset1_CuGraphStore): + assert dataset1_CuGraphStore.num_nodes() == 9 + + +def test_num_edges(dataset1_CuGraphStore): + assert dataset1_CuGraphStore.num_edges() == 14 + + +def test_get_node_storage_gs(dataset1_CuGraphStore): + fs = dataset1_CuGraphStore.get_node_storage( + key="merchant_k", ntype="merchant" + ) + merchent_gs = fs.fetch([11, 4, 21, 316, 11], device="cuda") + merchant_df = create_df_from_dataset( + dataset1["merchants"][0], dataset1["merchants"][1] + ) + cudf_ar = ( + merchant_df.set_index("merchant_id") + .loc[[11, 4, 21, 316, 11]] + .values + ) + assert cp.allclose(cudf_ar, merchent_gs) + + +def test_get_edge_storage_gs(dataset1_CuGraphStore): + fs = dataset1_CuGraphStore.get_edge_storage( + "relationships_k", "relationships" + ) + relationship_t = fs.fetch([6, 7, 8], device="cuda") + + relationships_df = create_df_from_dataset( + dataset1["relationships"][0], dataset1["relationships"][1] + ) + cudf_ar = relationships_df["relationship_type"].iloc[[0, 1, 2]].values + + assert cp.allclose(cudf_ar, relationship_t) + + +def test_sampling_gs(dataset1_CuGraphStore): + node_pack = cp.asarray([4]).toDlpack() + ( + parents_cap, + children_cap, + edge_id_cap, + ) = dataset1_CuGraphStore.sample_neighbors(node_pack, fanout=1) + x = cudf.from_dlpack(parents_cap) + + assert x is not None