-
Notifications
You must be signed in to change notification settings - Fork 10
/
local_structure.py
76 lines (64 loc) · 2.46 KB
/
local_structure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv
from dgl.nn.pytorch.softmax import edge_softmax
def _compute_local_feats(middle_feats, g, mode='mean'):
'''get local feature,
one can also compute it through a gcn layer
given:
middle_feats and graph
'''
graph = g.local_var()
graph.ndata['h'] = middle_feats
if mode == 'max':
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
elif mode == 'mean':
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
local_feats = graph.ndata['neigh']
return local_feats
class customized_GAT(nn.Module):
    """Single-head GAT layer applied to LeakyReLU-activated node features.

    Args:
        in_dim: width of the incoming node features.
        out_dim: width of the features produced by the GAT layer.
        retatt: when True, ``forward`` also returns the attention values that
            ``GATConv`` yields alongside its output.
    """

    def __init__(self, in_dim, out_dim, retatt=False):
        super(customized_GAT, self).__init__()
        self.GATConv1 = GATConv(in_feats=in_dim, out_feats=out_dim, num_heads=1)
        self.nonlinear = nn.LeakyReLU(negative_slope=0.2)
        self.retatt = retatt

    def forward(self, graph, feats):
        """Activate ``feats``, run the GAT layer and flatten dims after the first.

        Returns ``rst.flatten(1)``; when ``self.retatt`` is set, returns the
        pair ``(rst.flatten(1), att)`` instead.
        """
        activated = self.nonlinear(feats)
        # self.retatt is passed as GATConv's third positional argument
        # (``get_attention`` in DGL >= 0.5); when it is truthy GATConv returns
        # a (features, attention) pair rather than a single tensor.
        out = self.GATConv1(graph, activated, self.retatt)
        if not self.retatt:
            # flatten(1) merges the single attention-head axis into features.
            return out.flatten(1)
        rst, att = out
        return rst.flatten(1), att
class distanceNet(nn.Module):
    """Edge weights from the L1 distance between endpoint node features.

    For every edge (u, v) the module computes the kernel value
    ``exp(-||h_u - h_v||_1 / scale)`` and then normalises those values with
    DGL's ``edge_softmax`` (by default over each destination node's incoming
    edges). It has no learnable parameters.

    Args:
        scale: positive kernel bandwidth; larger values make the kernel less
            sensitive to feature distance. Defaults to 100.0.

    Raises:
        ValueError: if ``scale`` is not positive.
    """

    def __init__(self, scale=100.0):
        super(distanceNet, self).__init__()
        if scale <= 0:
            raise ValueError("scale must be positive, got {!r}".format(scale))
        # Plain attribute (not a buffer/parameter), so state_dict is unchanged.
        self.scale = scale

    def forward(self, graph, feats):
        """Return normalised per-edge weights for ``graph``.

        Args:
            graph: DGL graph; not modified (work happens on ``local_var()``).
            feats: 2-D node feature tensor; ``feats.shape[1]`` is taken as
                the feature width.

        Returns:
            Tensor of shape ``(num_edges, 1)`` produced by ``edge_softmax``.
        """
        graph = graph.local_var()
        # (N, D) -> (N, 1, D): a singleton axis so edge tensors are (E, 1, D).
        feats = feats.view(-1, 1, feats.shape[1])
        graph.ndata.update({'ftl': feats, 'ftr': feats})
        # Per-edge feature difference h_u - h_v.
        graph.apply_edges(fn.u_sub_v('ftl', 'ftr', 'diff'))
        e = graph.edata.pop('diff')
        # Kernel on the L1 distance: exp(-sum|h_u - h_v| / scale), shape (E, 1).
        e = torch.exp((-1.0 / self.scale) * torch.sum(torch.abs(e), dim=-1))
        # NOTE(review): edge_softmax exponentiates its input again, so the
        # result is softmax(exp(-d / scale)); since exp(-d / scale) lies in
        # (0, 1] the weights end up close to uniform -- confirm this is
        # intended.
        e = edge_softmax(graph, e)
        return e
def old_get_local_model(feat_info, upsampling=False):
    """Build the earlier GAT-based local model.

    The returned ``customized_GAT`` is created with ``retatt=True``, so its
    forward returns both the transformed features and GATConv's attention
    values. GATConv itself still includes a learned linear projection.

    Args:
        feat_info: mapping with a 't_feat' entry (and 's_feat' when
            ``upsampling`` is True); element ``[1]`` of each entry is used as
            a feature width -- presumably a shape tuple, TODO confirm.
        upsampling: if True the input width is ``feat_info['s_feat'][1]``,
            otherwise ``feat_info['t_feat'][1]``. The output width is always
            ``feat_info['t_feat'][1]``.
    """
    in_key = 's_feat' if upsampling else 't_feat'
    in_dim = feat_info[in_key][1]
    out_dim = feat_info['t_feat'][1]
    return customized_GAT(in_dim, out_dim, retatt=True)
def get_local_model(feat_info, upsampling=False):
    """Build the module that computes local (per-edge) weights.

    Both arguments are currently ignored -- presumably kept so this can be
    swapped with ``old_get_local_model``. The result is always a freshly
    constructed ``distanceNet``.
    """
    local_model = distanceNet()
    return local_model
def get_upsampling_model(feat_info):
    """Build a GAT layer mapping source-graph features to the target width.

    Args:
        feat_info: mapping with 's_feat' and 't_feat' entries; element ``[1]``
            of each is used as a feature width (presumably a shape tuple --
            TODO confirm).

    Returns:
        A ``customized_GAT`` with the default ``retatt=False``, so its forward
        returns only the transformed features.
    """
    src_dim = feat_info['s_feat'][1]
    dst_dim = feat_info['t_feat'][1]
    return customized_GAT(src_dim, dst_dim)