forked from HEPTrkX/heptrkx-gnn-tracking
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gnn.py
91 lines (86 loc) · 3.52 KB
/
gnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
This module implements the PyTorch modules that define the
message-passing graph neural networks for hit or segment classification.
"""
import torch
import torch.nn as nn
class EdgeNetwork(nn.Module):
    """
    A module which computes weights for edges of the graph.

    For each edge, it selects the associated nodes' features
    and applies some fully-connected network layers with a final
    sigmoid activation.

    Args:
        input_dim: number of features per node in ``X``.
        hidden_dim: width of the single hidden layer.
        hidden_activation: zero-argument callable (e.g. an ``nn.Module``
            class) that builds the activation used after the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim=8, hidden_activation=nn.Tanh):
        super(EdgeNetwork, self).__init__()
        # The first layer takes 2 * input_dim features because forward()
        # concatenates the features of the two nodes attached to each edge.
        self.network = nn.Sequential(
            nn.Linear(input_dim * 2, hidden_dim),
            hidden_activation(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid())

    def forward(self, X, Ri, Ro):
        """Compute one sigmoid score per edge.

        Args:
            X: node features; must be 3-D because it is fed to ``torch.bmm``
                (assumed layout: (batch, n_nodes, input_dim) -- confirm
                against the caller).
            Ri: 3-D association matrix between nodes and edges for one end
                of each edge (assumed (batch, n_nodes, n_edges)).
            Ro: 3-D association matrix for the other end of each edge, same
                layout as ``Ri``.

        Returns:
            Tensor of edge scores with the trailing size-1 axis of the final
            linear layer removed (i.e. (batch, n_edges) under the layout
            assumed above).
        """
        # Select the features of the associated nodes: R^T @ X projects the
        # node features onto the edge axis.
        bo = torch.bmm(Ro.transpose(1, 2), X)
        bi = torch.bmm(Ri.transpose(1, 2), X)
        B = torch.cat([bo, bi], dim=2)
        # Apply the network to each edge
        return self.network(B).squeeze(-1)
class NodeNetwork(nn.Module):
    """
    A module which computes new node features on the graph.

    For each node, it aggregates the neighbor node features
    (separately on the input and output side), and combines
    them with the node's previous features in a fully-connected
    network to compute the new features.

    Args:
        input_dim: number of features per node in ``X``.
        output_dim: number of features produced per node; also used as the
            width of the intermediate layer.
        hidden_activation: zero-argument callable (e.g. an ``nn.Module``
            class) that builds the activation used after each linear layer.
    """

    def __init__(self, input_dim, output_dim, hidden_activation=nn.Tanh):
        super(NodeNetwork, self).__init__()
        # The first layer takes 3 * input_dim features because forward()
        # concatenates two aggregated messages with the node's own features.
        self.network = nn.Sequential(
            nn.Linear(input_dim * 3, output_dim),
            hidden_activation(),
            nn.Linear(output_dim, output_dim),
            hidden_activation())

    def forward(self, X, e, Ri, Ro):
        """Compute updated node features from edge-weighted neighbor messages.

        Args:
            X: node features; must be 3-D because it is fed to ``torch.bmm``
                (assumed layout: (batch, n_nodes, input_dim) -- confirm
                against the caller).
            e: per-edge weights (assumed (batch, n_edges)); an axis is
                inserted at position 1 so it broadcasts against ``Ri``/``Ro``.
            Ri: 3-D association matrix between nodes and edges for one end
                of each edge (assumed (batch, n_nodes, n_edges)).
            Ro: 3-D association matrix for the other end of each edge, same
                layout as ``Ri``.

        Returns:
            Output of the fully-connected network applied to the
            concatenation of both aggregated messages and ``X`` along dim 2.
        """
        # Project node features onto the edge axis for each end of the edge.
        bo = torch.bmm(Ro.transpose(1, 2), X)
        bi = torch.bmm(Ri.transpose(1, 2), X)
        # Scale every edge column of the association matrices by its weight.
        Rwo = Ro * e[:, None]
        Rwi = Ri * e[:, None]
        # Aggregate the weighted features from the opposite end of each edge
        # back onto the nodes.
        mi = torch.bmm(Rwi, bo)
        mo = torch.bmm(Rwo, bi)
        M = torch.cat([mi, mo, X], dim=2)
        return self.network(M)
class GNNSegmentClassifier(nn.Module):
    """
    Segment classification graph neural network model.

    Consists of an input network, an edge network, and a node network.
    The same edge network and node network instances are applied in every
    iteration, and the edge network is applied once more to produce the
    final output.

    Args:
        input_dim: number of features per node in the raw input ``X``.
        hidden_dim: width of the hidden node representation and of the
            hidden layers of the edge and node networks.
        n_iters: number of edge-network / node-network rounds run before
            the final edge network pass.
        hidden_activation: zero-argument callable (e.g. an ``nn.Module``
            class) that builds the activation used in all sub-networks.
    """

    def __init__(self, input_dim=2, hidden_dim=8, n_iters=3, hidden_activation=nn.Tanh):
        super(GNNSegmentClassifier, self).__init__()
        self.n_iters = n_iters
        # Setup the input network
        self.input_network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            hidden_activation())
        # Setup the edge network. Its node feature width is
        # input_dim + hidden_dim because forward() concatenates the raw
        # inputs onto the hidden representation before every call.
        self.edge_network = EdgeNetwork(input_dim + hidden_dim, hidden_dim,
                                        hidden_activation)
        # Setup the node layers (same concatenated input width as above)
        self.node_network = NodeNetwork(input_dim + hidden_dim, hidden_dim,
                                        hidden_activation)

    def forward(self, inputs):
        """Apply forward pass of the model.

        Args:
            inputs: iterable unpacking to ``(X, Ri, Ro)``: the node features
                and the two node-edge association matrices handed on to the
                edge and node networks.

        Returns:
            The output of the edge network after the last iteration.
        """
        X, Ri, Ro = inputs
        # Apply input network to get hidden representation
        H = self.input_network(X)
        # Shortcut connect the inputs onto the hidden representation
        H = torch.cat([H, X], dim=-1)
        # Loop over iterations of edge and node networks
        for _ in range(self.n_iters):
            # Apply edge network
            e = self.edge_network(H, Ri, Ro)
            # Apply node network
            H = self.node_network(H, e, Ri, Ro)
            # Shortcut connect the inputs onto the hidden representation
            H = torch.cat([H, X], dim=-1)
        # Apply final edge network
        return self.edge_network(H, Ri, Ro)