-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
sag_pool.py
150 lines (120 loc) · 5.7 KB
/
sag_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from typing import Callable, Optional, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.nn import GraphConv
from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK
from torch_geometric.typing import OptTensor
class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_ and `"Understanding
    Attention and Generalization in Graph Neural Networks"
    <https://arxiv.org/abs/1905.02850>`_ papers.

    If :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`, computes:

        .. math::
            \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})

            \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

            \mathbf{X}^{\prime} &= (\mathbf{X} \odot
            \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

            \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}

    If :obj:`min_score` :math:`\tilde{\alpha}` is a value in :obj:`[0, 1]`,
    computes:

        .. math::
            \mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))

            \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}

            \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}

            \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}.

    Projection scores are learned based on a graph neural network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float or int): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value
            of :math:`k` itself, depending on whether the type of :obj:`ratio`
            is :obj:`float` or :obj:`int`.
            This value is ignored if :obj:`min_score` is not :obj:`None`.
            (default: :obj:`0.5`)
        GNN (callable, optional): The graph neural network layer class used
            for calculating projection scores (one of
            :class:`torch_geometric.nn.conv.GraphConv`,
            :class:`torch_geometric.nn.conv.GCNConv`,
            :class:`torch_geometric.nn.conv.GATConv` or
            :class:`torch_geometric.nn.conv.SAGEConv`). It is passed
            un-instantiated and constructed internally as
            :obj:`GNN(in_channels, 1, **kwargs)`. (default:
            :class:`torch_geometric.nn.conv.GraphConv`)
        min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
            which is used to compute indices of pooled nodes
            :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features get
            multiplied after pooling. This can be useful for large graphs and
            when :obj:`min_score` is used. (default: :obj:`1`)
        nonlinearity (str or callable, optional): The non-linearity to use.
            (default: :obj:`"tanh"`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """
    def __init__(
        self,
        in_channels: int,
        ratio: Union[float, int] = 0.5,
        # `GNN` is a layer *factory* (typically the layer class itself), not
        # an already-built module: it is called below to create the scorer.
        GNN: Callable[..., torch.nn.Module] = GraphConv,
        min_score: Optional[float] = None,
        multiplier: float = 1.0,
        nonlinearity: Union[str, Callable] = 'tanh',
        **kwargs,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier

        # The scoring layer maps `in_channels` features to a single channel,
        # i.e. one raw score per node.
        self.gnn = GNN(in_channels, 1, **kwargs)
        # Node selection operates on that single score channel.
        self.select = SelectTopK(1, ratio, min_score, nonlinearity)
        self.connect = FilterEdges()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        self.gnn.reset_parameters()
        self.select.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        attn: OptTensor = None,
    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node feature matrix.
            edge_index (torch.Tensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example. (default: :obj:`None`)
            attn (torch.Tensor, optional): Optional node-level matrix to use
                for computing attention scores instead of using the node
                feature matrix :obj:`x`. (default: :obj:`None`)

        Returns:
            A tuple :obj:`(x, edge_index, edge_attr, batch, perm, score)`:
            the features of the selected nodes scaled by their score (and by
            :obj:`multiplier`), the edge indices, edge features and batch
            vector produced by the connect step, the indices of the selected
            nodes, and the score of each selected node.
        """
        if batch is None:
            # No batch vector given: assign every node to example 0.
            batch = edge_index.new_zeros(x.size(0))

        # Score on `attn` when provided, otherwise on the node features.
        attn = x if attn is None else attn
        # Promote a 1-D input to a single-column matrix before the GNN call.
        attn = attn.view(-1, 1) if attn.dim() == 1 else attn
        attn = self.gnn(attn, edge_index)

        select_out = self.select(attn, batch)

        perm = select_out.node_index
        score = select_out.weight
        # Guard: a score is required for the gating below (`weight` appears
        # to be optional on the select output).
        assert score is not None

        # Keep only the selected nodes and gate their features by the score.
        x = x[perm] * score.view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        connect_out = self.connect(select_out, edge_index, edge_attr, batch)

        return (x, connect_out.edge_index, connect_out.edge_attr,
                connect_out.batch, perm, score)

    def __repr__(self) -> str:
        # Show whichever selection criterion is active: `min_score` takes
        # precedence over `ratio` when it is set.
        if self.min_score is None:
            ratio = f'ratio={self.ratio}'
        else:
            ratio = f'min_score={self.min_score}'

        return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '
                f'{self.in_channels}, {ratio}, multiplier={self.multiplier})')