forked from rapidsai/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuilder.py
118 lines (100 loc) · 4.73 KB
/
builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Graph builder from pandas dataframes"""
from collections import namedtuple

import pandas as pd
import torch
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_categorical

import dgl
# Public API of this module: only the graph builder class is exported by
# ``from builder import *``.
__all__ = ['PandasGraphBuilder']
def _series_to_tensor(series):
if is_categorical(series):
return torch.LongTensor(series.cat.codes.values.astype('int64'))
else: # numeric
return torch.FloatTensor(series.values)
class PandasGraphBuilder(object):
    """Creates a heterogeneous graph from multiple pandas dataframes.

    Entity (node) tables are registered with :meth:`add_entities`, relation
    (edge) tables with :meth:`add_binary_relations`, and :meth:`build` then
    assembles a DGL heterograph from everything registered so far.  Node IDs
    of each entity type follow the row order of its entity table.

    Examples
    --------
    Let's say we have the following three pandas dataframes:

    User table ``users``:

    ===========  ===========  =======
    ``user_id``  ``country``  ``age``
    ===========  ===========  =======
    XYZZY        U.S.         25
    FOO          China        24
    BAR          China        23
    ===========  ===========  =======

    Game table ``games``:

    ===========  =========  ==============  ==================
    ``game_id``  ``title``  ``is_sandbox``  ``is_multiplayer``
    ===========  =========  ==============  ==================
    1            Minecraft  True            True
    2            Tetris 99  False           True
    ===========  =========  ==============  ==================

    Play relationship table ``plays``:

    ===========  ===========  =========
    ``user_id``  ``game_id``  ``hours``
    ===========  ===========  =========
    XYZZY        1            24
    FOO          1            20
    FOO          2            16
    BAR          2            28
    ===========  ===========  =========

    One could then create a bidirectional bipartite graph as follows:

    >>> builder = PandasGraphBuilder()
    >>> builder.add_entities(users, 'user_id', 'user')
    >>> builder.add_entities(games, 'game_id', 'game')
    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 'plays')
    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')
    >>> g = builder.build()
    >>> g.num_nodes('user')
    3
    >>> g.num_edges('plays')
    4
    """

    def __init__(self):
        self.entity_tables = {}           # entity name -> entity dataframe
        self.relation_tables = {}         # relation name -> relation dataframe
        self.entity_pk_to_name = {}       # mapping from primary key name to entity name
        self.entity_pk = {}               # mapping from entity name to primary key
        self.entity_key_map = {}          # mapping from entity names to primary key values
                                          # (categorical Series, categories in row order)
        self.num_nodes_per_type = {}      # entity name -> number of rows / nodes
        self.edges_per_relation = {}      # (srctype, relation, dsttype) -> (src ids, dst ids)
        self.relation_name_to_etype = {}  # relation name -> (srctype, relation, dsttype)
        self.relation_src_key = {}        # mapping from relation name to source key
        self.relation_dst_key = {}        # mapping from relation name to destination key

    def _entity_name_for_key(self, key, role, relation):
        """Return the entity name registered with primary key column ``key``.

        Raises
        ------
        KeyError
            If no entity was registered with ``key`` as its primary key; the
            message names the offending relation and key instead of only
            echoing the key.
        """
        try:
            return self.entity_pk_to_name[key]
        except KeyError:
            raise KeyError(
                '%s key %r of relation %s is not the primary key of any entity '
                'registered with add_entities().' % (role, key, relation)) from None

    def add_entities(self, entity_table, primary_key, name):
        """Register an entity (node) type.

        Parameters
        ----------
        entity_table : pandas.DataFrame
            One row per entity; the row order defines the node IDs of this
            type (row ``i`` becomes node ``i``).
        primary_key : str
            Column holding a unique identifier per row.  Relation tables refer
            to this entity through a column with the same name.
        name : str
            Node type name used in the resulting graph.

        Raises
        ------
        ValueError
            If the primary key column contains duplicate values.
        """
        keys = entity_table[primary_key]
        if not keys.is_unique:
            raise ValueError('Different entity with the same primary key detected.')
        # A column that is already categorical keeps its dtype under
        # astype('category'), including categories no row uses (e.g. after
        # filtering).  Drop those so the reordering below sees exactly the
        # values present in the table.
        entities = keys.astype('category').cat.remove_unused_categories()
        # preserve the category order in the original entity table, so that
        # each row's category code equals its row position (its node ID)
        entities = entities.cat.reorder_categories(keys.to_numpy())
        self.entity_pk_to_name[primary_key] = name
        self.entity_pk[name] = primary_key
        self.num_nodes_per_type[name] = entity_table.shape[0]
        self.entity_key_map[name] = entities
        self.entity_tables[name] = entity_table

    def add_binary_relations(self, relation_table, source_key, destination_key, name):
        """Register an edge type whose endpoints are given by two key columns.

        Each row of ``relation_table`` becomes one edge from the entity whose
        primary key equals ``relation_table[source_key]`` to the entity whose
        primary key equals ``relation_table[destination_key]``.

        Parameters
        ----------
        relation_table : pandas.DataFrame
            One row per edge.
        source_key : str
            Column naming the source entity; must be the primary key of an
            entity registered with :meth:`add_entities`.
        destination_key : str
            Column naming the destination entity; must be the primary key of
            an entity registered with :meth:`add_entities`.
        name : str
            Edge type name; the canonical edge type is
            ``(source entity name, name, destination entity name)``.

        Raises
        ------
        KeyError
            If either key column is missing from ``relation_table`` or is not
            the primary key of a registered entity.
        ValueError
            If some key value (including a missing value) does not match any
            row of the corresponding entity table.
        """
        srctype = self._entity_name_for_key(source_key, 'Source', name)
        dsttype = self._entity_name_for_key(destination_key, 'Destination', name)
        # Recode key values against the entity's row-ordered categories so the
        # category codes become node IDs; values unknown to the entity table
        # become NaN and are rejected below.
        src = relation_table[source_key].astype('category')
        src = src.cat.set_categories(self.entity_key_map[srctype].cat.categories)
        dst = relation_table[destination_key].astype('category')
        dst = dst.cat.set_categories(self.entity_key_map[dsttype].cat.categories)
        if src.isnull().any():
            raise ValueError(
                'Some source entities in relation %s do not exist in entity %s.' %
                (name, source_key))
        if dst.isnull().any():
            raise ValueError(
                'Some destination entities in relation %s do not exist in entity %s.' %
                (name, destination_key))
        etype = (srctype, name, dsttype)
        self.relation_name_to_etype[name] = etype
        self.edges_per_relation[etype] = (
            src.cat.codes.values.astype('int64'),
            dst.cat.codes.values.astype('int64'))
        self.relation_tables[name] = relation_table
        self.relation_src_key[name] = source_key
        self.relation_dst_key[name] = destination_key

    def build(self):
        """Create a DGL heterograph from all registered entities and relations.

        Edge types and endpoint IDs come from the registered relations; the
        node count of each type is the row count of its entity table.

        Returns
        -------
        The graph returned by ``dgl.heterograph``.
        """
        # Create heterograph
        graph = dgl.heterograph(self.edges_per_relation, self.num_nodes_per_type)
        return graph