-
Notifications
You must be signed in to change notification settings - Fork 47
/
conv.jl
1690 lines (1386 loc) · 51.4 KB
/
conv.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
@doc raw"""
    GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight])

Graph convolutional layer from the paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).

Performs the operation
```math
\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j
```
where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.

If the input graph has weighted edges and `use_edge_weight=true`, then ``a_{ij}`` will be computed as
```math
a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
```

The input to the layer is a node feature array `X` of size `(num_features, num_nodes)`
and optionally an edge weight vector.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `σ`: Activation function. Default `identity`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
  If `add_self_loops=true` the new weights will be set to 1.
  This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
  Default `false`.

# Forward

    (::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing) -> AbstractMatrix

Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`,
and optionally an edge weight vector. Returns a node feature matrix of size
`[out, num_nodes]`. When `add_self_loops=true` and an explicit `edge_weight` is given,
the weights of the added self loops are set to 1.

# Examples

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)
x = randn(3, g.num_nodes)

# create layer
l = GCNConv(3 => 5)

# forward pass
y = l(g, x)       # size:  5 × num_nodes

# convolution with edge weights
w = [1.1, 0.1, 2.3, 0.5]
y = l(g, x, w)

# Edge weights can also be embedded in the graph.
g = GNNGraph(s, t, w)
l = GCNConv(3 => 5, use_edge_weight=true)
y = l(g, x) # same as l(g, x, w)
```
"""
struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer
    weight::W
    bias::B
    σ::F
    add_self_loops::Bool
    use_edge_weight::Bool
end

@functor GCNConv

function GCNConv(ch::Pair{Int, Int}, σ = identity;
                 init = glorot_uniform,
                 bias::Bool = true,
                 add_self_loops = true,
                 use_edge_weight = false)
    in, out = ch
    W = init(out, in)
    # `false` acts as a "no bias" sentinel that broadcasts as zero in `x .+ l.bias`.
    b = bias ? Flux.create_bias(W, true, out) : false
    return GCNConv(W, b, σ, add_self_loops, use_edge_weight)
end
# Validation of a user-supplied `edge_weight` for the GCNConv forward pass.

# External edge weights are rejected for adjacency-matrix graphs.
check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
    throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))

# Exactly one weight per edge is required.
function check_gcnconv_input(g::GNNGraph, edge_weight::AbstractVector)
    # Compare by value (`!=`) rather than identity (`!==`): `!==` reports a mismatch
    # whenever the two integers have different types, even if they are numerically equal.
    if length(edge_weight) != g.num_edges
        throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"))
    end
    return nothing
end

# Nothing to validate when no external weights are given.
check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing
# Forward pass of GCNConv: symmetric degree normalization, neighbourhood sum,
# linear transform and activation. See the GCNConv docstring for the formula.
function (l::GCNConv)(g::GNNGraph,
                      x::AbstractMatrix{T},
                      edge_weight::EW = nothing
                      ) where {T, EW <: Union{Nothing, AbstractVector}}
    check_gcnconv_input(g, edge_weight)

    if l.add_self_loops
        g = add_self_loops(g)
        if edge_weight !== nothing
            # Pad weights with ones for the newly added self loops.
            # TODO for ADJMAT_T the new edges are not generally at the end
            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
            @assert length(edge_weight) == g.num_edges
        end
    end

    Dout, Din = size(l.weight)
    if Dout < Din
        # Multiply before the convolution when that shrinks the features, so message
        # passing runs on the smaller dimension; otherwise multiply afterwards.
        x = l.weight * x
    end

    if edge_weight !== nothing
        d = degree(g, T; dir = :in, edge_weight)
    else
        d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
    end

    # Normalization factor d^(-1/2). A node with zero in-degree would get an infinite
    # factor, which turns its outgoing messages into Inf/NaN and poisons its
    # neighbours. As in PyG's `gcn_norm`, such factors are replaced by zero.
    c = 1 ./ sqrt.(d)
    c = ifelse.(isinf.(c), zero.(c), c)

    x = x .* c'
    if edge_weight !== nothing
        x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
    elseif l.use_edge_weight
        x = propagate(w_mul_xj, g, +, xj = x)
    else
        x = propagate(copy_xj, g, +, xj = x)
    end
    x = x .* c'

    if Dout >= Din
        x = l.weight * x
    end
    return l.σ.(x .+ l.bias)
end
# Adjacency-matrix graphs cannot take external edge weights directly. Rebuild the
# graph in COO form from its edge index, then defer to the generic forward method.
function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
                      edge_weight::AbstractVector)
    coo_graph = GNNGraph(edge_index(g)...; num_nodes = g.num_nodes)
    return l(coo_graph, x, edge_weight)
end
# Compact display: `GCNConv(in => out[, σ])`.
function Base.show(io::IO, l::GCNConv)
    nout, nin = size(l.weight)
    print(io, "GCNConv(", nin, " => ", nout)
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    print(io, ")")
end
@doc raw"""
    ChebConv(in => out, k; bias=true, init=glorot_uniform)

Chebyshev spectral graph convolutional layer from
paper [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375).

Implements
```math
X' = \sum^{K-1}_{k=0} W^{(k)} Z^{(k)}
```
where ``Z^{(k)}`` is the ``k``-th term of Chebyshev polynomials, and can be calculated by the following recursive form:
```math
\begin{aligned}
Z^{(0)} &= X \\
Z^{(1)} &= \hat{L} X \\
Z^{(k)} &= 2 \hat{L} Z^{(k-1)} - Z^{(k-2)}
\end{aligned}
```
with ``\hat{L}`` the [`scaled_laplacian`](@ref).

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `k`: The number ``K`` of Chebyshev terms in the sum (i.e. polynomial order ``K-1``).
  Must be at least 1.
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
"""
struct ChebConv{W <: AbstractArray{<:Number, 3}, B} <: GNNLayer
    weight::W
    bias::B
    k::Int
end

function ChebConv(ch::Pair{Int, Int}, k::Int;
                  init = glorot_uniform, bias::Bool = true)
    # At least one term is required: with k < 1 the weight tensor would be empty
    # and the forward pass would index out of bounds.
    k >= 1 || throw(ArgumentError("ChebConv requires k >= 1, got k = $k"))
    in, out = ch
    W = init(out, in, k)   # one (out × in) weight matrix per Chebyshev term
    b = bias ? Flux.create_bias(W, true, out) : false
    return ChebConv(W, b, k)
end

@functor ChebConv
# Forward pass of ChebConv: accumulate W⁽ᵏ⁾ Z⁽ᵏ⁾ over the Chebyshev recursion.
# Features are stored column-per-node, so the Laplacian multiplies from the right.
function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where {T}
    check_num_nodes(g, X)
    @assert size(X, 1)==size(c.weight, 2) "Input feature size must match input channel size."
    # Term k = 0: Z⁽⁰⁾ = X.
    Y = view(c.weight, :, :, 1) * X
    # With a single term the Laplacian is not needed. Returning here also avoids
    # indexing a non-existent second weight slice (previously a BoundsError).
    c.k == 1 && return Y .+ c.bias
    L̃ = scaled_laplacian(g, eltype(X))
    Z_prev = X
    Z = X * L̃
    Y += view(c.weight, :, :, 2) * Z
    for k in 3:(c.k)
        Z, Z_prev = 2 * Z * L̃ - Z_prev, Z
        Y += view(c.weight, :, :, k) * Z
    end
    return Y .+ c.bias
end
# Compact display: `ChebConv(in => out, k=K)`, with all sizes read from the weight tensor.
function Base.show(io::IO, l::ChebConv)
    nout, nin, nterms = size(l.weight)
    print(io, "ChebConv(", nin, " => ", nout, ", k=", nterms, ")")
end
@doc raw"""
    GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)

Graph convolution layer from the paper [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).

Performs:
```math
\mathbf{x}_i' = W_1 \mathbf{x}_i + W_2 \square_{j \in \mathcal{N}(i)} \mathbf{x}_j
```

where the aggregation type is selected by `aggr`. The neighbour features are aggregated
first and then transformed by ``W_2``. For a linear aggregator such as `+` this is the
same as aggregating ``W_2 \mathbf{x}_j``.

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
"""
struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer
    weight1::W
    weight2::W
    bias::B
    σ::F
    aggr::A
end

@functor GraphConv

function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +,
                   init = glorot_uniform, bias::Bool = true)
    in, out = ch
    W1 = init(out, in)   # applied to the node's own features
    W2 = init(out, in)   # applied to the aggregated neighbour features
    b = bias ? Flux.create_bias(W1, true, out) : false
    return GraphConv(W1, W2, b, σ, aggr)
end
# Forward pass: σ(W1 * x .+ W2 * aggr_j(x_j) .+ b), aggregating raw neighbour features.
function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
    neigh = propagate(copy_xj, g, l.aggr, xj = x)
    return l.σ.(l.weight1 * x .+ l.weight2 * neigh .+ l.bias)
end
# Compact display: `GraphConv(in => out[, σ], aggr=...)`, with sizes read from `weight1`.
function Base.show(io::IO, l::GraphConv)
    W = l.weight1
    nd = ndims(W)
    print(io, "GraphConv(", size(W, nd), " => ", size(W, nd - 1))
    l.σ != identity && print(io, ", ", l.σ)
    print(io, ", aggr=", l.aggr, ")")
end
@doc raw"""
    GATConv(in => out, [σ; heads, concat, init, bias, negative_slope, add_self_loops])
    GATConv((in, ein) => out, ...)

Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.org/abs/1710.10903).

Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i; W \mathbf{x}_j]))
```
with ``z_i`` a normalization factor.

In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass
and the attention coefficients will be calculated as
```math
\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i; W \mathbf{x}_j; W_e \mathbf{e}_{j\to i}]))
```

# Arguments

- `in`: The dimension of input node features.
- `ein`: The dimension of input edge features. Default 0 (i.e. no edge features passed in the forward).
- `out`: The dimension of output node features per attention head.
  With `concat=true` the layer outputs `out * heads` features.
- `σ`: Activation function. Default `identity`.
- `bias`: Learn the additive bias if true. Default `true`.
- `heads`: Number attention heads. Default `1`.
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
- `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
- `init`: Weights' initializer, used for all projections and for the attention vector.
  Default `glorot_uniform`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
  Combining it with edge features (`ein > 0`) is not yet supported.
"""
struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <:
       GNNLayer
    dense_x::DX
    dense_e::DE
    bias::B
    a::A
    σ::F
    negative_slope::T
    channel::Pair{NTuple{2, Int}, Int}
    heads::Int
    concat::Bool
    add_self_loops::Bool
end

@functor GATConv

Flux.trainable(l::GATConv) = (l.dense_x, l.dense_e, l.bias, l.a)

GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)

function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                 heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                 init = glorot_uniform, bias::Bool = true, add_self_loops = true)
    (in, ein), out = ch
    if add_self_loops
        @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
    end
    # Honour the user-supplied `init` for the projections as well (previously it was
    # used only for the attention vector `a`), matching GATv2Conv.
    dense_x = Dense(in, out * heads; bias = false, init = init)
    dense_e = ein > 0 ? Dense(ein, out * heads; bias = false, init = init) : nothing
    b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
    # One attention vector per head, acting on [Wx_i; Wx_j] (or [Wx_i; Wx_j; W_e e]).
    a = init(ein > 0 ? 3out : 2out, heads)
    # Store the slope in the weights' element type rather than a hard-coded Float32.
    negative_slope = convert(eltype(dense_x.weight), negative_slope)
    return GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat,
                   add_self_loops)
end
# Apply the layer to a graph's own node/edge features and return a graph carrying the result.
function (l::GATConv)(g::GNNGraph)
    y = l(g, node_features(g), edge_features(g))
    return GNNGraph(g, ndata = y)
end

# Forward pass of GATConv: per-edge attention logits, neighbourhood softmax, weighted
# sum of projected neighbour features, then an optional head average, bias and activation.
function (l::GATConv)(g::GNNGraph, x::AbstractMatrix,
                      e::Union{Nothing, AbstractMatrix} = nothing)
    check_num_nodes(g, x)
    has_e = e !== nothing
    has_dense_e = l.dense_e !== nothing
    @assert has_e || !has_dense_e "Input edge features required for this layer"
    @assert !has_e || has_dense_e "Input edge features were not specified in the layer constructor"

    if l.add_self_loops
        @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
        g = add_self_loops(g)
    end

    chout = last(l.channel)
    nheads = l.heads
    Wx = reshape(l.dense_x(x), chout, nheads, :)   # chout × heads × nnodes

    # Hand-written message passing.
    msg = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
    α = softmax_edge_neighbors(g, msg.logα)
    h = aggregate_neighbors(g, +, α .* msg.Wxj)

    if !l.concat
        h = mean(h, dims = 2)   # average over heads
    end
    h = reshape(h, :, size(h, 3))   # back to a feature × node matrix
    return l.σ.(h .+ l.bias)
end
# Per-edge GAT message: LeakyReLU attention logits plus the projected source features.
function message(l::GATConv, Wxi, Wxj, e)
    chout = last(l.channel)
    # Stack the per-head projections along dim 1: [Wxi; Wxj] or [Wxi; Wxj; We].
    Wxx = if e === nothing
        vcat(Wxi, Wxj)
    else
        We = reshape(l.dense_e(e), chout, l.heads, :)   # chout × heads × nedges
        vcat(Wxi, Wxj, We)
    end
    scores = sum(l.a .* Wxx, dims = 1)   # 1 × heads × nedges
    logα = leakyrelu.(scores, l.negative_slope)
    return (; logα, Wxj)
end
# Compact display mirroring the constructor call:
# `GATConv(in => out[, σ][, heads=H], negative_slope=...)`.
function Base.show(io::IO, l::GATConv)
    (in, ein), out = l.channel
    # `l.channel` holds the per-head `out` passed to the constructor (the dense layer
    # has `out * heads` outputs), so print it as is. Dividing by `heads` showed the
    # wrong size whenever heads > 1.
    print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out)
    l.σ == identity || print(io, ", ", l.σ)
    l.heads == 1 || print(io, ", heads=", l.heads)
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end
@doc raw"""
    GATv2Conv(in => out, [σ; heads, concat, init, bias, negative_slope, add_self_loops])
    GATv2Conv((in, ein) => out, ...)

GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).

Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU(W_2 \mathbf{x}_i + W_1 \mathbf{x}_j))
```
with ``z_i`` a normalization factor.

When `ein > 0`, edge features of dimension `ein` must be passed in the forward pass,
and the attention coefficients become
```math
\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU(W_3 \mathbf{e}_{j\to i} + W_2 \mathbf{x}_i + W_1 \mathbf{x}_j)).
```

# Arguments

- `in`: The dimension of input node features.
- `ein`: The dimension of input edge features. Default 0 (i.e. no edge features passed in the forward).
- `out`: The dimension of output node features per attention head.
  With `concat=true` the layer outputs `out * heads` features.
- `σ`: Activation function. Default `identity`.
- `bias`: Learn the additive bias if true. Default `true`.
- `heads`: Number attention heads. Default `1`.
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
- `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
  Combining it with edge features (`ein > 0`) is not yet supported.
"""
struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
    dense_i::A1
    dense_j::A2
    dense_e::A3
    bias::B
    a::C
    σ::F
    negative_slope::T
    channel::Pair{NTuple{2, Int}, Int}
    heads::Int
    concat::Bool
    add_self_loops::Bool
end

@functor GATv2Conv

Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.dense_e, l.bias, l.a)

GATv2Conv(ch::Pair{Int, Int}, args...; kws...) =
    GATv2Conv((first(ch), 0) => last(ch), args...; kws...)

function GATv2Conv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                   heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                   init = glorot_uniform, bias::Bool = true, add_self_loops = true)
    (nin, ein), nout = ch
    if add_self_loops
        @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
    end
    width = nout * heads   # all heads are stacked in a single projection
    dense_i = Dense(nin, width; bias = bias, init = init)   # W_2, for target x_i
    dense_j = Dense(nin, width; bias = false, init = init)  # W_1, for source x_j
    dense_e = ein > 0 ? Dense(ein, width; bias = false, init = init) : nothing
    b = bias ? Flux.create_bias(dense_i.weight, true, concat ? width : nout) : false
    a = init(nout, heads)
    slope = convert(eltype(dense_i.weight), negative_slope)
    return GATv2Conv(dense_i, dense_j, dense_e, b, a, σ, slope, ch, heads, concat,
                     add_self_loops)
end
# Apply the layer to a graph's own node/edge features and return a graph carrying the result.
function (l::GATv2Conv)(g::GNNGraph)
    y = l(g, node_features(g), edge_features(g))
    return GNNGraph(g, ndata = y)
end

# Forward pass of GATv2Conv: dynamic attention logits per edge, neighbourhood softmax,
# weighted sum of the source projections W_1 x_j, optional head average, bias, activation.
function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix,
                        e::Union{Nothing, AbstractMatrix} = nothing)
    check_num_nodes(g, x)
    has_e = e !== nothing
    has_dense_e = l.dense_e !== nothing
    @assert has_e || !has_dense_e "Input edge features required for this layer"
    @assert !has_e || has_dense_e "Input edge features were not specified in the layer constructor"

    if l.add_self_loops
        @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
        g = add_self_loops(g)
    end

    nout = last(l.channel)
    nheads = l.heads
    Wxi = reshape(l.dense_i(x), nout, nheads, :)   # out × heads × nnodes
    Wxj = reshape(l.dense_j(x), nout, nheads, :)   # out × heads × nnodes

    msg = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
    α = softmax_edge_neighbors(g, msg.logα)
    h = aggregate_neighbors(g, +, α .* msg.Wxj)

    if !l.concat
        h = mean(h, dims = 2)   # average over heads
    end
    h = reshape(h, :, size(h, 3))
    return l.σ.(h .+ l.bias)
end
# Per-edge GATv2 message: attention logits a^T LeakyReLU(W_2 x_i + W_1 x_j [+ W_3 e])
# together with the projected source features used for aggregation.
function message(l::GATv2Conv, Wxi, Wxj, e)
    nout = last(l.channel)
    # Summing the two projections is equivalent to W * vcat(x_i, x_j), as in
    # "How Attentive are Graph Attention Networks?".
    Wx = Wxi + Wxj
    e === nothing || (Wx = Wx + reshape(l.dense_e(e), nout, l.heads, :))
    activated = leakyrelu.(Wx, l.negative_slope)
    logα = sum(l.a .* activated, dims = 1)   # 1 × heads × nedges
    return (; logα, Wxj)
end
# Compact display mirroring the constructor call:
# `GATv2Conv(in => out[, σ][, heads=H], negative_slope=...)`.
function Base.show(io::IO, l::GATv2Conv)
    (in, ein), out = l.channel
    # `l.channel` holds the per-head `out` passed to the constructor (the dense layers
    # have `out * heads` outputs), so print it as is. Dividing by `heads` showed the
    # wrong size whenever heads > 1.
    print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out)
    l.σ == identity || print(io, ", ", l.σ)
    l.heads == 1 || print(io, ", heads=", l.heads)
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end
@doc raw"""
    GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform)

Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493).

Implements the recursion
```math
\begin{aligned}
\mathbf{h}^{(0)}_i &= [\mathbf{x}_i; \mathbf{0}] \\
\mathbf{h}^{(l)}_i &= GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
\end{aligned}
```

where ``\mathbf{h}^{(l)}_i`` is the hidden state after the ``l``-th GRU step.
The dimension of the input ``\mathbf{x}_i`` must not exceed `out`. Smaller inputs are
zero-padded up to `out`.

# Arguments

- `out`: The dimension of output features.
- `num_layers`: The number of propagation steps. Each step has its own weight matrix ``W``
  and applies the same GRU cell.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `init`: Weight initialization function.
"""
struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
    weight::W
    gru::R
    out_ch::Int
    num_layers::Int
    aggr::A
end

@functor GatedGraphConv

function GatedGraphConv(out_ch::Int, num_layers::Int;
                        aggr = +, init = glorot_uniform)
    weight = init(out_ch, out_ch, num_layers)   # one (out × out) matrix per step
    return GatedGraphConv(weight, GRUCell(out_ch, out_ch), out_ch, num_layers, aggr)
end
# Declare `fill!` non-differentiable for automatic differentiation. It is used here for
# the zero padding in the GatedGraphConv forward pass.
# NOTE(review): this adds a rule for a Base function on all argument types (type-piracy
# style); it is a temporary workaround.
# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
@non_differentiable fill!(x...)
# Forward pass of GatedGraphConv: zero-pad the input to `out_ch` rows, then run
# `num_layers` rounds of "transform, aggregate, GRU update".
function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
    check_num_nodes(g, H)
    nfeat, nnodes = size(H)
    @assert (nfeat<=l.out_ch) "number of input features must less or equals to output features."
    if nfeat < l.out_ch
        padding = fill!(similar(H, S, l.out_ch - nfeat, nnodes), 0)
        H = vcat(H, padding)
    end
    for step in 1:(l.num_layers)
        M = propagate(copy_xj, g, l.aggr; xj = view(l.weight, :, :, step) * H)
        H, _ = l.gru(H, M)
    end
    return H
end
# Compact display: `GatedGraphConv((out => out)^num_layers, aggr=...)`.
function Base.show(io::IO, l::GatedGraphConv)
    n = l.out_ch
    print(io, "GatedGraphConv((", n, " => ", n, ")^", l.num_layers, ", aggr=", l.aggr, ")")
end
@doc raw"""
    EdgeConv(nn; aggr=max)

Edge convolutional layer from the paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829).

Performs the operation
```math
\mathbf{x}_i' = \square_{j \in N(i)}\, nn([\mathbf{x}_i; \mathbf{x}_j - \mathbf{x}_i])
```

where `nn` is usually a learnable function such as a linear layer or a multi-layer perceptron.

# Arguments

- `nn`: A (possibly learnable) function applied to each edge's concatenated features.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
"""
struct EdgeConv{NN, A} <: GNNLayer
    nn::NN
    aggr::A
end

@functor EdgeConv

function EdgeConv(nn; aggr = max)
    return EdgeConv(nn, aggr)
end
# Forward pass of EdgeConv: apply `nn` to [x_i; x_j - x_i] on every edge and aggregate.
function (l::EdgeConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
    edge_message(layer, xi, xj, e) = layer.nn(vcat(xi, xj .- xi))
    return propagate(edge_message, g, l.aggr, l, xi = x, xj = x)
end
# Compact display: `EdgeConv(nn, aggr=...)`.
Base.show(io::IO, l::EdgeConv) = print(io, "EdgeConv(", l.nn, ", aggr=", l.aggr, ")")
@doc raw"""
    GINConv(f, ϵ; aggr=+)

Graph Isomorphism convolutional layer from the paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf).

Implements the graph convolution
```math
\mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
```

where ``f_\Theta`` is usually a learnable function such as a linear layer or a multi-layer perceptron.

# Arguments

- `f`: A (possibly learnable) function acting on node features.
- `ϵ`: Weighting factor. It is not part of the trainable parameters.
- `aggr`: Aggregation operator for the neighbour features, replacing the sum above. Default `+`.
"""
struct GINConv{R <: Real, NN, A} <: GNNLayer
    nn::NN
    ϵ::R
    aggr::A
end

@functor GINConv

Flux.trainable(l::GINConv) = (l.nn,)

function GINConv(nn, ϵ; aggr = +)
    return GINConv(nn, ϵ, aggr)
end
# Forward pass of GINConv: nn((1 + ϵ) x_i + aggr_j x_j), with ϵ cast to x's element type.
function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
    neigh = propagate(copy_xj, g, l.aggr, xj = x)
    scale = 1 + ofeltype(x, l.ϵ)
    return l.nn(scale * x + neigh)
end
# Compact display: `GINConv(nn, ϵ)`.
Base.show(io::IO, l::GINConv) = print(io, "GINConv($(l.nn), $(l.ϵ))")
@doc raw"""
    NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)

The continuous kernel-based convolutional operator from the
[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
This convolution is also known as the edge-conditioned convolution from the
[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.

Performs the operation
```math
\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
```

where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
Given batched edge features `e` of size `(num_edge_features, num_edges)`,
the function `f` must return a batched array of matrices of size `(out, in, num_edges)`.
For convenience, functions returning a single `(out*in, num_edges)` matrix are also allowed.

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `f`: A (possibly learnable) function acting on edge features.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `σ`: Activation function.
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
"""
struct NNConv{W, B, NN, F, A} <: GNNLayer
    weight::W
    bias::B
    nn::NN
    σ::F
    aggr::A
end

@functor NNConv

function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true,
                init = glorot_uniform)
    nin, nout = ch
    W_root = init(nout, nin)   # the self term W x_i
    b = bias ? Flux.create_bias(W_root, true, nout) : false
    return NNConv(W_root, b, nn, σ, aggr)
end
# Forward pass of NNConv: σ(W x_i + aggr_j f(e_ji) x_j + b).
function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
    check_num_nodes(g, x)
    neigh = propagate(message, g, l.aggr, l, xj = x, e = e)
    self_term = l.weight * x
    return l.σ.(self_term .+ neigh .+ l.bias)
end
# Per-edge NNConv message: the edge network's output, reshaped to one (out × in)
# matrix per edge, applied to the source features x_j.
function message(l::NNConv, xi, xj, e)
    nin, nedges = size(xj)
    kernels = reshape(l.nn(e), :, nin, nedges)
    # Treat each x_j as an (in × 1) matrix so batched_mul yields out × 1 × nedges.
    prod = NNlib.batched_mul(kernels, reshape(xj, nin, 1, nedges))
    return reshape(prod, :, nedges)
end
function (l::NNConv)(g::GNNGraph)
    y = l(g, node_features(g), edge_features(g))
    return GNNGraph(g, ndata = y)
end
# Compact display: `NNConv(in => out, aggr=...)`.
function Base.show(io::IO, l::NNConv)
    nout, nin = size(l.weight)
    print(io, "NNConv(", nin, " => ", nout, ", aggr=", l.aggr, ")")
end
@doc raw"""
    SAGEConv(in => out, σ=identity; aggr=mean, bias=true, init=glorot_uniform)

GraphSAGE convolution layer from the paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf).

Performs:
```math
\mathbf{x}_i' = W \cdot [\mathbf{x}_i; \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
```

where the aggregation type is selected by `aggr`.

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
"""
struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer
    weight::W
    bias::B
    σ::F
    aggr::A
end

@functor SAGEConv

function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean,
                  init = glorot_uniform, bias::Bool = true)
    nin, nout = ch
    W = init(nout, 2 * nin)   # acts on the concatenation [x_i; aggregated x_j]
    b = bias ? Flux.create_bias(W, true, nout) : false
    return SAGEConv(W, b, σ, aggr)
end
# Forward pass of SAGEConv: σ(W [x_i; aggr_j x_j] + b).
function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
    neigh = propagate(copy_xj, g, l.aggr, xj = x)
    stacked = vcat(x, neigh)
    return l.σ.(l.weight * stacked .+ l.bias)
end
# Compact display: `SAGEConv(in => out[, σ], aggr=...)`. The weight's input width is
# twice `in` because it acts on [x_i; aggregated x_j].
function Base.show(io::IO, l::SAGEConv)
    nout, ncat = size(l.weight)
    print(io, "SAGEConv(", ncat ÷ 2, " => ", nout)
    l.σ != identity && print(io, ", ", l.σ)
    print(io, ", aggr=", l.aggr, ")")
end
@doc raw"""
    ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true)

The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.

The layer's forward pass is given by
```math
\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
```
where the edge gates ``\eta_{ij}`` are given by
```math
\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
```

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `act`: Activation function.
- `init`: Weight matrices' initializing function.
- `bias`: Learn an additive bias if true.
"""
struct ResGatedGraphConv{W, B, F} <: GNNLayer
    A::W
    B::W
    U::W
    V::W
    bias::B
    σ::F
end

@functor ResGatedGraphConv

function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
                           init = glorot_uniform, bias::Bool = true)
    nin, nout = ch
    # Initialized in the order A, B, U, V.
    A = init(nout, nin)
    B = init(nout, nin)
    U = init(nout, nin)
    V = init(nout, nin)
    b = bias ? Flux.create_bias(A, true, nout) : false
    return ResGatedGraphConv(A, B, U, V, b, σ)
end
# Forward pass: σ(U x_i + Σ_j sigmoid(A x_i + B x_j) .* V x_j + b).
function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
    gated_message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) .* xj.Vx
    Ax, Bx, Vx = l.A * x, l.B * x, l.V * x
    m = propagate(gated_message, g, +, xi = (; Ax), xj = (; Bx, Vx))
    return l.σ.(l.U * x .+ m .+ l.bias)
end
# Compact display: `ResGatedGraphConv(in => out[, σ])`.
function Base.show(io::IO, l::ResGatedGraphConv)
    nout, nin = size(l.A)
    print(io, "ResGatedGraphConv(", nin, " => ", nout)
    l.σ != identity && print(io, ", ", l.σ)
    print(io, ")")
end
@doc raw"""
    CGConv((in, ein) => out, f, act=identity; bias=true, init=glorot_uniform, residual=false)
    CGConv(in => out, ...)

The crystal graph convolutional layer from the paper
[Crystal Graph Convolutional Neural Networks for an Accurate and
Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf).

Performs the operation
```math
\mathbf{x}_i' = \mathbf{x}_i + \sum_{j\in N(i)}\sigma(W_f \mathbf{z}_{ij} + \mathbf{b}_f)\, act(W_s \mathbf{z}_{ij} + \mathbf{b}_s)
```

where ``\mathbf{z}_{ij}`` is the concatenation of node and edge features
``[\mathbf{x}_i; \mathbf{x}_j; \mathbf{e}_{j\to i}]``
and ``\sigma`` is the sigmoid function.
The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size equals
the input size.

# Arguments

- `in`: The dimension of input node features.
- `ein`: The dimension of input edge features.
  If `ein` is not given, no edge features are expected in the forward pass.
- `out`: The dimension of output node features.
- `act`: Activation function.
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
- `residual`: Add a residual connection.

# Examples

```julia
g = rand_graph(5, 6)
x = rand(Float32, 2, g.num_nodes)
e = rand(Float32, 3, g.num_edges)

l = CGConv((2, 3) => 4, tanh)
y = l(g, x, e)    # size: (4, num_nodes)

# No edge features
l = CGConv(2 => 4, tanh)
y = l(g, x)    # size: (4, num_nodes)
```
"""
struct CGConv{D1, D2} <: GNNLayer
    ch::Pair{NTuple{2, Int}, Int}
    dense_f::D1
    dense_s::D2
    residual::Bool
end

@functor CGConv

CGConv(ch::Pair{Int, Int}, args...; kws...) =
    CGConv((first(ch), 0) => last(ch), args...; kws...)

function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
                bias = true, init = glorot_uniform)
    (nin, ein), out = ch
    zdim = 2nin + ein   # length of z_ij = [x_i; x_j; e_ij]
    dense_f = Dense(zdim, out, sigmoid; bias, init)   # gate
    dense_s = Dense(zdim, out, act; bias, init)       # core
    return CGConv(ch, dense_f, dense_s, residual)
end
# Forward pass of CGConv: sum the gated edge messages, then optionally add the residual
# when the feature sizes match.
function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
                     e::Union{Nothing, AbstractMatrix} = nothing)
    check_num_nodes(g, x)
    e === nothing || check_num_edges(g, e)
    m = propagate(message, g, +, l, xi = x, xj = x, e = e)
    l.residual || return m
    size(x, 1) == size(m, 1) && return m + x
    @warn "number of output features different from number of input features, residual not applied."
    return m
end
# Per-edge CGConv message: sigmoid gate times activated core, both applied to
# z = [x_i; x_j] or [x_i; x_j; e].
function message(l::CGConv, xi, xj, e)
    z = e === nothing ? vcat(xi, xj) : vcat(xi, xj, e)
    return l.dense_f(z) .* l.dense_s(z)
end
function (l::CGConv)(g::GNNGraph)
    y = l(g, node_features(g), edge_features(g))
    return GNNGraph(g, ndata = y)
end
# Compact display: `CGConv((in, ein) => out[, act], residual=...)`.
function Base.show(io::IO, l::CGConv)
    act = l.dense_s.σ
    print(io, "CGConv($(l.ch)")
    act == identity || print(io, ", ", act)
    print(io, ", residual=$(l.residual))")
end
@doc raw"""
AGNNConv(init_beta=1f0)
Attention-based Graph Neural Network layer from paper [Attention-based
Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
The forward pass is given by
```math
\mathbf{x}_i' = \sum_{j \in {N(i) \cup \{i\}}} \alpha_{ij} W \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} =\frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
{\sum_{j'}e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
```
with the cosine distance defined by
```math
\cos(\mathbf{x}_i, \mathbf{x}_j) =
\frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert}
```
and ``\beta`` a trainable parameter.