-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathlayers.py
119 lines (101 loc) · 3.62 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import MLP, MessageNorm
from ogb.graphproppred.mol_encoder import BondEncoder
import dgl.function as fn
from dgl.nn.functional import edge_softmax
class GENConv(nn.Module):
    r"""
    Description
    -----------
    Generalized Message Aggregator was introduced in "DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>"

    Each edge message is ``h_src + BondEncoder(edge_feats)``. Messages are
    aggregated onto destination nodes either with a softmax-weighted sum
    (``aggregator="softmax"``) or with a power mean
    ``(mean(m ** p)) ** (1 / p)`` (``aggregator="power"``). The aggregate is
    optionally rescaled by message normalization, added to the input node
    features as a residual, and passed through an MLP.

    Parameters
    ----------
    in_dim: int
        Input size. Also the output size of the bond encoder, so node and
        encoded edge features can be summed.
    out_dim: int
        Output size.
    aggregator: str
        Type of aggregation, ``"softmax"`` or ``"power"``. Default is 'softmax'.
    beta: float
        Inverse temperature applied to messages before the edge softmax
        (softmax aggregation only). Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable variable or not. Only takes effect when
        ``aggregator="softmax"``. Default is False.
    p: float
        Initial power for power mean aggregation. Must be non-zero, since
        the p-th root is taken. Default is 1.0.
    learn_p: bool
        Whether p is a learnable variable or not. Default is False.
    msg_norm: bool
        Whether message normalization is used. Default is False.
    learn_msg_scale: bool
        Whether s is a learnable scaling factor or not in message normalization. Default is False.
    mlp_layers: int
        The number of MLP layers. Hidden layers have width ``2 * in_dim``.
        Default is 1.
    eps: float
        A small positive constant added to the ReLU'd messages in softmax
        aggregation. Default is 1e-7.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        aggregator="softmax",
        beta=1.0,
        learn_beta=False,
        p=1.0,
        learn_p=False,
        msg_norm=False,
        learn_msg_scale=False,
        mlp_layers=1,
        eps=1e-7,
    ):
        super().__init__()

        self.aggr = aggregator
        self.eps = eps

        # MLP widths: in_dim -> (2 * in_dim) * (mlp_layers - 1) -> out_dim.
        channels = [in_dim] + [in_dim * 2] * (mlp_layers - 1) + [out_dim]
        self.mlp = MLP(channels)

        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

        # beta is only meaningful for softmax aggregation, so it is learnable
        # only in that case; otherwise it stays a plain float.
        self.beta = (
            nn.Parameter(torch.Tensor([beta]), requires_grad=True)
            if learn_beta and self.aggr == "softmax"
            else beta
        )
        self.p = (
            nn.Parameter(torch.Tensor([p]), requires_grad=True)
            if learn_p
            else p
        )

        # Encodes edge features to in_dim so they can be added to node features.
        self.edge_encoder = BondEncoder(in_dim)

    def forward(self, g, node_feats, edge_feats):
        """Run one generalized message-passing step.

        Parameters
        ----------
        g : DGLGraph
            Input graph.
        node_feats : torch.Tensor
            Node features with feature size ``in_dim`` (presumably shape
            ``(num_nodes, in_dim)`` -- confirm against callers).
        edge_feats : torch.Tensor
            Raw edge features consumed by ``BondEncoder``; presumably OGB
            integer bond attributes -- confirm against the data pipeline.

        Returns
        -------
        torch.Tensor
            ``self.mlp(node_feats + aggregated_messages)``.

        Raises
        ------
        NotImplementedError
            If the aggregator is neither ``"softmax"`` nor ``"power"``.
        """
        with g.local_scope():
            # Node and edge feature size need to match for u_add_e.
            g.ndata["h"] = node_feats
            g.edata["h"] = self.edge_encoder(edge_feats)
            g.apply_edges(fn.u_add_e("h", "h", "m"))

            if self.aggr == "softmax":
                # ReLU + eps makes every message strictly positive.
                g.edata["m"] = F.relu(g.edata["m"]) + self.eps
                # Attention weights normalized over each destination node's
                # incoming edges, sharpened/flattened by beta.
                g.edata["a"] = edge_softmax(g, g.edata["m"] * self.beta)
                g.update_all(
                    lambda edge: {"x": edge.data["m"] * edge.data["a"]},
                    fn.sum("x", "m"),
                )
            elif self.aggr == "power":
                minv, maxv = 1e-7, 1e1
                # Clamp into a positive, bounded range so a fractional or
                # learned p never sees a negative base (NaN) and a large p
                # cannot overflow. Out-of-place to avoid mutating graph frames.
                g.edata["m"] = torch.clamp(g.edata["m"], minv, maxv)
                g.update_all(
                    lambda edge: {"x": torch.pow(edge.data["m"], self.p)},
                    fn.mean("x", "m"),
                )
                # Power mean: take the p-th root of mean(m ** p).
                mean_pow = torch.clamp(g.ndata["m"], minv, maxv)
                g.ndata["m"] = torch.pow(mean_pow, 1.0 / self.p)
            else:
                raise NotImplementedError(
                    f"Aggregator {self.aggr} is not supported."
                )

            if self.msg_norm is not None:
                g.ndata["m"] = self.msg_norm(node_feats, g.ndata["m"])

            # Residual connection before the MLP.
            feats = node_feats + g.ndata["m"]
            return self.mlp(feats)