From 786b20009d095128f61bd736a64891820cb9c439 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Fri, 2 Aug 2024 15:34:14 +0530 Subject: [PATCH 01/41] WIP --- GNNLux/src/layers/conv.jl | 71 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 83c3efddc..7960d2721 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -628,3 +628,74 @@ function Base.show(io::IO, l::GINConv) print(io, ", $(l.ϵ)") print(io, ")") end + +@concrete struct NNConv <: GNNContainerLayer{(:nn,)} + nn <: AbstractExplicitLayer + aggr + in_dims::Int + out_dims::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias + σ +end + +""" +function NNConv(ch::Pair{Int, Int}, σ = identity; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, + allow_fast_activation::Bool = true) +""" +# fix args order +function NNConv(ch::Pair{Int, Int}, nn, σ = identity; + aggr = +, + init_bias = zeros32, + use_bias::Bool = true, + init_weight = glorot_uniform, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, + allow_fast_activation::Bool = true) + in_dims, out_dims = ch + σ = allow_fast_activation ? NNlib.fast_act(σ) : σ + return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) +end + +function (l::GCNConv)(g, x, edge_weight, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps, st) + + # what would be the order of args here? 
+ m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.σ) + y = GNNlib.nn_conv(m, g, x, edge_weight) + stnew = _getstate(nn) + return y, stnew +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::NNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims # nn wont affect this right? +LuxCore.outputsize(d::NNConv) = (d.out_dims,) + + +function Base.show(io::IO, l::GINConv) + print(io, "NNConv($(l.nn)") + print(io, ", $(l.ϵ)") + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") + print(io, ")") +end \ No newline at end of file From 01767eef1723997bb0aafabcc85023c7643f6fd4 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Fri, 2 Aug 2024 23:43:56 +0530 Subject: [PATCH 02/41] WIP --- GNNLux/src/layers/conv.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 7960d2721..1f14c3582 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -642,16 +642,7 @@ end σ end -""" -function NNConv(ch::Pair{Int, Int}, σ = identity; - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias::Bool = true, - add_self_loops::Bool = true, - use_edge_weight::Bool = false, - allow_fast_activation::Bool = true) -""" -# fix args order + function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, init_bias = zeros32, From b8c4db675cb3fd0e190191e212236fab7268745e Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Sat, 3 Aug 2024 20:33:52 +0530 Subject: [PATCH 03/41] Update conv.jl --- GNNLux/src/layers/conv.jl | 14 
+------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 1f14c3582..009c64cc0 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -667,20 +667,8 @@ function (l::GCNConv)(g, x, edge_weight, ps, st) return y, stnew end -function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv) - weight = l.init_weight(rng, l.out_dims, l.in_dims) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weight, bias) - else - return (; weight) - end -end - -LuxCore.parameterlength(l::NNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims # nn wont affect this right? LuxCore.outputsize(d::NNConv) = (d.out_dims,) - function Base.show(io::IO, l::GINConv) print(io, "NNConv($(l.nn)") print(io, ", $(l.ϵ)") @@ -689,4 +677,4 @@ function Base.show(io::IO, l::GINConv) l.add_self_loops || print(io, ", add_self_loops=false") !l.use_edge_weight || print(io, ", use_edge_weight=true") print(io, ")") -end \ No newline at end of file +end From b6c1a27febbbdc944729afdebb1640e8f6dc6e2d Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sun, 4 Aug 2024 17:13:09 +0530 Subject: [PATCH 04/41] fix --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 1f14c3582..e68c8fc7f 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -681,7 +681,7 @@ LuxCore.parameterlength(l::NNConv) = l.use_bias ? 
l.in_dims * l.out_dims + l.out LuxCore.outputsize(d::NNConv) = (d.out_dims,) -function Base.show(io::IO, l::GINConv) +function Base.show(io::IO, l::NNConv) print(io, "NNConv($(l.nn)") print(io, ", $(l.ϵ)") l.σ == identity || print(io, ", ", l.σ) From fb0bb1da97f220a2de1dfeb2c368c95ccc746807 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:16:11 +0530 Subject: [PATCH 05/41] Update conv.jl --- GNNLux/src/layers/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 009c64cc0..cc838050e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -656,7 +656,7 @@ function NNConv(ch::Pair{Int, Int}, nn, σ = identity; return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) end -function (l::GCNConv)(g, x, edge_weight, ps, st) +function (l::NNConv)(g, x, edge_weight, ps, st) nn = StatefulLuxLayer{true}(l.nn, ps, st) # what would be the order of args here? 
@@ -669,7 +669,7 @@ end LuxCore.outputsize(d::NNConv) = (d.out_dims,) -function Base.show(io::IO, l::GINConv) +function Base.show(io::IO, l::NNConv) print(io, "NNConv($(l.nn)") print(io, ", $(l.ϵ)") l.σ == identity || print(io, ", ", l.σ) From cd28e977446729dcee0dfaf40fc2b0a3975ebd55 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:18:51 +0530 Subject: [PATCH 06/41] Update conv.jl --- GNNLux/src/layers/conv.jl | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index da96dfc68..378dd61d6 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -667,17 +667,6 @@ function (l::NNConv)(g, x, edge_weight, ps, st) return y, stnew end -function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv) - weight = l.init_weight(rng, l.out_dims, l.in_dims) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weight, bias) - else - return (; weight) - end -end - -LuxCore.parameterlength(l::NNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims # nn wont affect this right? 
LuxCore.outputsize(d::NNConv) = (d.out_dims,) function Base.show(io::IO, l::GINConv) @@ -688,4 +677,4 @@ function Base.show(io::IO, l::GINConv) l.add_self_loops || print(io, ", add_self_loops=false") !l.use_edge_weight || print(io, ", use_edge_weight=true") print(io, ")") -end \ No newline at end of file +end From 7f1a07a5f9b7e8d8af54701e822fe53083b4e8d4 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:19:44 +0530 Subject: [PATCH 07/41] Update conv.jl --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 378dd61d6..cc838050e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -669,7 +669,7 @@ end LuxCore.outputsize(d::NNConv) = (d.out_dims,) -function Base.show(io::IO, l::GINConv) +function Base.show(io::IO, l::NNConv) print(io, "NNConv($(l.nn)") print(io, ", $(l.ϵ)") l.σ == identity || print(io, ", ", l.σ) From 70674a2adfca13c5365a8f49e327161fd38f5a18 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 12:30:48 +0530 Subject: [PATCH 08/41] added tests --- GNNLux/Project.toml | 14 +++++ GNNLux/src/layers/conv.jl | 1 - GNNLux/test/layers/conv_tests.jl | 7 +++ GNNLux/test/layers/temp.jl | 94 ++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 GNNLux/test/layers/temp.jl diff --git a/GNNLux/Project.toml b/GNNLux/Project.toml index 9f27ee3d1..d98416ece 100644 --- a/GNNLux/Project.toml +++ b/GNNLux/Project.toml @@ -4,15 +4,29 @@ authors = ["Carlo Lucibello and contributors"] version = "0.1.0" [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" +GraphNeuralNetworks = 
"cffab07f-9bc2-4db1-8861-388f63bf7694" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ConcreteStructs = "0.2.3" diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index da96dfc68..bb65fc4b0 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -659,7 +659,6 @@ end function (l::NNConv)(g, x, edge_weight, ps, st) nn = StatefulLuxLayer{true}(l.nn, ps, st) - # what would be the order of args here? 
m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.add_self_loops, l.use_edge_weight, l.σ) y = GNNlib.nn_conv(m, g, x, edge_weight) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9f010f39e..8db856803 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -93,4 +93,11 @@ l = GINConv(nn, 0.5) test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end + + @testset "NNConv" begin + edim = 10 + nn = Dense(edim, out_dims * in_dims) + l = NNConv(in_dims => out_dims, nn, tanh, bias = true, aggr = +) + test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) + end end diff --git a/GNNLux/test/layers/temp.jl b/GNNLux/test/layers/temp.jl new file mode 100644 index 000000000..46e4ba1a2 --- /dev/null +++ b/GNNLux/test/layers/temp.jl @@ -0,0 +1,94 @@ + + +@testset "GINConv" begin + nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) + l = GINConv(nn, 0.5) + test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) +end + +@testset "SGConv" begin + l = SGConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) +end + + + +function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; + outputsize=nothing, sizey=nothing, container=false, + atol=1.0f-2, rtol=1.0f-2) + + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) + @test LuxCore.statelength(l) == LuxCore.statelength(st) + + y, st′ = l(g, x, ps, st) + @test eltype(y) == eltype(x) + if outputsize !== nothing + @test LuxCore.outputsize(l) == outputsize + end + if sizey !== nothing + @test size(y) == sizey + elseif outputsize !== nothing + @test size(y) == (outputsize..., g.num_nodes) + end + + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) 
+end + +using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme +using StableRNGs + +""" +MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean) +""" + +g = rand_graph(10, 40, seed=1234) + in_dims = 3 + out_dims = 5 + x = randn(Float32, in_dims, 10) + rng = StableRNG(1234) + l = MEGNetConv(in_dims => out_dims) + l + l isa GNNContainerLayer + test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) + + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + edata = rand(T, in_channel, g.num_edges) + + (x_new, e_new), st_new = l(g, x, ps, st) + + @test size(x_new) == (out_dims, g.num_nodes) + @test size(e_new) == (out_dims, g.num_edges) + + +nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) +l = GINConv(nn, 0.5) +test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) + + + + hin = 6 + hout = 7 + hidden = 8 + l = EGNNConv(hin => hout, hidden) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + h = randn(rng, Float32, hin, g.num_nodes) + (hnew, xnew), stnew = l(g, h, x, ps, st) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_dims, g.num_nodes) + + + l = MEGNetConv(in_dims => out_dims) + l + l isa GNNContainerLayer + test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) + + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) \ No newline at end of file From 8f081cd0de5e75f04c854a85c0dd285b4188a662 Mon Sep 17 00:00:00 2001 From: Rishabh 
<59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:33:57 +0530 Subject: [PATCH 09/41] Delete GNNLux/test/layers/temp.jl --- GNNLux/test/layers/temp.jl | 94 -------------------------------------- 1 file changed, 94 deletions(-) delete mode 100644 GNNLux/test/layers/temp.jl diff --git a/GNNLux/test/layers/temp.jl b/GNNLux/test/layers/temp.jl deleted file mode 100644 index 46e4ba1a2..000000000 --- a/GNNLux/test/layers/temp.jl +++ /dev/null @@ -1,94 +0,0 @@ - - -@testset "GINConv" begin - nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) - l = GINConv(nn, 0.5) - test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) -end - -@testset "SGConv" begin - l = SGConv(in_dims => out_dims, 2) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) -end - - - -function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; - outputsize=nothing, sizey=nothing, container=false, - atol=1.0f-2, rtol=1.0f-2) - - - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) - @test LuxCore.statelength(l) == LuxCore.statelength(st) - - y, st′ = l(g, x, ps, st) - @test eltype(y) == eltype(x) - if outputsize !== nothing - @test LuxCore.outputsize(l) == outputsize - end - if sizey !== nothing - @test size(y) == sizey - elseif outputsize !== nothing - @test size(y) == (outputsize..., g.num_nodes) - end - - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) -end - -using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme -using StableRNGs - -""" -MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, 
Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean) -""" - -g = rand_graph(10, 40, seed=1234) - in_dims = 3 - out_dims = 5 - x = randn(Float32, in_dims, 10) - rng = StableRNG(1234) - l = MEGNetConv(in_dims => out_dims) - l - l isa GNNContainerLayer - test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) - - - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - edata = rand(T, in_channel, g.num_edges) - - (x_new, e_new), st_new = l(g, x, ps, st) - - @test size(x_new) == (out_dims, g.num_nodes) - @test size(e_new) == (out_dims, g.num_edges) - - -nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) -l = GINConv(nn, 0.5) -test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) - - - - hin = 6 - hout = 7 - hidden = 8 - l = EGNNConv(hin => hout, hidden) - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - h = randn(rng, Float32, hin, g.num_nodes) - (hnew, xnew), stnew = l(g, h, x, ps, st) - @test size(hnew) == (hout, g.num_nodes) - @test size(xnew) == (in_dims, g.num_nodes) - - - l = MEGNetConv(in_dims => out_dims) - l - l isa GNNContainerLayer - test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) - - - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) \ No newline at end of file From 890fcda374d6fd3254abc8a4189930b5f6891648 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 12:47:02 +0530 Subject: [PATCH 10/41] add to lux --- GNNLux/src/GNNLux.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index d8970095c..831a96c45 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -31,7 +31,7 @@ export AGNNConv, # GMMConv, GraphConv, # MEGNetConv, - # 
NNConv, + NNConv, # ResGatedGraphConv, # SAGEConv, SGConv From fc2db99683dcf04d317916a6b054caebbe6c6129 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 13:06:54 +0530 Subject: [PATCH 11/41] fix test --- GNNLux/test/layers/basic_tests.jl | 2 +- GNNLux/test/layers/conv_tests.jl | 4 +- data.txt | 25145 ++++++++++++++++++++++++++++ sccript.py | 16 + 4 files changed, 25164 insertions(+), 3 deletions(-) create mode 100644 data.txt create mode 100644 sccript.py diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index ac937d128..cac2a45fa 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/basic" setup=[SharedTestSetup] begin rng = StableRNG(17) - g = rand_graph(rng, 10, 40) + g = rand_graph(10, 40) x = randn(rng, Float32, 3, 10) @testset "GNNLayer" begin diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 142435074..9f3722b05 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) - g = rand_graph(rng, 10, 40) + g = rand_graph(10, 40) in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) @@ -97,7 +97,7 @@ @testset "NNConv" begin edim = 10 nn = Dense(edim, out_dims * in_dims) - l = NNConv(in_dims => out_dims, nn, tanh, bias = true, aggr = +) + l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) end end diff --git a/data.txt b/data.txt new file mode 100644 index 000000000..5174930c5 --- /dev/null +++ b/data.txt @@ -0,0 +1,25145 @@ +[.\docs\make.jl] +using GraphNeuralNetworks +using GNNGraphs +using Flux +using NNlib +using Graphs +using SparseArrays +using Pluto, PlutoStaticHTML # for tutorials +using Documenter, DemoCards +using DocumenterInterLinks + + +tutorials, tutorials_cb, tutorial_assets = makedemos("tutorials") 
+assets = [] +isnothing(tutorial_assets) || push!(assets, tutorial_assets) + +interlinks = InterLinks( + "NNlib" => "https://fluxml.ai/NNlib.jl/stable/", + "Graphs" => "https://juliagraphs.org/Graphs.jl/stable/") + + +DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, + :(using GraphNeuralNetworks, Graphs, SparseArrays, NNlib, Flux); + recursive = true) + +prettyurls = get(ENV, "CI", nothing) == "true" +mathengine = MathJax3() + +makedocs(; + modules = [GraphNeuralNetworks, GNNGraphs, GNNlib], + doctest = false, + clean = true, + plugins = [interlinks], + format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing), + sitename = "GraphNeuralNetworks.jl", + pages = ["Home" => "index.md", + "Graphs" => ["gnngraph.md", "heterograph.md", "temporalgraph.md"], + "Message Passing" => "messagepassing.md", + "Model Building" => "models.md", + "Datasets" => "datasets.md", + "Tutorials" => tutorials, + "API Reference" => [ + "GNNGraph" => "api/gnngraph.md", + "Basic Layers" => "api/basic.md", + "Convolutional Layers" => "api/conv.md", + "Pooling Layers" => "api/pool.md", + "Message Passing" => "api/messagepassing.md", + "Heterogeneous Graphs" => "api/heterograph.md", + "Temporal Graphs" => "api/temporalgraph.md", + "Utils" => "api/utils.md", + ], + "Developer Notes" => "dev.md", + "Summer Of Code" => "gsoc.md", + ]) + +tutorials_cb() + +deploydocs(repo = "github.com/CarloLucibello/GraphNeuralNetworks.jl.git") + +[.\docs\tutorials\introductory_tutorials\gnn_intro_pluto.jl] +### A Pluto.jl notebook ### +# v0.19.45 + +#> [frontmatter] +#> author = "[Carlo Lucibello](https://github.com/CarloLucibello)" +#> title = "Hands-on introduction to Graph Neural Networks" +#> date = "2022-05-22" +#> description = "A beginner level introduction to graph machine learning using GraphNeuralNetworks.jl" +#> cover = "assets/intro_1.png" + +using Markdown +using InteractiveUtils + +# ╔═╡ 42c84361-222a-46c4-b81f-d33eb41635c9 +begin + using Flux + 
using Flux: onecold, onehotbatch, logitcrossentropy + using MLDatasets + using LinearAlgebra, Random, Statistics + import GraphMakie + import CairoMakie as Makie + using Graphs + using PlutoUI + using GraphNeuralNetworks +end + +# ╔═╡ 03a9e023-e682-4ea3-a10b-14c4d101b291 +md""" +*This Pluto notebook is a Julia adaptation of the Pytorch Geometric tutorials that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).* + +Recently, deep learning on graphs has emerged to one of the hottest research fields in the deep learning community. +Here, **Graph Neural Networks (GNNs)** aim to generalize classical deep learning concepts to irregular structured data (in contrast to images or texts) and to enable neural networks to reason about objects and their relations. + +This is done by following a simple **neural message passing scheme**, where node features ``\mathbf{x}_i^{(\ell)}`` of all nodes ``i \in \mathcal{V}`` in a graph ``\mathcal{G} = (\mathcal{V}, \mathcal{E})`` are iteratively updated by aggregating localized information from their neighbors ``\mathcal{N}(i)``: + +```math +\mathbf{x}_i^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_i^{(\ell)}, \left\{ \mathbf{x}_j^{(\ell)} : j \in \mathcal{N}(i) \right\} \right) +``` + +This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the **[GraphNeuralNetworks.jl library](https://github.com/CarloLucibello/GraphNeuralNetworks.jl)**. +GraphNeuralNetworks.jl is an extension library to the popular deep learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/), and consists of various methods and utilities to ease the implementation of Graph Neural Networks. 
+ +Let's first import the packages we need: +""" + +# ╔═╡ 361e0948-d91a-11ec-2d95-2db77435a0c1 +begin + ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation + Random.seed!(17) # for reproducibility +end; + +# ╔═╡ ef96f5ae-724d-4b8e-b7d7-c116ad1c3279 +md""" +Following [Kipf et al. (2017)](https://arxiv.org/abs/1609.02907), let's dive into the world of GNNs by looking at a simple graph-structured example, the well-known [**Zachary's karate club network**](https://en.wikipedia.org/wiki/Zachary%27s_karate_club). This graph describes a social network of 34 members of a karate club and documents links between members who interacted outside the club. Here, we are interested in detecting communities that arise from the member's interaction. + +GraphNeuralNetworks.jl provides utilities to convert [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl)'s datasets to its own type: +""" + +# ╔═╡ 4ba372d4-7a6a-41e0-92a0-9547a78e2898 +dataset = MLDatasets.KarateClub() + +# ╔═╡ 55aca2f0-4bbb-4d3a-9777-703896cfc548 +md""" +After initializing the `KarateClub` dataset, we first can inspect some of its properties. +For example, we can see that this dataset holds exactly **one graph**. +Furthermore, the graph holds exactly **4 classes**, which represent the community each node belongs to. +""" + +# ╔═╡ a1d35896-0f52-4c8b-b7dc-ec65649237c8 +karate = dataset[1] + +# ╔═╡ 48d7df25-9190-45c9-9829-140f452e5151 +karate.node_data.labels_comm + +# ╔═╡ 4598bf67-5448-4ce5-8be8-a473ab1a6a07 +md""" +Now we convert the single-graph dataset to a `GNNGraph`. Moreover, we add a an array of node features, a **34-dimensional feature vector** for each node which uniquely describes the members of the karate club. We also add a training mask selecting the nodes to be used for training in our semi-supervised node classification task. 
+""" + +# ╔═╡ 8d41a9fa-eefe-40c9-8cc3-cd503cf7434d +begin + # convert a MLDataset.jl's dataset to a GNNGraphs (or a collection of graphs) + g = mldataset2gnngraph(dataset) + + x = zeros(Float32, g.num_nodes, g.num_nodes) + x[diagind(x)] .= 1 + + train_mask = [true, false, false, false, true, false, false, false, true, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, true, false, false, false, false, false, + false, false, false, false] + + labels = g.ndata.labels_comm + y = onehotbatch(labels, 0:3) + + g = GNNGraph(g, ndata = (; x, y, train_mask)) +end + +# ╔═╡ c42c7f73-f84e-4e72-9af4-a6421af57f0d +md""" +Let's now look at the underlying graph in more detail: +""" + +# ╔═╡ a7ad9de3-3e18-4aff-b118-a4d798a2f4ec +with_terminal() do + # Gather some statistics about the graph. + println("Number of nodes: $(g.num_nodes)") + println("Number of edges: $(g.num_edges)") + println("Average node degree: $(g.num_edges / g.num_nodes)") + println("Number of training nodes: $(sum(g.ndata.train_mask))") + println("Training node label rate: $(mean(g.ndata.train_mask))") + # println("Has isolated nodes: $(has_isolated_nodes(g))") + println("Has self-loops: $(has_self_loops(g))") + println("Is undirected: $(is_bidirected(g))") +end + +# ╔═╡ 1e362709-a0d0-45d5-b2fd-a91c45fa317a +md""" +Each graph in GraphNeuralNetworks.jl is represented by a `GNNGraph` object, which holds all the information to describe its graph representation. +We can print the data object anytime via `print(g)` to receive a short summary about its attributes and their shapes. + +The `g` object holds 3 attributes: +- `g.ndata`: contains node-related information. +- `g.edata`: holds edge-related information. +- `g.gdata`: this stores the global data, therefore neither node nor edge-specific features. + +These attributes are `NamedTuples` that can store multiple feature arrays: we can access a specific set of features e.g. `x`, with `g.ndata.x`. 
+ + +In our task, `g.ndata.train_mask` describes for which nodes we already know their community assignments. In total, we are only aware of the ground-truth labels of 4 nodes (one for each community), and the task is to infer the community assignment for the remaining nodes. + +The `g` object also provides some **utility functions** to infer some basic properties of the underlying graph. +For example, we can easily infer whether there exist isolated nodes in the graph (*i.e.* there exists no edge to any node), whether the graph contains self-loops (*i.e.*, ``(v, v) \in \mathcal{E}``), or whether the graph is bidirected (*i.e.*, for each edge ``(v, w) \in \mathcal{E}`` there also exists the edge ``(w, v) \in \mathcal{E}``). + +Let us now inspect the `edge_index` method: + +""" + +# ╔═╡ d627736a-fd5a-4cdc-bd4e-89ff8b8c55bd +edge_index(g) + +# ╔═╡ 98bb86d2-a7b9-4110-8851-8829a9f9b4d0 +md""" +By printing `edge_index(g)`, we can understand how GraphNeuralNetworks.jl represents graph connectivity internally. +We can see that for each edge, `edge_index` holds a tuple of two node indices, where the first value describes the node index of the source node and the second value describes the node index of the destination node of an edge. + +This representation is known as the **COO format (coordinate format)** commonly used for representing sparse matrices. +Instead of holding the adjacency information in a dense representation ``\mathbf{A} \in \{ 0, 1 \}^{|\mathcal{V}| \times |\mathcal{V}|}``, GraphNeuralNetworks.jl represents graphs sparsely, which refers to only holding the coordinates/values for which entries in ``\mathbf{A}`` are non-zero. + +Importantly, GraphNeuralNetworks.jl does not distinguish between directed and undirected graphs, and treats undirected graphs as a special case of directed graphs in which reverse edges exist for every entry in the `edge_index`. 
+ +Since a `GNNGraph` is an `AbstractGraph` from the `Graphs.jl` library, it supports graph algorithms and visualization tools from the wider julia graph ecosystem: +""" + +# ╔═╡ 9820cc77-ae0a-454a-86b6-a23dbc56b6fd +GraphMakie.graphplot(g |> to_unidirected, node_size = 20, node_color = labels, + arrow_show = false) + +# ╔═╡ 86135c51-950c-4c08-b9e0-6c892234ff87 +md""" + +## Implementing Graph Neural Networks + +After learning about GraphNeuralNetworks.jl's data handling, it's time to implement our first Graph Neural Network! + +For this, we will use on of the most simple GNN operators, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)), which is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} +``` + +where ``\mathbf{W}^{(\ell + 1)}`` denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and ``c_{w,v}`` refers to a fixed normalization coefficient for each edge. + +GraphNeuralNetworks.jl implements this layer via `GCNConv`, which can be executed by passing in the node feature representation `x` and the COO graph connectivity representation `edge_index`. + +With this, we are ready to create our first Graph Neural Network by defining our network architecture: +""" + +# ╔═╡ 88d1e59f-73d6-46ee-87e8-35beb7bc7674 +begin + struct GCN + layers::NamedTuple + end + + Flux.@layer GCN # provides parameter collection, gpu movement and more + + function GCN(num_features, num_classes) + layers = (conv1 = GCNConv(num_features => 4), + conv2 = GCNConv(4 => 4), + conv3 = GCNConv(4 => 2), + classifier = Dense(2, num_classes)) + return GCN(layers) + end + + function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix) + l = gcn.layers + x = l.conv1(g, x) + x = tanh.(x) + x = l.conv2(g, x) + x = tanh.(x) + x = l.conv3(g, x) + x = tanh.(x) # Final GNN embedding space. 
+ out = l.classifier(x) + # Apply a final (linear) classifier. + return out, x + end +end + +# ╔═╡ 9838189c-5cf6-4f21-b58e-3bb905408ad3 +md""" + +Here, we first initialize all of our building blocks in the constructor and define the computation flow of our network in the call method. +We first define and stack **three graph convolution layers**, which corresponds to aggregating 3-hop neighborhood information around each node (all nodes up to 3 "hops" away). +In addition, the `GCNConv` layers reduce the node feature dimensionality to ``2``, *i.e.*, ``34 \rightarrow 4 \rightarrow 4 \rightarrow 2``. Each `GCNConv` layer is enhanced by a `tanh` non-linearity. + +After that, we apply a single linear transformation (`Flux.Dense` that acts as a classifier to map our nodes to 1 out of the 4 classes/communities. + +We return both the output of the final classifier as well as the final node embeddings produced by our GNN. +We proceed to initialize our final model via `GCN()`, and printing our model produces a summary of all its used sub-modules. + +### Embedding the Karate Club Network + +Let's take a look at the node embeddings produced by our GNN. +Here, we pass in the initial node features `x` and the graph information `g` to the model, and visualize its 2-dimensional embedding. 
+""" + +# ╔═╡ ad2c2e51-08ec-4ddc-9b5c-668a3688db12 +begin + num_features = 34 + num_classes = 4 + gcn = GCN(num_features, num_classes) +end + +# ╔═╡ ce26c963-0438-4ab2-b5c6-520272beef2b +_, h = gcn(g, g.ndata.x) + +# ╔═╡ e545e74f-0a3c-4d18-9cc7-557ca60be567 +function visualize_embeddings(h; colors = nothing) + xs = h[1, :] |> vec + ys = h[2, :] |> vec + Makie.scatter(xs, ys, color = labels, markersize = 20) +end + +# ╔═╡ 26138606-2e8d-435b-aa1a-b6159a0d2739 +visualize_embeddings(h, colors = labels) + +# ╔═╡ b9359c7d-b7fe-412d-8f5e-55ba6bccb4e9 +md""" +Remarkably, even before training the weights of our model, the model produces an embedding of nodes that closely resembles the community-structure of the graph. +Nodes of the same color (community) are already closely clustered together in the embedding space, although the weights of our model are initialized **completely at random** and we have not yet performed any training so far! +This leads to the conclusion that GNNs introduce a strong inductive bias, leading to similar embeddings for nodes that are close to each other in the input graph. + +### Training on the Karate Club Network + +But can we do better? Let's look at an example on how to train our network parameters based on the knowledge of the community assignments of 4 nodes in the graph (one for each community). + +Since everything in our model is differentiable and parameterized, we can add some labels, train the model and observe how the embeddings react. +Here, we make use of a semi-supervised or transductive learning procedure: we simply train against one node per class, but are allowed to make use of the complete input graph data. + +Training our model is very similar to any other Flux model. +In addition to defining our network architecture, we define a loss criterion (here, `logitcrossentropy`), and initialize a stochastic gradient optimizer (here, `Adam`). 
+After that, we perform multiple rounds of optimization, where each round consists of a forward and backward pass to compute the gradients of our model parameters w.r.t. the loss derived from the forward pass.
+If you are not new to Flux, this scheme should appear familiar to you.
+
+Note that our semi-supervised learning scenario is achieved by the following line:
+```
+loss = logitcrossentropy(ŷ[:,train_mask], y[:,train_mask])
+```
+While we compute node embeddings for all of our nodes, we **only make use of the training nodes for computing the loss**.
+Here, this is implemented by filtering the output of the classifier `out` and ground-truth labels `data.y` to only contain the nodes in the `train_mask`.
+
+Let us now start training and see how our node embeddings evolve over time (best experienced by explicitly running the code):
+"""
+
+# ╔═╡ 912560a1-9c72-47bd-9fce-9702b346b603
+begin
+ model = GCN(num_features, num_classes)
+ opt = Flux.setup(Adam(1e-2), model)
+ epochs = 2000
+
+ emb = h
+ function report(epoch, loss, h)
+ # p = visualize_embeddings(h)
+ @info (; epoch, loss)
+ end
+
+ report(0, 10.0, emb)
+ for epoch in 1:epochs
+ loss, grad = Flux.withgradient(model) do model
+ ŷ, emb = model(g, g.ndata.x)
+ logitcrossentropy(ŷ[:, train_mask], y[:, train_mask])
+ end
+
+ Flux.update!(opt, model, grad[1])
+ if epoch % 200 == 0
+ report(epoch, loss, emb)
+ end
+ end
+end
+
+# ╔═╡ c8a217c9-0087-41f0-90c8-aac29bc1c996
+ŷ, emb_final = model(g, g.ndata.x)
+
+# ╔═╡ 727b24bc-0b1e-4ebd-b8ef-987015751e38
+# train accuracy
+mean(onecold(ŷ[:, train_mask]) .== onecold(y[:, train_mask]))
+
+# ╔═╡ 8c60ec7e-46b0-40f7-bf6a-6228a31e1f66
+# test accuracy
+mean(onecold(ŷ[:, .!train_mask]) .== onecold(y[:, .!train_mask]))
+
+# ╔═╡ 44d9f8cf-1023-48ad-a01f-07e59f4b4226
+visualize_embeddings(emb_final, colors = labels)
+
+# ╔═╡ a8841d35-97f9-431d-acab-abf478ce91a9
+md"""
+As one can see, our 3-layer GCN model manages to linearly separate the communities and classify most
of the nodes correctly. + +Furthermore, we did this all with a few lines of code, thanks to the GraphNeuralNetworks.jl which helped us out with data handling and GNN implementations. +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[compat] +CairoMakie = "~0.12.5" +Flux = "~0.14.16" +GraphMakie = "~0.5.12" +GraphNeuralNetworks = "~0.6.19" +Graphs = "~1.11.2" +MLDatasets = "~0.7.16" +PlutoUI = "~0.7.59" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "0bbe321bcd3061714ce11e8a8428022b3809de5f" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.AbstractPlutoDingetjes]] +deps = ["Pkg"] +git-tree-sha1 = "6e1d2a35f2f90a4bc7c2ed98079b2ba09c35b83a" +uuid = "6e696c72-6542-2067-7265-42206c756150" +version = "1.3.2" + +[[deps.AbstractTrees]] +git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.5" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", 
"InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.37" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.AliasTables]] +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" +uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" +version = "1.1.3" + +[[deps.Animations]] +deps = ["Colors"] +git-tree-sha1 = "e81c509d2c8e49592413bfb0bb3b08150056c79d" +uuid = "27a7e980-b3e6-11e9-2bcd-0b925532e340" +version = "0.4.1" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = 
"c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.Automa]] +deps = ["PrecompileTools", "TranscodingStreams"] +git-tree-sha1 = "014bc22d6c400a7703c0f5dc1fdc302440cf88be" +uuid = "67c07d97-cdcb-5c2c-af73-a7f9c32a568b" +version = "1.0.4" + +[[deps.AxisAlgorithms]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] +git-tree-sha1 = "01b8ccb13d68535d73d2b0c23e39bd23155fb712" +uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" +version = "1.1.0" + +[[deps.AxisArrays]] +deps = ["Dates", "IntervalSets", "IterTools", "RangeArrays"] +git-tree-sha1 = "16351be62963a67ac4083f748fdb3cca58bfd52f" +uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +version = "0.4.7" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.4.3" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = 
"bd369af6-aec1-5ad0-b16a-f7cc5008161c" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.8+1" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CRC32c]] +uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" + +[[deps.CRlibm_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "e329286945d0cfc04456972ea732551869af1cfc" +uuid = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0" +version = "1.0.1+0" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + +[[deps.Cairo]] +deps = ["Cairo_jll", "Colors", "Glib_jll", "Graphics", "Libdl", "Pango_jll"] +git-tree-sha1 = "d0b3f8b4ad16cb0a2988c6788646a5e6a17b6b1b" +uuid = "159f3aea-2a34-519c-b102-8c37f9878175" +version = "1.0.5" + +[[deps.CairoMakie]] +deps = ["CRC32c", "Cairo", "Cairo_jll", "Colors", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools"] +git-tree-sha1 = 
"e4da5095557f24713bae4c9f50e34ff4d3b959c0" +uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +version = "0.12.5" + +[[deps.Cairo_jll]] +deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" +version = "1.18.0+2" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.69.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.24.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.5" + 
+[[deps.ColorBrewer]] +deps = ["Colors", "JSON", "Test"] +git-tree-sha1 = "61c5334f33d91e570e1d0c3eb5465835242582c4" +uuid = "a2cac450-b92f-5266-8821-25eda20663c8" +version = "0.4.0" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.26.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" +weakdeps = ["SpecialFunctions"] + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = 
["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.6" +weakdeps = ["IntervalSets", "StaticArrays"] + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Contour]] +git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" +uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.6.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = 
"a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelaunayTriangulation]] +deps = ["EnumX", "ExactPredicates", "Random"] +git-tree-sha1 = "078c716cbb032242df18b960e8b1fec6b1b0b9f9" +uuid = "927a84f5-c5f4-47a5-9785-b46e178433df" +version = "1.0.5" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Distributions]] +deps = 
["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.109" + + [deps.Distributions.extensions] + DistributionsChainRulesCoreExt = "ChainRulesCore" + DistributionsDensityInterfaceExt = "DensityInterface" + DistributionsTestExt = "Test" + + [deps.Distributions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + +[[deps.EarCut_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "e3290f2d49e661fbd94046d7e3726ffcb2d41053" +uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5" +version = "2.2.4+0" + +[[deps.EnumX]] +git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.4" + +[[deps.ExactPredicates]] +deps = ["IntervalArithmetic", "Random", "StaticArrays"] +git-tree-sha1 = "b3f2ff58735b5f024c392fde763f29b057e4b025" +uuid = "429591f6-91af-11e9-00e2-59fbe8cec110" +version = "2.2.8" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + +[[deps.Expat_jll]] +deps = ["Artifacts", "JLLWrappers", 
"Libdl"] +git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" +uuid = "2e619515-83b5-522b-bb60-26c02a35a201" +version = "2.6.2+0" + +[[deps.Extents]] +git-tree-sha1 = "94997910aca72897524d2237c41eb852153b0f65" +uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" +version = "0.1.3" + +[[deps.FFMPEG_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] +git-tree-sha1 = "ab3f7e1819dba9434a3a5126510c8fda3a4e7000" +uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" +version = "6.1.1+0" + +[[deps.FFTW]] +deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] +git-tree-sha1 = "4820348781ae578893311153d69049a93d05f39d" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "1.8.0" + +[[deps.FFTW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" +uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" +version = "3.3.10+0" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.2" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + +[[deps.FilePaths]] +deps = ["FilePathsBase", "MacroTools", "Reexport", "Requires"] +git-tree-sha1 = "919d9412dbf53a2e6fe74af62a73ceed0bce0629" +uuid = "8fc22ac5-c921-52a6-82fd-178b2807b824" +version = "0.8.3" + 
+[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" +weakdeps = ["PDMats", "SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.16" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.Fontconfig_jll]] +deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] +git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" +uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" +version = "2.13.96+0" + +[[deps.Format]] +git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" +uuid = 
"1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +version = "1.3.7" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.FreeType]] +deps = ["CEnum", "FreeType2_jll"] +git-tree-sha1 = "907369da0f8e80728ab49c1c7e09327bf0d6d999" +uuid = "b38be410-82b0-50bf-ab77-7b57e271db43" +version = "4.1.1" + +[[deps.FreeType2_jll]] +deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" +uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" +version = "2.13.2+0" + +[[deps.FreeTypeAbstraction]] +deps = ["ColorVectorSpace", "Colors", "FreeType", "GeometryBasics"] +git-tree-sha1 = "2493cdfd0740015955a8e46de4ef28f49460d8bc" +uuid = "663a7486-cb36-511b-a19d-713bb74d65c9" +version = "0.10.3" + +[[deps.FriBidi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" +uuid = "559328eb-81f9-559d-9380-de523a88c83c" +version = "1.0.14+0" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.11" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.3.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" 
+version = "0.1.6" + +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + +[[deps.GeoInterface]] +deps = ["Extents"] +git-tree-sha1 = "9fff8990361d5127b770e3454488360443019bb3" +uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" +version = "1.3.5" + +[[deps.GeometryBasics]] +deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "b62f2b2d76cee0d61a2ef2b3118cd2a3215d3134" +uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" +version = "0.4.11" + +[[deps.Gettext_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" +uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" +version = "0.21.0+0" + +[[deps.Glib_jll]] +deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] +git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" +version = "2.80.2+0" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.GraphMakie]] +deps = ["DataStructures", "GeometryBasics", "Graphs", "LinearAlgebra", "Makie", "NetworkLayout", "PolynomialRoots", "SimpleTraits", "StaticArrays"] +git-tree-sha1 = "c8c3ece1211905888da48e16f438af85e951ea55" +uuid = "1ecd5474-83a3-4783-bb4f-06765db800d2" +version = "0.5.12" + +[[deps.GraphNeuralNetworks]] +deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" +uuid = 
"cffab07f-9bc2-4db1-8861-388f63bf7694" +version = "0.6.19" + + [deps.GraphNeuralNetworks.extensions] + GraphNeuralNetworksCUDAExt = "CUDA" + GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" + + [deps.GraphNeuralNetworks.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" + +[[deps.Graphics]] +deps = ["Colors", "LinearAlgebra", "NaNMath"] +git-tree-sha1 = "d61890399bc535850c4bf08e4e0d3a7ad0f21cbd" +uuid = "a2bd30eb-e257-5431-a919-1863eab51364" +version = "1.1.2" + +[[deps.Graphite2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" +uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" +version = "1.3.14+0" + +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.11.2" + +[[deps.GridLayoutBase]] +deps = ["GeometryBasics", "InteractiveUtils", "Observables"] +git-tree-sha1 = "fc713f007cff99ff9e50accba6373624ddd33588" +uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" +version = "0.11.0" + +[[deps.Grisu]] +git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" +uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" +version = "1.0.2" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", 
"MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.3+3" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.HarfBuzz_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] +git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" +uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" +version = "2.8.1+1" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.11.1+0" + +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.23" + +[[deps.Hyperscript]] +deps = ["Test"] +git-tree-sha1 = "179267cfa5e712760cd43dcae385d7ea90cc25a4" +uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" +version = "0.0.5" + +[[deps.HypertextLiteral]] +deps = ["Tricks"] +git-tree-sha1 = "7134810b1afce04bbc1045ca1985fbe81ce17653" +uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" +version = "0.9.5" + +[[deps.IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.5" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = 
"7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.ImageAxes]] +deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] +git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" +uuid = "2803e5a7-5153-5ecf-9a86-9b4c37f5f5ac" +version = "0.6.11" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageIO]] +deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] +git-tree-sha1 = "437abb322a41d527c197fa800455f79d414f0a3c" +uuid = "82e4d734-157c-48bb-816b-45c225c6df19" +version = "0.6.8" + +[[deps.ImageMetadata]] +deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] +git-tree-sha1 = "355e2b974f2e3212a75dfb60519de21361ad3cb7" +uuid = "bc367c6b-8a6b-528e-b4bd-a4b897500b49" +version = "0.9.9" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + +[[deps.Imath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "0936ba688c6d201805a83da835b55c61a180db52" +uuid = "905a6f67-0a94-5f89-b386-d35d92009cd1" +version = "3.1.11+0" + +[[deps.IndirectArrays]] +git-tree-sha1 = "012e604e1c7458645cb8b436f8fba789a51b257f" +uuid = "9b13fd28-a010-5f03-acff-a1bbcff69959" +version = "1.0.0" + +[[deps.Inflate]] +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" 
+version = "0.1.5" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.IntelOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "14eb2b542e748570b56446f4c50fbfb2306ebc45" +uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" +version = "2024.2.0+0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.Interpolations]] +deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] +git-tree-sha1 = "88a101217d7cb38a7b481ccd50d21876e1d1b0e0" +uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +version = "0.15.1" +weakdeps = ["Unitful"] + + [deps.Interpolations.extensions] + InterpolationsUnitfulExt = "Unitful" + +[[deps.IntervalArithmetic]] +deps = ["CRlibm_jll", "MacroTools", "RoundingEmulator"] +git-tree-sha1 = "433b0bb201cd76cb087b017e49244f10394ebe9c" +uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" +version = "0.22.14" + + [deps.IntervalArithmetic.extensions] + IntervalArithmeticDiffRulesExt = "DiffRules" + IntervalArithmeticForwardDiffExt = "ForwardDiff" + IntervalArithmeticRecipesBaseExt = "RecipesBase" + + [deps.IntervalArithmetic.weakdeps] + DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" + ForwardDiff = 
"f6369f11-7733-5829-9624-2563aa707210" + RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" + +[[deps.IntervalSets]] +git-tree-sha1 = "dba9ddf07f77f60450fe5d2e2beb9854d9a49bd0" +uuid = "8197267c-284f-5f27-9208-e0e47529a953" +version = "0.7.10" + + [deps.IntervalSets.extensions] + IntervalSetsRandomExt = "Random" + IntervalSetsRecipesBaseExt = "RecipesBase" + IntervalSetsStatisticsExt = "Statistics" + + [deps.IntervalSets.weakdeps] + Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.15" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.Isoband]] +deps = ["isoband_jll"] +git-tree-sha1 = "f9b6d97355599074dc867318950adaa6f9946137" +uuid = "f1662d9f-8043-43de-a69a-05efc1cc6ff4" +version = "0.1.1" + +[[deps.IterTools]] +git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023" +uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +version = "1.10.0" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.50" + +[[deps.JLLWrappers]] +deps = ["Artifacts", 
"Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JpegTurbo]] +deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] +git-tree-sha1 = "fa6d0bcff8583bac20f1ffa708c3913ca605c611" +uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" +version = "0.1.5" + +[[deps.JpegTurbo_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" +version = "3.0.3+0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.22" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.KernelDensity]] +deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] +git-tree-sha1 = "7d703202e65efa1369de1279c162b915e245eed1" +uuid = 
"5ab0869b-81aa-558d-bb23-cbf5423bbe9b" +version = "0.6.9" + +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] +git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.7.1" + +[[deps.LAME_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" +uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" +version = "3.100.2+0" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "8.0.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.30+0" + +[[deps.LLVMOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" +uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" +version = "15.0.7+0" + +[[deps.LZO_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" +version = "2.10.2+0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = 
["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libffi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" +uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" +version = "3.2.2+1" + +[[deps.Libgcrypt_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] +git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" +uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" +version = "1.8.11+0" + +[[deps.Libgpg_error_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" +uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" +version = "1.49.0+0" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + +[[deps.Libmount_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" +uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" +version = "2.40.1+0" + +[[deps.Libuuid_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" +uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" +version = "2.40.1+0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = 
"37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.7" + +[[deps.MIMEs]] +git-tree-sha1 = "65f28ad4b594aebe22157d6fac869786a255b7eb" +uuid = "6c6e2e6c-3030-632d-7369-2d6c69616d65" +version = "0.1.4" + +[[deps.MKL_jll]] +deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "oneTBB_jll"] +git-tree-sha1 = "f046ccd0c6db2832a9f639e2c669c6fe867e5f4f" +uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +version = "2024.2.0+0" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.16" + +[[deps.MLStyle]] +git-tree-sha1 = 
"bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.2+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.4.0+0" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.Makie]] +deps = ["Animations", "Base64", "CRC32c", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "Dates", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG_jll", "FileIO", "FilePaths", "FixedPointNumbers", "Format", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", 
"RelocatableFolders", "Scratch", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun", "Unitful"] +git-tree-sha1 = "863b9e666b5a099c8835e85476a5834f9d77c4c1" +uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +version = "0.21.5" + +[[deps.MakieCore]] +deps = ["ColorTypes", "GeometryBasics", "IntervalSets", "Observables"] +git-tree-sha1 = "c1c950560397ee68ad7302ee0e3efa1b07466a2f" +uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" +version = "0.8.4" + +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MathTeXEngine]] +deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "UnicodeFun"] +git-tree-sha1 = "e1641f32ae592e415e3dbae7f4a188b5316d4b62" +uuid = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53" +version = "0.6.1" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MicroCollections]] +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.2.0" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = 
"e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.21" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.18" + +[[deps.Netpbm]] +deps = ["FileIO", "ImageCore", "ImageMetadata"] +git-tree-sha1 = 
"d92b107dbb887293622df7697a2223f9f8176fcd" +uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" +version = "1.1.1" + +[[deps.NetworkLayout]] +deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "StaticArrays"] +git-tree-sha1 = "91bb2fedff8e43793650e7a677ccda6e6e6e166b" +uuid = "46757867-2c16-5918-afeb-47bfcb05e46a" +version = "0.4.6" +weakdeps = ["Graphs"] + + [deps.NetworkLayout.extensions] + NetworkLayoutGraphsExt = "Graphs" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.Observables]] +git-tree-sha1 = "7438a59546cf62428fc9d1bc94729146d37a7225" +uuid = "510215fc-4207-5dde-b226-833fc4488ee2" +version = "0.5.5" + +[[deps.OffsetArrays]] +git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.1" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.Ogg_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" +uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" +version = "1.3.5+1" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenEXR]] +deps = ["Colors", "FileIO", "OpenEXR_jll"] +git-tree-sha1 = "327f53360fdb54df7ecd01e96ef1983536d1e633" +uuid = "52e1d378-f018-4a11-a4be-720524705ac7" +version = "0.3.2" + +[[deps.OpenEXR_jll]] +deps = ["Artifacts", "Imath_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "8292dd5c8a38257111ada2174000a33745b06d4e" +uuid = "18a262bb-aa17-5467-a713-aee519bc75cb" +version = "3.2.4+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = 
"05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.6+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.14+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.Opus_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" +uuid = "91d4177d-7536-5919-b921-800302f37372" +version = "1.3.2+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PCRE2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" +version = "10.42.0+1" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.31" + +[[deps.PNGFiles]] +deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", 
"OffsetArrays", "libpng_jll"] +git-tree-sha1 = "67186a2bc9a90f9f85ff3cc8277868961fb57cbd" +uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" +version = "0.4.3" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.Packing]] +deps = ["GeometryBasics"] +git-tree-sha1 = "ec3edfe723df33528e085e632414499f26650501" +uuid = "19eb6ba3-879d-56ad-ad62-d5c202156566" +version = "0.5.0" + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + +[[deps.Pango_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "cb5a2ab6763464ae0f19c86c56c63d4a2b0f5bda" +uuid = "36c8627f-9965-5494-a995-c6b170f724f3" +version = "1.52.2+0" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + +[[deps.Pixman_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] +git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" +uuid = "30392449-352a-5448-841d-b1acce4e97dc" +version = "0.43.4+0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", 
"Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PkgVersion]] +deps = ["Pkg"] +git-tree-sha1 = "f9501cc0430a26bc3d156ae1b5b0c1b47af4d6da" +uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" +version = "0.3.3" + +[[deps.PlotUtils]] +deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] +git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" +version = "1.4.1" + +[[deps.PlutoUI]] +deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"] +git-tree-sha1 = "ab55ee1510ad2af0ff674dbcced5e94921f867a9" +uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +version = "0.7.59" + +[[deps.PolygonOps]] +git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" +uuid = "647866c9-e3ac-4575-94e7-e3d426903924" +version = "0.1.2" + +[[deps.PolynomialRoots]] +git-tree-sha1 = "5f807b5345093487f733e520a1b7395ee9324825" +uuid = "3a141323-8675-5d76-9d11-e1df1406c778" +version = "1.0.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = 
["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.10.2" + +[[deps.PtrArrays]] +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.0" + +[[deps.QOI]] +deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] +git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" +uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" +version = "1.0.0" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.9.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RangeArrays]] +git-tree-sha1 = "b9039e93773ddcfc828f12aadf7115b4b4d225f5" +uuid = "b3c3ace0-ae52-54e7-9d0b-2c1406fd6b9d" +version = "0.3.2" + +[[deps.Ratios]] +deps = ["Requires"] +git-tree-sha1 = "1342a47bf3260ee108163042310d26f2be5ec90b" +uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" +version = "0.4.5" +weakdeps = ["FixedPointNumbers"] + + [deps.Ratios.extensions] + RatiosFixedPointNumbersExt = "FixedPointNumbers" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = 
"c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.1" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.1" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.4.2+0" + +[[deps.RoundingEmulator]] +git-tree-sha1 = "40b9edad2e5287e05bd413a38f61a8ff55b9557b" +uuid = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705" +version = "0.2.1" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.SIMD]] +deps = ["PrecompileTools"] +git-tree-sha1 = "2803cab51702db743f3fda07dd1745aadfbf43bd" +uuid = "fdea26ae-647d-5447-a871-4b548cad5224" +version = "3.5.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + 
+[[deps.ShaderAbstractions]] +deps = ["ColorTypes", "FixedPointNumbers", "GeometryBasics", "LinearAlgebra", "Observables", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "79123bc60c5507f035e6d1d9e563bb2971954ec8" +uuid = "65257c39-d410-5151-9873-9b3e5be5013e" +version = "0.4.1" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.Showoff]] +deps = ["Dates", "Grisu"] +git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" +uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" +version = "1.0.3" + +[[deps.SignedDistanceFields]] +deps = ["Random", "Statistics", "Test"] +git-tree-sha1 = "d263a08ec505853a5ff1c1ebde2070419e3f28e9" +uuid = "73760f76-fbc4-59ce-8f25-708e95d2df96" +version = "0.4.0" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sixel]] +deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] +git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" +uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" +version = "0.1.3" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = 
["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.7" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = 
"5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StatsFuns]] +deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.3.1" +weakdeps = ["ChainRulesCore", "InverseFunctions"] + + [deps.StatsFuns.extensions] + StatsFunsChainRulesCoreExt = "ChainRulesCore" + StatsFunsInverseFunctionsExt = "InverseFunctions" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", 
"Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TiffImages]] +deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "SIMD", "UUIDs"] +git-tree-sha1 = "bc7fd5c91041f44636b2c134041f7e5263ce58ae" +uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" +version = "0.10.0" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.1" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + +[[deps.Transducers]] +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", 
"Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.82" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.Tricks]] +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" +uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.8" + +[[deps.TriplotBase]] +git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b" +uuid = "981d1d27-644d-49a2-9326-4793e63143c3" +version = "0.1.0" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnicodeFun]] +deps = ["REPL"] +git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" +uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" +version = "0.4.1" + +[[deps.Unitful]] +deps = ["Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.21.0" +weakdeps = ["ConstructionBase", "InverseFunctions"] + + [deps.Unitful.extensions] + ConstructionBaseUnitfulExt = "ConstructionBase" + InverseFunctionsUnitfulExt = "InverseFunctions" + +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = 
"903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.5" + +[[deps.VectorInterface]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" +uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +version = "0.4.6" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WoodburyMatrices]] +deps = ["LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "c1a7aa6219628fcd757dede0ca95e245c5cd9511" +uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" +version = "1.0.0" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.XML2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] +git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" +uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.13.1+0" + +[[deps.XSLT_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] +git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" +uuid = "aed1982a-8fda-507f-9586-7b0439959a61" +version = "1.1.41+0" + +[[deps.Xorg_libX11_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] +git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" +uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" +version = "1.8.6+0" + +[[deps.Xorg_libXau_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] 
+git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" +uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" +version = "1.0.11+0" + +[[deps.Xorg_libXdmcp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" +uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" +version = "1.1.4+0" + +[[deps.Xorg_libXext_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" +uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" +version = "1.3.6+0" + +[[deps.Xorg_libXrender_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" +uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" +version = "0.9.11+0" + +[[deps.Xorg_libpthread_stubs_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" +uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" +version = "0.1.1+0" + +[[deps.Xorg_libxcb_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] +git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" +uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" +version = "1.17.0+0" + +[[deps.Xorg_xtrans_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" +uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" +version = "1.5.0+0" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", 
"LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.isoband_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "51b5eeb3f98367157a7a12a1fb0aa5328946c03c" +uuid = "9a68df92-36a6-505f-a73e-abb412b6bfb4" +version = "0.2.3+0" + +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+0" + +[[deps.libaom_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" +uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" +version = "3.9.0+0" + +[[deps.libass_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" +uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" +version = "0.15.1+0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.libfdk_aac_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" +uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" +version = "2.0.2+0" + 
+[[deps.libpng_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" +version = "1.6.43+1" + +[[deps.libsixel_jll]] +deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Pkg", "libpng_jll"] +git-tree-sha1 = "d4f63314c8aa1e48cd22aa0c17ed76cd1ae48c3c" +uuid = "075b6546-f08a-558a-be8f-8157d0f608a5" +version = "1.10.3+0" + +[[deps.libvorbis_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] +git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" +uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" +version = "1.3.7+2" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.oneTBB_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "7d0ea0f4895ef2f5cb83645fa689e52cb55cf493" +uuid = "1317d2d5-d96f-522e-a858-c73665f53c3e" +version = "2021.12.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" + +[[deps.x264_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" +uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a" +version = "2021.5.5+0" + +[[deps.x265_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" +uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" +version = "3.5.0+0" +""" + +# ╔═╡ Cell order: +# ╟─03a9e023-e682-4ea3-a10b-14c4d101b291 +# ╠═42c84361-222a-46c4-b81f-d33eb41635c9 +# ╠═361e0948-d91a-11ec-2d95-2db77435a0c1 +# ╟─ef96f5ae-724d-4b8e-b7d7-c116ad1c3279 +# ╠═4ba372d4-7a6a-41e0-92a0-9547a78e2898 +# ╟─55aca2f0-4bbb-4d3a-9777-703896cfc548 +# ╠═a1d35896-0f52-4c8b-b7dc-ec65649237c8 +# ╠═48d7df25-9190-45c9-9829-140f452e5151 +# ╟─4598bf67-5448-4ce5-8be8-a473ab1a6a07 +# ╠═8d41a9fa-eefe-40c9-8cc3-cd503cf7434d +# 
╟─c42c7f73-f84e-4e72-9af4-a6421af57f0d +# ╠═a7ad9de3-3e18-4aff-b118-a4d798a2f4ec +# ╟─1e362709-a0d0-45d5-b2fd-a91c45fa317a +# ╠═d627736a-fd5a-4cdc-bd4e-89ff8b8c55bd +# ╟─98bb86d2-a7b9-4110-8851-8829a9f9b4d0 +# ╠═9820cc77-ae0a-454a-86b6-a23dbc56b6fd +# ╟─86135c51-950c-4c08-b9e0-6c892234ff87 +# ╠═88d1e59f-73d6-46ee-87e8-35beb7bc7674 +# ╟─9838189c-5cf6-4f21-b58e-3bb905408ad3 +# ╠═ad2c2e51-08ec-4ddc-9b5c-668a3688db12 +# ╠═ce26c963-0438-4ab2-b5c6-520272beef2b +# ╠═e545e74f-0a3c-4d18-9cc7-557ca60be567 +# ╠═26138606-2e8d-435b-aa1a-b6159a0d2739 +# ╟─b9359c7d-b7fe-412d-8f5e-55ba6bccb4e9 +# ╠═912560a1-9c72-47bd-9fce-9702b346b603 +# ╠═c8a217c9-0087-41f0-90c8-aac29bc1c996 +# ╠═727b24bc-0b1e-4ebd-b8ef-987015751e38 +# ╠═8c60ec7e-46b0-40f7-bf6a-6228a31e1f66 +# ╠═44d9f8cf-1023-48ad-a01f-07e59f4b4226 +# ╟─a8841d35-97f9-431d-acab-abf478ce91a9 +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 + +[.\docs\tutorials\introductory_tutorials\graph_classification_pluto.jl] +### A Pluto.jl notebook ### +# v0.19.45 + +#> [frontmatter] +#> author = "[Carlo Lucibello](https://github.com/CarloLucibello)" +#> title = "Graph Classification with Graph Neural Networks" +#> date = "2022-05-23" +#> description = "Tutorial for Graph Classification using GraphNeuralNetworks.jl" +#> cover = "assets/graph_classification.gif" + +using Markdown +using InteractiveUtils + +# ╔═╡ 361e0948-d91a-11ec-2d95-2db77435a0c1 +# ╠═╡ show_logs = false +begin + using Flux + using Flux: onecold, onehotbatch, logitcrossentropy + using Flux: DataLoader + using GraphNeuralNetworks + using MLDatasets + using MLUtils + using LinearAlgebra, Random, Statistics + + ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation + Random.seed!(17) # for reproducibility +end; + +# ╔═╡ 15136fd8-f9b2-4841-9a95-9de7b8969687 +md""" +*This Pluto notebook is a julia adaptation of the Pytorch Geometric tutorials that can be found 
[here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).* + +In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**. +Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties. +Here, we want to embed entire graphs, and we want to embed those graphs in such a way so that they are linearly separable given a task at hand. + + +The most common task for graph classification is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not. + +The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl. +Let's load and inspect one of the smaller ones, the **MUTAG dataset**: +""" + +# ╔═╡ f6e86958-e96f-4c77-91fc-c72d8967575c +dataset = TUDataset("MUTAG") + +# ╔═╡ 24f76360-8599-46c8-a49f-4c31f02eb7d8 +dataset.graph_data.targets |> union + +# ╔═╡ 5d5e5152-c860-4158-8bc7-67ee1022f9f8 +g1, y1 = dataset[1] #get the first graph and target + +# ╔═╡ 33163dd2-cb35-45c7-ae5b-d4854d141773 +reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union + +# ╔═╡ a8d6a133-a828-4d51-83c4-fb44f9d5ede1 +reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union + +# ╔═╡ 3b3e0a79-264b-47d7-8bda-2a6db7290828 +md""" +This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**. + +By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**. +It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes). 
+However, for the sake of simplicity, we will not make use of edge labels. +""" + +# ╔═╡ 7f7750ff-b7fa-4fe2-a5a8-6c9c26c479bb +md""" +We now convert the MLDatasets.jl graph types to our `GNNGraph`s and we also onehot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict): +""" + +# ╔═╡ 936c09f6-ee62-4bc2-a0c6-749a66080fd2 +begin + graphs = mldataset2gnngraph(dataset) + graphs = [GNNGraph(g, + ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)), + edata = nothing) + for g in graphs] + y = onehotbatch(dataset.graph_data.targets, [-1, 1]) +end + +# ╔═╡ 2c6ccfdd-cf11-415b-b398-95e5b0b2bbd4 +md"""We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing: +""" + +# ╔═╡ 519477b2-8323-4ece-a7eb-141e9841117c +train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs + +# ╔═╡ 3c3d5038-0ef6-47d7-a1b7-50880c5f3a0b +begin + train_loader = DataLoader(train_data, batchsize = 32, shuffle = true) + test_loader = DataLoader(test_data, batchsize = 32, shuffle = false) +end + +# ╔═╡ f7778e2d-2e2a-4fc8-83b0-5242e4ec5eb4 +md""" +Here, we opt for a `batch_size` of 32, leading to 5 (randomly shuffled) mini-batches, containing all ``4 \cdot 32+22 = 150`` graphs. +""" + +# ╔═╡ 2a1c501e-811b-4ddd-887b-91e8c929c8b7 +md""" +## Mini-batching of graphs + +Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization. +In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. 
+The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`. + +However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. +Therefore, GraphNeuralNetworks.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension). + +This procedure has some crucial advantages over other batching procedures: + +1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs. + +2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges. + +GraphNeuralNetworks.jl can **batch multiple graphs into a single giant graph**: +""" + +# ╔═╡ a142610a-d862-42a9-88af-c8d8b6825650 +vec_gs, _ = first(train_loader) + +# ╔═╡ 6faaf637-a0ff-468c-86b5-b0a7250258d6 +MLUtils.batch(vec_gs) + +# ╔═╡ e314b25f-e904-4c39-bf60-24cddf91fe9d +md""" +Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch: + +```math +\textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ] +``` +""" + +# ╔═╡ ac69571a-998b-4630-afd6-f3d405618bc5 +md""" +## Training a Graph Neural Network (GNN) + +Training a GNN for graph classification usually follows a simple recipe: + +1. Embed each node by performing multiple rounds of message passing +2. Aggregate node embeddings into a unified graph embedding (**readout layer**) +3. 
Train a final classifier on the graph embedding + +There exists multiple **readout layers** in literature, but the most common one is to simply take the average of node embeddings: + +```math +\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v +``` + +GraphNeuralNetworks.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`. + +The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training: +""" + +# ╔═╡ 04402032-18a4-42b5-ad04-19b286bd29b7 +function create_model(nin, nh, nout) + GNNChain(GCNConv(nin => nh, relu), + GCNConv(nh => nh, relu), + GCNConv(nh => nh), + GlobalPool(mean), + Dropout(0.5), + Dense(nh, nout)) +end + +# ╔═╡ 2313fd8d-6e84-4bde-bacc-fb697dc33cbb +md""" +Here, we again make use of the `GCNConv` with ``\mathrm{ReLU}(x) = \max(x, 0)`` activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer. 
+ +Let's train our network for a few epochs to see how well it performs on the training as well as test set: +""" + +# ╔═╡ c956ed97-fa5c-45c6-84dd-39f3e37d8070 +function eval_loss_accuracy(model, data_loader, device) + loss = 0.0 + acc = 0.0 + ntot = 0 + for (g, y) in data_loader + g, y = MLUtils.batch(g) |> device, y |> device + n = length(y) + ŷ = model(g, g.ndata.x) + loss += logitcrossentropy(ŷ, y) * n + acc += mean((ŷ .> 0) .== y) * n + ntot += n + end + return (loss = round(loss / ntot, digits = 4), + acc = round(acc * 100 / ntot, digits = 2)) +end + +# ╔═╡ 968c7087-7637-4844-9509-dd838cf99a8c +function train!(model; epochs = 200, η = 1e-2, infotime = 10) + # device = Flux.gpu # uncomment this for GPU training + device = Flux.cpu + model = model |> device + opt = Flux.setup(Adam(1e-3), model) + + function report(epoch) + train = eval_loss_accuracy(model, train_loader, device) + test = eval_loss_accuracy(model, test_loader, device) + @info (; epoch, train, test) + end + + report(0) + for epoch in 1:epochs + for (g, y) in train_loader + g, y = MLUtils.batch(g) |> device, y |> device + grad = Flux.gradient(model) do model + ŷ = model(g, g.ndata.x) + logitcrossentropy(ŷ, y) + end + Flux.update!(opt, model, grad[1]) + end + epoch % infotime == 0 && report(epoch) + end +end + +# ╔═╡ dedf18d8-4281-49fa-adaf-bd57fc15095d +begin + nin = 7 + nh = 64 + nout = 2 + model = create_model(nin, nh, nout) + train!(model) +end + +# ╔═╡ 3454b311-9545-411d-b47a-b43724b84c36 +md""" +As one can see, our model reaches around **74% test accuracy**. +Reasons for the fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and usually disappear once one applies GNNs to larger datasets. + +## (Optional) Exercise + +Can we do better than this? +As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. 
(2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**. +An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information: + +```math +\mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)} +``` + +This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl. + +As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`. +This should bring you close to **82% test accuracy**. +""" + +# ╔═╡ 93e08871-2929-4279-9f8a-587168617365 +md""" +## Conclusion + +In this chapter, you have learned how to apply GNNs to the task of graph classification. +You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings. 
+""" + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[compat] +Flux = "~0.14.16" +GraphNeuralNetworks = "~0.6.19" +MLDatasets = "~0.7.16" +MLUtils = "~0.4.4" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "4d31565cd40e53ce5e158a179486a694e9c7da67" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.37" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful 
= "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.4.3" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + 
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.69.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.24.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = 
"SparseArrays" + +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.5" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.26.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" +weakdeps = ["SpecialFunctions"] + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = 
"34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.6" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = 
["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + 
DistancesSparseArraysExt = "SparseArrays" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.2" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = 
"90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.16" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.11" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +uuid = 
"0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.3.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.GraphNeuralNetworks]] +deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" +uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" +version = "0.6.19" + + [deps.GraphNeuralNetworks.extensions] + GraphNeuralNetworksCUDAExt = "CUDA" + GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" + + [deps.GraphNeuralNetworks.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" + +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.11.2" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", 
"CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.3+3" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.11.1+0" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + +[[deps.Inflate]] +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.5" + 
+[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.15" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.50" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = 
"692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.22" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] +git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.7.1" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "8.0.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.30+0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = 
"b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + 
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.7" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.16" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.2+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", 
"CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.4.0+0" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MicroCollections]] +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.2.0" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] 
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.21" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.18" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OffsetArrays]] +git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.1" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" 
+version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.6+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.14+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + 
+[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] 
+deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" 
+version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.7" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + 
+[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.StructTypes]] 
+deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.1" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + +[[deps.Transducers]] +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" 
+uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.82" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Unitful]] +deps = ["Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.21.0" +weakdeps = ["ConstructionBase", "InverseFunctions"] + + [deps.Unitful.extensions] + ConstructionBaseUnitfulExt = "ConstructionBase" + InverseFunctionsUnitfulExt = "InverseFunctions" + +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.5" + +[[deps.VectorInterface]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" +uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" 
+version = "0.4.6" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + 
+[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" +""" + +# ╔═╡ Cell order: +# ╠═361e0948-d91a-11ec-2d95-2db77435a0c1 +# ╟─15136fd8-f9b2-4841-9a95-9de7b8969687 +# ╠═f6e86958-e96f-4c77-91fc-c72d8967575c +# ╠═24f76360-8599-46c8-a49f-4c31f02eb7d8 +# ╠═5d5e5152-c860-4158-8bc7-67ee1022f9f8 +# ╠═33163dd2-cb35-45c7-ae5b-d4854d141773 +# ╠═a8d6a133-a828-4d51-83c4-fb44f9d5ede1 +# ╟─3b3e0a79-264b-47d7-8bda-2a6db7290828 +# ╟─7f7750ff-b7fa-4fe2-a5a8-6c9c26c479bb +# ╠═936c09f6-ee62-4bc2-a0c6-749a66080fd2 +# ╟─2c6ccfdd-cf11-415b-b398-95e5b0b2bbd4 +# ╠═519477b2-8323-4ece-a7eb-141e9841117c +# ╠═3c3d5038-0ef6-47d7-a1b7-50880c5f3a0b +# ╟─f7778e2d-2e2a-4fc8-83b0-5242e4ec5eb4 +# ╟─2a1c501e-811b-4ddd-887b-91e8c929c8b7 +# ╠═a142610a-d862-42a9-88af-c8d8b6825650 +# ╠═6faaf637-a0ff-468c-86b5-b0a7250258d6 +# ╟─e314b25f-e904-4c39-bf60-24cddf91fe9d +# ╟─ac69571a-998b-4630-afd6-f3d405618bc5 +# ╠═04402032-18a4-42b5-ad04-19b286bd29b7 +# ╟─2313fd8d-6e84-4bde-bacc-fb697dc33cbb +# ╠═c956ed97-fa5c-45c6-84dd-39f3e37d8070 +# ╠═968c7087-7637-4844-9509-dd838cf99a8c +# ╠═dedf18d8-4281-49fa-adaf-bd57fc15095d +# ╟─3454b311-9545-411d-b47a-b43724b84c36 +# ╟─93e08871-2929-4279-9f8a-587168617365 +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 + +[.\docs\tutorials\introductory_tutorials\node_classification_pluto.jl] +### A Pluto.jl notebook ### +# v0.19.45 + +#> [frontmatter] +#> author = "[Deeptendu Santra](https://github.com/Dsantra92)" +#> title = "Node Classification with Graph Neural Networks" +#> date = "2022-09-25" +#> description = "Tutorial for Node classification using GraphNeuralNetworks.jl" +#> cover = "assets/node_classsification.gif" + +using Markdown +using InteractiveUtils + +# ╔═╡ 5463330a-0161-11ed-1b18-936030a32bbf +# ╠═╡ show_logs = false +begin + using 
MLDatasets + using GraphNeuralNetworks + using Flux + using Flux: onecold, onehotbatch, logitcrossentropy + using Plots + using PlutoUI + using TSne + using Random + using Statistics + + ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" + Random.seed!(17) # for reproducibility +end; + +# ╔═╡ ca2f0293-7eac-4d9a-9a2f-fda47fd95a99 +md""" +In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning). +""" + +# ╔═╡ 4455f18c-2bd9-42ed-bce3-cfe6561eab23 +md""" +## Import +Let us start off by importing some libraries. We will be using Flux.jl and `GraphNeuralNetworks.jl` for our tutorial. +""" + +# ╔═╡ 0d556a7c-d4b6-4cef-806c-3e1712de0791 +md""" +## Visualize +We want to visualize the the outputs of the results using t-distributed stochastic neighbor embedding (tsne) to embed our output embeddings onto a 2D plane. +""" + +# ╔═╡ 997b5387-3811-4998-a9d1-7981b58b9e09 +function visualize_tsne(out, targets) + z = tsne(out, 2) + scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) +end + +# ╔═╡ 4b6fa18d-7ccd-4c07-8dc3-ded4d7da8562 +md""" +## Dataset: Cora + +For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents classified into one of seven classes and 5429 links. Each node represent articles/documents and the edges between these nodes if one of them cite each other. + +Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. + +This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. 
+""" + +# ╔═╡ edab1e3a-31f6-471f-9835-5b1f97e5cf3f +dataset = Cora() + +# ╔═╡ d73a2db5-9417-4b2c-a9f5-b7d499a53fcb +md""" +Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. +""" + +# ╔═╡ 32bb90c1-c802-4c0c-a620-5d3b8f3f2477 +dataset.metadata + +# ╔═╡ 3438ee7f-bfca-465d-85df-13379622d415 +md""" +The `graphs` variable GraphDataset contains the graph. The `Cora` dataset contains only 1 graph. +""" + +# ╔═╡ eec6fb60-0774-4f2a-bcb7-dbc28ab747a6 +dataset.graphs + +# ╔═╡ bd2fd04d-7fb0-4b31-959b-bddabe681754 +md""" +There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. +""" + +# ╔═╡ b29c3a02-c21b-4b10-aa04-b90bcc2931d8 +g = mldataset2gnngraph(dataset) + +# ╔═╡ 16d9fbad-d4dc-4b51-9576-1736d228e2b3 +with_terminal() do + # Gather some statistics about the graph. + println("Number of nodes: $(g.num_nodes)") + println("Number of edges: $(g.num_edges)") + println("Average node degree: $(g.num_edges / g.num_nodes)") + println("Number of training nodes: $(sum(g.ndata.train_mask))") + println("Training node label rate: $(mean(g.ndata.train_mask))") + # println("Has isolated nodes: $(has_isolated_nodes(g))") + println("Has self-loops: $(has_self_loops(g))") + println("Is undirected: $(is_bidirected(g))") +end + +# ╔═╡ 923d061c-25c3-4826-8147-9afa3dbd5bac +md""" +Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. +We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. +For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). +This results in a training node label rate of only 5%. 
+ +We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). +""" + +# ╔═╡ 28e00b95-56db-4d36-a205-fd24d3c54e17 +begin + x = g.ndata.features + # we onehot encode both the node labels (what we want to predict): + y = onehotbatch(g.ndata.targets, 1:7) + train_mask = g.ndata.train_mask + num_features = size(x)[1] + hidden_channels = 16 + num_classes = dataset.metadata["num_classes"] +end; + +# ╔═╡ fa743000-604f-4d28-99f1-46ab2f884b8e +md""" +## Multi-layer Perception Network (MLP) + +In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. + +Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): +""" + +# ╔═╡ f972f61b-2001-409b-9190-ac2c0652829a +begin + struct MLP + layers::NamedTuple + end + + Flux.@layer :expand MLP + + function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5) + layers = (hidden = Dense(num_features => hidden_channels), + drop = Dropout(drop_rate), + classifier = Dense(hidden_channels => num_classes)) + return MLP(layers) + end + + function (model::MLP)(x::AbstractMatrix) + l = model.layers + x = l.hidden(x) + x = relu(x) + x = l.drop(x) + x = l.classifier(x) + return x + end +end + +# ╔═╡ 4dade64a-e28e-42c7-8ad5-93fc04724d4d +md""" +### Training a Multilayer Perceptron + +Our MLP is defined by two linear layers and enhanced by [ReLU](https://fluxml.ai/Flux.jl/stable/models/nnlib/#NNlib.relu) non-linearity and [Dropout](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Dropout). +Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. 
+ +Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/tutorials/introductory_tutorials/gnn_intro_pluto/#Hands-on-introduction-to-Graph-Neural-Networks). +We again make use of the **cross entropy loss** and **Adam optimizer**. +This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). +""" + +# ╔═╡ 05979cfe-439c-4abc-90cd-6ca2a05f6e0f +function train(model::MLP, data::AbstractMatrix, epochs::Int, opt) + Flux.trainmode!(model) + + for epoch in 1:epochs + loss, grad = Flux.withgradient(model) do model + ŷ = model(data) + logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) + end + + Flux.update!(opt, model, grad[1]) + if epoch % 200 == 0 + @show epoch, loss + end + end +end + +# ╔═╡ a3f420e1-7521-4df9-b6d5-fc0a1fd05095 +function accuracy(model::MLP, x::AbstractMatrix, y::Flux.OneHotArray, mask::BitVector) + Flux.testmode!(model) + mean(onecold(model(x))[mask] .== onecold(y)[mask]) +end + +# ╔═╡ b18384fe-b8ae-4f51-bd73-d129d5e70f98 +md""" +After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. +Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: +""" + +# ╔═╡ 54a2972e-b107-47c8-bf7e-eb51b4ccbe02 +md""" +As one can see, our MLP performs rather bad with only about 47% test accuracy. +But why does the MLP do not perform better? +The main reason for that is that this model suffers from heavy overfitting due to only having access to a **small amount of training nodes**, and therefore generalizes poorly to unseen node representations. + +It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**. 
+That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model. +""" + +# ╔═╡ 623e7b53-046c-4858-89d9-13caae45255d +md""" +## Training a Graph Convolutional Neural Network (GNN) + +Following-up on [the first part of this tutorial](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/tutorials/introductory_tutorials/node_classification_pluto/#Multi-layer-Perception-Network-(MLP)), we replace the `Dense` linear layers by the [`GCNConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GCNConv) module. +To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} +``` + +where ``\mathbf{W}^{(\ell + 1)}`` denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. +In contrast, a single `Linear` layer is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} +``` + +which does not make use of neighboring node information. +""" + +# ╔═╡ eb36a46c-f139-425e-8a93-207bc4a16f89 +begin + struct GCN + layers::NamedTuple + end + + Flux.@layer GCN # provides parameter collection, gpu movement and more + + function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5) + layers = (conv1 = GCNConv(num_features => hidden_channels), + drop = Dropout(drop_rate), + conv2 = GCNConv(hidden_channels => num_classes)) + return GCN(layers) + end + + function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix) + l = gcn.layers + x = l.conv1(g, x) + x = relu.(x) + x = l.drop(x) + x = l.conv2(g, x) + return x + end +end + +# ╔═╡ 20b5f802-abce-49e1-a442-f381e80c0f85 +md""" +Now let's visualize the node embeddings of our **untrained** GCN network. 
+""" + +# ╔═╡ b295adce-b37e-45f3-963a-3699d714e36d +# ╠═╡ show_logs = false +begin + gcn = GCN(num_features, num_classes, hidden_channels) + h_untrained = gcn(g, x) |> transpose + visualize_tsne(h_untrained, g.ndata.targets) +end + +# ╔═╡ 5538970f-b273-4122-9d50-7deb049e6934 +md""" +We certainly can do better by training our model. +The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model. +""" + +# ╔═╡ 901d9478-9a12-4122-905d-6cfc6d80e84c +function train(model::GCN, g::GNNGraph, x::AbstractMatrix, epochs::Int, opt) + Flux.trainmode!(model) + + for epoch in 1:epochs + loss, grad = Flux.withgradient(model) do model + ŷ = model(g, x) + logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) + end + + Flux.update!(opt, model, grad[1]) + if epoch % 200 == 0 + @show epoch, loss + end + end +end + +# ╔═╡ 026911dd-6a27-49ce-9d41-21e01646c10a +# ╠═╡ show_logs = false +begin + mlp = MLP(num_features, num_classes, hidden_channels) + opt_mlp = Flux.setup(Adam(1e-3), mlp) + epochs = 2000 + train(mlp, g.ndata.features, epochs, opt_mlp) +end + +# ╔═╡ 65d9fd3d-1649-4b95-a106-f26fa4ab9bce +function accuracy(model::GCN, g::GNNGraph, x::AbstractMatrix, y::Flux.OneHotArray, + mask::BitVector) + Flux.testmode!(model) + mean(onecold(model(g, x))[mask] .== onecold(y)[mask]) +end + +# ╔═╡ b2302697-1e20-4721-ae93-0b121ff9ce8f +accuracy(mlp, g.ndata.features, y, .!train_mask) + +# ╔═╡ 20be52b1-1c33-4f54-b5c0-fecc4e24fbb5 +# ╠═╡ show_logs = false +begin + opt_gcn = Flux.setup(Adam(1e-2), gcn) + train(gcn, g, x, epochs, opt_gcn) +end + +# ╔═╡ 5aa99aff-b5ed-40ec-a7ec-0ba53385e6bd +md""" +Now let's evaluate the loss of our trained GCN. 
+""" + +# ╔═╡ 2163d0d8-0661-4d11-a09e-708769011d35 +with_terminal() do + train_accuracy = accuracy(gcn, g, g.ndata.features, y, train_mask) + test_accuracy = accuracy(gcn, g, g.ndata.features, y, .!train_mask) + + println("Train accuracy: $(train_accuracy)") + println("Test accuracy: $(test_accuracy)") +end + +# ╔═╡ 6cd49f3f-a415-4b6a-9323-4d6aa6b87f18 +md""" +**There it is!** +By simply swapping the linear layers with GNN layers, we can reach **75.77% of test accuracy**! +This is in stark contrast to the 59% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance. + +We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category. +""" + +# ╔═╡ 7a93a802-6774-42f9-b6da-7ae614464e72 +# ╠═╡ show_logs = false +begin + Flux.testmode!(gcn) # inference mode + + out_trained = gcn(g, x) |> transpose + visualize_tsne(out_trained, g.ndata.targets) +end + +# ╔═╡ 50a409fd-d80b-4c48-a51b-173c39a6dcb4 +md""" +## (Optional) Exercises + +1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**. + +2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? + +3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? 
Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. +""" + +# ╔═╡ c343419f-a1d7-45a0-b600-2c868588b33a +md""" +## Conclusion +In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification. +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TSne = "24678dba-d5e9-5843-a4c6-250288b04835" + +[compat] +Flux = "~0.14.16" +GraphNeuralNetworks = "~0.6.19" +MLDatasets = "~0.7.16" +Plots = "~1.40.5" +PlutoUI = "~0.7.59" +TSne = "~1.3.0" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "fb2b669c9e43473fabf01e07c834a510ae36fa5e" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.AbstractPlutoDingetjes]] +deps = ["Pkg"] +git-tree-sha1 = "6e1d2a35f2f90a4bc7c2ed98079b2ba09c35b83a" +uuid = 
"6e696c72-6542-2067-7265-42206c756150" +version = "1.3.2" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.37" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", 
"UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.4.3" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.8+1" + +[[deps.CEnum]] +git-tree-sha1 = 
"389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + +[[deps.Cairo_jll]] +deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" +version = "1.18.0+2" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.69.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.24.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + +[[deps.CodecZlib]] +deps = 
["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.5" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.26.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" +weakdeps = ["SpecialFunctions"] + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + 
[deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.6" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Contour]] +git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" +uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.6.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = 
"04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + 
+[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.EpollShim_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" +uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" +version = "0.0.20230411+0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + +[[deps.Expat_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" +uuid = "2e619515-83b5-522b-bb60-26c02a35a201" +version = "2.6.2+0" + +[[deps.FFMPEG]] +deps = ["FFMPEG_jll"] +git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" +version = "0.4.1" + +[[deps.FFMPEG_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] +git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" +uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" +version = "4.4.4+1" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.2" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", 
"Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.16" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.Fontconfig_jll]] +deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] +git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" +uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" +version = "2.13.96+0" + +[[deps.Format]] 
+git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" +uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +version = "1.3.7" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.FreeType2_jll]] +deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" +uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" +version = "2.13.2+0" + +[[deps.FriBidi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" +uuid = "559328eb-81f9-559d-9380-de523a88c83c" +version = "1.0.14+0" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.11" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GLFW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] +git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" +uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" +version = "3.4.0+0" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.3.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + 
+[[deps.GR]] +deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] +git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" +uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" +version = "0.73.7" + +[[deps.GR_jll]] +deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" +uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" +version = "0.73.7+0" + +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + +[[deps.Gettext_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" +uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" +version = "0.21.0+0" + +[[deps.Glib_jll]] +deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] +git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" +version = "2.80.2+0" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.GraphNeuralNetworks]] +deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" +uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" +version = "0.6.19" + + 
[deps.GraphNeuralNetworks.extensions] + GraphNeuralNetworksCUDAExt = "CUDA" + GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" + + [deps.GraphNeuralNetworks.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" + +[[deps.Graphite2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" +uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" +version = "1.3.14+0" + +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.11.2" + +[[deps.Grisu]] +git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" +uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" +version = "1.0.2" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.2+1" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] 
+git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.HarfBuzz_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] +git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" +uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" +version = "2.8.1+1" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.11.1+0" + +[[deps.Hyperscript]] +deps = ["Test"] +git-tree-sha1 = "179267cfa5e712760cd43dcae385d7ea90cc25a4" +uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" +version = "0.0.5" + +[[deps.HypertextLiteral]] +deps = ["Tricks"] +git-tree-sha1 = "7134810b1afce04bbc1045ca1985fbe81ce17653" +uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" +version = "0.9.5" + +[[deps.IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.5" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" 
+uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + +[[deps.Inflate]] +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.5" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.15" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +uuid = 
"033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.50" + +[[deps.JLFzf]] +deps = ["Pipe", "REPL", "Random", "fzf_jll"] +git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" +uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" +version = "0.1.7" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JpegTurbo_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" +version = "3.0.3+0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.22" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", 
"VectorInterface"] +git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.7.1" + +[[deps.LAME_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" +uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" +version = "3.100.2+0" + +[[deps.LERC_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +uuid = "88015f11-f218-50d7-93a8-a6af411a945d" +version = "3.0.0+1" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "8.0.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.30+0" + +[[deps.LLVMOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" +uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" +version = "15.0.7+0" + +[[deps.LZO_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" +version = "2.10.2+0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.Latexify]] +deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] +git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50" +uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" +version = "0.16.4" + + [deps.Latexify.extensions] + DataFramesExt = "DataFrames" + SymEngineExt = "SymEngine" + + 
[deps.Latexify.weakdeps] + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libffi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" +uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" +version = "3.2.2+1" + +[[deps.Libgcrypt_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] +git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" +uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" +version = "1.8.11+0" + +[[deps.Libglvnd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] +git-tree-sha1 = "6f73d1dd803986947b2c750138528a999a6c7733" +uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" +version = "1.6.0+0" + +[[deps.Libgpg_error_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" +uuid 
= "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" +version = "1.49.0+0" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + +[[deps.Libmount_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" +uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" +version = "2.40.1+0" + +[[deps.Libtiff_jll]] +deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] +git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" +uuid = "89763e89-9b03-5906-acba-b20f662cd828" +version = "4.5.1+1" + +[[deps.Libuuid_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" +uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" +version = "2.40.1+0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MAT]] +deps = 
["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.7" + +[[deps.MIMEs]] +git-tree-sha1 = "65f28ad4b594aebe22157d6fac869786a255b7eb" +uuid = "6c6e2e6c-3030-632d-7369-2d6c69616d65" +version = "0.1.4" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.16" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.2+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.4.0+0" + +[[deps.MacroTools]] +deps = 
["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.Measures]] +git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" +uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" +version = "0.3.2" + +[[deps.MicroCollections]] +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.2.0" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", 
"Random", "Requires", "Statistics"] +git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.21" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.18" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OffsetArrays]] +git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.1" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.Ogg_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" +uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" +version = "1.3.5+1" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = 
"963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] +git-tree-sha1 = "2f0a1d8c79bc385ec3fcda12830c9d0e72b30e71" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "5.0.4+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.14+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.Opus_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" +uuid = "91d4177d-7536-5919-b921-800302f37372" +version = "1.3.2+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PCRE2_jll]] +deps = ["Artifacts", "Libdl"] +uuid 
= "efcefdf7-47ab-520b-bdef-62a2eaa19f15" +version = "10.42.0+1" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + +[[deps.Pipe]] +git-tree-sha1 = "6842804e7867b115ca9de748a0cf6b364523c16d" +uuid = "b98c9c47-44ae-5843-9183-064241ee97a0" +version = "1.3.0" + +[[deps.Pixman_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] +git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" +uuid = "30392449-352a-5448-841d-b1acce4e97dc" +version = "0.43.4+0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PlotThemes]] +deps = ["PlotUtils", "Statistics"] +git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" +uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" +version = "3.2.0" + +[[deps.PlotUtils]] +deps = ["ColorSchemes", 
"Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] +git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" +version = "1.4.1" + +[[deps.Plots]] +deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] +git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" +uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +version = "1.40.5" + + [deps.Plots.extensions] + FileIOExt = "FileIO" + GeometryBasicsExt = "GeometryBasics" + IJuliaExt = "IJulia" + ImageInTerminalExt = "ImageInTerminal" + UnitfulExt = "Unitful" + + [deps.Plots.weakdeps] + FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" + GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" + IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" + ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.PlutoUI]] +deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"] +git-tree-sha1 = "ab55ee1510ad2af0ff674dbcced5e94921f867a9" +uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +version = "0.7.59" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = 
"aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.10.2" + +[[deps.Qt6Base_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] +git-tree-sha1 = "492601870742dcd38f233b23c3ec629628c1d724" +uuid = "c0090381-4147-56d7-9ebc-da0b1113ec56" +version = "6.7.1+1" + +[[deps.Qt6Declarative_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6ShaderTools_jll"] +git-tree-sha1 = "e5dd466bf2569fe08c91a2cc29c1003f4797ac3b" +uuid = "629bc702-f1f5-5709-abd5-49b8460ea067" +version = "6.7.1+2" + +[[deps.Qt6ShaderTools_jll]] +deps = ["Artifacts", "JLLWrappers", 
"Libdl", "Qt6Base_jll"] +git-tree-sha1 = "1a180aeced866700d4bebc3120ea1451201f16bc" +uuid = "ce943373-25bb-56aa-8eca-768745ed7b5a" +version = "6.7.1+1" + +[[deps.Qt6Wayland_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6Declarative_jll"] +git-tree-sha1 = "729927532d48cf79f49070341e1d918a65aba6b0" +uuid = "e99dba38-086e-5de3-a5b1-6e4c66e897c3" +version = "6.7.1+1" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecipesPipeline]] +deps = ["Dates", "NaNMath", "PlotUtils", "PrecompileTools", "RecipesBase"] +git-tree-sha1 = "45cf9fd0ca5839d06ef333c8201714e888486342" +uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c" +version = "0.6.12" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.1" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = 
"ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.Showoff]] +deps = ["Dates", "Grisu"] +git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" +uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" +version = "1.0.3" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = 
"2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.7" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + 
[deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TSne]] +deps = ["Distances", "LinearAlgebra", "Printf", "ProgressMeter", "Statistics"] +git-tree-sha1 = "6f1dfbf9dad6958439816fa9c5fa20898203fdf4" +uuid = "24678dba-d5e9-5843-a4c6-250288b04835" +version = "1.3.0" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps 
= ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.1" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + +[[deps.Transducers]] +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.82" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.Tricks]] +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" +uuid = 
"410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.8" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnicodeFun]] +deps = ["REPL"] +git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" +uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" +version = "0.4.1" + +[[deps.Unitful]] +deps = ["Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.21.0" +weakdeps = ["ConstructionBase", "InverseFunctions"] + + [deps.Unitful.extensions] + ConstructionBaseUnitfulExt = "ConstructionBase" + InverseFunctionsUnitfulExt = "InverseFunctions" + +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + +[[deps.UnitfulLatexify]] +deps = ["LaTeXStrings", "Latexify", "Unitful"] +git-tree-sha1 = "975c354fcd5f7e1ddcc1f1a23e6e091d99e99bc8" +uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" +version = "1.6.4" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.5" + +[[deps.Unzip]] +git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" +uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d" +version = "0.2.0" + +[[deps.VectorInterface]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" +uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +version = "0.4.6" + +[[deps.Vulkan_Loader_jll]] +deps = ["Artifacts", 
"JLLWrappers", "Libdl", "Wayland_jll", "Xorg_libX11_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] +git-tree-sha1 = "2f0486047a07670caad3a81a075d2e518acc5c59" +uuid = "a44049a8-05dd-5a78-86c9-5fde0876e88c" +version = "1.3.243+0" + +[[deps.Wayland_jll]] +deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "7558e29847e99bc3f04d6569e82d0f5c54460703" +uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89" +version = "1.21.0+1" + +[[deps.Wayland_protocols_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "93f43ab61b16ddfb2fd3bb13b3ce241cafb0e6c9" +uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" +version = "1.31.0+0" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.XML2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] +git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" +uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.13.1+0" + +[[deps.XSLT_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] +git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" +uuid = "aed1982a-8fda-507f-9586-7b0439959a61" +version = "1.1.41+0" + +[[deps.XZ_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" +uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" +version = "5.4.6+0" + +[[deps.Xorg_libICE_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "326b4fea307b0b39892b3e85fa451692eda8d46c" +uuid = "f67eecfb-183a-506d-b269-f58e52b52d7c" +version = "1.1.1+0" + +[[deps.Xorg_libSM_jll]] +deps = 
["Artifacts", "JLLWrappers", "Libdl", "Xorg_libICE_jll"] +git-tree-sha1 = "3796722887072218eabafb494a13c963209754ce" +uuid = "c834827a-8449-5923-a945-d239c165b7dd" +version = "1.2.4+0" + +[[deps.Xorg_libX11_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] +git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" +uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" +version = "1.8.6+0" + +[[deps.Xorg_libXau_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" +uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" +version = "1.0.11+0" + +[[deps.Xorg_libXcursor_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXfixes_jll", "Xorg_libXrender_jll"] +git-tree-sha1 = "12e0eb3bc634fa2080c1c37fccf56f7c22989afd" +uuid = "935fb764-8cf2-53bf-bb30-45bb1f8bf724" +version = "1.2.0+4" + +[[deps.Xorg_libXdmcp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" +uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" +version = "1.1.4+0" + +[[deps.Xorg_libXext_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" +uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" +version = "1.3.6+0" + +[[deps.Xorg_libXfixes_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] +git-tree-sha1 = "0e0dc7431e7a0587559f9294aeec269471c991a4" +uuid = "d091e8ba-531a-589c-9de9-94069b037ed8" +version = "5.0.3+4" + +[[deps.Xorg_libXi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXfixes_jll"] +git-tree-sha1 = "89b52bc2160aadc84d707093930ef0bffa641246" +uuid = "a51aa0fd-4e3c-5386-b890-e753decda492" +version = "1.7.10+4" + +[[deps.Xorg_libXinerama_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll"] +git-tree-sha1 = "26be8b1c342929259317d8b9f7b53bf2bb73b123" +uuid = 
"d1454406-59df-5ea1-beac-c340f2130bc3" +version = "1.1.4+4" + +[[deps.Xorg_libXrandr_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll"] +git-tree-sha1 = "34cea83cb726fb58f325887bf0612c6b3fb17631" +uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" +version = "1.5.2+4" + +[[deps.Xorg_libXrender_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" +uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" +version = "0.9.11+0" + +[[deps.Xorg_libpthread_stubs_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" +uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" +version = "0.1.1+0" + +[[deps.Xorg_libxcb_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] +git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" +uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" +version = "1.17.0+0" + +[[deps.Xorg_libxkbfile_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "730eeca102434283c50ccf7d1ecdadf521a765a4" +uuid = "cc61e674-0454-545c-8b26-ed2c68acab7a" +version = "1.1.2+0" + +[[deps.Xorg_xcb_util_cursor_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_jll", "Xorg_xcb_util_renderutil_jll"] +git-tree-sha1 = "04341cb870f29dcd5e39055f895c39d016e18ccd" +uuid = "e920d4aa-a673-5f3a-b3d7-f755a4d47c43" +version = "0.1.4+0" + +[[deps.Xorg_xcb_util_image_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "0fab0a40349ba1cba2c1da699243396ff8e94b97" +uuid = "12413925-8142-5f55-bb0e-6d7ca50bb09b" +version = "0.4.0+1" + +[[deps.Xorg_xcb_util_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll"] +git-tree-sha1 = "e7fd7b2881fa2eaa72717420894d3938177862d1" +uuid = "2def613f-5ad1-5310-b15b-b15d46f528f5" 
+version = "0.4.0+1" + +[[deps.Xorg_xcb_util_keysyms_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "d1151e2c45a544f32441a567d1690e701ec89b00" +uuid = "975044d2-76e6-5fbe-bf08-97ce7c6574c7" +version = "0.4.0+1" + +[[deps.Xorg_xcb_util_renderutil_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "dfd7a8f38d4613b6a575253b3174dd991ca6183e" +uuid = "0d47668e-0667-5a69-a72c-f761630bfb7e" +version = "0.3.9+1" + +[[deps.Xorg_xcb_util_wm_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "e78d10aab01a4a154142c5006ed44fd9e8e31b67" +uuid = "c22f9ab0-d5fe-5066-847c-f4bb1cd4e361" +version = "0.4.1+1" + +[[deps.Xorg_xkbcomp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxkbfile_jll"] +git-tree-sha1 = "330f955bc41bb8f5270a369c473fc4a5a4e4d3cb" +uuid = "35661453-b289-5fab-8a00-3d9160c6a3a4" +version = "1.4.6+0" + +[[deps.Xorg_xkeyboard_config_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xkbcomp_jll"] +git-tree-sha1 = "691634e5453ad362044e2ad653e79f3ee3bb98c3" +uuid = "33bec58e-1273-512f-9401-5d533626f822" +version = "2.39.0+0" + +[[deps.Xorg_xtrans_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" +uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" +version = "1.5.0+0" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zstd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" +version = "1.5.6+0" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", 
"FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.eudev_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] +git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba" +uuid = "35ca27e7-8b34-5b7f-bca9-bdc33f59eb06" +version = "3.2.9+0" + +[[deps.fzf_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" +uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" +version = "0.43.0+0" + +[[deps.gperf_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "3516a5630f741c9eecb3720b1ec9d8edc3ecc033" +uuid = "1a1c6b14-54f6-533d-8383-74cd7377aa70" +version = "3.1.1+0" + +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+0" + +[[deps.libaom_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" +uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" +version = "3.9.0+0" + +[[deps.libass_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", 
"JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" +uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" +version = "0.15.1+0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.libevdev_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22" +uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" +version = "1.11.0+0" + +[[deps.libfdk_aac_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" +uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" +version = "2.0.2+0" + +[[deps.libinput_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"] +git-tree-sha1 = "ad50e5b90f222cfe78aa3d5183a20a12de1322ce" +uuid = "36db933b-70db-51c0-b978-0f229ee0e533" +version = "1.18.0+0" + +[[deps.libpng_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" +version = "1.6.43+1" + +[[deps.libvorbis_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] +git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" +uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" +version = "1.3.7+2" + +[[deps.mtdev_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "814e154bdb7be91d78b6802843f76b6ece642f11" +uuid = "009596ad-96f7-51b1-9f1b-5ce2d5e8a71e" +version = "1.1.6+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" + +[[deps.x264_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" +uuid = 
"1270edf5-f2f9-52d2-97e9-ab00b5d0237a" +version = "2021.5.5+0" + +[[deps.x265_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" +uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" +version = "3.5.0+0" + +[[deps.xkbcommon_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] +git-tree-sha1 = "9c304562909ab2bab0262639bd4f444d7bc2be37" +uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" +version = "1.4.1+1" +""" + +# ╔═╡ Cell order: +# ╟─ca2f0293-7eac-4d9a-9a2f-fda47fd95a99 +# ╟─4455f18c-2bd9-42ed-bce3-cfe6561eab23 +# ╠═5463330a-0161-11ed-1b18-936030a32bbf +# ╟─0d556a7c-d4b6-4cef-806c-3e1712de0791 +# ╠═997b5387-3811-4998-a9d1-7981b58b9e09 +# ╟─4b6fa18d-7ccd-4c07-8dc3-ded4d7da8562 +# ╠═edab1e3a-31f6-471f-9835-5b1f97e5cf3f +# ╟─d73a2db5-9417-4b2c-a9f5-b7d499a53fcb +# ╠═32bb90c1-c802-4c0c-a620-5d3b8f3f2477 +# ╟─3438ee7f-bfca-465d-85df-13379622d415 +# ╠═eec6fb60-0774-4f2a-bcb7-dbc28ab747a6 +# ╟─bd2fd04d-7fb0-4b31-959b-bddabe681754 +# ╠═b29c3a02-c21b-4b10-aa04-b90bcc2931d8 +# ╠═16d9fbad-d4dc-4b51-9576-1736d228e2b3 +# ╟─923d061c-25c3-4826-8147-9afa3dbd5bac +# ╠═28e00b95-56db-4d36-a205-fd24d3c54e17 +# ╟─fa743000-604f-4d28-99f1-46ab2f884b8e +# ╠═f972f61b-2001-409b-9190-ac2c0652829a +# ╟─4dade64a-e28e-42c7-8ad5-93fc04724d4d +# ╠═05979cfe-439c-4abc-90cd-6ca2a05f6e0f +# ╠═a3f420e1-7521-4df9-b6d5-fc0a1fd05095 +# ╠═026911dd-6a27-49ce-9d41-21e01646c10a +# ╟─b18384fe-b8ae-4f51-bd73-d129d5e70f98 +# ╠═b2302697-1e20-4721-ae93-0b121ff9ce8f +# ╟─54a2972e-b107-47c8-bf7e-eb51b4ccbe02 +# ╟─623e7b53-046c-4858-89d9-13caae45255d +# ╠═eb36a46c-f139-425e-8a93-207bc4a16f89 +# ╟─20b5f802-abce-49e1-a442-f381e80c0f85 +# ╠═b295adce-b37e-45f3-963a-3699d714e36d +# ╟─5538970f-b273-4122-9d50-7deb049e6934 +# ╠═901d9478-9a12-4122-905d-6cfc6d80e84c +# ╠═65d9fd3d-1649-4b95-a106-f26fa4ab9bce +# ╠═20be52b1-1c33-4f54-b5c0-fecc4e24fbb5 +# 
╟─5aa99aff-b5ed-40ec-a7ec-0ba53385e6bd +# ╠═2163d0d8-0661-4d11-a09e-708769011d35 +# ╟─6cd49f3f-a415-4b6a-9323-4d6aa6b87f18 +# ╠═7a93a802-6774-42f9-b6da-7ae614464e72 +# ╟─50a409fd-d80b-4c48-a51b-173c39a6dcb4 +# ╟─c343419f-a1d7-45a0-b600-2c868588b33a +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 + +[.\docs\tutorials_broken\temporal_graph_classification_pluto.jl] +### A Pluto.jl notebook ### +# v0.19.45 + +#> [frontmatter] +#> author = "[Aurora Rossi](https://github.com/aurorarossi)" +#> title = "Temporal Graph classification with Graph Neural Networks" +#> date = "2024-03-06" +#> description = "Temporal Graph classification with GraphNeuralNetworks.jl" +#> cover = "assets/brain_gnn.gif" + +using Markdown +using InteractiveUtils + +# ╔═╡ b8df1800-c69d-4e18-8a0a-097381b62a4c +begin + using Flux + using GraphNeuralNetworks + using Statistics, Random + using LinearAlgebra + using MLDatasets: TemporalBrains + using CUDA + using cuDNN +end + +# ╔═╡ 69d00ec8-da47-11ee-1bba-13a14e8a6db2 +md"In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying. + +We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU. +" + +# ╔═╡ ef8406e4-117a-4cc6-9fa5-5028695b1a4f +md" +## Import + +We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others. 
+" + +# ╔═╡ 2544d468-1430-4986-88a9-be4df2a7cf27 +md" +## Dataset: TemporalBrains +The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects obtained from resting-state functional MRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation). +Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions. + +The graph nodes represent brain regions and their number is fixed at 102 for each of the 27 snapshots, while the edges, representing functional connectivity, change over time. +For each snapshot, the feature of a node represents the average activation of the node during that snapshot. +Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+). +The network's edge weights are binarized, and the threshold is set to 0.6 by default. +" + +# ╔═╡ f2dbc66d-b8b7-46ae-ad5b-cbba1af86467 +brain_dataset = TemporalBrains() + +# ╔═╡ d9e4722d-6f02-4d41-955c-8bb3e411e404 +md"After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format. +So we create a function called `data_loader` that implements the latter and splits the dataset into the training set that will be used to train the model and the test set that will be used to test the performance of the model. 
+" + +# ╔═╡ bb36237a-5545-47d0-a873-7ddff3efe8ba +function data_loader(brain_dataset) + graphs = brain_dataset.graphs + dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs)) + for i in 1:length(graphs) + graph = graphs[i] + dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(graph.snapshots)) + # Add graph and node features + for t in 1:27 + s = dataset[i].snapshots[t] + s.ndata.x = [I(102); s.ndata.x'] + end + dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"])) + end + # Split the dataset into a 80% training set and a 20% test set + train_loader = dataset[1:200] + test_loader = dataset[201:250] + return train_loader, test_loader +end; + +# ╔═╡ d4732340-9179-4ada-b82e-a04291d745c2 +md" +The first part of the `data_loader` function calls the `mlgraph2gnngraph` function for each snapshot, which takes the graph and converts it to a `GNNGraph`. The vector of `GNNGraph`s is then rewritten to a `TemporalSnapshotsGNNGraph`. + +The second part adds the graph and node features to the temporal graphs, in particular it adds the one-hot encoding of the label of the graph (in this case we directly use the identity matrix) and appends the mean activation of the node of the snapshot (which is contained in the vector `dataset[i].snapshots[t].ndata.x`, where `i` is the index indicating the subject and `t` is the snapshot). For the graph feature, it adds the one-hot encoding of gender. + +The last part splits the dataset. +" + + +# ╔═╡ ec088a59-2fc2-426a-a406-f8f8d6784128 +md" +## Model + +We now implement a simple model that takes a `TemporalSnapshotsGNNGraph` as input. +It consists of a `GINConv` applied independently to each snapshot, a `GlobalPool` to get an embedding for each snapshot, a pooling on the time dimension to get an embedding for the whole temporal graph, and finally a `Dense` layer. + +First, we start by adapting the `GlobalPool` to the `TemporalSnapshotsGNNGraphs`. 
+" + +# ╔═╡ 5ea98df9-4920-4c94-9472-3ef475af89fd +function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector) + h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)] + sze = size(h[1]) + reshape(reduce(hcat, h), sze[1], length(h)) +end + +# ╔═╡ cfda2cf4-d08b-4f46-bd39-02ae3ed53369 +md" +Then we implement the constructor of the model, which we call `GenderPredictionModel`, and the foward pass. +" + +# ╔═╡ 2eedd408-67ee-47b2-be6f-2caec94e95b5 +begin + struct GenderPredictionModel + gin::GINConv + mlp::Chain + globalpool::GlobalPool + f::Function + dense::Dense + end + + Flux.@layer GenderPredictionModel + + function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) + mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) + gin = GINConv(mlp, 0.5) + globalpool = GlobalPool(mean) + f = x -> mean(x, dims = 2) + dense = Dense(nhidden, 2) + GenderPredictionModel(gin, mlp, globalpool, f, dense) + end + + function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph) + h = m.gin(g, g.ndata.x) + h = m.globalpool(g, h) + h = m.f(h) + m.dense(h) + end + +end + +# ╔═╡ 76780020-406d-4803-9af0-d928e54fc18c +md" +## Training + +We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the `logitbinarycrossentropy` as the loss function, which is typically used as the loss in two-class classification, where the labels are given in a one-hot format. +The accuracy expresses the number of correct classifications. 
+" + +# ╔═╡ 0a1e07b0-a4f3-4a4b-bcd1-7fe200967cf8 +lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y); + +# ╔═╡ cc2ebdcf-72de-4a3b-af46-5bddab6689cc +function eval_loss_accuracy(model, data_loader) + error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader]) + acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader]) + return (loss = error, acc = acc) +end; + +# ╔═╡ d64be72e-8c1f-4551-b4f2-28c8b78466c0 +function train(dataset; usecuda::Bool, kws...) + + if usecuda && CUDA.functional() #check if GPU is available + my_device = gpu + @info "Training on GPU" + else + my_device = cpu + @info "Training on CPU" + end + + function report(epoch) + train_loss, train_acc = eval_loss_accuracy(model, train_loader) + test_loss, test_acc = eval_loss_accuracy(model, test_loader) + println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") + return (train_loss, train_acc, test_loss, test_acc) + end + + model = GenderPredictionModel() |> my_device + + opt = Flux.setup(Adam(1.0f-3), model) + + train_loader, test_loader = data_loader(dataset) + train_loader = train_loader |> my_device + test_loader = test_loader |> my_device + + report(0) + for epoch in 1:100 + for g in train_loader + grads = Flux.gradient(model) do model + ŷ = model(g) + lossfunction(vec(ŷ), g.tgdata.g) + end + Flux.update!(opt, model, grads[1]) + end + if epoch % 10 == 0 + report(epoch) + end + end + return model +end; + + +# ╔═╡ 483f17ba-871c-4769-88bd-8ec781d1909d +train(brain_dataset; usecuda = true) + +# ╔═╡ b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5 +md" +We set up the training on the GPU because training takes a lot of time, especially when working on the CPU. +" + +# ╔═╡ cb4eed19-2658-411d-886c-e0c9c2b44219 +md" +## Conclusions + +In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. 
We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but can be improved by fine-tuning the parameters and training on more data. +" + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[compat] +CUDA = "~5.4.3" +Flux = "~0.14.16" +GraphNeuralNetworks = "~0.6.19" +MLDatasets = "~0.7.16" +cuDNN = "~1.3.2" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "25724970092e282d6cd2d6ea9e021d61f3714205" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.37" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = 
"94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.4.3" + + [deps.BangBang.extensions] + 
BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", 
"Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] +git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "5.4.3" + + [deps.CUDA.extensions] + ChainRulesCoreExt = "ChainRulesCore" + EnzymeCoreExt = "EnzymeCore" + SpecialFunctionsExt = "SpecialFunctions" + + [deps.CUDA.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "97df9d4d6be8ac6270cb8fd3b8fc413690820cbd" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.9.1+1" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "f3b237289a5a77c759b2dd5d4c2ff641d67c4030" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.3.4" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.14.1+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "9.0.0+1" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.69.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.24.0" +weakdeps 
= ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.5" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.26.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" +weakdeps = ["SpecialFunctions"] + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", 
"UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.6" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = 
"124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", 
"SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.2" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = 
"1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.16" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.11" + +[[deps.Future]] +deps = ["Random"] +uuid = 
"9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.3.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.26.7" + +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.GraphNeuralNetworks]] +deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" +uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" +version = "0.6.19" + + [deps.GraphNeuralNetworks.extensions] + GraphNeuralNetworksCUDAExt = "CUDA" + GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" + + [deps.GraphNeuralNetworks.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" + +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = 
"ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.11.2" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.3+3" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.11.1+0" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] 
+git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + +[[deps.Inflate]] +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.5" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.15" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = 
"a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.50" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JuliaNVTXCallbacks_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" +uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" +version = "0.2.1+0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.22" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] 
+git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.7.1" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "8.0.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.30+0" + +[[deps.LLVMLoopInfo]] +git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" +uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" +version = "1.0.0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = 
"8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.7" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.16" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = 
"d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.2+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.4.0+0" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MicroCollections]] +deps = ["Accessors", "BangBang", "InitialValues"] 
+git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.2.0" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.21" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NVTX]] +deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] +git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" +uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" 
+version = "0.3.4" + +[[deps.NVTX_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" +uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" +version = "3.1.0+2" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.18" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OffsetArrays]] +git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.1" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.6+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] 
+git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.14+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + 
+[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.7.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = 
"043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = 
"6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.7" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = 
"10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" +weakdeps = ["CUDA"] + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = 
"4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.24" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.1" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + +[[deps.Transducers]] +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.82" + + 
[deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Unitful]] +deps = ["Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.21.0" +weakdeps = ["ConstructionBase", "InverseFunctions"] + + [deps.Unitful.extensions] + ConstructionBaseUnitfulExt = "ConstructionBase" + InverseFunctionsUnitfulExt = "InverseFunctions" + +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.5" + +[[deps.VectorInterface]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" +uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +version = "0.4.6" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", 
"InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] +git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.3.2" + +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+0" + 
+[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" +""" + +# ╔═╡ Cell order: +# ╟─69d00ec8-da47-11ee-1bba-13a14e8a6db2 +# ╟─ef8406e4-117a-4cc6-9fa5-5028695b1a4f +# ╠═b8df1800-c69d-4e18-8a0a-097381b62a4c +# ╟─2544d468-1430-4986-88a9-be4df2a7cf27 +# ╠═f2dbc66d-b8b7-46ae-ad5b-cbba1af86467 +# ╟─d9e4722d-6f02-4d41-955c-8bb3e411e404 +# ╠═bb36237a-5545-47d0-a873-7ddff3efe8ba +# ╟─d4732340-9179-4ada-b82e-a04291d745c2 +# ╟─ec088a59-2fc2-426a-a406-f8f8d6784128 +# ╠═5ea98df9-4920-4c94-9472-3ef475af89fd +# ╟─cfda2cf4-d08b-4f46-bd39-02ae3ed53369 +# ╠═2eedd408-67ee-47b2-be6f-2caec94e95b5 +# ╟─76780020-406d-4803-9af0-d928e54fc18c +# ╠═0a1e07b0-a4f3-4a4b-bcd1-7fe200967cf8 +# ╠═cc2ebdcf-72de-4a3b-af46-5bddab6689cc +# ╠═d64be72e-8c1f-4551-b4f2-28c8b78466c0 +# ╠═483f17ba-871c-4769-88bd-8ec781d1909d +# ╟─b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5 +# ╟─cb4eed19-2658-411d-886c-e0c9c2b44219 +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 + +[.\docs\tutorials_broken\traffic_prediction.jl] +### A Pluto.jl notebook ### +# v0.19.45 + +#> [frontmatter] +#> author = "[Aurora Rossi](https://github.com/aurorarossi)" +#> title = "Traffic Prediction using recurrent Temporal Graph Convolutional Network" +#> date = "2023-08-21" +#> description = "Traffic Prediction using GraphNeuralNetworks.jl" +#> cover = "assets/traffic.gif" + +using Markdown +using InteractiveUtils + +# ╔═╡ 1f95ad97-a007-4724-84db-392b0026e1a4 +begin + using GraphNeuralNetworks + using Flux + using Flux.Losses: mae + using MLDatasets: METRLA + using Statistics + using Plots +end + +# ╔═╡ 5fdab668-4003-11ee-33f5-3953225b0c0f +md" +In this tutorial, we will learn how to use 
a recurrent Temporal Graph Convolutional Network (TGCN) to predict traffic in a spatio-temporal setting. Traffic forecasting is the problem of predicting future traffic trends on a road network given historical traffic data, such as, in our case, traffic speed and time of day. +" + +# ╔═╡ 3dd0ce32-2339-4d5a-9a6f-1f662bc5500b +md" +## Import + +We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others. +" + +# ╔═╡ ec5caeb6-1f95-4cb9-8739-8cadba29a22d +md" +## Dataset: METR-LA + +We use the `METR-LA` dataset from the paper [Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926.pdf), which contains traffic data from loop detectors in the highway of Los Angeles County. The dataset contains traffic speed data from March 1, 2012 to June 30, 2012. The data is collected every 5 minutes, resulting in 12 observations per hour, from 207 sensors. Each sensor is a node in the graph, and the edges represent the distances between the sensors. +" + +# ╔═╡ f531e39c-6842-494a-b4ac-8904321098c9 +dataset_metrla = METRLA(; num_timesteps = 3) + +# ╔═╡ d5ebf9aa-cec8-4417-baaf-f2e8e19f1cad + g = dataset_metrla[1] + +# ╔═╡ dc2d5e98-2201-4754-bfc6-8ed2bbb82153 +md" +`edge_data` contains the weights of the edges of the graph and +`node_data` contains a node feature vector and a target vector. The latter vectors contain batches of dimension `num_timesteps`, which means that they contain vectors with the node features and targets of `num_timesteps` time steps. Two consecutive batches are shifted by one-time step. +The node features are the traffic speed of the sensors and the time of the day, and the targets are the traffic speed of the sensors in the next time step. 
+Let's see some examples: +" + +# ╔═╡ 0dde5fd3-72d0-4b15-afb3-9a5b102327c9 +size(g.node_data.features[1]) + +# ╔═╡ f7a6d572-28cf-4d69-a9be-d49f367eca37 +md" +The first dimension correspond to the two features (first line the speed value and the second line the time of the day), the second to the nodes and the third to the number of timestep `num_timesteps`. +" + +# ╔═╡ 3d5503bc-bb97-422e-9465-becc7d3dbe07 +size(g.node_data.targets[1]) + +# ╔═╡ 3569715d-08f5-4605-b946-9ef7ccd86ae5 +md" +In the case of the targets the first dimension is 1 because they store just the speed value. +" + +# ╔═╡ aa4eb172-2a42-4c01-a6ef-c6c95208d5b2 +g.node_data.features[1][:,1,:] + +# ╔═╡ 367ed417-4f53-44d4-8135-0c91c842a75f +g.node_data.features[2][:,1,:] + +# ╔═╡ 7c084eaa-655c-4251-a342-6b6f4df76ddb +g.node_data.targets[1][:,1,:] + +# ╔═╡ bf0d820d-32c0-4731-8053-53d5d499e009 +function plot_data(data,sensor) + p = plot(legend=false, xlabel="Time (h)", ylabel="Normalized speed") + plotdata = [] + for i in 1:3:length(data) + push!(plotdata,data[i][1,sensor,:]) + end + plotdata = reduce(vcat,plotdata) + plot!(p, collect(1:length(data)), plotdata, color = :green, xticks =([i for i in 0:50:250], ["$(i)" for i in 0:4:24])) + return p +end + +# ╔═╡ cb89d1a3-b4ff-421a-8717-a0b7f21dea1a +plot_data(g.node_data.features[1:288],1) + +# ╔═╡ 3b49a612-3a04-4eb5-bfbc-360614f4581a +md" +Now let's construct the static graph, the temporal features and targets from the dataset. +" + +# ╔═╡ 95d8bd24-a40d-409f-a1e7-4174428ef860 +begin + graph = GNNGraph(g.edge_index; edata = g.edge_data, g.num_nodes) + features = g.node_data.features + targets = g.node_data.targets +end; + +# ╔═╡ fde2ac9e-b121-4105-8428-1820b9c17a43 +md" +Now let's construct the `train_loader` and `data_loader`. 
+" + + +# ╔═╡ 111b7d5d-c7e3-44c0-9e5e-2ed1a86854d3 +begin + train_loader = zip(features[1:200], targets[1:200]) + test_loader = zip(features[2001:2288], targets[2001:2288]) +end; + +# ╔═╡ 572a6633-875b-4d7e-9afc-543b442948fb +md" +## Model: T-GCN + +We use the T-GCN model from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction] (https://arxiv.org/pdf/1811.05320.pdf), which consists of a graph convolutional network (GCN) and a gated recurrent unit (GRU). The GCN is used to capture spatial features from the graph, and the GRU is used to capture temporal features from the feature time series. +" + +# ╔═╡ 5502f4fa-3201-4980-b766-2ab88b175b11 +model = GNNChain(TGCN(2 => 100), Dense(100, 1)) + +# ╔═╡ 4a1ec34a-1092-4b4a-b8a8-bd91939ffd9e +md" +![](https://www.researchgate.net/profile/Haifeng-Li-3/publication/335353434/figure/fig4/AS:851870352437249@1580113127759/The-architecture-of-the-Gated-Recurrent-Unit-model.jpg) +" + +# ╔═╡ 755a88c2-c2e5-46d1-9582-af4b2c5a6bbd +md" +## Training + +We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the mean absolute error (MAE) as the loss function. 
+" + +# ╔═╡ e83253b2-9f3a-44e2-a747-cce1661657c4 +function train(graph, train_loader, model) + + opt = Flux.setup(Adam(0.001), model) + + for epoch in 1:100 + for (x, y) in train_loader + x, y = (x, y) + grads = Flux.gradient(model) do model + ŷ = model(graph, x) + Flux.mae(ŷ, y) + end + Flux.update!(opt, model, grads[1]) + end + + if epoch % 10 == 0 + loss = mean([Flux.mae(model(graph,x), y) for (x, y) in train_loader]) + @show epoch, loss + end + end + return model +end + +# ╔═╡ 85a923da-3027-4f71-8db6-96852c115c03 +train(graph, train_loader, model) + +# ╔═╡ 39c82234-97ea-48d6-98dd-915f072b7f85 +function plot_predicted_data(graph,features,targets, sensor) + p = plot(xlabel="Time (h)", ylabel="Normalized speed") + prediction = [] + grand_truth = [] + for i in 1:3:length(features) + push!(grand_truth,targets[i][1,sensor,:]) + push!(prediction, model(graph, features[i])[1,sensor,:]) + end + prediction = reduce(vcat,prediction) + grand_truth = reduce(vcat, grand_truth) + plot!(p, collect(1:length(features)), grand_truth, color = :blue, label = "Grand Truth", xticks =([i for i in 0:50:250], ["$(i)" for i in 0:4:24])) + plot!(p, collect(1:length(features)), prediction, color = :red, label= "Prediction") + return p +end + +# ╔═╡ 8c3a903b-2c8a-4d4f-8eef-74d5611f2ce4 +plot_predicted_data(graph,features[301:588],targets[301:588], 1) + +# ╔═╡ 2c5f6250-ee7a-41b1-9551-bcfeba83ca8b +accuracy(ŷ, y) = 1 - Statistics.norm(y-ŷ)/Statistics.norm(y) + +# ╔═╡ 1008dad4-d784-4c38-a7cf-d9b64728e28d +mean([accuracy(model(graph,x), y) for (x, y) in test_loader]) + +# ╔═╡ 8d0e8b9f-226f-4bff-9deb-046e6a897b71 +md"The accuracy is not very good but can be improved by training using more data. We used a small subset of the dataset for this tutorial because of the computational cost of training the model. From the plot of the predictions, we can see that the model is able to capture the general trend of the traffic speed, but it is not able to capture the peaks of the traffic." 
+ +# ╔═╡ a7e4bb23-6687-476a-a0c2-1b2736873d9d +md" +## Conclusion + +In this tutorial, we learned how to use a recurrent temporal graph convolutional network to predict traffic in a spatio-temporal setting. We used the TGCN model, which consists of a graph convolutional network (GCN) and a gated recurrent unit (GRU). We then trained the model for 100 epochs on a small subset of the METR-LA dataset. The accuracy of the model is not very good, but it can be improved by training on more data. +" + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[compat] +Flux = "~0.14.16" +GraphNeuralNetworks = "~0.6.19" +MLDatasets = "~0.7.16" +Plots = "~1.40.5" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "8742c1fb8ae152ad31b34471cf90f234c1b8b06c" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.37" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + 
AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.AtomsBase]] +deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" +uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +version = "0.3.5" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = 
"e2144b631226d9eeab2d746ca8880b7ccff504ae" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.4.3" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.BufferedStreams]] +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.2.1" + +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.8+1" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + +[[deps.Cairo_jll]] +deps = ["Artifacts", "Bzip2_jll", 
"CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" +version = "1.18.0+2" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.69.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.24.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.Chemfiles]] +deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] +git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.41" + +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.4+0" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.5" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.26.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = 
"b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" +weakdeps = ["SpecialFunctions"] + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.6" + + [deps.ConstructionBase.extensions] 
+ ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Contour]] +git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" +uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.6.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = 
["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.EpollShim_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" +uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" +version = "0.0.20230411+0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = 
"0.1.10" + +[[deps.Expat_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" +uuid = "2e619515-83b5-522b-bb60-26c02a35a201" +version = "2.6.2+0" + +[[deps.FFMPEG]] +deps = ["FFMPEG_jll"] +git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" +version = "0.4.1" + +[[deps.FFMPEG_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] +git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" +uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" +version = "4.4.4+1" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.2" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + 
FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.16" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.Fontconfig_jll]] +deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] +git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" +uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" +version = "2.13.96+0" + +[[deps.Format]] +git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" +uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +version = "1.3.7" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = 
"StaticArrays" + +[[deps.FreeType2_jll]] +deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" +uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" +version = "2.13.2+0" + +[[deps.FriBidi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" +uuid = "559328eb-81f9-559d-9380-de523a88c83c" +version = "1.0.14+0" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.11" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GLFW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] +git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" +uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" +version = "3.4.0+0" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.3.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.GR]] +deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] +git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" +uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" +version = "0.73.7" + +[[deps.GR_jll]] +deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", 
"Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" +uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" +version = "0.73.7+0" + +[[deps.GZip]] +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.6.2" + +[[deps.Gettext_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" +uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" +version = "0.21.0+0" + +[[deps.Glib_jll]] +deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] +git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" +version = "2.80.2+0" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.GraphNeuralNetworks]] +deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" +uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" +version = "0.6.19" + + [deps.GraphNeuralNetworks.extensions] + GraphNeuralNetworksCUDAExt = "CUDA" + GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" + + [deps.GraphNeuralNetworks.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" + +[[deps.Graphite2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" +uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" +version = "1.3.14+0" + +[[deps.Graphs]] +deps = 
["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.11.2" + +[[deps.Grisu]] +git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" +uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" +version = "1.0.2" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.2+1" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.HarfBuzz_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] +git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" +uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" +version = "2.8.1+1" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = 
"5e19e1e4fa3e71b774ce746274364aef0234634e" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.11.1+0" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.8" + +[[deps.Inflate]] +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.5" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + 
+[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.15" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.50" + +[[deps.JLFzf]] +deps = ["Pipe", "REPL", "Random", "fzf_jll"] +git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" +uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" +version = "0.1.7" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JpegTurbo_jll]] +deps = 
["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" +version = "3.0.3+0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.22" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.KrylovKit]] +deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] +git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.7.1" + +[[deps.LAME_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" +uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" +version = "3.100.2+0" + +[[deps.LERC_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +uuid = "88015f11-f218-50d7-93a8-a6af411a945d" +version = "3.0.0+1" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "8.0.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +uuid = 
"dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.30+0" + +[[deps.LLVMOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" +uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" +version = "15.0.7+0" + +[[deps.LZO_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" +version = "2.10.2+0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.Latexify]] +deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] +git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50" +uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" +version = "0.16.4" + + [deps.Latexify.extensions] + DataFramesExt = "DataFrames" + SymEngineExt = "SymEngine" + + [deps.Latexify.weakdeps] + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", 
"Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libffi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" +uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" +version = "3.2.2+1" + +[[deps.Libgcrypt_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] +git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" +uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" +version = "1.8.11+0" + +[[deps.Libglvnd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] +git-tree-sha1 = "6f73d1dd803986947b2c750138528a999a6c7733" +uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" +version = "1.6.0+0" + +[[deps.Libgpg_error_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" +uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" +version = "1.49.0+0" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + +[[deps.Libmount_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" +uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" +version = "2.40.1+0" + +[[deps.Libtiff_jll]] +deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] +git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" +uuid = "89763e89-9b03-5906-acba-b20f662cd828" +version = "4.5.1+1" + +[[deps.Libuuid_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" +uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" +version = "2.40.1+0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = 
"37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.7" + +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] +git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.7.16" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = 
"f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.2+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.4.0+0" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.Measures]] +git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" +uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" +version = "0.3.2" + +[[deps.MicroCollections]] +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.2.0" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", 
"JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.21" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + 
+[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.18" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OffsetArrays]] +git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.1" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.Ogg_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" +uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" +version = "1.3.5+1" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] +git-tree-sha1 = "2f0a1d8c79bc385ec3fcda12830c9d0e72b30e71" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "5.0.4+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.14+0" + +[[deps.OpenSpecFun_jll]] 
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.Opus_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" +uuid = "91d4177d-7536-5919-b921-800302f37372" +version = "1.3.2+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PCRE2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" +version = "10.42.0+1" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PeriodicTable]] +deps = ["Base64", "Unitful"] +git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" +uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" +version = "1.2.1" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + +[[deps.Pipe]] +git-tree-sha1 = 
"6842804e7867b115ca9de748a0cf6b364523c16d" +uuid = "b98c9c47-44ae-5843-9183-064241ee97a0" +version = "1.3.0" + +[[deps.Pixman_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] +git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" +uuid = "30392449-352a-5448-841d-b1acce4e97dc" +version = "0.43.4+0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PlotThemes]] +deps = ["PlotUtils", "Statistics"] +git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" +uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" +version = "3.2.0" + +[[deps.PlotUtils]] +deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] +git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" +version = "1.4.1" + +[[deps.Plots]] +deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] +git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" +uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +version = "1.40.5" + + [deps.Plots.extensions] + FileIOExt = "FileIO" + GeometryBasicsExt = "GeometryBasics" + IJuliaExt = "IJulia" + ImageInTerminalExt = "ImageInTerminal" + UnitfulExt = "Unitful" + + [deps.Plots.weakdeps] + FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" + GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" + IJulia 
= "7073ff75-c697-5162-941a-fcdaad2a7d2a" + ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.Qt6Base_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] +git-tree-sha1 = "492601870742dcd38f233b23c3ec629628c1d724" +uuid = "c0090381-4147-56d7-9ebc-da0b1113ec56" +version = "6.7.1+1" + +[[deps.Qt6Declarative_jll]] +deps = 
["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6ShaderTools_jll"] +git-tree-sha1 = "e5dd466bf2569fe08c91a2cc29c1003f4797ac3b" +uuid = "629bc702-f1f5-5709-abd5-49b8460ea067" +version = "6.7.1+2" + +[[deps.Qt6ShaderTools_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll"] +git-tree-sha1 = "1a180aeced866700d4bebc3120ea1451201f16bc" +uuid = "ce943373-25bb-56aa-8eca-768745ed7b5a" +version = "6.7.1+1" + +[[deps.Qt6Wayland_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6Declarative_jll"] +git-tree-sha1 = "729927532d48cf79f49070341e1d918a65aba6b0" +uuid = "e99dba38-086e-5de3-a5b1-6e4c66e897c3" +version = "6.7.1+1" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecipesPipeline]] +deps = ["Dates", "NaNMath", "PlotUtils", "PrecompileTools", "RecipesBase"] +git-tree-sha1 = "45cf9fd0ca5839d06ef333c8201714e888486342" +uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c" +version = "0.6.12" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.1" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = 
"ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.Showoff]] +deps = ["Dates", "Grisu"] +git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" +uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" +version = "1.0.3" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", 
"SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.7" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = 
"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" 
+version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.1" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + +[[deps.Transducers]] +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.82" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" 
+uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnicodeFun]] +deps = ["REPL"] +git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" +uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" +version = "0.4.1" + +[[deps.Unitful]] +deps = ["Dates", "LinearAlgebra", "Random"] +git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" +uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" +version = "1.21.0" +weakdeps = ["ConstructionBase", "InverseFunctions"] + + [deps.Unitful.extensions] + ConstructionBaseUnitfulExt = "ConstructionBase" + InverseFunctionsUnitfulExt = "InverseFunctions" + +[[deps.UnitfulAtomic]] +deps = ["Unitful"] +git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" +uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" +version = "1.0.0" + +[[deps.UnitfulLatexify]] +deps = ["LaTeXStrings", "Latexify", "Unitful"] +git-tree-sha1 = "975c354fcd5f7e1ddcc1f1a23e6e091d99e99bc8" +uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" +version = "1.6.4" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.5" + +[[deps.Unzip]] +git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" +uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d" +version = "0.2.0" + +[[deps.VectorInterface]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" +uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +version = "0.4.6" + +[[deps.Vulkan_Loader_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Wayland_jll", "Xorg_libX11_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] +git-tree-sha1 = 
"2f0486047a07670caad3a81a075d2e518acc5c59" +uuid = "a44049a8-05dd-5a78-86c9-5fde0876e88c" +version = "1.3.243+0" + +[[deps.Wayland_jll]] +deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] +git-tree-sha1 = "7558e29847e99bc3f04d6569e82d0f5c54460703" +uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89" +version = "1.21.0+1" + +[[deps.Wayland_protocols_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "93f43ab61b16ddfb2fd3bb13b3ce241cafb0e6c9" +uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" +version = "1.31.0+0" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.XML2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] +git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" +uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.13.1+0" + +[[deps.XSLT_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] +git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" +uuid = "aed1982a-8fda-507f-9586-7b0439959a61" +version = "1.1.41+0" + +[[deps.XZ_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" +uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" +version = "5.4.6+0" + +[[deps.Xorg_libICE_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "326b4fea307b0b39892b3e85fa451692eda8d46c" +uuid = "f67eecfb-183a-506d-b269-f58e52b52d7c" +version = "1.1.1+0" + +[[deps.Xorg_libSM_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libICE_jll"] +git-tree-sha1 = "3796722887072218eabafb494a13c963209754ce" +uuid = 
"c834827a-8449-5923-a945-d239c165b7dd" +version = "1.2.4+0" + +[[deps.Xorg_libX11_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] +git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" +uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" +version = "1.8.6+0" + +[[deps.Xorg_libXau_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" +uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" +version = "1.0.11+0" + +[[deps.Xorg_libXcursor_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXfixes_jll", "Xorg_libXrender_jll"] +git-tree-sha1 = "12e0eb3bc634fa2080c1c37fccf56f7c22989afd" +uuid = "935fb764-8cf2-53bf-bb30-45bb1f8bf724" +version = "1.2.0+4" + +[[deps.Xorg_libXdmcp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" +uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" +version = "1.1.4+0" + +[[deps.Xorg_libXext_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" +uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" +version = "1.3.6+0" + +[[deps.Xorg_libXfixes_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] +git-tree-sha1 = "0e0dc7431e7a0587559f9294aeec269471c991a4" +uuid = "d091e8ba-531a-589c-9de9-94069b037ed8" +version = "5.0.3+4" + +[[deps.Xorg_libXi_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXfixes_jll"] +git-tree-sha1 = "89b52bc2160aadc84d707093930ef0bffa641246" +uuid = "a51aa0fd-4e3c-5386-b890-e753decda492" +version = "1.7.10+4" + +[[deps.Xorg_libXinerama_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll"] +git-tree-sha1 = "26be8b1c342929259317d8b9f7b53bf2bb73b123" +uuid = "d1454406-59df-5ea1-beac-c340f2130bc3" +version = "1.1.4+4" + +[[deps.Xorg_libXrandr_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", 
"Xorg_libXext_jll", "Xorg_libXrender_jll"] +git-tree-sha1 = "34cea83cb726fb58f325887bf0612c6b3fb17631" +uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" +version = "1.5.2+4" + +[[deps.Xorg_libXrender_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" +uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" +version = "0.9.11+0" + +[[deps.Xorg_libpthread_stubs_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" +uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" +version = "0.1.1+0" + +[[deps.Xorg_libxcb_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] +git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" +uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" +version = "1.17.0+0" + +[[deps.Xorg_libxkbfile_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "730eeca102434283c50ccf7d1ecdadf521a765a4" +uuid = "cc61e674-0454-545c-8b26-ed2c68acab7a" +version = "1.1.2+0" + +[[deps.Xorg_xcb_util_cursor_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_jll", "Xorg_xcb_util_renderutil_jll"] +git-tree-sha1 = "04341cb870f29dcd5e39055f895c39d016e18ccd" +uuid = "e920d4aa-a673-5f3a-b3d7-f755a4d47c43" +version = "0.1.4+0" + +[[deps.Xorg_xcb_util_image_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "0fab0a40349ba1cba2c1da699243396ff8e94b97" +uuid = "12413925-8142-5f55-bb0e-6d7ca50bb09b" +version = "0.4.0+1" + +[[deps.Xorg_xcb_util_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll"] +git-tree-sha1 = "e7fd7b2881fa2eaa72717420894d3938177862d1" +uuid = "2def613f-5ad1-5310-b15b-b15d46f528f5" +version = "0.4.0+1" + +[[deps.Xorg_xcb_util_keysyms_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] 
+git-tree-sha1 = "d1151e2c45a544f32441a567d1690e701ec89b00" +uuid = "975044d2-76e6-5fbe-bf08-97ce7c6574c7" +version = "0.4.0+1" + +[[deps.Xorg_xcb_util_renderutil_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "dfd7a8f38d4613b6a575253b3174dd991ca6183e" +uuid = "0d47668e-0667-5a69-a72c-f761630bfb7e" +version = "0.3.9+1" + +[[deps.Xorg_xcb_util_wm_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] +git-tree-sha1 = "e78d10aab01a4a154142c5006ed44fd9e8e31b67" +uuid = "c22f9ab0-d5fe-5066-847c-f4bb1cd4e361" +version = "0.4.1+1" + +[[deps.Xorg_xkbcomp_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxkbfile_jll"] +git-tree-sha1 = "330f955bc41bb8f5270a369c473fc4a5a4e4d3cb" +uuid = "35661453-b289-5fab-8a00-3d9160c6a3a4" +version = "1.4.6+0" + +[[deps.Xorg_xkeyboard_config_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xkbcomp_jll"] +git-tree-sha1 = "691634e5453ad362044e2ad653e79f3ee3bb98c3" +uuid = "33bec58e-1273-512f-9401-5d533626f822" +version = "2.39.0+0" + +[[deps.Xorg_xtrans_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" +uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" +version = "1.5.0+0" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zstd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" +version = "1.5.6+0" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", 
"MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.eudev_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] +git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba" +uuid = "35ca27e7-8b34-5b7f-bca9-bdc33f59eb06" +version = "3.2.9+0" + +[[deps.fzf_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" +uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" +version = "0.43.0+0" + +[[deps.gperf_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "3516a5630f741c9eecb3720b1ec9d8edc3ecc033" +uuid = "1a1c6b14-54f6-533d-8383-74cd7377aa70" +version = "3.1.1+0" + +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+0" + +[[deps.libaom_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" +uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" +version = "3.9.0+0" + +[[deps.libass_jll]] +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] +git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" +uuid = 
"0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" +version = "0.15.1+0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.libevdev_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22" +uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" +version = "1.11.0+0" + +[[deps.libfdk_aac_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" +uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" +version = "2.0.2+0" + +[[deps.libinput_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"] +git-tree-sha1 = "ad50e5b90f222cfe78aa3d5183a20a12de1322ce" +uuid = "36db933b-70db-51c0-b978-0f229ee0e533" +version = "1.18.0+0" + +[[deps.libpng_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" +version = "1.6.43+1" + +[[deps.libvorbis_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] +git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" +uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" +version = "1.3.7+2" + +[[deps.mtdev_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "814e154bdb7be91d78b6802843f76b6ece642f11" +uuid = "009596ad-96f7-51b1-9f1b-5ce2d5e8a71e" +version = "1.1.6+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" + +[[deps.x264_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" +uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a" +version = "2021.5.5+0" + +[[deps.x265_jll]] +deps = ["Artifacts", "JLLWrappers", 
"Libdl", "Pkg"] +git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" +uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" +version = "3.5.0+0" + +[[deps.xkbcommon_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] +git-tree-sha1 = "9c304562909ab2bab0262639bd4f444d7bc2be37" +uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" +version = "1.4.1+1" +""" + +# ╔═╡ Cell order: +# ╟─5fdab668-4003-11ee-33f5-3953225b0c0f +# ╟─3dd0ce32-2339-4d5a-9a6f-1f662bc5500b +# ╠═1f95ad97-a007-4724-84db-392b0026e1a4 +# ╟─ec5caeb6-1f95-4cb9-8739-8cadba29a22d +# ╠═f531e39c-6842-494a-b4ac-8904321098c9 +# ╠═d5ebf9aa-cec8-4417-baaf-f2e8e19f1cad +# ╟─dc2d5e98-2201-4754-bfc6-8ed2bbb82153 +# ╠═0dde5fd3-72d0-4b15-afb3-9a5b102327c9 +# ╟─f7a6d572-28cf-4d69-a9be-d49f367eca37 +# ╠═3d5503bc-bb97-422e-9465-becc7d3dbe07 +# ╟─3569715d-08f5-4605-b946-9ef7ccd86ae5 +# ╠═aa4eb172-2a42-4c01-a6ef-c6c95208d5b2 +# ╠═367ed417-4f53-44d4-8135-0c91c842a75f +# ╠═7c084eaa-655c-4251-a342-6b6f4df76ddb +# ╠═bf0d820d-32c0-4731-8053-53d5d499e009 +# ╠═cb89d1a3-b4ff-421a-8717-a0b7f21dea1a +# ╟─3b49a612-3a04-4eb5-bfbc-360614f4581a +# ╠═95d8bd24-a40d-409f-a1e7-4174428ef860 +# ╟─fde2ac9e-b121-4105-8428-1820b9c17a43 +# ╠═111b7d5d-c7e3-44c0-9e5e-2ed1a86854d3 +# ╟─572a6633-875b-4d7e-9afc-543b442948fb +# ╠═5502f4fa-3201-4980-b766-2ab88b175b11 +# ╟─4a1ec34a-1092-4b4a-b8a8-bd91939ffd9e +# ╟─755a88c2-c2e5-46d1-9582-af4b2c5a6bbd +# ╠═e83253b2-9f3a-44e2-a747-cce1661657c4 +# ╠═85a923da-3027-4f71-8db6-96852c115c03 +# ╠═39c82234-97ea-48d6-98dd-915f072b7f85 +# ╠═8c3a903b-2c8a-4d4f-8eef-74d5611f2ce4 +# ╠═2c5f6250-ee7a-41b1-9551-bcfeba83ca8b +# ╠═1008dad4-d784-4c38-a7cf-d9b64728e28d +# ╟─8d0e8b9f-226f-4bff-9deb-046e6a897b71 +# ╟─a7e4bb23-6687-476a-a0c2-1b2736873d9d +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 + +[.\examples\graph_classification_temporalbrains.jl] +# Example of graph classification when graphs are 
temporal and modeled as `TemporalSnapshotsGNNGraphs'. +# In this code, we train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. +# The dataset used is the TemporalBrains dataset from the MLDataset.jl package, and the accuracy achieved with the model reaches 65-70% (it can be improved by fine-tuning the parameters of the model). +# Author: Aurora Rossi + +# Load packages +using Flux +using Flux.Losses: mae +using GraphNeuralNetworks +using CUDA +using Statistics, Random +using LinearAlgebra +using MLDatasets +CUDA.allowscalar(false) + +# Load data +MLdataset = TemporalBrains() +graphs = MLdataset.graphs + +# Function to transform the graphs from the MLDatasets format to the TemporalSnapshotsGNNGraph format +# and split the dataset into a training and a test set +function data_loader(graphs) + dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs)) + for i in 1:length(graphs) + gr = graphs[i] + dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(gr.snapshots)) + for t in 1:27 + dataset[i].snapshots[t].ndata.x = reduce( + vcat, [I(102), dataset[i].snapshots[t].ndata.x']) + end + dataset[i].tgdata.g = Float32.(Array(Flux.onehot(gr.graph_data.g, ["F", "M"]))) + end + # Split the dataset into a 80% training set and a 20% test set + train_loader = dataset[1:800] + test_loader = dataset[801:1000] + return train_loader, test_loader +end + +# Arguments for the train function +Base.@kwdef mutable struct Args + η = 1.0f-3 # learning rate + epochs = 200 # number of epochs + seed = -5 # set seed > 0 for reproducibility + usecuda = true # if true use cuda (if available) + nhidden = 128 # dimension of hidden features + infotime = 10 # report every `infotime` epochs +end + +# Adapt GlobalPool to work with TemporalSnapshotsGNNGraph +function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector) + h = 
[reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)] + sze = size(h[1]) + reshape(reduce(hcat, h), sze[1], length(h)) +end + +# Define the model +struct GenderPredictionModel + gin::GINConv + mlp::Chain + globalpool::GlobalPool + f::Function + dense::Dense +end + +Flux.@layer GenderPredictionModel + +function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) + mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) + gin = GINConv(mlp, 0.5) + globalpool = GlobalPool(mean) + f = x -> mean(x, dims = 2) + dense = Dense(nhidden, 2) + GenderPredictionModel(gin, mlp, globalpool, f, dense) +end + +function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph) + h = m.gin(g, g.ndata.x) + h = m.globalpool(g, h) + h = m.f(h) + m.dense(h) +end + +# Train the model + +function train(graphs; kws...) + args = Args(; kws...) + args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + my_device = gpu + args.seed > 0 && CUDA.seed!(args.seed) + @info "Training on GPU" + else + my_device = cpu + @info "Training on CPU" + end + + lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y) |> my_device + + function eval_loss_accuracy(model, data_loader) + error = mean([lossfunction(model(g), gpu(g.tgdata.g)) for g in data_loader]) + acc = mean([round( + 100 * + mean(Flux.onecold(model(g)) .== Flux.onecold(gpu(g.tgdata.g))); + digits = 2) for g in data_loader]) + return (loss = error, acc = acc) + end + + function report(epoch) + train_loss, train_acc = eval_loss_accuracy(model, train_loader) + test_loss, test_acc = eval_loss_accuracy(model, test_loader) + println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") + return (train_loss, train_acc, test_loss, test_acc) + end + + model = GenderPredictionModel() |> my_device + + opt = Flux.setup(Adam(args.η), model) + + train_loader, test_loader = data_loader(graphs) # it takes a while to load the data + + train_loader = 
train_loader |> my_device + test_loader = test_loader |> my_device + + report(0) + for epoch in 1:(args.epochs) + for g in train_loader + grads = Flux.gradient(model) do model + ŷ = model(g) + lossfunction(vec(ŷ), g.tgdata.g) + end + Flux.update!(opt, model, grads[1]) + end + if args.infotime > 0 && epoch % args.infotime == 0 + report(epoch) + end + end + return model +end + +model = train(graphs) +[.\examples\graph_classification_tudataset.jl] +# An example of graph classification + +using Flux +using Flux: onecold, onehotbatch +using Flux.Losses: logitbinarycrossentropy +using Flux: DataLoader +using GraphNeuralNetworks +using MLDatasets: TUDataset +using Statistics, Random +using MLUtils +using CUDA +CUDA.allowscalar(false) + +function eval_loss_accuracy(model, data_loader, device) + loss = 0.0 + acc = 0.0 + ntot = 0 + for (g, y) in data_loader + g, y = (g, y) |> device + n = length(y) + ŷ = model(g, g.ndata.x) |> vec + loss += logitbinarycrossentropy(ŷ, y) * n + acc += mean((ŷ .> 0) .== y) * n + ntot += n + end + return (loss = round(loss / ntot, digits = 4), + acc = round(acc * 100 / ntot, digits = 2)) +end + +function getdataset() + tudata = TUDataset("MUTAG") + display(tudata) + graphs = mldataset2gnngraph(tudata) + oh(x) = Float32.(onehotbatch(x, 0:6)) + graphs = [GNNGraph(g, ndata = oh(g.ndata.targets)) for g in graphs] + y = (1 .+ Float32.(tudata.graph_data.targets)) ./ 2 + @assert all(∈([0, 1]), y) # binary classification + return graphs, y +end + +# arguments for the `train` function +Base.@kwdef mutable struct Args + η = 1.0f-3 # learning rate + batchsize = 32 # batch size (number of graphs in each batch) + epochs = 200 # number of epochs + seed = 17 # set seed > 0 for reproducibility + usecuda = true # if true use cuda (if available) + nhidden = 128 # dimension of hidden features + infotime = 10 # report every `infotime` epochs +end + +function train(; kws...) + args = Args(; kws...) 
+ args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + device = gpu + args.seed > 0 && CUDA.seed!(args.seed) + @info "Training on GPU" + else + device = cpu + @info "Training on CPU" + end + + # LOAD DATA + NUM_TRAIN = 150 + + dataset = getdataset() + train_data, test_data = splitobs(dataset, at = NUM_TRAIN, shuffle = true) + + train_loader = DataLoader(train_data; args.batchsize, shuffle = true, collate = true) + test_loader = DataLoader(test_data; args.batchsize, shuffle = false, collate = true) + + # DEFINE MODEL + + nin = size(dataset[1][1].ndata.x, 1) + nhidden = args.nhidden + + model = GNNChain(GraphConv(nin => nhidden, relu), + GraphConv(nhidden => nhidden, relu), + GlobalPool(mean), + Dense(nhidden, 1)) |> device + + opt = Flux.setup(Adam(args.η), model) + + # LOGGING FUNCTION + + function report(epoch) + train = eval_loss_accuracy(model, train_loader, device) + test = eval_loss_accuracy(model, test_loader, device) + println("Epoch: $epoch Train: $(train) Test: $(test)") + end + + # TRAIN + + report(0) + for epoch in 1:(args.epochs) + for (g, y) in train_loader + g, y = (g, y) |> device + grads = Flux.gradient(model) do model + ŷ = model(g, g.ndata.x) |> vec + logitbinarycrossentropy(ŷ, y) + end + Flux.update!(opt, model, grads[1]) + end + epoch % args.infotime == 0 && report(epoch) + end +end + +train() + +[.\examples\link_prediction_pubmed.jl] +# An example of link prediction using negative and positive samples. 
+# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py +# See the comparison paper https://arxiv.org/pdf/2102.12557.pdf for more details + +using Flux +using Flux: onecold, onehotbatch +using Flux.Losses: logitbinarycrossentropy +using GraphNeuralNetworks +using MLDatasets: PubMed +using Statistics, Random, LinearAlgebra +using CUDA +CUDA.allowscalar(false) + +# arguments for the `train` function +Base.@kwdef mutable struct Args + η = 1.0f-3 # learning rate + epochs = 200 # number of epochs + seed = 17 # set seed > 0 for reproducibility + usecuda = true # if true use cuda (if available) + nhidden = 64 # dimension of hidden features + infotime = 10 # report every `infotime` epochs +end + +# We define our own edge prediction layer but could also +# use GraphNeuralNetworks.DotDecoder instead. +struct DotPredictor end + +function (::DotPredictor)(g, x) + z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims = 1), g, xi = x, xj = x) + # z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with built-in method + return vec(z) +end + +function train(; kws...) + args = Args(; kws...) 
+ + args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + device = gpu + args.seed > 0 && CUDA.seed!(args.seed) + @info "Training on GPU" + else + device = cpu + @info "Training on CPU" + end + + ### LOAD DATA + g = mldataset2gnngraph(PubMed()) + + # Print some info + display(g) + @show is_bidirected(g) + @show has_self_loops(g) + @show has_multi_edges(g) + @show mean(degree(g)) + isbidir = is_bidirected(g) + + # Move to device + g = g |> device + X = g.ndata.features + + #### TRAIN/TEST splits + # With bidirected graph, we make sure that an edge and its reverse + # are in the same split + train_pos_g, test_pos_g = rand_edge_split(g, 0.9, bidirected = isbidir) + test_neg_g = negative_sample(g, num_neg_edges = test_pos_g.num_edges, + bidirected = isbidir) + + ### DEFINE MODEL ######### + nin, nhidden = size(X, 1), args.nhidden + + # We embed the graph with positive training edges in the model + model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu), + GCNConv(nhidden => nhidden)), + train_pos_g) |> device + + pred = DotPredictor() + + opt = Flux.setup(Adam(args.η), model) + + ### LOSS FUNCTION ############ + + function loss(model, pos_g, neg_g = nothing; with_accuracy = false) + h = model(X) + if neg_g === nothing + # We sample a negative graph at each training step + neg_g = negative_sample(pos_g, bidirected = isbidir) + end + pos_score = pred(pos_g, h) + neg_score = pred(neg_g, h) + scores = [pos_score; neg_score] + labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)] + l = logitbinarycrossentropy(scores, labels) + if with_accuracy + acc = 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0) + return l, acc + else + return l + end + end + + ### LOGGING FUNCTION + function report(epoch) + train_loss, train_acc = loss(model, train_pos_g, with_accuracy = true) + test_loss, test_acc = loss(model, test_pos_g, test_neg_g, with_accuracy = true) + println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, 
test_acc))") + end + + ### TRAINING + report(0) + for epoch in 1:(args.epochs) + grads = Flux.gradient(model -> loss(model, train_pos_g), model) + Flux.update!(opt, model, grads[1]) + epoch % args.infotime == 0 && report(epoch) + end +end + +train() + +[.\examples\neural_ode_cora.jl] +# Load the packages +using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations +using Flux: onehotbatch, onecold +using Flux.Losses: logitcrossentropy +using Flux +using Statistics: mean +using MLDatasets: Cora +using CUDA +# CUDA.allowscalar(false) # Some scalar indexing is still done by DiffEqFlux + +# device = cpu # `gpu` not working yet +device = CUDA.functional() ? gpu : cpu + +# LOAD DATA +dataset = Cora() +classes = dataset.metadata["classes"] +g = mldataset2gnngraph(dataset) |> device +X = g.ndata.features +y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged +(; train_mask, val_mask, test_mask) = g.ndata +ytrain = y[:, train_mask] + +# Model and Data Configuration +nin = size(X, 1) +nhidden = 16 +nout = length(classes) +epochs = 40 + +# Define the Neural GDE +diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2]) + +node_chain = GNNChain(GCNConv(nhidden => nhidden, relu), + GCNConv(nhidden => nhidden, relu)) |> device + +node = NeuralODE(WithGraph(node_chain, g), + (0.0f0, 1.0f0), Tsit5(), save_everystep = false, + reltol = 1e-3, abstol = 1e-3, save_start = false) |> device + +model = GNNChain(GCNConv(nin => nhidden, relu), + node, + diffeqsol_to_array, + Dense(nhidden, nout)) |> device + +# # Training + +opt = Flux.setup(Adam(0.01), model) + +function eval_loss_accuracy(X, y, mask) + ŷ = model(g, X) + l = logitcrossentropy(ŷ[:, mask], y[:, mask]) + acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask])) + return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) +end + +# ## Training Loop +for epoch in 1:epochs + grad = gradient(model) do model + ŷ = model(g, X) + 
logitcrossentropy(ŷ[:, train_mask], ytrain) + end + Flux.update!(opt, model, grad[1]) + @show eval_loss_accuracy(X, y, train_mask) +end + +[.\examples\node_classification_cora.jl] +# An example of semi-supervised node classification + +using Flux +using Flux: onecold, onehotbatch +using Flux.Losses: logitcrossentropy +using GraphNeuralNetworks +using MLDatasets: Cora +using Statistics, Random +using CUDA +CUDA.allowscalar(false) + +function eval_loss_accuracy(X, y, mask, model, g) + ŷ = model(g, X) + l = logitcrossentropy(ŷ[:, mask], y[:, mask]) + acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask])) + return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) +end + +# arguments for the `train` function +Base.@kwdef mutable struct Args + η = 1.0f-3 # learning rate + epochs = 100 # number of epochs + seed = 17 # set seed > 0 for reproducibility + usecuda = true # if true use cuda (if available) + nhidden = 128 # dimension of hidden features + infotime = 10 # report every `infotime` epochs +end + +function train(; kws...) + args = Args(; kws...) 
+ + args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + device = gpu + args.seed > 0 && CUDA.seed!(args.seed) + @info "Training on GPU" + else + device = cpu + @info "Training on CPU" + end + + # LOAD DATA + dataset = Cora() + classes = dataset.metadata["classes"] + g = mldataset2gnngraph(dataset) |> device + X = g.features + y = onehotbatch(g.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged + ytrain = y[:, g.train_mask] + + nin, nhidden, nout = size(X, 1), args.nhidden, length(classes) + + ## DEFINE MODEL + model = GNNChain(GCNConv(nin => nhidden, relu), + GCNConv(nhidden => nhidden, relu), + Dense(nhidden, nout)) |> device + + opt = Flux.setup(Adam(args.η), model) + + display(g) + + ## LOGGING FUNCTION + function report(epoch) + train = eval_loss_accuracy(X, y, g.train_mask, model, g) + test = eval_loss_accuracy(X, y, g.test_mask, model, g) + println("Epoch: $epoch Train: $(train) Test: $(test)") + end + + ## TRAINING + report(0) + for epoch in 1:(args.epochs) + grad = Flux.gradient(model) do model + ŷ = model(g, X) + logitcrossentropy(ŷ[:, g.train_mask], ytrain) + end + + Flux.update!(opt, model, grad[1]) + + epoch % args.infotime == 0 && report(epoch) + end +end + +train() + +[.\examples\traffic_prediction.jl] +# Example of using TGCN, a recurrent temporal graph convolutional network of the paper https://arxiv.org/pdf/1811.05320.pdf, for traffic prediction by training it on the METRLA dataset + +# Load packages +using Flux +using Flux.Losses: mae +using GraphNeuralNetworks +using MLDatasets: METRLA +using CUDA +using Statistics, Random +CUDA.allowscalar(false) + +# Import dataset function +function getdataset() + metrla = METRLA(; num_timesteps = 3) + g = metrla[1] + graph = GNNGraph(g.edge_index; edata = g.edge_data, g.num_nodes) + features = g.node_data.features + targets = g.node_data.targets + train_loader = zip(features[1:2000], targets[1:2000]) + test_loader = 
zip(features[2001:2288], targets[2001:2288]) + return graph, train_loader, test_loader +end + +# Loss and accuracy functions +lossfunction(ŷ, y) = Flux.mae(ŷ, y) +accuracy(ŷ, y) = 1 - Statistics.norm(y-ŷ)/Statistics.norm(y) + +function eval_loss_accuracy(model, graph, data_loader) + error = mean([lossfunction(model(graph,x), y) for (x, y) in data_loader]) + acc = mean([accuracy(model(graph,x), y) for (x, y) in data_loader]) + return (loss = round(error, digits = 4), acc = round(acc , digits = 4)) +end + +# Arguments for the train function +Base.@kwdef mutable struct Args + η = 1.0f-3 # learning rate + epochs = 100 # number of epochs + seed = 17 # set seed > 0 for reproducibility + usecuda = true # if true use cuda (if available) + nhidden = 100 # dimension of hidden features + infotime = 20 # report every `infotime` epochs +end + +# Train function +function train(; kws...) + args = Args(; kws...) + args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + device = gpu + args.seed > 0 && CUDA.seed!(args.seed) + @info "Training on GPU" + else + device = cpu + @info "Training on CPU" + end + + # Define model + model = GNNChain(TGCN(2 => args.nhidden), Dense(args.nhidden, 1)) |> device + + opt = Flux.setup(Adam(args.η), model) + + graph, train_loader, test_loader = getdataset() + graph = graph |> device + train_loader = train_loader |> device + test_loader = test_loader |> device + + function report(epoch) + train_loss, train_acc = eval_loss_accuracy(model, graph, train_loader) + test_loss, test_acc = eval_loss_accuracy(model, graph, test_loader) + println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") + end + + report(0) + for epoch in 1:(args.epochs) + for (x, y) in train_loader + x, y = (x, y) + grads = Flux.gradient(model) do model + ŷ = model(graph, x) + lossfunction(y,ŷ) + end + Flux.update!(opt, model, grads[1]) + end + + args.infotime > 0 && epoch % args.infotime == 0 && report(epoch) + + end + return model 
+end + +train() + + +[.\GNNGraphs\ext\GNNGraphsCUDAExt.jl] +module GNNGraphsCUDAExt + +using CUDA +using Random, Statistics, LinearAlgebra +using GNNGraphs +using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T + +const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} + +# Query + +GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1)) + +# Transform + +GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz) + + +# Utils + +GNNGraphs.iscuarray(x::AnyCuArray) = true + + +function sort_edge_index(u::AnyCuArray, v::AnyCuArray) + dev = get_device(u) + cdev = cpu_device() + u, v = u |> cdev, v |> cdev + #TODO proper cuda friendly implementation + sort_edge_index(u, v) |> dev +end + + +end #module + +[.\GNNGraphs\ext\GNNGraphsSimpleWeightedGraphsExt.jl] +module GNNGraphsSimpleWeightedGraphsExt + +using Graphs +using GNNGraphs +using SimpleWeightedGraphs + +function GNNGraphs.GNNGraph(g::T; kws...) where + {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}} + return GNNGraph(g.weights, kws...) +end + +end #module +[.\GNNGraphs\src\abstracttypes.jl] + +const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}} +const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} +const ADJMAT_T = AbstractMatrix +const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T + +const AVecI = AbstractVector{<:Integer} + +# All concrete graph types should be subtypes of AbstractGNNGraph{T}. +# GNNGraph and GNNHeteroGraph are the two concrete types. +abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end + +[.\GNNGraphs\src\chainrules.jl] +# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648 +# Remove when merged + +function ChainRulesCore.rrule(::Type{T}, ps::Pair...) 
where {T<:Dict} + ks = map(first, ps) + project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps) + function Dict_pullback(ȳ) + dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v + dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent())) + Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv) + end + return (NoTangent(), dps...) + end + return T(ps...), Dict_pullback +end + +[.\GNNGraphs\src\convert.jl] +### CONVERT_TO_COO REPRESENTATION ######## + +function to_coo(data::EDict; num_nodes = nothing, kws...) + graph = EDict{COO_T}() + _num_nodes = NDict{Int}() + num_edges = EDict{Int}() + for k in keys(data) + d = data[k] + @assert d isa Tuple + if length(d) == 2 + d = (d..., nothing) + end + if num_nodes !== nothing + n1 = get(num_nodes, k[1], nothing) + n2 = get(num_nodes, k[3], nothing) + else + n1 = nothing + n2 = nothing + end + g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) + graph[k] = g + num_edges[k] = nedges + _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) + _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) + end + return graph, _num_nodes, num_edges +end + +function to_coo(coo::COO_T; dir = :out, num_nodes = nothing, weighted = true, + hetero = false) + s, t, val = coo + + if isnothing(num_nodes) + ns = maximum(s) + nt = maximum(t) + num_nodes = hetero ? (ns, nt) : max(ns, nt) + elseif num_nodes isa Integer + ns = num_nodes + nt = num_nodes + elseif num_nodes isa Tuple + ns = isnothing(num_nodes[1]) ? maximum(s) : num_nodes[1] + nt = isnothing(num_nodes[2]) ? 
maximum(t) : num_nodes[2] + num_nodes = (ns, nt) + else + error("Invalid num_nodes $num_nodes") + end + @assert isnothing(val) || length(val) == length(s) + @assert length(s) == length(t) + if !isempty(s) + @assert minimum(s) >= 1 + @assert minimum(t) >= 1 + @assert maximum(s) <= ns + @assert maximum(t) <= nt + end + num_edges = length(s) + if !weighted + coo = (s, t, nothing) + end + return coo, num_nodes, num_edges +end + +function to_coo(A::SPARSE_T; dir = :out, num_nodes = nothing, weighted = true) + s, t, v = findnz(A) + if dir == :in + s, t = t, s + end + num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes + num_edges = length(s) + if !weighted + v = nothing + end + return (s, t, v), num_nodes, num_edges +end + +function _findnz_idx(A) + nz = findall(!=(0), A) # vec of cartesian indexes + s, t = ntuple(i -> map(t -> t[i], nz), 2) + return s, t, nz +end + +@non_differentiable _findnz_idx(A) + +function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true) + s, t, nz = _findnz_idx(A) + v = A[nz] + if dir == :in + s, t = t, s + end + num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes + num_edges = length(s) + if !weighted + v = nothing + end + return (s, t, v), num_nodes, num_edges +end + +function to_coo(adj_list::ADJLIST_T; dir = :out, num_nodes = nothing, weighted = true) + @assert dir ∈ [:out, :in] + num_nodes = length(adj_list) + num_edges = sum(length.(adj_list)) + @assert num_nodes > 0 + s = similar(adj_list[1], eltype(adj_list[1]), num_edges) + t = similar(adj_list[1], eltype(adj_list[1]), num_edges) + e = 0 + for i in 1:num_nodes + for j in adj_list[i] + e += 1 + s[e] = i + t[e] = j + end + end + @assert e == num_edges + if dir == :in + s, t = t, s + end + (s, t, nothing), num_nodes, num_edges +end + +### CONVERT TO ADJACENCY MATRIX ################ + +### DENSE #################### + +to_dense(A::AbstractSparseMatrix, x...; kws...) = to_dense(collect(A), x...; kws...) 
function to_dense(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing,
                  weighted = true)
    @assert dir ∈ [:out, :in]
    T = T === nothing ? eltype(A) : T
    num_nodes = size(A, 1)
    @assert num_nodes == size(A, 2)
    # @assert all(x -> (x == 1) || (x == 0), A)
    num_edges = numnonzeros(A)
    if dir == :in
        A = A'
    end
    if T != eltype(A)
        A = T.(A)
    end
    if !weighted
        # Binarize: any positive entry becomes 1, everything else 0.
        A = map(x -> ifelse(x > 0, T(1), T(0)), A)
    end
    return A, num_nodes, num_edges
end

function to_dense(adj_list::ADJLIST_T, T = nothing; dir = :out, num_nodes = nothing,
                  weighted = true)
    @assert dir ∈ [:out, :in]
    num_nodes = length(adj_list)
    num_edges = sum(length.(adj_list))
    @assert num_nodes > 0
    T = T === nothing ? eltype(adj_list[1]) : T
    A = fill!(similar(adj_list[1], T, (num_nodes, num_nodes)), 0)
    if dir == :out
        for (i, neigs) in enumerate(adj_list)
            A[i, neigs] .= 1
        end
    else
        for (i, neigs) in enumerate(adj_list)
            A[neigs, i] .= 1
        end
    end
    return A, num_nodes, num_edges
end

function to_dense(coo::COO_T, T = nothing; dir = :out, num_nodes = nothing, weighted = true)
    # `dir` will be ignored since the input `coo` is always in source -> target format.
    # The output will always be an adjmat in :out format (e.g. A[i,j] denotes from i to j).
    s, t, val = coo
    n::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
    if T === nothing
        T = isnothing(val) ? eltype(s) : eltype(val)
    end
    if val === nothing || !weighted
        val = ones_like(s, T)
    end
    if eltype(val) != T
        val = T.(val)
    end

    idxs = s .+ n .* (t .- 1)  # linear indices into the n×n matrix

    ## using scatter instead of indexing since there could be multiple edges
    # A = fill!(similar(s, T, (n, n)), 0)
    # v = vec(A) # vec view of A
    # A[idxs] .= val # exploiting linear indexing
    v = NNlib.scatter(+, val, idxs, dstsize = n^2)
    A = reshape(v, (n, n))
    return A, n, length(s)
end

### SPARSE #############

function to_sparse(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing,
                   weighted = true)
    @assert dir ∈ [:out, :in]
    num_nodes = size(A, 1)
    @assert num_nodes == size(A, 2)
    T = T === nothing ? eltype(A) : T
    num_edges = A isa AbstractSparseMatrix ? nnz(A) : count(!=(0), A)
    if dir == :in
        A = A'
    end
    if T != eltype(A)
        A = T.(A)
    end
    if !(A isa AbstractSparseMatrix)
        A = sparse(A)
    end
    if !weighted
        A = map(x -> ifelse(x > 0, T(1), T(0)), A)
    end
    return A, num_nodes, num_edges
end

function to_sparse(adj_list::ADJLIST_T, T = nothing; dir = :out, num_nodes = nothing,
                   weighted = true)
    coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes)
    # Forward `T` and `weighted` so that the caller's requested eltype is honored
    # (previously they were silently dropped).
    return to_sparse(coo, T; num_nodes, weighted)
end

function to_sparse(coo::COO_T, T = nothing; dir = :out, num_nodes = nothing,
                   weighted = true)
    s, t, eweight = coo
    T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T

    if eweight === nothing || !weighted
        eweight = fill!(similar(s, T), 1)
    end

    num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
    A = sparse(s, t, eweight, num_nodes, num_nodes)
    # Multi-edges are merged by `sparse`, so recount edges from the result.
    num_edges::Int = nnz(A)
    if eltype(A) != T
        A = T.(A)
    end
    return A, num_nodes, num_edges
end

# [.\GNNGraphs\src\datastore.jl]
"""
    DataStore([n, data])
    DataStore([n,] k1 = x1, k2 = x2, ...)

A container for feature arrays. The optional argument `n` enforces that
`numobs(x) == n` for each array contained in the datastore.

At construction time, the `data` can be provided as any iterable of pairs
of symbols and arrays or as keyword arguments:

```jldoctest
julia> ds = DataStore(3, x = rand(Float32, 2, 3), y = rand(Float32, 3))
DataStore(3) with 2 elements:
  y = 3-element Vector{Float32}
  x = 2×3 Matrix{Float32}

julia> ds = DataStore(3, Dict(:x => rand(Float32, 2, 3), :y => rand(Float32, 3))); # equivalent to above

julia> ds = DataStore(3, (x = rand(Float32, 2, 3), y = rand(Float32, 30)))
ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3
[...]

julia> ds = DataStore(x = rand(Float32, 2, 3), y = rand(Float32, 30)) # no checks
DataStore() with 2 elements:
  y = 30-element Vector{Float32}
  x = 2×3 Matrix{Float32}
```

The `DataStore` has an interface similar to both dictionaries and named tuples.
Arrays can be accessed and added using either the indexing or the property syntax:

```jldoctest
julia> ds = DataStore(x = ones(Float32, 2, 3), y = zeros(Float32, 3))
DataStore() with 2 elements:
  y = 3-element Vector{Float32}
  x = 2×3 Matrix{Float32}

julia> ds.x # same as `ds[:x]`
2×3 Matrix{Float32}:
 1.0  1.0  1.0
 1.0  1.0  1.0

julia> ds.z = zeros(Float32, 3) # Add new feature array `z`. Same as `ds[:z] = zeros(Float32, 3)`
3-element Vector{Float32}:
 0.0
 0.0
 0.0
```

The `DataStore` can be iterated over, and the keys and values can be accessed
using `keys(ds)` and `values(ds)`. `map(f, ds)` applies the function `f`
to each feature array:

```jldoctest
julia> ds = DataStore(a = zeros(2), b = zeros(2));

julia> ds2 = map(x -> x .+ 1, ds);

julia> ds2.a
2-element Vector{Float64}:
 1.0
 1.0
```
"""
struct DataStore
    _n::Int # either -1 (no constraint) or numobs(data)
    _data::Dict{Symbol, Any}

    function DataStore(n::Int, data::Dict{Symbol, Any})
        if n >= 0
            for (k, v) in data
                @assert numobs(v)==n "DataStore: data[$k] has $(numobs(v)) observations, but n = $n"
            end
        end
        return new(n, data)
    end
end

@functor DataStore

DataStore(data) = DataStore(-1, data)
DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol, Any}(pairs(data)))
DataStore(n::Int, data) = DataStore(n, Dict{Symbol, Any}(data))

DataStore(; kws...) = DataStore(-1; kws...)
DataStore(n::Int; kws...) = DataStore(n, Dict{Symbol, Any}(kws...))

getdata(ds::DataStore) = getfield(ds, :_data)
getn(ds::DataStore) = getfield(ds, :_n)
# setn!(ds::DataStore, n::Int) = setfield!(ds, :n, n)

function Base.getproperty(ds::DataStore, s::Symbol)
    if s === :_n
        return getn(ds)
    elseif s === :_data
        return getdata(ds)
    else
        return getdata(ds)[s]
    end
end

function Base.getproperty(vds::Vector{DataStore}, s::Symbol)
    if s === :_n
        return [getn(ds) for ds in vds]
    elseif s === :_data
        return [getdata(ds) for ds in vds]
    else
        return [getdata(ds)[s] for ds in vds]
    end
end

function Base.setproperty!(ds::DataStore, s::Symbol, x)
    @assert s != :_n "cannot set _n directly"
    @assert s != :_data "cannot set _data directly"
    if getn(ds) >= 0
        numobs(x) == getn(ds) || throw(DimensionMismatch("expected $(getn(ds)) object features but got $(numobs(x))."))
    end
    return getdata(ds)[s] = x
end

Base.getindex(ds::DataStore, s::Symbol) = getproperty(ds, s)
Base.setindex!(ds::DataStore, x, s::Symbol) = setproperty!(ds, s, x)

function Base.show(io::IO, ds::DataStore)
    len = length(ds)
    n = getn(ds)
    if n < 0
        print(io, "DataStore()")
    else
        print(io, "DataStore($(getn(ds)))")
    end
    if len > 0
        print(io, " with $(length(getdata(ds))) element")
        len > 1 && print(io, "s")
        print(io, ":")
        for (k, v) in getdata(ds)
            print(io, "\n  $(k) = $(summary(v))")
        end
    else
        print(io, " with no elements")
    end
end

Base.iterate(ds::DataStore) = iterate(getdata(ds))
Base.iterate(ds::DataStore, state) = iterate(getdata(ds), state)
Base.keys(ds::DataStore) = keys(getdata(ds))
Base.values(ds::DataStore) = values(getdata(ds))
Base.length(ds::DataStore) = length(getdata(ds))
Base.haskey(ds::DataStore, k) = haskey(getdata(ds), k)
Base.get(ds::DataStore, k, default) = get(getdata(ds), k, default)
Base.pairs(ds::DataStore) = pairs(getdata(ds))
Base.:(==)(ds1::DataStore, ds2::DataStore) = getdata(ds1) == getdata(ds2)
Base.isempty(ds::DataStore) = isempty(getdata(ds))
Base.delete!(ds::DataStore, k) = delete!(getdata(ds), k)

function Base.map(f, ds::DataStore)
    d = getdata(ds)
    newd = Dict{Symbol, Any}(k => f(v) for (k, v) in d)
    return DataStore(getn(ds), newd)
end

MLUtils.numobs(ds::DataStore) = numobs(getdata(ds))

function MLUtils.getobs(ds::DataStore, i::Int)
    newdata = getobs(getdata(ds), i)
    # A single observation has no observation dimension left, hence n = -1.
    return DataStore(-1, newdata)
end

function MLUtils.getobs(ds::DataStore,
                        i::AbstractVector{T}) where {T <: Union{Integer, Bool}}
    newdata = getobs(getdata(ds), i)
    n = getn(ds)
    if n >= 0
        if length(ds) > 0
            n = numobs(newdata)
        else
            # if newdata is empty, then we can't get the number of observations from it
            n = T == Bool ? sum(i) : length(i)
        end
    end
    if !(newdata isa Dict{Symbol, Any})
        newdata = Dict{Symbol, Any}(newdata)
    end
    return DataStore(n, newdata)
end

function cat_features(ds1::DataStore, ds2::DataStore)
    n1, n2 = getn(ds1), getn(ds2)
    # Unconstrained stores (-1) are treated as holding a single observation.
    n1 = n1 >= 0 ? n1 : 1
    n2 = n2 >= 0 ? n2 : 1
    return DataStore(n1 + n2, cat_features(getdata(ds1), getdata(ds2)))
end

function cat_features(dss::AbstractVector{DataStore}; kws...)
    ns = getn.(dss)
    ns = map(n -> n >= 0 ? n : 1, ns)
    return DataStore(sum(ns), cat_features(getdata.(dss); kws...))
end

# DataStore is always already normalized
normalize_graphdata(ds::DataStore; kws...) = ds

_gather(x::DataStore, i) = map(v -> _gather(v, i), x)

function _scatter(aggr, src::DataStore, idx, n)
    newdata = _scatter(aggr, getdata(src), idx, n)
    if !(newdata isa Dict{Symbol, Any})
        newdata = Dict{Symbol, Any}(newdata)
    end
    return DataStore(n, newdata)
end

function Base.hash(ds::D, h::UInt) where {D <: DataStore}
    # Hash all fields so that hash is consistent with ==.
    fs = (getfield(ds, k) for k in fieldnames(D))
    return foldl((h, f) -> hash(f, h), fs, init = hash(D, h))
end

# [.\GNNGraphs\src\gatherscatter.jl]
_gather(x::NamedTuple, i) = map(v -> _gather(v, i), x)
_gather(x::Dict, i) = Dict([k => _gather(v, i) for (k, v) in x]...)
_gather(x::Tuple, i) = map(v -> _gather(v, i), x)
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
_gather(x::Nothing, i) = nothing

_scatter(aggr, src::Nothing, idx, n) = nothing
_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
_scatter(aggr, src::Dict, idx, n) = Dict([k => _scatter(aggr, v, idx, n) for (k, v) in src]...)

function _scatter(aggr,
                  src::AbstractArray,
                  idx::AbstractVector{<:Integer},
                  n::Integer)
    dstsize = (size(src)[1:(end - 1)]..., n)
    return NNlib.scatter(aggr, src, idx; dstsize)
end

# [.\GNNGraphs\src\generate.jl]
"""
    rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...)

Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges.

If `bidirected=true` the reverse edge of each edge will be present.
If `bidirected=false` instead, `m` unrelated edges are generated.
In any case, the output graph will contain no self-loops or multi-edges.

A vector can be passed as `edge_weight`. Its length has to be equal to `m`
in the directed case, and `m÷2` in the bidirected one.

Pass a random number generator as the first argument to make the generation reproducible.

Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.

# Examples

```jldoctest
julia> g = rand_graph(5, 4, bidirected=false)
GNNGraph:
  num_nodes = 5
  num_edges = 4

julia> edge_index(g)
([1, 3, 3, 4], [5, 4, 5, 2])

# In the bidirected case, edge data will be duplicated on the reverse edges if needed.
julia> g = rand_graph(5, 4, edata=rand(Float32, 16, 2))
GNNGraph:
  num_nodes = 5
  num_edges = 4
  edata:
    e => (16, 4)

# Each edge has a reverse
julia> edge_index(g)
([1, 3, 3, 4], [3, 4, 1, 3])
```
"""
function rand_graph(n::Integer, m::Integer; seed=-1, kws...)
    if seed != -1
        Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph)
        rng = MersenneTwister(seed)
    else
        rng = Random.default_rng()
    end
    return rand_graph(rng, n, m; kws...)
end

function rand_graph(rng::AbstractRNG, n::Integer, m::Integer;
                    bidirected::Bool = true,
                    edge_weight::Union{AbstractVector, Nothing} = nothing, kws...)
    if bidirected
        @assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m."
        # Generate m÷2 undirected edges, then materialize both directions.
        s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false)
        s, t = vcat(s, t), vcat(t, s)
        if edge_weight !== nothing
            edge_weight = vcat(edge_weight, edge_weight)
        end
    else
        s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false)
    end
    return GNNGraph((s, t, edge_weight); num_nodes=n, kws...)
end

"""
    rand_heterograph([rng,] n, m; bidirected=false, kws...)

Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
specifying node/edge types and their numbers.

Pass a random number generator as a first argument to make the generation reproducible.

Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
will be generated.

Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.

# Examples

```jldoctest
julia> g = rand_heterograph((:user => 10, :movie => 20),
                            (:user, :rate, :movie) => 30)
GNNHeteroGraph:
  num_nodes: (:user => 10, :movie => 20)
  num_edges: ((:user, :rate, :movie) => 30,)
```
"""
function rand_heterograph end

# for generic iterators of pairs
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...)

function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...)
    if seed != -1
        Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph)
        rng = MersenneTwister(seed)
    else
        rng = Random.default_rng()
    end
    return rand_heterograph(rng, n, m; kws...)
end

function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...)
    if bidirected
        return _rand_bidirected_heterograph(rng, n, m; kws...)
    end
    graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
    return GNNHeteroGraph(graphs; num_nodes = n, kws...)
end

function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...)
    # Make sure every edge type has a matching reverse type with the same count.
    for k in keys(m)
        if reverse(k) ∈ keys(m)
            @assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
        else
            m[reverse(k)] = m[k]
        end
    end
    graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}()
    for k in keys(m)
        reverse(k) ∈ keys(graphs) && continue
        s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k])
        graphs[k] = s, t, val
        graphs[reverse(k)] = t, s, val
    end
    return GNNHeteroGraph(graphs; num_nodes = n, kws...)
end

"""
    rand_bipartite_heterograph([rng,]
                               (n1, n2), (m12, m21);
                               bidirected = true,
                               node_t = (:A, :B),
                               edge_t = :to,
                               kws...)

Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph.
The graph will have two types of nodes, and edges will only connect nodes of different types.

The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type.
The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2`
and vice versa.

The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments,
which default to `(:A, :B)` and `:to` respectively.

If `bidirected=true` (default), the reverse edge of each edge will be present. In this case
`m12 == m21` is required.

A random number generator can be passed as the first argument to make the generation reproducible.

Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.

See [`rand_heterograph`](@ref) for a more general version.

# Examples

```julia-repl
julia> g = rand_bipartite_heterograph((10, 15), 20)
GNNHeteroGraph:
  num_nodes: (:A => 10, :B => 15)
  num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20)

julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false)
GNNHeteroGraph:
  num_nodes: Dict(:item => 15, :user => 10)
  num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20)
```
"""
rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...)
function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true,
                                    node_t = (:A, :B), edge_t::Symbol = :to, kws...)
    # A single integer m means the same edge count in both directions.
    if m isa Integer
        m12 = m21 = m
    else
        m12, m21 = m
    end

    return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2),
                            Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21);
                            bidirected, kws...)
end

"""
    knn_graph(points::AbstractMatrix,
              k::Int;
              graph_indicator = nothing,
              self_loops = false,
              dir = :in,
              kws...)

Create a `k`-nearest neighbor graph where each node is linked
to its `k` closest `points`.

# Arguments

- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
- `k`: The number of neighbors considered in the kNN algorithm.
- `graph_indicator`: Either nothing or a vector containing the graph assignment of each node,
                     in which case the returned graph will be a batch of graphs.
- `self_loops`: If `true`, consider the node itself among its `k` nearest neighbors, in which
                case the graph will contain self-loops.
- `dir`: The direction of the edges. If `dir=:in` edges go from the `k`
         neighbors to the central node. If `dir=:out` we have the opposite
         direction.
- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.

# Examples

```jldoctest
julia> n, k = 10, 3;

julia> x = rand(Float32, 3, n);

julia> g = knn_graph(x, k)
GNNGraph:
    num_nodes = 10
    num_edges = 30

julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];

julia> g = knn_graph(x, k; graph_indicator)
GNNGraph:
    num_nodes = 10
    num_edges = 30
    num_graphs = 2

```
"""
function knn_graph(points::AbstractMatrix, k::Int;
                   graph_indicator = nothing,
                   self_loops = false,
                   dir = :in,
                   kws...)
    if graph_indicator !== nothing
        d, n = size(points)
        @assert graph_indicator isa AbstractVector{<:Integer}
        @assert length(graph_indicator) == n
        # All graphs in the batch must have at least k nodes.
        cm = StatsBase.countmap(graph_indicator)
        @assert all(values(cm) .>= k)

        # Make sure that the distance between points in different graphs
        # is always larger than any distance within the same graph.
        points = points .- minimum(points)
        points = points ./ maximum(points)
        dummy_feature = 2d .* reshape(graph_indicator, 1, n)
        points = vcat(points, dummy_feature)
    end

    kdtree = NearestNeighbors.KDTree(points)
    if !self_loops
        # Query one extra neighbor; the self match is removed below.
        k += 1
    end
    sortres = false
    idxs, dists = NearestNeighbors.knn(kdtree, points, k, sortres)

    g = GNNGraph(idxs; dir, graph_indicator, kws...)
    if !self_loops
        g = remove_self_loops(g)
    end
    return g
end

"""
    radius_graph(points::AbstractMatrix,
                 r::AbstractFloat;
                 graph_indicator = nothing,
                 self_loops = false,
                 dir = :in,
                 kws...)

Create a graph where each node is linked
to its neighbors within a given distance `r`.

# Arguments

- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
- `r`: The radius.
- `graph_indicator`: Either nothing or a vector containing the graph assignment of each node,
                     in which case the returned graph will be a batch of graphs.
- `self_loops`: If `true`, consider the node itself among its neighbors, in which
                case the graph will contain self-loops.
- `dir`: The direction of the edges. If `dir=:in` edges go from the
         neighbors to the central node. If `dir=:out` we have the opposite
         direction.
- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.

# Examples

```jldoctest
julia> n, r = 10, 0.75;

julia> x = rand(Float32, 3, n);

julia> g = radius_graph(x, r)
GNNGraph:
    num_nodes = 10
    num_edges = 46

julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];

julia> g = radius_graph(x, r; graph_indicator)
GNNGraph:
    num_nodes = 10
    num_edges = 20
    num_graphs = 2

```
# References
Section B paragraphs 1 and 2 of the paper [Dynamic Hidden-Variable Network Models](https://arxiv.org/pdf/2101.00414.pdf)
"""
function radius_graph(points::AbstractMatrix, r::AbstractFloat;
                      graph_indicator = nothing,
                      self_loops = false,
                      dir = :in,
                      kws...)
    if graph_indicator !== nothing
        d, n = size(points)
        @assert graph_indicator isa AbstractVector{<:Integer}
        @assert length(graph_indicator) == n

        # Make sure that the distance between points in different graphs
        # is always larger than r.
        dummy_feature = 2r .* reshape(graph_indicator, 1, n)
        points = vcat(points, dummy_feature)
    end

    balltree = NearestNeighbors.BallTree(points)

    sortres = false
    idxs = NearestNeighbors.inrange(balltree, points, r, sortres)

    g = GNNGraph(idxs; dir, graph_indicator, kws...)
    if !self_loops
        g = remove_self_loops(g)
    end
    return g
end

"""
    rand_temporal_radius_graph(number_nodes::Int,
                               number_snapshots::Int,
                               speed::AbstractFloat,
                               r::AbstractFloat;
                               self_loops = false,
                               dir = :in,
                               kws...)

Create a random temporal graph given `number_nodes` nodes and `number_snapshots` snapshots.
First, the positions of the nodes are randomly generated in the unit square. Two nodes are connected if their distance is less than a given radius `r`.
Each following snapshot is obtained by applying the same construction to new positions obtained as follows.
For each snapshot, the new positions of the points are determined by applying random independent displacement vectors to the previous positions. The direction of the displacement is chosen uniformly at random and its length is chosen uniformly in `[0, speed]`. Then the connections are recomputed.
If a point happens to move outside the boundary, its position is updated as if it had bounced off the boundary.

# Arguments

- `number_nodes`: The number of nodes of each snapshot.
- `number_snapshots`: The number of snapshots.
- `speed`: The speed to update the nodes.
- `r`: The radius of connection.
- `self_loops`: If `true`, consider the node itself among its neighbors, in which
                case the graph will contain self-loops.
- `dir`: The direction of the edges. If `dir=:in` edges go from the
         neighbors to the central node. If `dir=:out` we have the opposite
         direction.
- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor of each snapshot.

# Example

```jldoctest
julia> n, snaps, s, r = 10, 5, 0.1, 1.5;

julia> tg = rand_temporal_radius_graph(n,snaps,s,r) # complete graph at each snapshot
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10, 10, 10]
  num_edges: [90, 90, 90, 90, 90]
  num_snapshots: 5
```

"""
function rand_temporal_radius_graph(number_nodes::Int,
                                    number_snapshots::Int,
                                    speed::AbstractFloat,
                                    r::AbstractFloat;
                                    self_loops = false,
                                    dir = :in,
                                    kws...)
    points = rand(2, number_nodes)
    tg = Vector{GNNGraph}(undef, number_snapshots)
    for t in 1:number_snapshots
        tg[t] = radius_graph(points, r; graph_indicator = nothing, self_loops, dir, kws...)
        for i in 1:number_nodes
            # Random displacement: length uniform in [-speed, speed], direction uniform.
            ρ = 2 * speed * rand() - speed
            theta = 2 * pi * rand()
            # Reflect off the unit-square boundary ("bounce").
            points[1, i] = 1 - abs(1 - (abs(points[1, i] + ρ * cos(theta))))
            points[2, i] = 1 - abs(1 - (abs(points[2, i] + ρ * sin(theta))))
        end
    end
    return TemporalSnapshotsGNNGraph(tg)
end

# Hyperbolic distance between two points given in polar coordinates (radius, angle),
# on a disk of curvature parameter ζ. Generalized from Vector{Float64} to any real vector.
function _hyperbolic_distance(nodeA::AbstractVector{<:Real}, nodeB::AbstractVector{<:Real}; ζ::Real)
    if nodeA != nodeB
        a = cosh(ζ * nodeA[1]) * cosh(ζ * nodeB[1])
        b = sinh(ζ * nodeA[1]) * sinh(ζ * nodeB[1])
        c = cos(pi - abs(pi - abs(nodeA[2] - nodeB[2])))
        d = acosh(a - (b * c)) / ζ
    else
        d = 0.0
    end
    return d
end

"""
    rand_temporal_hyperbolic_graph(number_nodes::Int,
                                   number_snapshots::Int;
                                   α::Real,
                                   R::Real,
                                   speed::Real,
                                   ζ::Real=1,
                                   self_loop = false,
                                   kws...)

Create a random temporal graph given `number_nodes` nodes and `number_snapshots` snapshots.
First, the positions of the nodes are generated with a quasi-uniform distribution (depending on the parameter `α`) in hyperbolic space within a disk of radius `R`. Two nodes are connected if their hyperbolic distance is less than `R`. Each following snapshot is created in order to keep the same initial distribution.

# Arguments

- `number_nodes`: The number of nodes of each snapshot.
- `number_snapshots`: The number of snapshots.
- `α`: The parameter that controls the position of the points. If `α=ζ`, the points are uniformly distributed on the disk of radius `R`. If `α>ζ`, the points are more concentrated in the center of the disk. If `α<ζ`, the points are more concentrated at the boundary of the disk.
- `R`: The radius of the disk and of connection.
- `speed`: The speed to update the nodes.
- `ζ`: The parameter that controls the curvature of the disk.
- `self_loop`: If `true`, consider the node itself among its neighbors, in which
               case the graph will contain self-loops.
- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor of each snapshot.

# Example

```jldoctest
julia> n, snaps, α, R, speed, ζ = 10, 5, 1.0, 4.0, 0.1, 1.0;

julia> thg = rand_temporal_hyperbolic_graph(n, snaps; α, R, speed, ζ)
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10, 10, 10]
  num_edges: [44, 46, 48, 42, 38]
  num_snapshots: 5
```

# References
Section D of the paper [Dynamic Hidden-Variable Network Models](https://arxiv.org/pdf/2101.00414.pdf) and the paper
[Hyperbolic Geometry of Complex Networks](https://arxiv.org/pdf/1006.5169.pdf)
"""
function rand_temporal_hyperbolic_graph(number_nodes::Int,
                                        number_snapshots::Int;
                                        α::Real,
                                        R::Real,
                                        speed::Real,
                                        ζ::Real=1,
                                        self_loop = false,
                                        kws...)
    @assert number_snapshots > 1 "The number of snapshots must be greater than 1"
    @assert α > 0 "α must be greater than 0"

    probabilities = rand(number_nodes)

    # Polar coordinates: row 1 is the radius, row 2 the angle.
    points = Array{Float64}(undef, 2, number_nodes)
    points[1, :] .= (1 / α) * acosh.(1 .+ (cosh(α * R) - 1) * probabilities)
    points[2, :] .= 2 * pi * rand(number_nodes)

    tg = Vector{GNNGraph}(undef, number_snapshots)

    for time in 1:number_snapshots
        adj = zeros(number_nodes, number_nodes)
        for i in 1:number_nodes
            for j in 1:number_nodes
                if !self_loop && i == j
                    continue
                elseif _hyperbolic_distance(points[:, i], points[:, j]; ζ) <= R
                    adj[i, j] = adj[j, i] = 1
                end
            end
        end
        tg[time] = GNNGraph(adj)

        # Perturb radial probabilities, reflecting at 0 and 1 to stay in range.
        probabilities .= probabilities .+ (2 * speed * rand(number_nodes) .- speed)
        probabilities[probabilities .> 1] .= 1 .- (probabilities[probabilities .> 1] .% 1)
        probabilities[probabilities .< 0] .= abs.(probabilities[probabilities .< 0])

        points[1, :] .= (1 / α) * acosh.(1 .+ (cosh(α * R) - 1) * probabilities)
        points[2, :] .= points[2, :] .+ (2 * speed * rand(number_nodes) .- speed)
    end
    return TemporalSnapshotsGNNGraph(tg)
end

# [.\GNNGraphs\src\gnngraph.jl]
#===================================
Define GNNGraph type as a subtype of Graphs.AbstractGraph.
For the core methods to be implemented by any AbstractGraph, see
https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
=============================================#

"""
    GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
    GNNGraph(g::GNNGraph; [ndata, edata, gdata])

A type representing a graph structure that also stores
feature arrays associated to nodes, edges, and the graph itself.

The feature arrays are stored in the fields `ndata`, `edata`, and `gdata`
as [`DataStore`](@ref) objects offering a convenient dictionary-like
and namedtuple-like interface. The features can be passed at construction
time or added later.

A `GNNGraph` can be constructed out of different `data` objects
expressing the connections inside the graph. The internal representation type
is determined by `graph_type`.

When constructed from another `GNNGraph`, the internal graph representation
is preserved and shared. The node/edge/graph features are retained
as well, unless explicitly set by the keyword arguments
`ndata`, `edata`, and `gdata`.

A `GNNGraph` can also represent multiple graphs batched together
(see [`MLUtils.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
The field `g.graph_indicator` contains the graph membership
of each node.

`GNNGraph`s are always directed graphs, therefore each edge is defined
by a source node and a target node (see [`edge_index`](@ref)).
Self loops (edges connecting a node to itself) and multiple edges
(more than one edge between the same pair of nodes) are supported.

A `GNNGraph` is a Graphs.jl's `AbstractGraph`, therefore it supports most
functionality from that library.

# Arguments

- `data`: Some data representing the graph topology. Possible type are
    - An adjacency matrix
    - An adjacency list.
    - A tuple containing the source and target vectors (COO representation)
    - A Graphs.jl' graph.
- `graph_type`: A keyword argument that specifies
                the underlying representation used by the GNNGraph.
                Currently supported values are
    - `:coo`. Graph represented as a tuple `(source, target)`, such that the `k`-th edge
              connects the node `source[k]` to node `target[k]`.
              Optionally, also edge weights can be given: `(source, target, weights)`.
    - `:sparse`. A sparse adjacency matrix representation.
    - `:dense`. A dense adjacency matrix representation.
    Defaults to `:coo`, currently the most supported type.
- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
         Possible values are `:out` and `:in`. Default `:out`.
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
- `graph_indicator`: For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
- `ndata`: Node features. An array or named tuple of arrays whose last dimension has size `num_nodes`.
- `edata`: Edge features. An array or named tuple of arrays whose last dimension has size `num_edges`.
- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.

# Examples

```julia
using GraphNeuralNetworks

# Construct from adjacency list representation
data = [[2,3], [1,4,5], [1], [2,5], [2,4]]
g = GNNGraph(data)

# Number of nodes, edges, and batched graphs
g.num_nodes  # 5
g.num_edges  # 10
g.num_graphs # 1

# Same graph in COO representation
s = [1,1,2,2,2,3,4,4,5,5]
t = [2,3,1,4,5,3,2,5,2,4]
g = GNNGraph(s, t)

# From a Graphs' graph
g = GNNGraph(erdos_renyi(100, 20))

# Add 2 node feature arrays at creation time
g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes)))

# Add 1 edge feature array, after the graph creation
g.edata.z = rand(16, g.num_edges)

# Add node features and edge features with default names `x` and `e`
g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges))

g.ndata.x # or just g.x
g.edata.e # or just g.e

# Collect edges' source and target nodes.
# Both source and target are vectors of length num_edges
source, target = edge_index(g)
```
A `GNNGraph` can be sent to the GPU using e.g. Flux's `gpu` function:
```
# Send to gpu
using Flux, CUDA
g = g |> Flux.gpu
```
"""
struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
    graph::T
    num_nodes::Int
    num_edges::Int
    num_graphs::Int
    graph_indicator::Union{Nothing, AVecI} # vector of ints or nothing
    ndata::DataStore
    edata::DataStore
    gdata::DataStore
end

@functor GNNGraph

function GNNGraph(data::D;
                  num_nodes = nothing,
                  graph_indicator = nothing,
                  graph_type = :coo,
                  dir = :out,
                  ndata = nothing,
                  edata = nothing,
                  gdata = nothing) where {D <: Union{COO_T, ADJMAT_T, ADJLIST_T}}
    @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
    @assert dir ∈ [:in, :out]

    if graph_type == :coo
        graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
    elseif graph_type == :dense
        graph, num_nodes, num_edges = to_dense(data; num_nodes, dir)
    elseif graph_type == :sparse
        graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir)
    end

    num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1

    ndata = normalize_graphdata(ndata, default_name = :x, n = num_nodes)
    edata = normalize_graphdata(edata, default_name = :e, n = num_edges,
                                duplicate_if_needed = true)

    # don't force the shape of the data when there is only one graph
    gdata = normalize_graphdata(gdata, default_name = :u,
                                n = num_graphs > 1 ? num_graphs : -1)

    GNNGraph(graph,
             num_nodes, num_edges, num_graphs,
             graph_indicator,
             ndata, edata, gdata)
end

GNNGraph(; kws...) = GNNGraph(0; kws...)

function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer}
    s, t = T[], T[]
    return GNNGraph(s, t; num_nodes, kws...)
end

Base.zero(::Type{G}) where {G <: GNNGraph} = G(0)

# COO convenience constructors
function GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...)
    GNNGraph((s, t, v); kws...)
end
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
+ +# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...) + +function GNNGraph(g::AbstractGraph; edge_weight = nothing, kws...) + s = Graphs.src.(Graphs.edges(g)) + t = Graphs.dst.(Graphs.edges(g)) + w = edge_weight + if !Graphs.is_directed(g) + # add reverse edges since GNNGraph is directed + s, t = [s; t], [t; s] + if !isnothing(w) + @assert length(w) == Graphs.ne(g) "edge_weight must have length equal to the number of undirected edges" + w = [w; w] + end + end + num_nodes::Int = Graphs.nv(g) + GNNGraph((s, t, w); num_nodes = num_nodes, kws...) +end + +function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata, + graph_type = nothing) + ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes) + edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges, + duplicate_if_needed = true) + gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs) + + if !isnothing(graph_type) + if graph_type == :coo + graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes) + elseif graph_type == :dense + graph, num_nodes, num_edges = to_dense(g.graph; g.num_nodes) + elseif graph_type == :sparse + graph, num_nodes, num_edges = to_sparse(g.graph; g.num_nodes) + end + @assert num_nodes == g.num_nodes + @assert num_edges == g.num_edges + else + graph = g.graph + end + return GNNGraph(graph, + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, edata, gdata) +end + +""" + copy(g::GNNGraph; deep=false) + +Create a copy of `g`. If `deep` is `true`, then copy will be a deep copy (equivalent to `deepcopy(g)`), +otherwise it will be a shallow copy with the same underlying graph data. 
+""" +function Base.copy(g::GNNGraph; deep = false) + if deep + GNNGraph(deepcopy(g.graph), + g.num_nodes, g.num_edges, g.num_graphs, + deepcopy(g.graph_indicator), + deepcopy(g.ndata), deepcopy(g.edata), deepcopy(g.gdata)) + else + GNNGraph(g.graph, + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) + end +end + +function print_feature(io::IO, feature) + if !isempty(feature) + if length(keys(feature)) == 1 + k = first(keys(feature)) + v = first(values(feature)) + print(io, "$(k): $(dims2string(size(v)))") + else + print(io, "(") + for (i, (k, v)) in enumerate(pairs(feature)) + print(io, "$k: $(dims2string(size(v)))") + if i == length(feature) + print(io, ")") + else + print(io, ", ") + end + end + end + end +end + +function print_all_features(io::IO, feat1, feat2, feat3) + n1 = length(feat1) + n2 = length(feat2) + n3 = length(feat3) + if n1 == 0 && n2 == 0 && n3 == 0 + print(io, "no") + elseif n1 != 0 && (n2 != 0 || n3 != 0) + print_feature(io, feat1) + print(io, ", ") + elseif n2 == 0 && n3 == 0 + print_feature(io, feat1) + end + if n2 != 0 && n3 != 0 + print_feature(io, feat2) + print(io, ", ") + elseif n2 != 0 && n3 == 0 + print_feature(io, feat2) + end + print_feature(io, feat3) +end + +function Base.show(io::IO, g::GNNGraph) + print(io, "GNNGraph($(g.num_nodes), $(g.num_edges)) with ") + print_all_features(io, g.ndata, g.edata, g.gdata) + print(io, " data") +end + +function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph) + if get(io, :compact, false) + print(io, "GNNGraph($(g.num_nodes), $(g.num_edges)) with ") + print_all_features(io, g.ndata, g.edata, g.gdata) + print(io, " data") + else + print(io, + "GNNGraph:\n num_nodes: $(g.num_nodes)\n num_edges: $(g.num_edges)") + g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)") + if !isempty(g.ndata) + print(io, "\n ndata:") + for k in keys(g.ndata) + print(io, "\n\t$k = $(shortsummary(g.ndata[k]))") + end + end + if !isempty(g.edata) + print(io, "\n 
edata:") + for k in keys(g.edata) + print(io, "\n\t$k = $(shortsummary(g.edata[k]))") + end + end + if !isempty(g.gdata) + print(io, "\n gdata:") + for k in keys(g.gdata) + print(io, "\n\t$k = $(shortsummary(g.gdata[k]))") + end + end + end +end + +MLUtils.numobs(g::GNNGraph) = g.num_graphs +MLUtils.getobs(g::GNNGraph, i) = getgraph(g, i) + +######################### + +function Base.:(==)(g1::GNNGraph, g2::GNNGraph) + g1 === g2 && return true + for k in fieldnames(typeof(g1)) + k === :graph_indicator && continue + getfield(g1, k) != getfield(g2, k) && return false + end + return true +end + +function Base.hash(g::T, h::UInt) where {T <: GNNGraph} + fs = (getfield(g, k) for k in fieldnames(T) if k !== :graph_indicator) + return foldl((h, f) -> hash(f, h), fs, init = hash(T, h)) +end + +function Base.getproperty(g::GNNGraph, s::Symbol) + if s in fieldnames(GNNGraph) + return getfield(g, s) + end + if (s in keys(g.ndata)) + (s in keys(g.edata)) + (s in keys(g.gdata)) > 1 + throw(ArgumentError("Ambiguous property name $s")) + end + if s in keys(g.ndata) + return g.ndata[s] + elseif s in keys(g.edata) + return g.edata[s] + elseif s in keys(g.gdata) + return g.gdata[s] + else + throw(ArgumentError("$(s) is not a field of GNNGraph")) + end +end + +[.\GNNGraphs\src\GNNGraphs.jl] +module GNNGraphs + +using SparseArrays +using Functors: @functor +import Graphs +using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, + has_self_loops, is_directed +import NearestNeighbors +import NNlib +import StatsBase +import KrylovKit +using ChainRulesCore +using LinearAlgebra, Random, Statistics +import MLUtils +using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like +import Functors +using MLDataDevices: get_device, cpu_device, CPUDevice + +include("chainrules.jl") # hacks for differentiability + +include("datastore.jl") +export DataStore + +include("abstracttypes.jl") +export AbstractGNNGraph + +include("gnngraph.jl") +export GNNGraph, + 
node_features, + edge_features, + graph_features + +include("gnnheterograph.jl") +export GNNHeteroGraph, + num_edge_types, + num_node_types, + edge_type_subgraph + +include("temporalsnapshotsgnngraph.jl") +export TemporalSnapshotsGNNGraph, + add_snapshot, + # add_snapshot!, + remove_snapshot + # remove_snapshot! + +include("query.jl") +export adjacency_list, + edge_index, + get_edge_weight, + graph_indicator, + has_multi_edges, + is_directed, + is_bidirected, + normalized_laplacian, + scaled_laplacian, + laplacian_lambda_max, +# from Graphs + adjacency_matrix, + degree, + has_self_loops, + has_isolated_nodes, + inneighbors, + outneighbors, + khop_adj + +include("transform.jl") +export add_nodes, + add_edges, + add_self_loops, + getgraph, + negative_sample, + rand_edge_split, + remove_self_loops, + remove_edges, + remove_multi_edges, + set_edge_weight, + to_bidirected, + to_unidirected, + random_walk_pe, + perturb_edges, + remove_nodes, + ppr_diffusion, +# from MLUtils + batch, + unbatch, +# from SparseArrays + blockdiag + +include("generate.jl") +export rand_graph, + rand_heterograph, + rand_bipartite_heterograph, + knn_graph, + radius_graph, + rand_temporal_radius_graph, + rand_temporal_hyperbolic_graph + +include("sampling.jl") +export sample_neighbors + +include("operators.jl") +# Base.intersect + +include("convert.jl") +include("utils.jl") +export sort_edge_index, color_refinement + +include("gatherscatter.jl") +# _gather, _scatter + +include("mldatasets.jl") +export mldataset2gnngraph + +end #module + +[.\GNNGraphs\src\gnnheterograph.jl] + +const EType = Tuple{Symbol, Symbol, Symbol} +const NType = Symbol +const EDict{T} = Dict{EType, T} +const NDict{T} = Dict{NType, T} + +""" + GNNHeteroGraph(data; [ndata, edata, gdata, num_nodes]) + GNNHeteroGraph(pairs...; [ndata, edata, gdata, num_nodes]) + +A type representing a heterogeneous graph structure. +It is similar to [`GNNGraph`](@ref) but nodes and edges are of different types. 
+ +# Constructor Arguments + +- `data`: A dictionary or an iterable object that maps `(source_type, edge_type, target_type)` + triples to `(source, target)` index vectors (or to `(source, target, weight)` if also edge weights are present). +- `pairs`: Passing multiple relations as pairs is equivalent to passing `data=Dict(pairs...)`. +- `ndata`: Node features. A dictionary of arrays or named tuple of arrays. + The size of the last dimension of each array must be given by `g.num_nodes`. +- `edata`: Edge features. A dictionary of arrays or named tuple of arrays. Default `nothing`. + The size of the last dimension of each array must be given by `g.num_edges`. Default `nothing`. +- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`. Default `nothing`. +- `num_nodes`: The number of nodes for each type. If not specified, inferred from `data`. Default `nothing`. + +# Fields + +- `graph`: A dictionary that maps (source_type, edge_type, target_type) triples to (source, target) index vectors. +- `num_nodes`: The number of nodes for each type. +- `num_edges`: The number of edges for each type. +- `ndata`: Node features. +- `edata`: Edge features. +- `gdata`: Graph features. +- `ntypes`: The node types. +- `etypes`: The edge types. 
+ +# Examples + +```julia +julia> using GraphNeuralNetworks + +julia> nA, nB = 10, 20; + +julia> num_nodes = Dict(:A => nA, :B => nB); + +julia> edges1 = (rand(1:nA, 20), rand(1:nB, 20)) +([4, 8, 6, 3, 4, 7, 2, 7, 3, 2, 3, 4, 9, 4, 2, 9, 10, 1, 3, 9], [6, 4, 20, 8, 16, 7, 12, 16, 5, 4, 6, 20, 11, 19, 17, 9, 12, 2, 18, 12]) + +julia> edges2 = (rand(1:nB, 30), rand(1:nA, 30)) +([17, 5, 2, 4, 5, 3, 8, 7, 9, 7 … 19, 8, 20, 7, 16, 2, 9, 15, 8, 13], [1, 1, 3, 1, 1, 3, 2, 7, 4, 4 … 7, 10, 6, 3, 4, 9, 1, 5, 8, 5]) + +julia> data = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2); + +julia> hg = GNNHeteroGraph(data; num_nodes) +GNNHeteroGraph: + num_nodes: (:A => 10, :B => 20) + num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) + +julia> hg.num_edges +Dict{Tuple{Symbol, Symbol, Symbol}, Int64} with 2 entries: +(:A, :rel1, :B) => 20 +(:B, :rel2, :A) => 30 + +# Let's add some node features +julia> ndata = Dict(:A => (x = rand(2, nA), y = rand(3, num_nodes[:A])), + :B => rand(10, nB)); + +julia> hg = GNNHeteroGraph(data; num_nodes, ndata) +GNNHeteroGraph: + num_nodes: (:A => 10, :B => 20) + num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) + ndata: + :A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64}) + :B => x = 10×20 Matrix{Float64} + +# Access features of nodes of type :A +julia> hg.ndata[:A].x +2×10 Matrix{Float64}: + 0.825882 0.0797502 0.245813 0.142281 0.231253 0.685025 0.821457 0.888838 0.571347 0.53165 + 0.631286 0.316292 0.705325 0.239211 0.533007 0.249233 0.473736 0.595475 0.0623298 0.159307 +``` + +See also [`GNNGraph`](@ref) for a homogeneous graph type and [`rand_heterograph`](@ref) for a function to generate random heterographs. 
+""" +struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T} + graph::EDict{T} + num_nodes::NDict{Int} + num_edges::EDict{Int} + num_graphs::Int + graph_indicator::Union{Nothing, NDict} + ndata::NDict{DataStore} + edata::EDict{DataStore} + gdata::DataStore + ntypes::Vector{NType} + etypes::Vector{EType} +end + +@functor GNNHeteroGraph + +GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...) +GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...) + +GNNHeteroGraph() = GNNHeteroGraph(Dict{Tuple{Symbol,Symbol,Symbol}, Any}()) + +function GNNHeteroGraph(data::Dict; kws...) + all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`")) + return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...) +end + +function GNNHeteroGraph(data::EDict; + num_nodes = nothing, + graph_indicator = nothing, + graph_type = :coo, + dir = :out, + ndata = nothing, + edata = nothing, + gdata = (;)) + @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" + @assert dir ∈ [:in, :out] + @assert graph_type==:coo "only :coo graph_type is supported for now" + + if num_nodes !== nothing + num_nodes = Dict(num_nodes) + end + + ntypes = union([[k[1] for k in keys(data)]; [k[3] for k in keys(data)]]) + etypes = collect(keys(data)) + + if graph_type == :coo + graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) + elseif graph_type == :dense + graph, num_nodes, num_edges = to_dense(data; num_nodes, dir) + elseif graph_type == :sparse + graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir) + end + + num_graphs = !isnothing(graph_indicator) ? 
+ maximum([maximum(gi) for gi in values(graph_indicator)]) : 1 + + + if length(keys(graph)) == 0 + ndata = Dict{Symbol, DataStore}() + edata = Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}() + gdata = DataStore() + else + ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) + edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, + duplicate_if_needed = true) + gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs) + end + + return GNNHeteroGraph(graph, + num_nodes, num_edges, num_graphs, + graph_indicator, + ndata, edata, gdata, + ntypes, etypes) +end + +function show_sorted_dict(io::IO, d::Dict, compact::Bool) + # if compact + print(io, "Dict") + # end + print(io, "(") + if !isempty(d) + _keys = sort!(collect(keys(d))) + for key in _keys[1:end-1] + print(io, "$(_str(key)) => $(d[key]), ") + end + print(io, "$(_str(_keys[end])) => $(d[_keys[end]])") + end + # if length(d) == 1 + # print(io, ",") + # end + print(io, ")") +end + +function Base.show(io::IO, g::GNNHeteroGraph) + print(io, "GNNHeteroGraph(") + show_sorted_dict(io, g.num_nodes, true) + print(io, ", ") + show_sorted_dict(io, g.num_edges, true) + print(io, ")") +end + +function Base.show(io::IO, ::MIME"text/plain", g::GNNHeteroGraph) + if get(io, :compact, false) + print(io, "GNNHeteroGraph(") + show_sorted_dict(io, g.num_nodes, true) + print(io, ", ") + show_sorted_dict(io, g.num_edges, true) + print(io, ")") + else + print(io, "GNNHeteroGraph:\n num_nodes: ") + show_sorted_dict(io, g.num_nodes, false) + print(io, "\n num_edges: ") + show_sorted_dict(io, g.num_edges, false) + g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)") + if !isempty(g.ndata) && !all(isempty, values(g.ndata)) + print(io, "\n ndata:") + for k in sort(collect(keys(g.ndata))) + isempty(g.ndata[k]) && continue + print(io, "\n\t", _str(k), " => $(shortsummary(g.ndata[k]))") + end + end + if !isempty(g.edata) && !all(isempty, values(g.edata)) + print(io, "\n 
edata:") + for k in sort(collect(keys(g.edata))) + isempty(g.edata[k]) && continue + print(io, "\n\t$k => $(shortsummary(g.edata[k]))") + end + end + if !isempty(g.gdata) + print(io, "\n gdata:\n\t") + shortsummary(io, g.gdata) + end + end +end + +_str(s::Symbol) = ":$s" +_str(s) = "$s" + +MLUtils.numobs(g::GNNHeteroGraph) = g.num_graphs +# MLUtils.getobs(g::GNNHeteroGraph, i) = getgraph(g, i) + + +""" + num_edge_types(g) + +Return the number of edge types in the graph. For [`GNNGraph`](@ref)s, this is always 1. +For [`GNNHeteroGraph`](@ref)s, this is the number of unique edge types. +""" +num_edge_types(g::GNNGraph) = 1 + +num_edge_types(g::GNNHeteroGraph) = length(g.etypes) + +""" + num_node_types(g) + +Return the number of node types in the graph. For [`GNNGraph`](@ref)s, this is always 1. +For [`GNNHeteroGraph`](@ref)s, this is the number of unique node types. +""" +num_node_types(g::GNNGraph) = 1 + +num_node_types(g::GNNHeteroGraph) = length(g.ntypes) + +""" + edge_type_subgraph(g::GNNHeteroGraph, edge_ts) + +Return a subgraph of `g` that contains only the edges of type `edge_ts`. +Edge types can be specified as a single edge type (i.e. a tuple containing 3 symbols) or a vector of edge types. +""" +edge_type_subgraph(g::GNNHeteroGraph, edge_t::EType) = edge_type_subgraph(g, [edge_t]) + +function edge_type_subgraph(g::GNNHeteroGraph, edge_ts::AbstractVector{<:EType}) + for edge_t in edge_ts + @assert edge_t in g.etypes "Edge type $(edge_t) not found in graph" + end + node_ts = _ntypes_from_edges(edge_ts) + graph = Dict([edge_t => g.graph[edge_t] for edge_t in edge_ts]...) + num_nodes = Dict([node_t => g.num_nodes[node_t] for node_t in node_ts]...) + num_edges = Dict([edge_t => g.num_edges[edge_t] for edge_t in edge_ts]...) + if g.graph_indicator === nothing + graph_indicator = nothing + else + graph_indicator = Dict([node_t => g.graph_indicator[node_t] for node_t in node_ts]...) 
+ end + ndata = Dict([node_t => g.ndata[node_t] for node_t in node_ts if node_t in keys(g.ndata)]...) + edata = Dict([edge_t => g.edata[edge_t] for edge_t in edge_ts if edge_t in keys(g.edata)]...) + + return GNNHeteroGraph(graph, num_nodes, num_edges, g.num_graphs, + graph_indicator, ndata, edata, g.gdata, + node_ts, edge_ts) +end + +# TODO this is not correct but Zygote cannot differentiate +# through dictionary generation +# @non_differentiable edge_type_subgraph(::Any...) + +function _ntypes_from_edges(edge_ts::AbstractVector{<:EType}) + ntypes = Symbol[] + for edge_t in edge_ts + node1_t, _, node2_t = edge_t + !in(node1_t, ntypes) && push!(ntypes, node1_t) + !in(node2_t, ntypes) && push!(ntypes, node2_t) + end + return ntypes +end + +@non_differentiable _ntypes_from_edges(::Any...) + +function Base.getindex(g::GNNHeteroGraph, node_t::NType) + return g.ndata[node_t] +end + +Base.getindex(g::GNNHeteroGraph, n1_t::Symbol, rel::Symbol, n2_t::Symbol) = g[(n1_t, rel, n2_t)] + +function Base.getindex(g::GNNHeteroGraph, edge_t::EType) + return g.edata[edge_t] +end + +[.\GNNGraphs\src\mldatasets.jl] +# We load a Graph Dataset from MLDatasets without explicitly depending on it + +""" + mldataset2gnngraph(dataset) + +Convert a graph dataset from the package MLDatasets.jl into one or many [`GNNGraph`](@ref)s. 
+
+# Examples
+
+```jldoctest
+julia> using MLDatasets, GraphNeuralNetworks
+
+julia> mldataset2gnngraph(Cora())
+GNNGraph:
+  num_nodes = 2708
+  num_edges = 10556
+  ndata:
+    features => 1433×2708 Matrix{Float32}
+    targets => 2708-element Vector{Int64}
+    train_mask => 2708-element BitVector
+    val_mask => 2708-element BitVector
+    test_mask => 2708-element BitVector
+```
+"""
+function mldataset2gnngraph(dataset::D) where {D}
+    # Duck-typed: any dataset exposing a `graphs` collection works.
+    @assert hasproperty(dataset, :graphs)
+    graphs = mlgraph2gnngraph.(dataset.graphs)
+    # Unwrap the singleton case so single-graph datasets return a GNNGraph
+    # directly instead of a one-element vector.
+    if length(graphs) == 1
+        return graphs[1]
+    else
+        return graphs
+    end
+end
+
+# Convert a single MLDatasets-style graph (with edge_index / node_data /
+# edge_data / num_nodes properties) into a GNNGraph.
+function mlgraph2gnngraph(g::G) where {G}
+    @assert hasproperty(g, :num_nodes)
+    @assert hasproperty(g, :edge_index)
+    @assert hasproperty(g, :node_data)
+    @assert hasproperty(g, :edge_data)
+    # `g.num_nodes` uses keyword shorthand: it passes `num_nodes = g.num_nodes`.
+    return GNNGraph(g.edge_index; ndata = g.node_data, edata = g.edge_data, g.num_nodes)
+end
+
+[.\GNNGraphs\src\operators.jl]
+# 2 or more args graph operators
+"""
+    intersect(g1::GNNGraph, g2::GNNGraph)
+
+Intersect two graphs by keeping only the common edges.
+"""
+function Base.intersect(g1::GNNGraph, g2::GNNGraph)
+    @assert g1.num_nodes == g2.num_nodes
+    @assert graph_type_symbol(g1) == graph_type_symbol(g2)
+    graph_type = graph_type_symbol(g1)
+    num_nodes = g1.num_nodes
+
+    # Encode each (src, dst) pair as a single integer so set intersection
+    # can be done on flat integer vectors, then decode back to edge lists.
+    idx1, _ = edge_encoding(edge_index(g1)..., num_nodes)
+    idx2, _ = edge_encoding(edge_index(g2)..., num_nodes)
+    idx = intersect(idx1, idx2)
+    s, t = edge_decoding(idx, num_nodes)
+    return GNNGraph(s, t; num_nodes, graph_type)
+end
+
+[.\GNNGraphs\src\query.jl]
+
+"""
+    edge_index(g::GNNGraph)
+
+Return a tuple containing two vectors, respectively storing
+the source and target nodes for each edge in `g`.
+
+```julia
+s, t = edge_index(g)
+```
+"""
+# COO storage is the tuple (s, t, w); the first two entries are the edge index.
+edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]
+
+# Adjacency-matrix storage: convert to COO on the fly, then take (s, t).
+edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2]
+
+"""
+    edge_index(g::GNNHeteroGraph, [edge_t])
+
+Return a tuple containing two vectors, respectively storing the source and target nodes
+for each edge in `g` of type `edge_t = (src_t, rel_t, trg_t)`.
+
+If `edge_t` is not provided, it will error if `g` has more than one edge type.
+"""
+edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2]
+# `only` errors unless there is exactly one relation; `[2]` takes the Pair's value.
+edge_index(g::GNNHeteroGraph{<:COO_T}) = only(g.graph)[2][1:2]
+
+# Third entry of the COO tuple holds the edge weights (may be `nothing`).
+get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
+
+get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][3]
+
+get_edge_weight(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][3]
+
+Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...)
+
+Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)}
+
+# """
+#     eltype(g::GNNGraph)
+#
+# Type of nodes in `g`,
+# an integer type like `Int`, `Int32`, `Uint16`, ....
+# """
+# For weighted COO graphs the element type follows the weights,
+# otherwise the integer type of the node indices.
+function Base.eltype(g::GNNGraph{<:COO_T})
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    return w !== nothing ? eltype(w) : eltype(s)
+end
+
+Base.eltype(g::GNNGraph{<:ADJMAT_T}) = eltype(g.graph)
+
+# O(num_edges) scan; broadcasting keeps this GPU-compatible.
+function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer)
+    s, t = edge_index(g)
+    return any((s .== i) .& (t .== j))
+end
+
+Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i, j] != 0
+
+"""
+    has_edge(g::GNNHeteroGraph, edge_t, i, j)
+
+Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in `g`.
+ +# Examples + +```jldoctest +julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) +GNNHeteroGraph: + num_nodes: (:A => 2, :B => 2) + num_edges: ((:A, :to, :B) => 4, (:B, :to, :A) => 0) + +julia> has_edge(g, (:A,:to,:B), 1, 1) +true + +julia> has_edge(g, (:B,:to,:A), 1, 1) +false +``` +""" +function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Integer) + s, t = edge_index(g, edge_t) + return any((s .== i) .& (t .== j)) +end + +graph_type_symbol(::GNNGraph{<:COO_T}) = :coo +graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse +graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense + +Graphs.nv(g::GNNGraph) = g.num_nodes +Graphs.ne(g::GNNGraph) = g.num_edges +Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes +Graphs.vertices(g::GNNGraph) = 1:(g.num_nodes) + + +""" + neighbors(g::GNNGraph, i::Integer; dir=:out) + +Return the neighbors of node `i` in the graph `g`. +If `dir=:out`, return the neighbors through outgoing edges. +If `dir=:in`, return the neighbors through incoming edges. + +See also [`outneighbors`](@ref Graphs.outneighbors), [`inneighbors`](@ref Graphs.inneighbors). +""" +function Graphs.neighbors(g::GNNGraph, i::Integer; dir::Symbol = :out) + @assert dir ∈ (:in, :out) + if dir == :out + outneighbors(g, i) + else + inneighbors(g, i) + end +end + +""" + outneighbors(g::GNNGraph, i::Integer) + +Return the neighbors of node `i` in the graph `g` through outgoing edges. + +See also [`neighbors`](@ref Graphs.neighbors) and [`inneighbors`](@ref Graphs.inneighbors). +""" +function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer) + s, t = edge_index(g) + return t[s .== i] +end + +function Graphs.outneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) + A = g.graph + return findall(!=(0), A[i, :]) +end + +""" + inneighbors(g::GNNGraph, i::Integer) + +Return the neighbors of node `i` in the graph `g` through incoming edges. + +See also [`neighbors`](@ref Graphs.neighbors) and [`outneighbors`](@ref Graphs.outneighbors). 
+""" +function Graphs.inneighbors(g::GNNGraph{<:COO_T}, i::Integer) + s, t = edge_index(g) + return s[t .== i] +end + +function Graphs.inneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) + A = g.graph + return findall(!=(0), A[:, i]) +end + +Graphs.is_directed(::GNNGraph) = true +Graphs.is_directed(::Type{<:GNNGraph}) = true + +""" + adjacency_list(g; dir=:out) + adjacency_list(g, nodes; dir=:out) + +Return the adjacency list representation (a vector of vectors) +of the graph `g`. + +Calling `a` the adjacency list, if `dir=:out` than +`a[i]` will contain the neighbors of node `i` through +outgoing edges. If `dir=:in`, it will contain neighbors from +incoming edges instead. + +If `nodes` is given, return the neighborhood of the nodes in `nodes` only. +""" +function adjacency_list(g::GNNGraph, nodes; dir = :out, with_eid = false) + @assert dir ∈ [:out, :in] + s, t = edge_index(g) + if dir == :in + s, t = t, s + end + T = eltype(s) + idict = 0 + dmap = Dict(n => (idict += 1) for n in nodes) + adjlist = [T[] for _ in 1:length(dmap)] + eidlist = [T[] for _ in 1:length(dmap)] + for (eid, (i, j)) in enumerate(zip(s, t)) + inew = get(dmap, i, 0) + inew == 0 && continue + push!(adjlist[inew], j) + push!(eidlist[inew], eid) + end + if with_eid + return adjlist, eidlist + else + return adjlist + end +end + +# function adjacency_list(g::GNNGraph, nodes; dir=:out) +# @assert dir ∈ [:out, :in] +# fneighs = dir == :out ? outneighbors : inneighbors +# return [fneighs(g, i) for i in nodes] +# end + +adjacency_list(g::GNNGraph; dir = :out) = adjacency_list(g, 1:(g.num_nodes); dir) + +""" + adjacency_matrix(g::GNNGraph, T=eltype(g); dir=:out, weighted=true) + +Return the adjacency matrix `A` for the graph `g`. + +If `dir=:out`, `A[i,j] > 0` denotes the presence of an edge from node `i` to node `j`. +If `dir=:in` instead, `A[i,j] > 0` denotes the presence of an edge from node `j` to node `i`. + +User may specify the eltype `T` of the returned matrix. 
+ +If `weighted=true`, the `A` will contain the edge weights if any, otherwise the elements of `A` will be either 0 or 1. +""" +function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out, + weighted = true) + if iscuarray(g.graph[1]) + # Revisit after + # https://github.com/JuliaGPU/CUDA.jl/issues/1113 + A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted) + else + A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted) + end + @assert size(A) == (n, n) + return dir == :out ? A : A' +end + +function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g); + dir = :out, weighted = true) + @assert dir ∈ [:in, :out] + A = g.graph + if !weighted + A = binarize(A) + end + A = T != eltype(A) ? T.(A) : A + return dir == :out ? A : A' +end + +function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; + dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}} + A = adjacency_matrix(g, T; dir, weighted) + if !weighted + function adjacency_matrix_pullback_noweight(Δ) + return (NoTangent(), ZeroTangent(), NoTangent()) + end + return A, adjacency_matrix_pullback_noweight + else + function adjacency_matrix_pullback_weighted(Δ) + dg = Tangent{G}(; graph = Δ .* binarize(A)) + return (NoTangent(), dg, NoTangent()) + end + return A, adjacency_matrix_pullback_weighted + end +end + +function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; + dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}} + A = adjacency_matrix(g, T; dir, weighted) + w = get_edge_weight(g) + if !weighted || w === nothing + function adjacency_matrix_pullback_noweight(Δ) + return (NoTangent(), ZeroTangent(), NoTangent()) + end + return A, adjacency_matrix_pullback_noweight + else + function adjacency_matrix_pullback_weighted(Δ) + s, t = edge_index(g) + dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t))) + return (NoTangent(), dg, NoTangent()) + end + return A, 
 adjacency_matrix_pullback_weighted
+    end
+end
+
+# Resolve the `edge_weight` keyword of `degree`:
+#   `true`  -> use the graph's stored edge weights (may itself be `nothing`
+#              when the graph is unweighted)
+#   `false` -> ignore weights entirely
+function _get_edge_weight(g, edge_weight::Bool)
+    if edge_weight === true
+        return get_edge_weight(g)
+    elseif edge_weight === false
+        return nothing
+    end
+end
+
+# A user-supplied weight vector is used as-is, bypassing the graph's weights.
+_get_edge_weight(g, edge_weight::AbstractVector) = edge_weight
+
+"""
+    degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true)
+
+Return a vector containing the degrees of the nodes in `g`.
+
+The gradient is propagated through this function only if `edge_weight` is `true`
+or a vector.
+
+# Arguments
+
+- `g`: A graph.
+- `T`: Element type of the returned vector. If `nothing`, is
+       chosen based on the graph type and will be an integer
+       if `edge_weight = false`. Default `nothing`.
+- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges.
+         For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two.
+- `edge_weight`: If `true` and the graph contains weighted edges, the degree will
+                 be weighted. Set to `false` instead to just count the number of
+                 outgoing/ingoing edges.
+                 Finally, you can also pass a vector of weights to be used
+                 instead of the graph's own weights.
+                 Default `true`.
+
+"""
+function Graphs.degree(g::GNNGraph{<:COO_T}, T::TT = nothing; dir = :out,
+                       edge_weight = true) where {
+                                                  TT <: Union{Nothing, Type{<:Number}}}
+    s, t = edge_index(g)
+
+    ew = _get_edge_weight(g, edge_weight)
+
+    # Element type: prefer the weight eltype when weights are in play,
+    # otherwise fall back to the integer eltype of the node indices.
+    T = if isnothing(T)
+        if !isnothing(ew)
+            eltype(ew)
+        else
+            eltype(s)
+        end
+    else
+        T
+    end
+    return _degree((s, t), T, dir, ew, g.num_nodes)
+end
+
+# TODO:: Make efficient
+Graphs.degree(g::GNNGraph, i::Union{Int, AbstractVector}; dir = :out) = degree(g; dir)[i]
+
+function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
+                       edge_weight = true) where {TT<:Union{Nothing, Type{<:Number}}}
+
+    # edge_weight=true or edge_weight=nothing act the same here
+    @assert !(edge_weight isa AbstractArray) "passing the edge weights is not supported by adjacency matrix representations"
+    @assert dir ∈ (:in, :out, :both)
+    if T === nothing
+        Nt = eltype(g)
+        # Unweighted degrees of a float-typed adjacency matrix are counts:
+        # return an integer type of matching width.
+        if edge_weight === false && !(Nt <: Integer)
+            T = Nt == Float32 ? Int32 :
+                Nt == Float16 ? Int16 : Int
+        else
+            T = Nt
+        end
+    end
+    A = adjacency_matrix(g)
+    return _degree(A, T, dir, edge_weight, g.num_nodes)
+end
+
+"""
+    degree(g::GNNHeteroGraph, edge_type::EType; dir = :out)
+
+Return a vector containing the degrees of the nodes in `g` GNNHeteroGraph
+given `edge_type`.
+
+# Arguments
+
+- `g`: A graph.
+- `edge_type`: A tuple of symbols `(source_t, edge_t, target_t)` representing the edge type.
+- `T`: Element type of the returned vector. If `nothing`, is
+       chosen based on the graph type. Default `nothing`.
+- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges.
+         For `dir = :in`, the ingoing edges are used. `dir = :both` sums the two and is
+         only meaningful when the source and target node types coincide.
+         Default `dir = :out`.
+
+"""
+function Graphs.degree(g::GNNHeteroGraph, edge::EType,
+                       T::TT = nothing; dir = :out) where {
+                                                           TT <: Union{Nothing, Type{<:Number}}}
+
+    s, t = edge_index(g, edge)
+
+    T = isnothing(T) ? eltype(s) : T
+
+    # The node set whose degrees are reported depends on the direction:
+    # outgoing degrees belong to the source node type of this relation,
+    # incoming degrees to its target node type. Reading the types off the
+    # edge tuple itself (rather than by position in g.ntypes) is correct
+    # for any number of node types and any ordering.
+    src_t, _, dst_t = edge
+    n_type = dir == :in ? dst_t : src_t
+
+    return _degree((s, t), T, dir, nothing, g.num_nodes[n_type])
+end
+
+function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_nodes::Int)
+    # Unweighted degree: count each edge once via unit weights.
+    _degree((s, t), T, dir, ones_like(s, T), num_nodes)
+end
+
+function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::AbstractVector, num_nodes::Int)
+    degs = zeros_like(s, T, num_nodes)
+
+    # Scatter-add weights onto source nodes (out-degree) and/or target
+    # nodes (in-degree); `:both` accumulates both contributions.
+    if dir ∈ [:out, :both]
+        degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize = (num_nodes,))
+    end
+    if dir ∈ [:in, :both]
+        degs = degs .+ NNlib.scatter(+, edge_weight, t, dstsize = (num_nodes,))
+    end
+    return degs
+end
+
+function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int)
+    if edge_weight === false
+        A = binarize(A)
+    end
+    A = eltype(A) != T ? T.(A) : A
+    # Row sums give out-degrees, column sums in-degrees.
+    return dir == :out ? vec(sum(A, dims = 2)) :
+           dir == :in ? vec(sum(A, dims = 1)) :
+           vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
+end
+
+function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
+    degs = _degree(graph, T, dir, edge_weight, num_nodes)
+    # Unweighted degrees are counts; no gradient flows back.
+    function _degree_pullback(Δ)
+        return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
+    end
+    return degs, _degree_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
+    degs = _degree(A, T, dir, edge_weight, num_nodes)
+    if edge_weight === false
+        function _degree_pullback_noweights(Δ)
+            return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
+        end
+        return degs, _degree_pullback_noweights
+    else
+        function _degree_pullback_weights(Δ)
+            # We propagate the gradient only to the non-zero elements
+            # of the adjacency matrix.
+ bA = binarize(A) + if dir == :in + dA = bA .* Δ' + elseif dir == :out + dA = Δ .* bA + else # dir == :both + dA = Δ .* bA + Δ' .* bA + end + return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent()) + end + return degs, _degree_pullback_weights + end +end + +""" + has_isolated_nodes(g::GNNGraph; dir=:out) + +Return true if the graph `g` contains nodes with out-degree (if `dir=:out`) +or in-degree (if `dir = :in`) equal to zero. +""" +function has_isolated_nodes(g::GNNGraph; dir = :out) + return any(iszero, degree(g; dir)) +end + +function Graphs.laplacian_matrix(g::GNNGraph, T::DataType = eltype(g); dir::Symbol = :out) + A = adjacency_matrix(g, T; dir = dir) + D = Diagonal(vec(sum(A; dims = 2))) + return D - A +end + +""" + normalized_laplacian(g, T=Float32; add_self_loops=false, dir=:out) + +Normalized Laplacian matrix of graph `g`. + +# Arguments + +- `g`: A `GNNGraph`. +- `T`: result element type. +- `add_self_loops`: add self-loops while calculating the matrix. +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function normalized_laplacian(g::GNNGraph, T::DataType = Float32; + add_self_loops::Bool = false, dir::Symbol = :out) + Ã = normalized_adjacency(g, T; dir, add_self_loops) + return I - Ã +end + +function normalized_adjacency(g::GNNGraph, T::DataType = Float32; + add_self_loops::Bool = false, dir::Symbol = :out) + A = adjacency_matrix(g, T; dir = dir) + if add_self_loops + A = A + I + end + degs = vec(sum(A; dims = 2)) + ChainRulesCore.ignore_derivatives() do + @assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`." + end + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + return inv_sqrtD * A * inv_sqrtD +end + +@doc raw""" + scaled_laplacian(g, T=Float32; dir=:out) + +Scaled Laplacian matrix of graph `g`, +defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. + +# Arguments + +- `g`: A `GNNGraph`. +- `T`: result element type. 
+- `dir`: the edge directionality considered (:out, :in, :both). +""" +function scaled_laplacian(g::GNNGraph, T::DataType = Float32; dir = :out) + L = normalized_laplacian(g, T) + # @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices" + λmax = _eigmax(L) + return 2 / λmax * L - I +end + +# _eigmax(A) = eigmax(Symmetric(A)) # Doesn't work on sparse arrays +function _eigmax(A) + x0 = _rand_dense_vector(A) + KrylovKit.eigsolve(Symmetric(A), x0, 1, :LR)[1][1] # also eigs(A, x0, nev, mode) available +end + +_rand_dense_vector(A::AbstractMatrix{T}) where {T} = randn(float(T), size(A, 1)) + +# Eigenvalues for cuarray don't seem to be well supported. +# https://github.com/JuliaGPU/CUDA.jl/issues/154 +# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5 + +""" + graph_indicator(g::GNNGraph; edges=false) + +Return a vector containing the graph membership +(an integer from `1` to `g.num_graphs`) of each node in the graph. +If `edges=true`, return the graph membership of each edge instead. +""" +function graph_indicator(g::GNNGraph; edges = false) + if isnothing(g.graph_indicator) + gi = ones_like(edge_index(g)[1], Int, g.num_nodes) + else + gi = g.graph_indicator + end + if edges + s, t = edge_index(g) + return gi[s] + else + return gi + end +end + +""" + graph_indicator(g::GNNHeteroGraph, [node_t]) + +Return a Dict of vectors containing the graph membership +(an integer from `1` to `g.num_graphs`) of each node in the graph for each node type. +If `node_t` is provided, return the graph membership of each node of type `node_t` instead. + +See also [`batch`](@ref). 
+""" +function graph_indicator(g::GNNHeteroGraph) + return g.graph_indicator +end + +function graph_indicator(g::GNNHeteroGraph, node_t::Symbol) + @assert node_t ∈ g.ntypes + if isnothing(g.graph_indicator) + gi = ones_like(edge_index(g, first(g.etypes))[1], Int, g.num_nodes[node_t]) + else + gi = g.graph_indicator[node_t] + end + return gi +end + +function node_features(g::GNNGraph) + if isempty(g.ndata) + return nothing + elseif length(g.ndata) > 1 + @error "Multiple feature arrays, access directly through `g.ndata`" + else + return first(values(g.ndata)) + end +end + +function edge_features(g::GNNGraph) + if isempty(g.edata) + return nothing + elseif length(g.edata) > 1 + @error "Multiple feature arrays, access directly through `g.edata`" + else + return first(values(g.edata)) + end +end + +function graph_features(g::GNNGraph) + if isempty(g.gdata) + return nothing + elseif length(g.gdata) > 1 + @error "Multiple feature arrays, access directly through `g.gdata`" + else + return first(values(g.gdata)) + end +end + +""" + is_bidirected(g::GNNGraph) + +Check if the directed graph `g` essentially corresponds +to an undirected graph, i.e. if for each edge it also contains the +reverse edge. +""" +function is_bidirected(g::GNNGraph) + s, t = edge_index(g) + s1, t1 = sort_edge_index(s, t) + s2, t2 = sort_edge_index(t, s) + all((s1 .== s2) .& (t1 .== t2)) +end + +""" + has_self_loops(g::GNNGraph) + +Return `true` if `g` has any self loops. +""" +function Graphs.has_self_loops(g::GNNGraph) + s, t = edge_index(g) + any(s .== t) +end + +""" + has_multi_edges(g::GNNGraph) + +Return `true` if `g` has any multiple edges. +""" +function has_multi_edges(g::GNNGraph) + s, t = edge_index(g) + idxs, _ = edge_encoding(s, t, g.num_nodes) + length(union(idxs)) < length(idxs) +end + +""" + khop_adj(g::GNNGraph,k::Int,T::DataType=eltype(g); dir=:out, weighted=true) + +Return ``A^k`` where ``A`` is the adjacency matrix of the graph 'g'. 

"""
function khop_adj(g::GNNGraph, k::Int, T::DataType = eltype(g); dir = :out, weighted = true)
    return (adjacency_matrix(g, T; dir, weighted))^k
end

"""
    laplacian_lambda_max(g::GNNGraph, T=Float32; add_self_loops=false, dir=:out)

Return the largest eigenvalue of the normalized symmetric Laplacian of the graph `g`.

If the graph is batched from multiple graphs, return the list of the largest eigenvalue for each graph.
"""
function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
                              add_self_loops::Bool = false, dir::Symbol = :out)
    if g.num_graphs == 1
        return _eigmax(normalized_laplacian(g, T; add_self_loops, dir))
    else
        # One largest eigenvalue per component graph of the batch.
        # Allocate with element type `T` for consistency with the
        # single-graph branch (previously hard-coded to Float64).
        eigenvalues = zeros(T, g.num_graphs)
        for i in 1:(g.num_graphs)
            eigenvalues[i] = _eigmax(normalized_laplacian(getgraph(g, i), T;
                                                          add_self_loops, dir))
        end
        return eigenvalues
    end
end

# Graph-structure queries carry no gradient information: mark them as
# constants for automatic differentiation.
@non_differentiable edge_index(x...)
@non_differentiable adjacency_list(x...)
@non_differentiable graph_indicator(x...)
@non_differentiable has_multi_edges(x...)
@non_differentiable Graphs.has_self_loops(x...)
@non_differentiable is_bidirected(x...)
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future

[.\GNNGraphs\src\sampling.jl]
"""
    sample_neighbors(g, nodes, K=-1; dir=:in, replace=false, dropnodes=false)

Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when `dir = :out`) edges will be randomly chosen.
If `dropnodes=false`, the graph returned will then contain all the nodes in the original graph,
but only the sampled edges.

The returned graph will contain an edge feature `EID` corresponding to the id of the edge
in the original graph. If `dropnodes=true`, it will also contain a node feature `NID` with
the node ids in the original graph.
+ +# Arguments + +- `g`. The graph. +- `nodes`. A list of node IDs to sample neighbors from. +- `K`. The maximum number of edges to be sampled for each node. + If -1, all the neighboring edges will be selected. +- `dir`. Determines whether to sample inbound (`:in`) or outbound (``:out`) edges (Default `:in`). +- `replace`. If `true`, sample with replacement. +- `dropnodes`. If `true`, the resulting subgraph will contain only the nodes involved in the sampled edges. + +# Examples + +```julia +julia> g = rand_graph(20, 100) +GNNGraph: + num_nodes = 20 + num_edges = 100 + +julia> sample_neighbors(g, 2:3) +GNNGraph: + num_nodes = 20 + num_edges = 9 + edata: + EID => (9,) + +julia> sg = sample_neighbors(g, 2:3, dropnodes=true) +GNNGraph: + num_nodes = 10 + num_edges = 9 + ndata: + NID => (10,) + edata: + EID => (9,) + +julia> sg.ndata.NID +10-element Vector{Int64}: + 2 + 3 + 17 + 14 + 18 + 15 + 16 + 20 + 7 + 10 + +julia> sample_neighbors(g, 2:3, 5, replace=true) +GNNGraph: + num_nodes = 20 + num_edges = 10 + edata: + EID => (10,) +``` +""" +function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; + dir = :in, replace = false, dropnodes = false) + @assert dir ∈ (:in, :out) + _, eidlist = adjacency_list(g, nodes; dir, with_eid = true) + for i in 1:length(eidlist) + if replace + k = K > 0 ? K : length(eidlist[i]) + else + k = K > 0 ? min(length(eidlist[i]), K) : length(eidlist[i]) + end + eidlist[i] = StatsBase.sample(eidlist[i], k; replace) + end + eids = reduce(vcat, eidlist) + s, t = edge_index(g) + w = get_edge_weight(g) + s = s[eids] + t = t[eids] + w = isnothing(w) ? nothing : w[eids] + + edata = getobs(g.edata, eids) + edata.EID = eids + + num_edges = length(eids) + + if !dropnodes + graph = (s, t, w) + + gnew = GNNGraph(graph, + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) + else + nodes_other = dir == :in ? 
setdiff(s, nodes) : setdiff(t, nodes) + nodes_all = [nodes; nodes_other] + nodemap = Dict(n => i for (i, n) in enumerate(nodes_all)) + s = [nodemap[s] for s in s] + t = [nodemap[t] for t in t] + graph = (s, t, w) + graph_indicator = g.graph_indicator !== nothing ? g.graph_indicator[nodes_all] : + nothing + num_nodes = length(nodes_all) + ndata = getobs(g.ndata, nodes_all) + ndata.NID = nodes_all + + gnew = GNNGraph(graph, + num_nodes, num_edges, g.num_graphs, + graph_indicator, + ndata, edata, g.gdata) + end + return gnew +end + +[.\GNNGraphs\src\temporalsnapshotsgnngraph.jl] +""" + TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) + +A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref). + +`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object, +and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features. +The features can be passed at construction time or added later. + +# Constructor Arguments + +- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes. 

# Examples

```julia
julia> using GraphNeuralNetworks

julia> snapshots = [rand_graph(10,20) for i in 1:5];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10, 10, 10]
  num_edges: [20, 20, 20, 20, 20]
  num_snapshots: 5

julia> tg.tgdata.x = rand(4); # add temporal graph feature

julia> tg # show temporal graph with new feature
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10, 10, 10]
  num_edges: [20, 20, 20, 20, 20]
  num_snapshots: 5
  tgdata:
        x = 4-element Vector{Float64}
```
"""
struct TemporalSnapshotsGNNGraph
    num_nodes::AbstractVector{Int}       # number of nodes of each snapshot
    num_edges::AbstractVector{Int}       # number of edges of each snapshot
    num_snapshots::Int                   # length of `snapshots`
    snapshots::AbstractVector{<:GNNGraph} # one GNNGraph per time step
    tgdata::DataStore                    # features attached to the temporal graph as a whole
end

# Build the temporal graph from a vector of snapshots, caching per-snapshot
# node/edge counts. All snapshots must share the same number of nodes.
function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
    @assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes"
    return TemporalSnapshotsGNNGraph(
        [s.num_nodes for s in snapshots],
        [s.num_edges for s in snapshots],
        length(snapshots),
        snapshots,
        DataStore()
    )
end

# Structural (field-by-field) equality, with a fast path for identical objects.
function Base.:(==)(tsg1::TemporalSnapshotsGNNGraph, tsg2::TemporalSnapshotsGNNGraph)
    tsg1 === tsg2 && return true
    for k in fieldnames(typeof(tsg1))
        getfield(tsg1, k) != getfield(tsg2, k) && return false
    end
    return true
end

# `tg[t]` returns the single snapshot at time index `t`.
function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int)
    return tg.snapshots[t]
end

# `tg[ts]` with a vector of time indices returns a new temporal graph
# restricted to those snapshots; `tgdata` is shared, not copied.
function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector)
    return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata)
end

"""
    add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph)

Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the snapshot `g` at time index `t`.

# Examples

```jldoctest
julia> using GraphNeuralNetworks

julia> snapshots = [rand_graph(10, 20) for i in 1:5];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10, 10, 10]
  num_edges: [20, 20, 20, 20, 20]
  num_snapshots: 5

julia> new_tg = add_snapshot(tg, 3, rand_graph(10, 16)) # add a new snapshot at time 3
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10, 10, 10, 10]
  num_edges: [20, 20, 16, 20, 20, 20]
  num_snapshots: 6
```
"""
function add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph)
    if tg.num_snapshots > 0
        @assert g.num_nodes == first(tg.num_nodes) "number of nodes must match"
    end
    @assert t <= tg.num_snapshots + 1 "cannot add snapshot at time $t, the temporal graph has only $(tg.num_snapshots) snapshots"
    # Copy the bookkeeping vectors so the original `tg` is left untouched
    # (this is the non-mutating variant; see the commented-out add_snapshot!).
    num_nodes = tg.num_nodes |> copy
    num_edges = tg.num_edges |> copy
    snapshots = tg.snapshots |> copy
    num_snapshots = tg.num_snapshots + 1
    # Insert the new snapshot and its counts at position `t`, shifting
    # later snapshots one step forward in time.
    insert!(num_nodes, t, g.num_nodes)
    insert!(num_edges, t, g.num_edges)
    insert!(snapshots, t, g)
    return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata)
end

# """
#     add_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph)

# Add to `tg` the snapshot `g` at time index `t`.

# See also [`add_snapshot`](@ref) for a non-mutating version.
+# """ +# function add_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) +# if t > tg.num_snapshots + 1 +# error("cannot add snapshot at time $t, the temporal graph has only $(tg.num_snapshots) snapshots") +# end +# if tg.num_snapshots > 0 +# @assert g.num_nodes == first(tg.num_nodes) "number of nodes must match" +# end +# insert!(tg.num_nodes, t, g.num_nodes) +# insert!(tg.num_edges, t, g.num_edges) +# insert!(tg.snapshots, t, g) +# return tg +# end + +""" + remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int) + +Return a [`TemporalSnapshotsGNNGraph`](@ref) created starting from `tg` by removing the snapshot at time index `t`. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks + +julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10] + num_edges: [20, 14, 22] + num_snapshots: 3 + +julia> new_tg = remove_snapshot(tg, 2) # remove snapshot at time 2 +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10] + num_edges: [20, 22] + num_snapshots: 2 +``` +""" +function remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int) + num_nodes = tg.num_nodes |> copy + num_edges = tg.num_edges |> copy + snapshots = tg.snapshots |> copy + num_snapshots = tg.num_snapshots - 1 + deleteat!(num_nodes, t) + deleteat!(num_edges, t) + deleteat!(snapshots, t) + return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata) +end + +# """ +# remove_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int) + +# Remove the snapshot at time index `t` from `tg` and return `tg`. + +# See [`remove_snapshot`](@ref) for a non-mutating version. 
+# """ +# function remove_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int) +# @assert t <= tg.num_snapshots "snapshot index $t out of bounds" +# tg.num_snapshots -= 1 +# deleteat!(tg.num_nodes, t) +# deleteat!(tg.num_edges, t) +# deleteat!(tg.snapshots, t) +# return tg +# end + +function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol) + if prop ∈ fieldnames(TemporalSnapshotsGNNGraph) + return getfield(tg, prop) + elseif prop == :ndata + return [s.ndata for s in tg.snapshots] + elseif prop == :edata + return [s.edata for s in tg.snapshots] + elseif prop == :gdata + return [s.gdata for s in tg.snapshots] + else + return [getproperty(s,prop) for s in tg.snapshots] + end +end + +function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph) + print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") + print_feature_t(io, tsg.tgdata) + print(io, " data") +end + +function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph) + if get(io, :compact, false) + print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") + print_feature_t(io, tsg.tgdata) + print(io, " data") + else + print(io, + "TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)") + if !isempty(tsg.tgdata) + print(io, "\n tgdata:") + for k in keys(tsg.tgdata) + print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))") + end + end + end +end + +function print_feature_t(io::IO, feature) + if !isempty(feature) + if length(keys(feature)) == 1 + k = first(keys(feature)) + v = first(values(feature)) + print(io, "$(k): $(dims2string(size(v)))") + else + print(io, "(") + for (i, (k, v)) in enumerate(pairs(feature)) + print(io, "$k: $(dims2string(size(v)))") + if i == length(feature) + print(io, ")") + else + print(io, ", ") + end + end + end + else + print(io, "no") + end +end + +@functor TemporalSnapshotsGNNGraph + +[.\GNNGraphs\src\transform.jl] + +""" + add_self_loops(g::GNNGraph) + +Return a graph with 
the same features as `g` +but also adding edges connecting the nodes to themselves. + +Nodes with already existing self-loops will obtain a second self-loop. + +If the graphs has edge weights, the new edges will have weight 1. +""" +function add_self_loops(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + @assert isempty(g.edata) + ew = get_edge_weight(g) + n = g.num_nodes + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + if ew !== nothing + ew = [ew; fill!(similar(ew, n), 1)] + end + + return GNNGraph((s, t, ew), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +function add_self_loops(g::GNNGraph{<:ADJMAT_T}) + A = g.graph + @assert isempty(g.edata) + num_edges = g.num_edges + g.num_nodes + A = A + I + return GNNGraph(A, + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +""" + add_self_loops(g::GNNHeteroGraph, edge_t::EType) + add_self_loops(g::GNNHeteroGraph) + +If the source node type is the same as the destination node type in `edge_t`, +return a graph with the same features as `g` but also add self-loops +of the specified type, `edge_t`. Otherwise, it returns `g` unchanged. + +Nodes with already existing self-loops of type `edge_t` will obtain +a second set of self-loops of the same type. + +If the graph has edge weights for edges of type `edge_t`, the new edges will have weight 1. + +If no edges of type `edge_t` exist, or all existing edges have no weight, +then all new self loops will have no weight. + +If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same. +This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type. 
+""" +function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + + function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + get(g.graph, edge_t, (nothing, nothing, nothing))[3] + end + + src_t, _, tgt_t = edge_t + (src_t === tgt_t) || + return g + + n = get(g.num_nodes, src_t, 0) + + if haskey(g.graph, edge_t) + s, t = g.graph[edge_t][1:2] + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + else + if !isempty(g.graph) + T = typeof(first(values(g.graph))[1]) + nodes = convert(T, [1:n;]) + else + nodes = [1:n;] + end + s = nodes + t = nodes + end + + graph = g.graph |> copy + ew = get(g.graph, edge_t, (nothing, nothing, nothing))[3] + + if ew !== nothing + ew = [ew; fill!(similar(ew, n), 1)] + end + + graph[edge_t] = (s, t, ew) + edata = g.edata |> copy + ndata = g.ndata |> copy + ntypes = g.ntypes |> copy + etypes = g.etypes |> copy + num_nodes = g.num_nodes |> copy + num_edges = g.num_edges |> copy + num_edges[edge_t] = length(get(graph, edge_t, ([],[]))[1]) + + return GNNHeteroGraph(graph, + num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + ndata, edata, g.gdata, + ntypes, etypes) +end + +function add_self_loops(g::GNNHeteroGraph) + for edge_t in keys(g.graph) + g = add_self_loops(g, edge_t) + end + return g +end + +""" + remove_self_loops(g::GNNGraph) + +Return a graph constructed from `g` where self-loops (edges from a node to itself) +are removed. + +See also [`add_self_loops`](@ref) and [`remove_multi_edges`](@ref). +""" +function remove_self_loops(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + + mask_old_loops = s .!= t + s = s[mask_old_loops] + t = t[mask_old_loops] + edata = getobs(edata, mask_old_loops) + w = isnothing(w) ? 
nothing : getobs(w, mask_old_loops) + + GNNGraph((s, t, w), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + +function remove_self_loops(g::GNNGraph{<:ADJMAT_T}) + @assert isempty(g.edata) + A = g.graph + A[diagind(A)] .= 0 + if A isa AbstractSparseMatrix + dropzeros!(A) + end + num_edges = numnonzeros(A) + return GNNGraph(A, + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +""" + remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) + remove_edges(g::GNNGraph, p=0.5) + +Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability. + +# Arguments +- `g`: The input graph from which edges will be removed. +- `edges_to_remove`: Vector of edge indices to be removed. This argument is only required for the first method. +- `p`: Probability of removing each edge. This argument is only required for the second method and defaults to 0.5. + +# Returns +A new GNNGraph with the specified edges removed. + +# Example +```julia +julia> using GraphNeuralNetworks + +# Construct a GNNGraph +julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) +GNNGraph: + num_nodes: 3 + num_edges: 5 + +# Remove the second edge +julia> g_new = remove_edges(g, [2]); + +julia> g_new +GNNGraph: + num_nodes: 3 + num_edges: 4 + +# Remove edges with a probability of 0.5 +julia> g_new = remove_edges(g, 0.5); + +julia> g_new +GNNGraph: + num_nodes: 3 + num_edges: 2 +``` +""" +function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + + mask_to_keep = trues(length(s)) + + mask_to_keep[edges_to_remove] .= false + + s = s[mask_to_keep] + t = t[mask_to_keep] + edata = getobs(edata, mask_to_keep) + w = isnothing(w) ? 
nothing : getobs(w, mask_to_keep) + + return GNNGraph((s, t, w), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + + +function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) + num_edges = g.num_edges + edges_to_remove = filter(_ -> rand() < p, 1:num_edges) + return remove_edges(g, edges_to_remove) +end + +""" + remove_multi_edges(g::GNNGraph; aggr=+) + +Remove multiple edges (also called parallel edges or repeated edges) from graph `g`. +Possible edge features are aggregated according to `aggr`, that can take value +`+`,`min`, `max` or `mean`. + +See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref). +""" +function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + num_edges = g.num_edges + idxs, idxmax = edge_encoding(s, t, g.num_nodes) + + perm = sortperm(idxs) + idxs = idxs[perm] + s, t = s[perm], t[perm] + edata = getobs(edata, perm) + w = isnothing(w) ? nothing : getobs(w, perm) + idxs = [-1; idxs] + mask = idxs[2:end] .> idxs[1:(end - 1)] + if !all(mask) + s, t = s[mask], t[mask] + idxs = similar(s, num_edges) + idxs .= 1:num_edges + idxs .= idxs .- cumsum(.!mask) + num_edges = length(s) + w = _scatter(aggr, w, idxs, num_edges) + edata = _scatter(aggr, edata, idxs, num_edges) + end + + return GNNGraph((s, t, w), + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + +""" + remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector) + +Remove specified nodes, and their associated edges, from a GNNGraph. This operation reindexes the remaining nodes to maintain a continuous sequence of node indices, starting from 1. Similarly, edges are reindexed to account for the removal of edges connected to the removed nodes. + +# Arguments +- `g`: The input graph from which nodes (and their edges) will be removed. +- `nodes_to_remove`: Vector of node indices to be removed. 
+ +# Returns +A new GNNGraph with the specified nodes and all edges associated with these nodes removed. + +# Example +```julia +using GraphNeuralNetworks + +g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) + +# Remove nodes with indices 2 and 3, for example +g_new = remove_nodes(g, [2, 3]) + +# g_new now does not contain nodes 2 and 3, and any edges that were connected to these nodes. +println(g_new) +``` +""" +function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector) + nodes_to_remove = sort(union(nodes_to_remove)) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + ndata = g.ndata + + function find_edges_to_remove(nodes, nodes_to_remove) + return findall(node_id -> begin + idx = searchsortedlast(nodes_to_remove, node_id) + idx >= 1 && idx <= length(nodes_to_remove) && nodes_to_remove[idx] == node_id + end, nodes) + end + + edges_to_remove_s = find_edges_to_remove(s, nodes_to_remove) + edges_to_remove_t = find_edges_to_remove(t, nodes_to_remove) + edges_to_remove = union(edges_to_remove_s, edges_to_remove_t) + + mask_edges_to_keep = trues(length(s)) + mask_edges_to_keep[edges_to_remove] .= false + s = s[mask_edges_to_keep] + t = t[mask_edges_to_keep] + + w = isnothing(w) ? nothing : getobs(w, mask_edges_to_keep) + + for node in sort(nodes_to_remove, rev=true) + s[s .> node] .-= 1 + t[t .> node] .-= 1 + end + + nodes_to_keep = setdiff(1:g.num_nodes, nodes_to_remove) + ndata = getobs(ndata, nodes_to_keep) + edata = getobs(edata, mask_edges_to_keep) + + num_nodes = g.num_nodes - length(nodes_to_remove) + + return GNNGraph((s, t, w), + num_nodes, length(s), g.num_graphs, + g.graph_indicator, + ndata, edata, g.gdata) +end + +""" + remove_nodes(g::GNNGraph, p) + +Returns a new graph obtained by dropping nodes from `g` with independent probabilities `p`. 
+ +# Examples + +```julia +julia> g = GNNGraph([1, 1, 2, 2, 3, 4], [1, 2, 3, 1, 3, 1]) +GNNGraph: + num_nodes: 4 + num_edges: 6 + +julia> g_new = remove_nodes(g, 0.5) +GNNGraph: + num_nodes: 2 + num_edges: 2 +``` +""" +function remove_nodes(g::GNNGraph, p::AbstractFloat) + nodes_to_remove = filter(_ -> rand() < p, 1:g.num_nodes) + return remove_nodes(g, nodes_to_remove) +end + +""" + add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata]) + add_edges(g::GNNGraph, (s, t); [edata]) + add_edges(g::GNNGraph, (s, t, w); [edata]) + +Add to graph `g` the edges with source nodes `s` and target nodes `t`. +Optionally, pass the edge weight `w` and the features `edata` for the new edges. +Returns a new graph sharing part of the underlying data with `g`. + +If the `s` or `t` contain nodes that are not already present in the graph, +they are added to the graph as well. + +# Examples + +```jldoctest +julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; + +julia> w = Float32[1.0, 2.0, 3.0, 4.0, 5.0]; + +julia> g = GNNGraph((s, t, w)) +GNNGraph: + num_nodes: 4 + num_edges: 5 + +julia> add_edges(g, ([2, 3], [4, 1], [10.0, 20.0])) +GNNGraph: + num_nodes: 4 + num_edges: 7 +``` +```jldoctest +julia> g = GNNGraph() +GNNGraph: + num_nodes: 0 + num_edges: 0 + +julia> add_edges(g, [1,2], [2,3]) +GNNGraph: + num_nodes: 3 + num_edges: 2 +``` +""" +add_edges(g::GNNGraph{<:COO_T}, snew::AbstractVector, tnew::AbstractVector; kws...) = add_edges(g, (snew, tnew, nothing); kws...) +add_edges(g, data::Tuple{<:AbstractVector, <:AbstractVector}; kws...) = add_edges(g, (data..., nothing); kws...) 

# Append the COO triple `data = (snew, tnew, wnew)` to graph `g`:
# edge features are normalized and concatenated, and the node set grows
# when the new edges reference node ids beyond `g.num_nodes`.
function add_edges(g::GNNGraph{<:COO_T}, data::COO_T; edata = nothing)
    snew, tnew, wnew = data
    @assert length(snew) == length(tnew)
    @assert isnothing(wnew) || length(wnew) == length(snew)
    # Nothing to add: return the input graph unchanged.
    length(snew) == 0 && return g
    @assert minimum(snew) >= 1
    @assert minimum(tnew) >= 1

    num_new = length(snew)
    # Normalize the features of the new edges, then append them to the
    # existing edge features.
    edata_new = normalize_graphdata(edata, default_name = :e, n = num_new)
    edata_all = cat_features(g.edata, edata_new)

    s_old, t_old = edge_index(g)
    w_old = get_edge_weight(g)
    s_all = [s_old; snew]
    t_all = [t_old; tnew]
    w_all = cat_features(w_old, wnew, g.num_edges, num_new)

    # The new edges may mention nodes not yet present in the graph; pad
    # the node features accordingly.
    num_nodes = max(maximum(snew), maximum(tnew), g.num_nodes)
    ndata = if num_nodes > g.num_nodes
        padding = normalize_graphdata((;), default_name = :x, n = num_nodes - g.num_nodes)
        cat_features(g.ndata, padding)
    else
        g.ndata
    end

    return GNNGraph((s_all, t_all, w_all),
                    num_nodes, length(s_all), g.num_graphs,
                    g.graph_indicator,
                    ndata, edata_all, g.gdata)
end

"""
    add_edges(g::GNNHeteroGraph, edge_t, s, t; [edata, num_nodes])
    add_edges(g::GNNHeteroGraph, edge_t => (s, t); [edata, num_nodes])
    add_edges(g::GNNHeteroGraph, edge_t => (s, t, w); [edata, num_nodes])

Add to heterograph `g` edges of type `edge_t` with source node vector `s` and target node vector `t`.
Optionally, pass the edge weights `w` or the features `edata` for the new edges.
`edge_t` is a triplet of symbols `(src_t, rel_t, dst_t)`. 

If the edge type is not already present in the graph, it is added. 
If it involves new node types, they are added to the graph as well.
In this case, a dictionary or named tuple of `num_nodes` can be passed to specify the number of nodes of the new types,
otherwise the number of nodes is inferred from the maximum node id in `s` and `t`.
"""
add_edges(g::GNNHeteroGraph{<:COO_T}, edge_t::EType, snew::AbstractVector, tnew::AbstractVector; kws...) = add_edges(g, edge_t => (snew, tnew, nothing); kws...)
+add_edges(g::GNNHeteroGraph{<:COO_T}, data::Pair{EType, <:Tuple{<:AbstractVector, <:AbstractVector}}; kws...) = add_edges(g, data.first => (data.second..., nothing); kws...) + +function add_edges(g::GNNHeteroGraph{<:COO_T}, + data::Pair{EType, <:COO_T}; + edata = nothing, + num_nodes = Dict{Symbol,Int}()) + edge_t, (snew, tnew, wnew) = data + @assert length(snew) == length(tnew) + if length(snew) == 0 + return g + end + @assert minimum(snew) >= 1 + @assert minimum(tnew) >= 1 + + is_existing_rel = haskey(g.graph, edge_t) + + edata = normalize_graphdata(edata, default_name = :e, n = length(snew)) + _edata = g.edata |> copy + if haskey(_edata, edge_t) + _edata[edge_t] = cat_features(g.edata[edge_t], edata) + else + _edata[edge_t] = edata + end + + graph = g.graph |> copy + etypes = g.etypes |> copy + ntypes = g.ntypes |> copy + _num_nodes = g.num_nodes |> copy + ndata = g.ndata |> copy + if !is_existing_rel + for (node_t, st) in [(edge_t[1], snew), (edge_t[3], tnew)] + if node_t ∉ ntypes + push!(ntypes, node_t) + if haskey(num_nodes, node_t) + _num_nodes[node_t] = num_nodes[node_t] + else + _num_nodes[node_t] = maximum(st) + end + ndata[node_t] = DataStore(_num_nodes[node_t]) + end + end + push!(etypes, edge_t) + else + s, t = edge_index(g, edge_t) + snew = [s; snew] + tnew = [t; tnew] + w = get_edge_weight(g, edge_t) + wnew = cat_features(w, wnew, length(s), length(snew)) + end + + if maximum(snew) > _num_nodes[edge_t[1]] + ndata_new = normalize_graphdata((;), default_name = :x, n = maximum(snew) - _num_nodes[edge_t[1]]) + ndata[edge_t[1]] = cat_features(ndata[edge_t[1]], ndata_new) + _num_nodes[edge_t[1]] = maximum(snew) + end + if maximum(tnew) > _num_nodes[edge_t[3]] + ndata_new = normalize_graphdata((;), default_name = :x, n = maximum(tnew) - _num_nodes[edge_t[3]]) + ndata[edge_t[3]] = cat_features(ndata[edge_t[3]], ndata_new) + _num_nodes[edge_t[3]] = maximum(tnew) + end + + graph[edge_t] = (snew, tnew, wnew) + num_edges = g.num_edges |> copy + 
num_edges[edge_t] = length(graph[edge_t][1])

    return GNNHeteroGraph(graph,
                          _num_nodes, num_edges, g.num_graphs,
                          g.graph_indicator,
                          ndata, _edata, g.gdata,
                          ntypes, etypes)
end


"""
    perturb_edges([rng], g::GNNGraph, perturb_ratio)

Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`.
The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph.
These new edges are added without creating self-loops.

The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges.
The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges.

# Arguments

- `g::GNNGraph`: The graph to be perturbed.
- `perturb_ratio`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges.
- `rng`: An optional random number generator to ensure reproducible results.

# Examples

```julia
julia> g = rand_graph(4, 5)
GNNGraph:
  num_nodes: 4
  num_edges: 5

julia> perturbed_g = perturb_edges(g, 0.2)
GNNGraph:
  num_nodes: 4
  num_edges: 6
```
"""
# Convenience method: use the global default RNG.
perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::AbstractFloat) =
    perturb_edges(Random.default_rng(), g, perturb_ratio)

function perturb_edges(rng::AbstractRNG, g::GNNGraph{<:COO_T}, perturb_ratio::AbstractFloat)
    @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

    num_current_edges = g.num_edges
    num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio)

    if num_edges_to_add == 0
        return g
    end

    num_nodes = g.num_nodes
    @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

    # Draw uniform samples in (0, 1] and scale to node ids in 1:num_nodes.
    # `rand_like` on a dummy `ones(num_nodes)` array is presumably used so the
    # samples land on the same device/array type as graph data — TODO confirm.
    snew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)
    tnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes)

    # Reject self-loops from the first draw ...
    mask_loops = snew .!= tnew
    snew = snew[mask_loops]
    tnew = tnew[mask_loops]

    # ... then keep redrawing replacements until exactly `num_edges_to_add`
    # non-self-loop edges have been collected (rejection sampling).
    while length(snew) < num_edges_to_add
        n = num_edges_to_add - length(snew)
        snewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
        tnewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes)
        mask_new_loops = snewnew .!= tnewnew
        snewnew = snewnew[mask_new_loops]
        tnewnew = tnewnew[mask_new_loops]
        snew = [snew; snewnew]
        tnew = [tnew; tnewnew]
    end

    # New edges carry no weights or features; multi-edges are not deduplicated.
    return add_edges(g, (snew, tnew, nothing))
end


### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). 
make it mutable +# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector} +# s, t = edge_index(g) +# @assert length(snew) == length(tnew) +# # TODO remove this constraint +# @assert get_edge_weight(g) === nothing + +# edata = normalize_graphdata(edata, default_name=:e, n=length(snew)) +# edata = cat_features(g.edata, edata) + +# s, t = edge_index(g) +# append!(s, snew) +# append!(t, tnew) +# g.num_edges += length(snew) +# return true +# end + +""" + to_bidirected(g) + +Adds a reverse edge for each edge in the graph, then calls +[`remove_multi_edges`](@ref) with `mean` aggregation to simplify the graph. + +See also [`is_bidirected`](@ref). + +# Examples + +```jldoctest +julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; + +julia> w = [1.0, 2.0, 3.0, 4.0, 5.0]; + +julia> e = [10.0, 20.0, 30.0, 40.0, 50.0]; + +julia> g = GNNGraph(s, t, w, edata = e) +GNNGraph: + num_nodes = 4 + num_edges = 5 + edata: + e => (5,) + +julia> g2 = to_bidirected(g) +GNNGraph: + num_nodes = 4 + num_edges = 7 + edata: + e => (7,) + +julia> edge_index(g2) +([1, 2, 2, 3, 3, 4, 4], [2, 1, 3, 2, 4, 3, 4]) + +julia> get_edge_weight(g2) +7-element Vector{Float64}: + 1.0 + 1.0 + 2.0 + 2.0 + 3.5 + 3.5 + 5.0 + +julia> g2.edata.e +7-element Vector{Float64}: + 10.0 + 10.0 + 20.0 + 20.0 + 35.0 + 35.0 + 50.0 +``` +""" +function to_bidirected(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + w = get_edge_weight(g) + snew = [s; t] + tnew = [t; s] + w = cat_features(w, w) + edata = cat_features(g.edata, g.edata) + + g = GNNGraph((snew, tnew, w), + g.num_nodes, length(snew), g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) + + return remove_multi_edges(g; aggr = mean) +end + +""" + to_unidirected(g::GNNGraph) + +Return a graph that for each multiple edge between two nodes in `g` +keeps only an edge in one direction. 
+""" +function to_unidirected(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + w = get_edge_weight(g) + idxs, _ = edge_encoding(s, t, g.num_nodes, directed = false) + snew, tnew = edge_decoding(idxs, g.num_nodes, directed = false) + + g = GNNGraph((snew, tnew, w), + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) + + return remove_multi_edges(g; aggr = mean) +end + +function Graphs.SimpleGraph(g::GNNGraph) + G = Graphs.SimpleGraph(g.num_nodes) + for e in Graphs.edges(g) + Graphs.add_edge!(G, e) + end + return G +end +function Graphs.SimpleDiGraph(g::GNNGraph) + G = Graphs.SimpleDiGraph(g.num_nodes) + for e in Graphs.edges(g) + Graphs.add_edge!(G, e) + end + return G +end + +""" + add_nodes(g::GNNGraph, n; [ndata]) + +Add `n` new nodes to graph `g`. In the +new graph, these nodes will have indexes from `g.num_nodes + 1` +to `g.num_nodes + n`. +""" +function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata = (;)) + ndata = normalize_graphdata(ndata, default_name = :x, n = n) + ndata = cat_features(g.ndata, ndata) + + GNNGraph(g.graph, + g.num_nodes + n, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, g.edata, g.gdata) +end + +""" + set_edge_weight(g::GNNGraph, w::AbstractVector) + +Set `w` as edge weights in the returned graph. +""" +function set_edge_weight(g::GNNGraph, w::AbstractVector) + s, t = edge_index(g) + @assert length(w) == length(s) + + return GNNGraph((s, t, w), + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph) + nv1, nv2 = g1.num_nodes, g2.num_nodes + if g1.graph isa COO_T + s1, t1 = edge_index(g1) + s2, t2 = edge_index(g2) + s = vcat(s1, nv1 .+ s2) + t = vcat(t1, nv1 .+ t2) + w = cat_features(get_edge_weight(g1), get_edge_weight(g2)) + graph = (s, t, w) + ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, nv1) : g1.graph_indicator + ind2 = isnothing(g2.graph_indicator) ? 
ones_like(s2, nv2) : g2.graph_indicator + elseif g1.graph isa ADJMAT_T + graph = blockdiag(g1.graph, g2.graph) + ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, nv1) : g1.graph_indicator + ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, nv2) : g2.graph_indicator + end + graph_indicator = vcat(ind1, g1.num_graphs .+ ind2) + + GNNGraph(graph, + nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs, + graph_indicator, + cat_features(g1.ndata, g2.ndata), + cat_features(g1.edata, g2.edata), + cat_features(g1.gdata, g2.gdata)) +end + +# PIRACY +function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix) + m1, n1 = size(A1) + @assert m1 == n1 + m2, n2 = size(A2) + @assert m2 == n2 + O1 = fill!(similar(A1, eltype(A1), (m1, n2)), 0) + O2 = fill!(similar(A1, eltype(A1), (m2, n1)), 0) + return [A1 O1 + O2 A2] +end + +""" + blockdiag(xs::GNNGraph...) + +Equivalent to [`MLUtils.batch`](@ref). +""" +function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) + g = g1 + for go in gothers + g = blockdiag(g, go) + end + return g +end + +""" + batch(gs::Vector{<:GNNGraph}) + +Batch together multiple `GNNGraph`s into a single one +containing the total number of original nodes and edges. + +Equivalent to [`SparseArrays.blockdiag`](@ref). +See also [`MLUtils.unbatch`](@ref). 
+ +# Examples + +```jldoctest +julia> g1 = rand_graph(4, 6, ndata=ones(8, 4)) +GNNGraph: + num_nodes = 4 + num_edges = 6 + ndata: + x => (8, 4) + +julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7)) +GNNGraph: + num_nodes = 7 + num_edges = 4 + ndata: + x => (8, 7) + +julia> g12 = MLUtils.batch([g1, g2]) +GNNGraph: + num_nodes = 11 + num_edges = 10 + num_graphs = 2 + ndata: + x => (8, 11) + +julia> g12.ndata.x +8×11 Matrix{Float64}: + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 +``` +""" +function MLUtils.batch(gs::AbstractVector{<:GNNGraph}) + Told = eltype(gs) + # try to restrict the eltype + gs = [g for g in gs] + if eltype(gs) != Told + return MLUtils.batch(gs) + else + return blockdiag(gs...) + end +end + +function MLUtils.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T} + v_num_nodes = [g.num_nodes for g in gs] + edge_indices = [edge_index(g) for g in gs] + nodesum = cumsum([0; v_num_nodes])[1:(end - 1)] + s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + w = cat_features([get_edge_weight(g) for g in gs]) + graph = (s, t, w) + + function materialize_graph_indicator(g) + g.graph_indicator === nothing ? 
ones_like(s, g.num_nodes) : g.graph_indicator + end + + v_gi = materialize_graph_indicator.(gs) + v_num_graphs = [g.num_graphs for g in gs] + graphsum = cumsum([0; v_num_graphs])[1:(end - 1)] + v_gi = [ng .+ gi for (ng, gi) in zip(graphsum, v_gi)] + graph_indicator = cat_features(v_gi) + + GNNGraph(graph, + sum(v_num_nodes), + sum([g.num_edges for g in gs]), + sum(v_num_graphs), + graph_indicator, + cat_features([g.ndata for g in gs]), + cat_features([g.edata for g in gs]), + cat_features([g.gdata for g in gs])) +end + +function MLUtils.batch(g::GNNGraph) + throw(ArgumentError("Cannot batch a `GNNGraph` (containing $(g.num_graphs) graphs). Pass a vector of `GNNGraph`s instead.")) +end + + +function MLUtils.batch(gs::AbstractVector{<:GNNHeteroGraph}) + function edge_index_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + if haskey(g.graph, edge_t) + g.graph[edge_t][1:2] + else + nothing + end + end + + function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + get(g.graph, edge_t, (nothing, nothing, nothing))[3] + end + + @assert length(gs) > 0 + ntypes = union([g.ntypes for g in gs]...) + etypes = union([g.etypes for g in gs]...) 
+ + v_num_nodes = Dict(node_t => [get(g.num_nodes, node_t, 0) for g in gs] for node_t in ntypes) + num_nodes = Dict(node_t => sum(v_num_nodes[node_t]) for node_t in ntypes) + num_edges = Dict(edge_t => sum(get(g.num_edges, edge_t, 0) for g in gs) for edge_t in etypes) + edge_indices = edge_indices = Dict(edge_t => [edge_index_nullable(g, edge_t) for g in gs] for edge_t in etypes) + nodesum = Dict(node_t => cumsum([0; v_num_nodes[node_t]])[1:(end - 1)] for node_t in ntypes) + graphs = [] + for edge_t in etypes + src_t, _, dst_t = edge_t + # @show edge_t edge_indices[edge_t] first(edge_indices[edge_t]) + # for ei in edge_indices[edge_t] + # @show ei[1] + # end + # # [ei[1] for (ii, ei) in enumerate(edge_indices[edge_t])] + s = cat_features([ei[1] .+ nodesum[src_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing]) + t = cat_features([ei[2] .+ nodesum[dst_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing]) + w = cat_features(filter(x -> x !== nothing, [get_edge_weight_nullable(g, edge_t) for g in gs])) + push!(graphs, edge_t => (s, t, w)) + end + graph = Dict(graphs...) 
+ + #TODO relax this restriction + @assert all(g -> g.num_graphs == 1, gs) + + s = edge_index(gs[1], gs[1].etypes[1])[1] # grab any source vector + + function materialize_graph_indicator(g, node_t) + n = get(g.num_nodes, node_t, 0) + return ones_like(s, n) + end + v_gi = Dict(node_t => [materialize_graph_indicator(g, node_t) for g in gs] for node_t in ntypes) + v_num_graphs = [g.num_graphs for g in gs] + graphsum = cumsum([0; v_num_graphs])[1:(end - 1)] + v_gi = Dict(node_t => [ng .+ gi for (ng, gi) in zip(graphsum, v_gi[node_t])] for node_t in ntypes) + graph_indicator = Dict(node_t => cat_features(v_gi[node_t]) for node_t in ntypes) + + function data_or_else(data, types) + Dict(type => get(data, type, DataStore(0)) for type in types) + end + + return GNNHeteroGraph(graph, + num_nodes, + num_edges, + sum(v_num_graphs), + graph_indicator, + cat_features([data_or_else(g.ndata, ntypes) for g in gs]), + cat_features([data_or_else(g.edata, etypes) for g in gs]), + cat_features([g.gdata for g in gs]), + ntypes, etypes) +end + +""" + unbatch(g::GNNGraph) + +Opposite of the [`MLUtils.batch`](@ref) operation, returns +an array of the individual graphs batched together in `g`. + +See also [`MLUtils.batch`](@ref) and [`getgraph`](@ref). 
+ +# Examples + +```jldoctest +julia> gbatched = MLUtils.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)]) +GNNGraph: + num_nodes = 19 + num_edges = 16 + num_graphs = 3 + +julia> MLUtils.unbatch(gbatched) +3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}: + GNNGraph: + num_nodes = 5 + num_edges = 6 + + GNNGraph: + num_nodes = 10 + num_edges = 8 + + GNNGraph: + num_nodes = 4 + num_edges = 2 +``` +""" +function MLUtils.unbatch(g::GNNGraph{T}) where {T <: COO_T} + g.num_graphs == 1 && return [g] + + nodemasks = _unbatch_nodemasks(g.graph_indicator, g.num_graphs) + num_nodes = length.(nodemasks) + cumnum_nodes = [0; cumsum(num_nodes)] + + s, t = edge_index(g) + w = get_edge_weight(g) + + edgemasks = _unbatch_edgemasks(s, t, g.num_graphs, cumnum_nodes) + num_edges = length.(edgemasks) + @assert sum(num_edges)==g.num_edges "Error in unbatching, likely the edges are not sorted (first edges belong to the first graphs, then edges in the second graph and so on)" + + function build_graph(i) + node_mask = nodemasks[i] + edge_mask = edgemasks[i] + snew = s[edge_mask] .- cumnum_nodes[i] + tnew = t[edge_mask] .- cumnum_nodes[i] + wnew = w === nothing ? nothing : w[edge_mask] + graph = (snew, tnew, wnew) + graph_indicator = nothing + ndata = getobs(g.ndata, node_mask) + edata = getobs(g.edata, edge_mask) + gdata = getobs(g.gdata, i) + + nedges = num_edges[i] + nnodes = num_nodes[i] + ngraphs = 1 + + return GNNGraph(graph, + nnodes, nedges, ngraphs, + graph_indicator, + ndata, edata, gdata) + end + + return [build_graph(i) for i in 1:(g.num_graphs)] +end + +function MLUtils.unbatch(g::GNNGraph) + return [getgraph(g, i) for i in 1:(g.num_graphs)] +end + +function _unbatch_nodemasks(graph_indicator, num_graphs) + @assert issorted(graph_indicator) "The graph_indicator vector must be sorted." 
+ idxslast = [searchsortedlast(graph_indicator, i) for i in 1:num_graphs] + + nodemasks = [1:idxslast[1]] + for i in 2:num_graphs + push!(nodemasks, (idxslast[i - 1] + 1):idxslast[i]) + end + return nodemasks +end + +function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes) + edgemasks = [] + for i in 1:(num_graphs - 1) + lastedgeid = findfirst(s) do x + x > cumnum_nodes[i + 1] && x <= cumnum_nodes[i + 2] + end + firstedgeid = i == 1 ? 1 : last(edgemasks[i - 1]) + 1 + # if nothing make empty range + lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1 + + push!(edgemasks, firstedgeid:lastedgeid) + end + push!(edgemasks, (last(edgemasks[end]) + 1):length(s)) + return edgemasks +end + +@non_differentiable _unbatch_nodemasks(::Any...) +@non_differentiable _unbatch_edgemasks(::Any...) + +""" + getgraph(g::GNNGraph, i; nmap=false) + +Return the subgraph of `g` induced by those nodes `j` +for which `g.graph_indicator[j] == i` or, +if `i` is a collection, `g.graph_indicator[j] ∈ i`. +In other words, it extract the component graphs from a batched graph. + +If `nmap=true`, return also a vector `v` mapping the new nodes to the old ones. +The node `i` in the subgraph will correspond to the node `v[i]` in `g`. +""" +getgraph(g::GNNGraph, i::Int; kws...) = getgraph(g, [i]; kws...) + +function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap = false) + if g.graph_indicator === nothing + @assert i == [1] + if nmap + return g, 1:(g.num_nodes) + else + return g + end + end + + node_mask = g.graph_indicator .∈ Ref(i) + + nodes = (1:(g.num_nodes))[node_mask] + nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes)) + + graphmap = Dict(i => inew for (inew, i) in enumerate(i)) + graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]] + + s, t = edge_index(g) + w = get_edge_weight(g) + edge_mask = s .∈ Ref(nodes) + + if g.graph isa COO_T + s = [nodemap[i] for i in s[edge_mask]] + t = [nodemap[i] for i in t[edge_mask]] + w = isnothing(w) ? 
nothing : w[edge_mask] + graph = (s, t, w) + elseif g.graph isa ADJMAT_T + graph = g.graph[nodes, nodes] + end + + ndata = getobs(g.ndata, node_mask) + edata = getobs(g.edata, edge_mask) + gdata = getobs(g.gdata, i) + + num_edges = sum(edge_mask) + num_nodes = length(graph_indicator) + num_graphs = length(i) + + gnew = GNNGraph(graph, + num_nodes, num_edges, num_graphs, + graph_indicator, + ndata, edata, gdata) + + if nmap + return gnew, nodes + else + return gnew + end +end + +""" + negative_sample(g::GNNGraph; + num_neg_edges = g.num_edges, + bidirected = is_bidirected(g)) + +Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges. + +If `bidirected=true`, the output graph will be bidirected and there will be no +leakage from the origin graph. + +See also [`is_bidirected`](@ref). +""" +function negative_sample(g::GNNGraph; + max_trials = 3, + num_neg_edges = g.num_edges, + bidirected = is_bidirected(g)) + @assert g.num_graphs == 1 + # Consider self-loops as positive edges + # Construct new graph dropping features + g = add_self_loops(GNNGraph(edge_index(g), num_nodes = g.num_nodes)) + + s, t = edge_index(g) + n = g.num_nodes + dev = get_device(s) + cdev = cpu_device() + s, t = s |> cdev, t |> cdev + idx_pos, maxid = edge_encoding(s, t, n) + if bidirected + num_neg_edges = num_neg_edges ÷ 2 + pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge + else + pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge + end + # pneg * sample_prob * maxid == num_neg_edges + sample_prob = min(1, num_neg_edges / (pneg * maxid) * 1.1) + idx_neg = Int[] + for _ in 1:max_trials + rnd = randsubseq(1:maxid, sample_prob) + setdiff!(rnd, idx_pos) + union!(idx_neg, rnd) + if length(idx_neg) >= num_neg_edges + idx_neg = idx_neg[1:num_neg_edges] + break + end + end + s_neg, t_neg = edge_decoding(idx_neg, n) + if bidirected + s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg] + end + s_neg, t_neg = s_neg |> dev, t_neg |> dev + return 
GNNGraph(s_neg, t_neg, num_nodes = n) +end + +""" + rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g)) -> g1, g2 + +Randomly partition the edges in `g` to form two graphs, `g1` +and `g2`. Both will have the same number of nodes as `g`. +`g1` will contain a fraction `frac` of the original edges, +while `g2` wil contain the rest. + +If `bidirected = true` makes sure that an edge and its reverse go into the same split. +This option is supported only for bidirected graphs with no self-loops +and multi-edges. + +`rand_edge_split` is tipically used to create train/test splits in link prediction tasks. +""" +function rand_edge_split(g::GNNGraph, frac; bidirected = is_bidirected(g)) + s, t = edge_index(g) + ne = bidirected ? g.num_edges ÷ 2 : g.num_edges + eids = randperm(ne) + size1 = round(Int, ne * frac) + + if !bidirected + s1, t1 = s[eids[1:size1]], t[eids[1:size1]] + s2, t2 = s[eids[(size1 + 1):end]], t[eids[(size1 + 1):end]] + else + # @assert is_bidirected(g) + # @assert !has_self_loops(g) + # @assert !has_multi_edges(g) + mask = s .< t + s, t = s[mask], t[mask] + s1, t1 = s[eids[1:size1]], t[eids[1:size1]] + s1, t1 = [s1; t1], [t1; s1] + s2, t2 = s[eids[(size1 + 1):end]], t[eids[(size1 + 1):end]] + s2, t2 = [s2; t2], [t2; s2] + end + g1 = GNNGraph(s1, t1, num_nodes = g.num_nodes) + g2 = GNNGraph(s2, t2, num_nodes = g.num_nodes) + return g1, g2 +end + +""" + random_walk_pe(g, walk_length) + +Return the random walk positional encoding from the paper [Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875) of the given graph `g` and the length of the walk `walk_length` as a matrix of size `(walk_length, g.num_nodes)`. 
+""" +function random_walk_pe(g::GNNGraph, walk_length::Int) + matrix = zeros(walk_length, g.num_nodes) + adj = adjacency_matrix(g, Float32; dir = :out) + matrix = dense_zeros_like(adj, Float32, (walk_length, g.num_nodes)) + deg = sum(adj, dims = 2) |> vec + deg_inv = inv.(deg) + deg_inv[isinf.(deg_inv)] .= 0 + RW = adj * Diagonal(deg_inv) + out = RW + matrix[1, :] .= diag(RW) + for i in 2:walk_length + out = out * RW + matrix[i, :] .= diag(out) + end + return matrix +end + +dense_zeros_like(a::SparseMatrixCSC, T::Type, sz = size(a)) = zeros(T, sz) +dense_zeros_like(a::AbstractArray, T::Type, sz = size(a)) = fill!(similar(a, T, sz), 0) +dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz) + +# """ +# Transform vector of cartesian indexes into a tuple of vectors containing integers. +# """ +ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims) + +@non_differentiable negative_sample(x...) +@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule +@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule +@non_differentiable dense_zeros_like(x...) + +""" + ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph + +Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix. +References paper: [The pagerank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422) + + +The function performs the following steps: +1. Constructs a modified adjacency matrix `A` using the graph's edge weights, where `A` is adjusted by `(α - 1) * A + I`, with `α` being the damping factor (`alpha_f32`) and `I` the identity matrix. +2. Normalizes `A` to ensure each column sums to 1, representing transition probabilities. +3. 
Applies the PPR formula `α * (I + (α - 1) * A)^-1` to compute the diffusion matrix. +4. Updates the original edge weights of the graph based on the PPR diffusion matrix, assigning new weights for each edge from the PPR matrix. + +# Arguments +- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available. +- `alpha_f32::Float32`: The damping factor used in PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`. + +# Returns +- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation. +""" +function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0) + s, t = edge_index(g) + w = get_edge_weight(g) + if isnothing(w) + w = ones(Float32, g.num_edges) + end + + N = g.num_nodes + + initial_A = sparse(t, s, w, N, N) + scaled_A = (Float32(alpha) - 1) * initial_A + + I_sparse = sparse(Diagonal(ones(Float32, N))) + A_sparse = I_sparse + scaled_A + + A_dense = Matrix(A_sparse) + + PPR = alpha * inv(A_dense) + + new_w = [PPR[dst, src] for (src, dst) in zip(s, t)] + + return GNNGraph((s, t, new_w), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +[.\GNNGraphs\src\utils.jl] +function check_num_nodes(g::GNNGraph, x::AbstractArray) + @assert g.num_nodes==size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)" + return true +end +function check_num_nodes(g::GNNGraph, x::Union{Tuple, NamedTuple}) + map(x -> check_num_nodes(g, x), x) + return true +end + +check_num_nodes(::GNNGraph, ::Nothing) = true + +function check_num_nodes(g::GNNGraph, x::Tuple) + @assert length(x) == 2 + check_num_nodes(g, x[1]) + check_num_nodes(g, x[2]) + return true +end + +# x = (Xsrc, Xdst) = (Xj, Xi) +function check_num_nodes(g::GNNHeteroGraph, x::Tuple) + @assert length(x) == 2 + @assert length(g.etypes) == 1 + nt1, _, nt2 = only(g.etypes) + if 
x[1] isa AbstractArray + @assert size(x[1], ndims(x[1])) == g.num_nodes[nt1] + end + if x[2] isa AbstractArray + @assert size(x[2], ndims(x[2])) == g.num_nodes[nt2] + end + return true +end + +function check_num_edges(g::GNNGraph, e::AbstractArray) + @assert g.num_edges==size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(g.num_edges)" + return true +end +function check_num_edges(g::AbstractGNNGraph, x::Union{Tuple, NamedTuple}) + map(x -> check_num_edges(g, x), x) + return true +end + +check_num_edges(::AbstractGNNGraph, ::Nothing) = true + +function check_num_edges(g::GNNHeteroGraph, e::AbstractArray) + num_edgs = only(g.num_edges)[2] + @assert only(num_edgs)==size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(num_edgs)" + return true +end + +sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) + +""" + sort_edge_index(ei::Tuple) -> u', v' + sort_edge_index(u, v) -> u', v' + +Return a sorted version of the tuple of vectors `ei = (u, v)`, +applying a common permutation to `u` and `v`. +The sorting is lexycographic, that is the pairs `(ui, vi)` +are sorted first according to the `ui` and then according to `vi`. 
+""" +function sort_edge_index(u, v) + uv = collect(zip(u, v)) + p = sortperm(uv) # isless lexicographically defined for tuples + return u[p], v[p] +end + + +cat_features(x1::Nothing, x2::Nothing) = nothing +cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1)) +function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector}) + cat(x1, x2, dims = 1) +end + +# workaround for issue #98 #104 +# See https://github.com/JuliaStrings/InlineStrings.jl/issues/21 +# Remove when minimum supported version is julia v1.8 +cat_features(x1::NamedTuple{(), Tuple{}}, x2::NamedTuple{(), Tuple{}}) = (;) +cat_features(xs::AbstractVector{NamedTuple{(), Tuple{}}}) = (;) + +function cat_features(x1::NamedTuple, x2::NamedTuple) + sort(collect(keys(x1))) == sort(collect(keys(x2))) || + @error "cannot concatenate feature data with different keys" + + return NamedTuple(k => cat_features(x1[k], x2[k]) for k in keys(x1)) +end + +function cat_features(x1::Dict{Symbol, T}, x2::Dict{Symbol, T}) where {T} + sort(collect(keys(x1))) == sort(collect(keys(x2))) || + @error "cannot concatenate feature data with different keys" + + return Dict{Symbol, T}([k => cat_features(x1[k], x2[k]) for k in keys(x1)]...) +end + +function cat_features(x::Dict) + return Dict([k => cat_features(v) for (k, v) in pairs(x)]...) 
+end + + +function cat_features(xs::AbstractVector{<:AbstractArray{T, N}}) where {T <: Number, N} + cat(xs...; dims = N) +end + +cat_features(xs::AbstractVector{Nothing}) = nothing +cat_features(xs::AbstractVector{<:Number}) = xs + +function cat_features(xs::AbstractVector{<:NamedTuple}) + symbols = [sort(collect(keys(x))) for x in xs] + all(y -> y == symbols[1], symbols) || + @error "cannot concatenate feature data with different keys" + length(xs) == 1 && return xs[1] + + # concatenate + syms = symbols[1] + NamedTuple(k => cat_features([x[k] for x in xs]) for k in syms) +end + +# function cat_features(xs::AbstractVector{Dict{Symbol, T}}) where {T} +# symbols = [sort(collect(keys(x))) for x in xs] +# all(y -> y == symbols[1], symbols) || +# @error "cannot concatenate feature data with different keys" +# length(xs) == 1 && return xs[1] + +# # concatenate +# syms = symbols[1] +# return Dict{Symbol, T}([k => cat_features([x[k] for x in xs]) for k in syms]...) +# end + +function cat_features(xs::AbstractVector{<:Dict}) + _allkeys = [sort(collect(keys(x))) for x in xs] + _keys = union(_allkeys...) + length(xs) == 1 && return xs[1] + + # concatenate + return Dict([k => cat_features([x[k] for x in xs if haskey(x, k)]) for k in _keys]...) +end + + +# Used to concatenate edge weights +cat_features(w1::Nothing, w2::Nothing, n1::Int, n2::Int) = nothing +cat_features(w1::AbstractVector, w2::Nothing, n1::Int, n2::Int) = cat_features(w1, ones_like(w1, n2)) +cat_features(w1::Nothing, w2::AbstractVector, n1::Int, n2::Int) = cat_features(ones_like(w2, n1), w2) +cat_features(w1::AbstractVector, w2::AbstractVector, n1::Int, n2::Int) = cat_features(w1, w2) + + +# Turns generic type into named tuple +normalize_graphdata(data::Nothing; n, kws...) = DataStore(n) + +function normalize_graphdata(data; default_name::Symbol, kws...) + normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...) 
+end + +function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false) + # This had to workaround two Zygote bugs with NamedTuples + # https://github.com/FluxML/Zygote.jl/issues/1071 + # https://github.com/FluxML/Zygote.jl/issues/1072 + + if n > 1 + @assert all(x -> x isa AbstractArray, data) "Non-array features provided." + end + + if n <= 1 + # If last array dimension is not 1, add a new dimension. + # This is mostly useful to reshape global feature vectors + # of size D to Dx1 matrices. + unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v + unsqz_last(v) = v + + data = map(unsqz_last, data) + end + + if n > 0 + if duplicate_if_needed + function duplicate(v) + if v isa AbstractArray && size(v)[end] == n ÷ 2 + v = cat(v, v, dims = ndims(v)) + end + return v + end + data = map(duplicate, data) + end + + for x in data + if x isa AbstractArray + @assert size(x)[end]==n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])." + end + end + end + + return DataStore(n, data) +end + +# For heterogeneous graphs +function normalize_heterographdata(data::Nothing; default_name::Symbol, ns::Dict, kws...) + Dict([k => normalize_graphdata(nothing; default_name = default_name, n, kws...) + for (k, n) in ns]...) +end + +normalize_heterographdata(data; kws...) = normalize_heterographdata(Dict(data); kws...) + +function normalize_heterographdata(data::Dict; default_name::Symbol, ns::Dict, kws...) + Dict([k => normalize_graphdata(get(data, k, nothing); default_name = default_name, n, kws...) + for (k, n) in ns]...) 
+end + +numnonzeros(a::AbstractSparseMatrix) = nnz(a) +numnonzeros(a::AbstractMatrix) = count(!=(0), a) + +## Map edges into a contiguous range of integers +function edge_encoding(s, t, n; directed = true, self_loops = true) + if directed && self_loops + maxid = n^2 + idx = (s .- 1) .* n .+ t + elseif !directed && self_loops + maxid = n * (n + 1) ÷ 2 + mask = s .> t + snew = copy(s) + tnew = copy(t) + snew[mask] .= t[mask] + tnew[mask] .= s[mask] + s, t = snew, tnew + + # idx = ∑_{i',i'=i'}^n 1 + ∑_{j',i<=j'<=j} 1 + # = ∑_{i',i'=i'}^n 1 + (j - i + 1) + # = ∑_{i',i' s) + elseif !directed && !self_loops + @assert all(s .!= t) + maxid = n * (n - 1) ÷ 2 + mask = s .> t + snew = copy(s) + tnew = copy(t) + snew[mask] .= t[mask] + tnew[mask] .= s[mask] + s, t = snew, tnew + + # idx(s,t) = ∑_{s',1<= s'= s) + elseif !directed && !self_loops + # Considering t = s + 1 in + # idx = @. (s - 1) * n - s * (s - 1) ÷ 2 + (t - s) + # and inverting for s we have + s = @. floor(Int, 1/2 + n - 1/2 * sqrt(9 - 4n + 4n^2 - 8*idx)) + # now we can find t + t = @. idx - (s - 1) * n + s * (s - 1) ÷ 2 + s + end + return s, t +end + +# for bipartite graphs +function edge_decoding(idx, n1, n2) + @assert all(1 .<= idx .<= n1 * n2) + s = (idx .- 1) .÷ n2 .+ 1 + t = (idx .- 1) .% n2 .+ 1 + return s, t +end + +function _rand_edges(rng, n::Int, m::Int; directed = true, self_loops = true) + idmax = if directed && self_loops + n^2 + elseif !directed && self_loops + n * (n + 1) ÷ 2 + elseif directed && !self_loops + n * (n - 1) + elseif !directed && !self_loops + n * (n - 1) ÷ 2 + end + idx = StatsBase.sample(rng, 1:idmax, m, replace = false) + s, t = edge_decoding(idx, n; directed, self_loops) + val = nothing + return s, t, val +end + +function _rand_edges(rng, (n1, n2), m) + idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false) + s, t = edge_decoding(idx, n1, n2) + val = nothing + return s, t, val +end + +binarize(x) = map(>(0), x) + +@non_differentiable binarize(x...) 
+@non_differentiable edge_encoding(x...) +@non_differentiable edge_decoding(x...) + +### PRINTING ##### + +function shortsummary(io::IO, x) + s = shortsummary(x) + s === nothing && return + print(io, s) +end + +shortsummary(x) = summary(x) +shortsummary(x::Number) = "$x" + +function shortsummary(x::NamedTuple) + if length(x) == 0 + return nothing + elseif length(x) === 1 + return "$(keys(x)[1]) = $(shortsummary(x[1]))" + else + "(" * join(("$k = $(shortsummary(x[k]))" for k in keys(x)), ", ") * ")" + end +end + +function shortsummary(x::DataStore) + length(x) == 0 && return nothing + return "DataStore(" * join(("$k = [$(shortsummary(x[k]))]" for k in keys(x)), ", ") * + ")" +end + +# from (2,2,3) output of size function to a string "2×2×3" +function dims2string(d) + isempty(d) ? "0-dimensional" : + length(d) == 1 ? "$(d[1])-element" : + join(map(string, d), '×') +end + +@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}}) +@non_differentiable normalize_graphdata(::Nothing) + +iscuarray(x::AbstractArray) = false +@non_differentiable iscuarray(::Any) + + +@doc raw""" + color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters + +The color refinement algorithm for graph coloring. +Given a graph `g` and an initial coloring `x0`, the algorithm +iteratively refines the coloring until a fixed point is reached. + +At each iteration the algorithm computes a hash of the coloring and the sorted list of colors +of the neighbors of each node. This hash is used to determine if the coloring has changed. + +```math +x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))). +```` + +This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing. + +# Arguments +- `g::GNNGraph`: The graph to color. +- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1. + +# Returns +- `x::AbstractVector{<:Integer}`: The final coloring. +- `num_colors::Int`: The number of colors used. 
- `niters::Int`: The number of iterations until convergence.
"""
color_refinement(g::GNNGraph) = color_refinement(g, ones(Int, g.num_nodes))

function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer})
    @assert length(x0) == g.num_nodes
    s, t = edge_index(g)
    t, s = sort_edge_index(t, s) # sort by target
    degs = degree(g, dir=:in)
    x = x0

    # Map each (color, sorted neighbor colors) hash to a fresh color id.
    hashmap = Dict{UInt64, Int}()
    x′ = zeros(Int, length(x0))
    niters = 0
    while true
        # x[s] is grouped by target node thanks to the sort above, so
        # chunking by in-degree yields each node's multiset of neighbor colors.
        xneigs = chunk(x[s], size=degs)
        for (i, (xi, xineigs)) in enumerate(zip(x, xneigs))
            idx = hash((xi, sort(xineigs)))
            x′[i] = get!(hashmap, idx, length(hashmap) + 1)
        end
        niters += 1
        x == x′ && break
        # BUGFIX: the original `x = x′` aliased the two buffers, making the
        # convergence check `x == x′` trivially true on the next iteration
        # (the loop always stopped after at most two rounds). Copy instead,
        # so the comparison is between consecutive colorings.
        x = copy(x′)
    end
    num_colors = length(union(x))
    return x, num_colors, niters
end
[.\GNNGraphs\test\chainrules.jl]
@testset "dict constructor" begin
    grad = gradient(1.) do x
        d = Dict([:x => x, :y => 5]...)
        return sum(d[:x].^2)
    end[1]

    @test grad == 2

    ## BROKEN Constructors
    # grad = gradient(1.) do x
    #     d = Dict([(:x => x), (:y => 5)])
    #     return sum(d[:x].^2)
    # end[1]

    # @test grad == 2


    # grad = gradient(1.)
do x + # d = Dict([(:x => x), (:y => 5)]) + # return sum(d[:x].^2) + # end[1] + + # @test grad == 2 +end + +[.\GNNGraphs\test\convert.jl] +if TEST_GPU + @testset "to_coo(dense) on gpu" begin + get_st(A) = GNNGraphs.to_coo(A)[1][1:2] + get_val(A) = GNNGraphs.to_coo(A)[1][3] + + A = cu([0 2 2; 2.0 0 2; 2 2 0]) + + y = get_val(A) + @test y isa CuVector{Float32} + @test Array(y) ≈ [2, 2, 2, 2, 2, 2] + + s, t = get_st(A) + @test s isa CuVector{<:Integer} + @test t isa CuVector{<:Integer} + @test Array(s) == [2, 3, 1, 3, 1, 2] + @test Array(t) == [1, 1, 2, 2, 3, 3] + + @test gradient(A -> sum(get_val(A)), A)[1] isa CuMatrix{Float32} + end +end + +[.\GNNGraphs\test\datastore.jl] + +@testset "constructor" begin + @test_throws AssertionError DataStore(10, (:x => rand(10), :y => rand(2, 4))) + + @testset "keyword args" begin + ds = DataStore(10, x = rand(10), y = rand(2, 10)) + @test size(ds.x) == (10,) + @test size(ds.y) == (2, 10) + + ds = DataStore(x = rand(10), y = rand(2, 10)) + @test size(ds.x) == (10,) + @test size(ds.y) == (2, 10) + end +end + +@testset "getproperty / setproperty!" begin + x = rand(10) + ds = DataStore(10, (:x => x, :y => rand(2, 10))) + @test ds.x == ds[:x] == x + @test_throws DimensionMismatch ds.z=rand(12) + ds.z = [1:10;] + @test ds.z == [1:10;] + vec = [DataStore(10, (:x => x,)), DataStore(10, (:x => x, :y => rand(2, 10)))] + @test vec.x == [x, x] + @test_throws KeyError vec.z + @test vec._n == [10, 10] + @test vec._data == [Dict(:x => x), Dict(:x => x, :y => vec[2].y)] +end + +@testset "setindex!" begin + ds = DataStore(10) + x = rand(10) + @test (ds[:x] = x) == x # Tests setindex! 
+ @test ds.x == ds[:x] == x +end + +@testset "map" begin + ds = DataStore(10, (:x => rand(10), :y => rand(2, 10))) + ds2 = map(x -> x .+ 1, ds) + @test ds2.x == ds.x .+ 1 + @test ds2.y == ds.y .+ 1 + + @test_throws AssertionError ds2=map(x -> [x; x], ds) +end + +@testset "getdata / getn" begin + ds = DataStore(10, (:x => rand(10), :y => rand(2, 10))) + @test getdata(ds) == getfield(ds, :_data) + @test_throws KeyError ds.data + @test getn(ds) == getfield(ds, :_n) + @test_throws KeyError ds.n +end + +@testset "cat empty" begin + ds1 = DataStore(2, (:x => rand(2))) + ds2 = DataStore(1, (:x => rand(1))) + dsempty = DataStore(0, (:x => rand(0))) + + ds = GNNGraphs.cat_features(ds1, ds2) + @test getn(ds) == 3 + ds = GNNGraphs.cat_features(ds1, dsempty) + @test getn(ds) == 2 + + # issue #280 + g = GNNGraph([1], [2]) + h = add_edges(g, Int[], Int[]) # adds no edges + @test getn(g.edata) == 1 + @test getn(h.edata) == 1 +end + + +@testset "gradient" begin + ds = DataStore(10, (:x => rand(10), :y => rand(2, 10))) + + f1(ds) = sum(ds.x) + grad = gradient(f1, ds)[1] + @test grad._data[:x] ≈ ngradient(f1, ds)[1][:x] + + g = rand_graph(5, 2) + x = rand(2, 5) + grad = gradient(x -> sum(exp, GNNGraph(g, ndata = x).ndata.x), x)[1] + @test grad == exp.(x) +end + +@testset "functor" begin + ds = DataStore(10, (:x => zeros(10), :y => ones(2, 10))) + p, re = Functors.functor(ds) + @test p[1] === getn(ds) + @test p[2] === getdata(ds) + @test ds == re(p) + + ds2 = Functors.fmap(ds) do x + if x isa AbstractArray + x .+ 1 + else + x + end + end + @test ds isa DataStore + @test ds2.x == ds.x .+ 1 +end + +[.\GNNGraphs\test\generate.jl] +@testset "rand_graph" begin + n, m = 10, 20 + m2 = m ÷ 2 + x = rand(3, n) + e = rand(4, m2) + + g = rand_graph(n, m, ndata = x, edata = e, graph_type = GRAPH_T) + @test g.num_nodes == n + @test g.num_edges == m + @test g.ndata.x === x + if GRAPH_T == :coo + s, t = edge_index(g) + @test s[1:m2] == t[(m2 + 1):end] + @test t[1:m2] == s[(m2 + 1):end] + @test 
g.edata.e[:, 1:m2] == e + @test g.edata.e[:, (m2 + 1):end] == e + end + + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) + @test g.num_nodes == n + @test g.num_edges == m + + rng = MersenneTwister(17) + g2 = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) + @test edge_index(g2) == edge_index(g) + + ew = rand(m2) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = true, graph_type = GRAPH_T, edge_weight = ew) + @test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo) + + ew = rand(m) + rng = MersenneTwister(17) + g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T, edge_weight = ew) + @test get_edge_weight(g) == ew broken=(GRAPH_T != :coo) +end + +@testset "knn_graph" begin + n, k = 10, 3 + x = rand(3, n) + g = knn_graph(x, k; graph_type = GRAPH_T) + @test g.num_nodes == 10 + @test g.num_edges == n * k + @test degree(g, dir = :in) == fill(k, n) + @test has_self_loops(g) == false + + g = knn_graph(x, k; dir = :out, self_loops = true, graph_type = GRAPH_T) + @test g.num_nodes == 10 + @test g.num_edges == n * k + @test degree(g, dir = :out) == fill(k, n) + @test has_self_loops(g) == true + + graph_indicator = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2] + g = knn_graph(x, k; graph_indicator, graph_type = GRAPH_T) + @test g.num_graphs == 2 + s, t = edge_index(g) + ne = n * k ÷ 2 + @test all(1 .<= s[1:ne] .<= 5) + @test all(1 .<= t[1:ne] .<= 5) + @test all(6 .<= s[(ne + 1):end] .<= 10) + @test all(6 .<= t[(ne + 1):end] .<= 10) +end + +@testset "radius_graph" begin + n, r = 10, 0.5 + x = rand(3, n) + g = radius_graph(x, r; graph_type = GRAPH_T) + @test g.num_nodes == 10 + @test has_self_loops(g) == false + + g = radius_graph(x, r; dir = :out, self_loops = true, graph_type = GRAPH_T) + @test g.num_nodes == 10 + @test has_self_loops(g) == true + + graph_indicator = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2] + g = radius_graph(x, r; graph_indicator, graph_type = GRAPH_T) + @test g.num_graphs == 2 + s, t = 
edge_index(g) + @test (s .> 5) == (t .> 5) +end + +@testset "rand_bipartite_heterograph" begin + g = rand_bipartite_heterograph((10, 15), (20, 20)) + @test g.num_nodes == Dict(:A => 10, :B => 15) + @test g.num_edges == Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20) + sA, tB = edge_index(g, (:A, :to, :B)) + for (s, t) in zip(sA, tB) + @test 1 <= s <= 10 + @test 1 <= t <= 15 + @test has_edge(g, (:A,:to,:B), s, t) + @test has_edge(g, (:B,:to,:A), t, s) + end + + g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) + @test has_edge(g, (:A,:to,:B), 1, 1) + @test !has_edge(g, (:B,:to,:A), 1, 1) +end + +@testset "rand_temporal_radius_graph" begin + number_nodes = 30 + number_snapshots = 5 + r = 0.1 + speed = 0.1 + tg = rand_temporal_radius_graph(number_nodes, number_snapshots, speed, r) + @test tg.num_nodes == [number_nodes for i in 1:number_snapshots] + @test tg.num_snapshots == number_snapshots + r2 = 0.95 + tg2 = rand_temporal_radius_graph(number_nodes, number_snapshots, speed, r2) + @test mean(mean(degree.(tg.snapshots)))<=mean(mean(degree.(tg2.snapshots))) +end + +@testset "rand_temporal_hyperbolic_graph" begin + @test GNNGraphs._hyperbolic_distance([1.0,1.0],[1.0,1.0];ζ=1)==0 + @test GNNGraphs._hyperbolic_distance([0.23,0.11],[0.98,0.55];ζ=1) == GNNGraphs._hyperbolic_distance([0.98,0.55],[0.23,0.11];ζ=1) + number_nodes = 30 + number_snapshots = 5 + α, R, speed, ζ = 1, 1, 0.1, 1 + + tg = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ) + @test tg.num_nodes == [number_nodes for i in 1:number_snapshots] + @test tg.num_snapshots == number_snapshots + R = 10 + tg1 = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ) + @test mean(mean(degree.(tg1.snapshots)))<=mean(mean(degree.(tg.snapshots))) +end + +[.\GNNGraphs\test\gnngraph.jl] +@testset "Constructor: adjacency matrix" begin + A = sprand(10, 10, 0.5) + sA, tA, vA = findnz(A) + + g = GNNGraph(A, graph_type = GRAPH_T) + s, t = edge_index(g) + v = 
get_edge_weight(g) + @test s == sA + @test t == tA + @test v == vA + + g = GNNGraph(Matrix(A), graph_type = GRAPH_T) + s, t = edge_index(g) + v = get_edge_weight(g) + @test s == sA + @test t == tA + @test v == vA + + g = GNNGraph([0 0 0 + 0 0 1 + 0 1 0], graph_type = GRAPH_T) + @test g.num_nodes == 3 + @test g.num_edges == 2 + + g = GNNGraph([0 1 0 + 1 0 0 + 0 0 0], graph_type = GRAPH_T) + @test g.num_nodes == 3 + @test g.num_edges == 2 +end + +@testset "Constructor: integer" begin + g = GNNGraph(10, graph_type = GRAPH_T) + @test g.num_nodes == 10 + @test g.num_edges == 0 + + g2 = rand_graph(10, 30, graph_type = GRAPH_T) + G = typeof(g2) + g = G(10) + @test g.num_nodes == 10 + @test g.num_edges == 0 + + g = GNNGraph(graph_type = GRAPH_T) + @test g.num_nodes == 0 +end + +@testset "symmetric graph" begin + s = [1, 1, 2, 2, 3, 3, 4, 4] + t = [2, 4, 1, 3, 2, 4, 1, 3] + adj_mat = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + adj_list_out = [[2, 4], [1, 3], [2, 4], [1, 3]] + adj_list_in = [[2, 4], [1, 3], [2, 4], [1, 3]] + + # core functionality + g = GNNGraph(s, t; graph_type = GRAPH_T) + if TEST_GPU + dev = CUDADevice() + g_gpu = g |> dev + end + + @test g.num_edges == 8 + @test g.num_nodes == 4 + @test nv(g) == g.num_nodes + @test ne(g) == g.num_edges + @test Tuple.(collect(edges(g))) |> sort == collect(zip(s, t)) |> sort + @test sort(outneighbors(g, 1)) == [2, 4] + @test sort(inneighbors(g, 1)) == [2, 4] + @test is_directed(g) == true + s1, t1 = sort_edge_index(edge_index(g)) + @test s1 == s + @test t1 == t + @test vertices(g) == 1:(g.num_nodes) + + @test sort.(adjacency_list(g; dir = :in)) == adj_list_in + @test sort.(adjacency_list(g; dir = :out)) == adj_list_out + + @testset "adjacency_matrix" begin + @test adjacency_matrix(g) == adj_mat + @test adjacency_matrix(g; dir = :in) == adj_mat + @test adjacency_matrix(g; dir = :out) == adj_mat + + if TEST_GPU + # See https://github.com/JuliaGPU/CUDA.jl/pull/1093 + mat_gpu = adjacency_matrix(g_gpu) + @test mat_gpu isa 
ACUMatrix{Int} + @test Array(mat_gpu) == adj_mat + end + end + + @testset "normalized_laplacian" begin + mat = normalized_laplacian(g) + if TEST_GPU + mat_gpu = normalized_laplacian(g_gpu) + @test mat_gpu isa ACUMatrix{Float32} + @test Array(mat_gpu) == mat + end + end + + @testset "scaled_laplacian" begin if TEST_GPU + mat = scaled_laplacian(g) + mat_gpu = scaled_laplacian(g_gpu) + @test mat_gpu isa ACUMatrix{Float32} + @test Array(mat_gpu) ≈ mat + end end + + @testset "constructors" begin + adjacency_matrix(g; dir = :out) == adj_mat + adjacency_matrix(g; dir = :in) == adj_mat + end + + if TEST_GPU + @testset "functor" begin + s_cpu, t_cpu = edge_index(g) + s_gpu, t_gpu = edge_index(g_gpu) + @test s_gpu isa CuVector{Int} + @test Array(s_gpu) == s_cpu + @test t_gpu isa CuVector{Int} + @test Array(t_gpu) == t_cpu + end + end +end + +@testset "asymmetric graph" begin + s = [1, 2, 3, 4] + t = [2, 3, 4, 1] + adj_mat_out = [0 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + adj_list_out = [[2], [3], [4], [1]] + + adj_mat_in = [0 0 0 1 + 1 0 0 0 + 0 1 0 0 + 0 0 1 0] + adj_list_in = [[4], [1], [2], [3]] + + # core functionality + g = GNNGraph(s, t; graph_type = GRAPH_T) + if TEST_GPU + dev = CUDADevice() #TODO replace with `gpu_device()` + g_gpu = g |> dev + end + + @test g.num_edges == 4 + @test g.num_nodes == 4 + @test length(edges(g)) == 4 + @test sort(outneighbors(g, 1)) == [2] + @test sort(inneighbors(g, 1)) == [4] + @test is_directed(g) == true + @test is_directed(typeof(g)) == true + s1, t1 = sort_edge_index(edge_index(g)) + @test s1 == s + @test t1 == t + + # adjacency + @test adjacency_matrix(g) == adj_mat_out + @test adjacency_list(g) == adj_list_out + @test adjacency_matrix(g, dir = :out) == adj_mat_out + @test adjacency_list(g, dir = :out) == adj_list_out + @test adjacency_matrix(g, dir = :in) == adj_mat_in + @test adjacency_list(g, dir = :in) == adj_list_in +end + +@testset "zero" begin + g = rand_graph(4, 6, graph_type = GRAPH_T) + G = typeof(g) + @test zero(G) == 
G(0) +end + +@testset "Graphs.jl constructor" begin + lg = random_regular_graph(10, 4) + @test !Graphs.is_directed(lg) + g = GNNGraph(lg) + @test g.num_edges == 2 * ne(lg) # g in undirected + @test Graphs.is_directed(g) + for e in Graphs.edges(lg) + i, j = src(e), dst(e) + @test has_edge(g, i, j) + @test has_edge(g, j, i) + end + + @testset "SimpleGraph{Int32}" begin + g = GNNGraph(SimpleGraph{Int32}(6), graph_type = GRAPH_T) + @test g.num_nodes == 6 + end +end + +@testset "Features" begin + g = GNNGraph(sprand(10, 10, 0.3), graph_type = GRAPH_T) + + # default names + X = rand(10, g.num_nodes) + E = rand(10, g.num_edges) + U = rand(10, g.num_graphs) + + g = GNNGraph(g, ndata = X, edata = E, gdata = U) + @test g.ndata.x === X + @test g.edata.e === E + @test g.gdata.u === U + @test g.x === g.ndata.x + @test g.e === g.edata.e + @test g.u === g.gdata.u + + # Check no args + g = GNNGraph(g) + @test g.ndata.x === X + @test g.edata.e === E + @test g.gdata.u === U + + # multiple features names + g = GNNGraph(g, ndata = (x2 = 2X, g.ndata...), edata = (e2 = 2E, g.edata...), + gdata = (u2 = 2U, g.gdata...)) + @test g.ndata.x === X + @test g.edata.e === E + @test g.gdata.u === U + @test g.ndata.x2 ≈ 2X + @test g.edata.e2 ≈ 2E + @test g.gdata.u2 ≈ 2U + @test g.x === g.ndata.x + @test g.e === g.edata.e + @test g.u === g.gdata.u + @test g.x2 === g.ndata.x2 + @test g.e2 === g.edata.e2 + @test g.u2 === g.gdata.u2 + + # Dimension checks + @test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata = rand(29), + graph_type = GRAPH_T) + @test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata = rand(2, 29), + graph_type = GRAPH_T) + @test_throws AssertionError GNNGraph(erdos_renyi(10, 30), + edata = (; x = rand(30), y = rand(29)), + graph_type = GRAPH_T) + + # Copy features on reverse edge + e = rand(30) + g = GNNGraph(erdos_renyi(10, 30), edata = e, graph_type = GRAPH_T) + @test g.edata.e == [e; e] + + # non-array global + g = rand_graph(10, 30, gdata = "ciao", graph_type 
= GRAPH_T) + @test g.gdata.u == "ciao" + + # vectors stays vectors + g = rand_graph(10, 30, ndata = rand(10), + edata = rand(30), + gdata = (u = rand(2), z = rand(1), q = 1), + graph_type = GRAPH_T) + @test size(g.ndata.x) == (10,) + @test size(g.edata.e) == (30,) + @test size(g.gdata.u) == (2, 1) + @test size(g.gdata.z) == (1,) + @test g.gdata.q === 1 + + # Error for non-array ndata + @test_throws AssertionError rand_graph(10, 30, ndata = "ciao", graph_type = GRAPH_T) + @test_throws AssertionError rand_graph(10, 30, ndata = 1, graph_type = GRAPH_T) + + # Error for Ambiguous getproperty + g = rand_graph(10, 20, ndata = rand(2, 10), edata = (; x = rand(3, 20)), + graph_type = GRAPH_T) + @test size(g.ndata.x) == (2, 10) + @test size(g.edata.x) == (3, 20) + @test_throws ArgumentError g.x +end + +@testset "MLUtils and DataLoader compat" begin + n, m, num_graphs = 10, 30, 50 + X = rand(10, n) + E = rand(10, m) + U = rand(10, 1) + data = [rand_graph(n, m, ndata = X, edata = E, gdata = U, graph_type = GRAPH_T) + for _ in 1:num_graphs] + g = MLUtils.batch(data) + + @testset "batch then pass to dataloader" begin + @test MLUtils.getobs(g, 3) == getgraph(g, 3) + @test MLUtils.getobs(g, 3:5) == getgraph(g, 3:5) + @test MLUtils.numobs(g) == g.num_graphs + + d = MLUtils.DataLoader(g, batchsize = 2, shuffle = false) + @test first(d) == getgraph(g, 1:2) + end + + @testset "pass to dataloader and no automatic collation" begin + @test MLUtils.getobs(data, 3) == data[3] + @test MLUtils.getobs(data, 3:5) isa Vector{<:GNNGraph} + @test MLUtils.getobs(data, 3:5) == [data[3], data[4], data[5]] + @test MLUtils.numobs(data) == g.num_graphs + + d = MLUtils.DataLoader(data, batchsize = 2, shuffle = false) + @test first(d) == [data[1], data[2]] + end +end + +@testset "Graphs.jl integration" begin + g = GNNGraph(erdos_renyi(10, 20), graph_type = GRAPH_T) + @test g isa Graphs.AbstractGraph +end + +@testset "==" begin + g1 = rand_graph(5, 6, ndata = rand(5), edata = rand(6), graph_type = 
GRAPH_T) + @test g1 == g1 + @test g1 == deepcopy(g1) + @test g1 !== deepcopy(g1) + + g2 = GNNGraph(g1, graph_type = GRAPH_T) + @test g1 == g2 + @test g1 === g2 # this is true since GNNGraph is immutable + + g2 = GNNGraph(g1, ndata = rand(5), graph_type = GRAPH_T) + @test g1 != g2 + @test g1 !== g2 + + g2 = GNNGraph(g1, edata = rand(6), graph_type = GRAPH_T) + @test g1 != g2 + @test g1 !== g2 +end + +@testset "hash" begin + g1 = rand_graph(5, 6, ndata = rand(5), edata = rand(6), graph_type = GRAPH_T) + @test hash(g1) == hash(g1) + @test hash(g1) == hash(deepcopy(g1)) + @test hash(g1) == hash(GNNGraph(g1, ndata = g1.ndata, graph_type = GRAPH_T)) + @test hash(g1) == hash(GNNGraph(g1, ndata = g1.ndata, graph_type = GRAPH_T)) + @test hash(g1) != hash(GNNGraph(g1, ndata = rand(5), graph_type = GRAPH_T)) + @test hash(g1) != hash(GNNGraph(g1, edata = rand(6), graph_type = GRAPH_T)) +end + +@testset "copy" begin + g1 = rand_graph(10, 4, ndata = rand(2, 10), graph_type = GRAPH_T) + g2 = copy(g1) + @test g1 === g2 # shallow copies are identical for immutable objects + + g2 = copy(g1, deep = true) + @test g1 == g2 + @test g1 !== g2 +end + +## Cannot test this because DataStore is not an ordered collection +## Uncomment when/if it will be based on OrderedDict +# @testset "show" begin +# @test sprint(show, rand_graph(10, 20)) == "GNNGraph(10, 20) with no data" +# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10))) == "GNNGraph(10, 20) with x: 5×10 data" +# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20), gdata=(q=rand(1, 1), p=rand(3, 1)))) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20, (q: 1×1, p: 3×1) data" +# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5, 10),))) == "GNNGraph(10, 20) with a: 5×10 data" +# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10), edata=rand(2, 20))) == "GNNGraph(10, 20) with x: 5×10, e: 2×20 data" +# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10), gdata=rand(1, 1))) == 
"GNNGraph(10, 20) with x: 5×10, u: 1×1 data" +# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10), edata=(e=rand(2, 20), f=rand(2, 20), h=rand(3, 20)), gdata=rand(1, 1))) == "GNNGraph(10, 20) with x: 5×10, (e: 2×20, f: 2×20, h: 3×20), u: 1×1 data" +# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20))) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20 data" +# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5,5, 10), b=rand(3,2, 10)), edata=rand(2, 20))) == "GNNGraph(10, 20) with (a: 5×5×10, b: 3×2×10), e: 2×20 data" +# end + +# @testset "show plain/text compact true" begin +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => true) == "GNNGraph(10, 20) with no data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10 data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20), gdata=(q=rand(1, 1), p=rand(3, 1))); context=:compact => true) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20, (q: 1×1, p: 3×1) data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10),)); context=:compact => true) == "GNNGraph(10, 20) with a: 5×10 data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=rand(2, 20)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10, e: 2×20 data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), gdata=rand(1, 1)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10, u: 1×1 data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=(e=rand(2, 20), f=rand(2, 20), h=rand(3, 20)), gdata=rand(1, 1)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10, (e: 2×20, f: 2×20, h: 3×20), u: 1×1 data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 
10), b=rand(3, 10)), edata=rand(2, 20)); context=:compact => true) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20 data" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5,5, 10), b=rand(3,2, 10)), edata=rand(2, 20)); context=:compact => true) == "GNNGraph(10, 20) with (a: 5×5×10, b: 3×2×10), e: 2×20 data" +# end + +# @testset "show plain/text compact false" begin +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20), gdata=(q=rand(1, 1), p=rand(3, 1))); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×10 Matrix{Float64}\n\tb = 3×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}\n gdata:\n\tq = 1×1 Matrix{Float64}\n\tp = 3×1 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10),)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×10 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=rand(2, 20)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), gdata=rand(1, 1)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}\n gdata:\n\tu = 1×1 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=(e=rand(2, 20), f=rand(2, 20), h=rand(3, 20)), gdata=rand(1, 1)); context=:compact => false) == 
"GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}\n\tf = 2×20 Matrix{Float64}\n\th = 3×20 Matrix{Float64}\n gdata:\n\tu = 1×1 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×10 Matrix{Float64}\n\tb = 3×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 5, 10), b=rand(3, 2, 10)), edata=rand(2, 20)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×5×10 Array{Float64, 3}\n\tb = 3×2×10 Array{Float64, 3}\n edata:\n\te = 2×20 Matrix{Float64}" +# end + +[.\GNNGraphs\test\gnnheterograph.jl] + + +@testset "Empty constructor" begin + g = GNNHeteroGraph() + @test isempty(g.num_nodes) + g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4])) + @test g.num_nodes[:user] == 3 + @test g.num_nodes[:actor] == 9 + @test g.num_edges[(:user, :like, :actor)] == 5 +end + +@testset "Constructor from pairs" begin + hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3,4], [3,2,1,5])) + @test hg.num_nodes == Dict(:A => 4, :B => 5) + @test hg.num_edges == Dict((:A, :e1, :B) => 4) + + hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3], [3,2,1]), + (:A, :e2, :C) => ([1,2,3], [4,5,6])) + @test hg.num_nodes == Dict(:A => 3, :B => 3, :C => 6) + @test hg.num_edges == Dict((:A, :e1, :B) => 3, (:A, :e2, :C) => 3) +end + +@testset "Generation" begin + hg = rand_heterograph(Dict(:A => 10, :B => 20), + Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10)) + + @test hg.num_nodes == Dict(:A => 10, :B => 20) + @test hg.num_edges == Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10) + @test hg.graph_indicator === nothing + @test hg.num_graphs == 1 + @test hg.ndata isa Dict{Symbol, DataStore} + @test hg.edata isa Dict{Tuple{Symbol, Symbol, 
Symbol}, DataStore} + @test isempty(hg.gdata) + @test sort(hg.ntypes) == [:A, :B] + @test sort(hg.etypes) == [(:A, :rel1, :B), (:B, :rel2, :A)] + +end + +@testset "features" begin + hg = rand_heterograph(Dict(:A => 10, :B => 20), + Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), + ndata = Dict(:A => rand(2, 10), + :B => (x = rand(3, 20), y = rand(4, 20))), + edata = Dict((:A, :rel1, :B) => rand(5, 30)), + gdata = 1) + + @test size(hg.ndata[:A].x) == (2, 10) + @test size(hg.ndata[:B].x) == (3, 20) + @test size(hg.ndata[:B].y) == (4, 20) + @test size(hg.edata[(:A, :rel1, :B)].e) == (5, 30) + @test hg.gdata == DataStore(u = 1) + +end + +@testset "indexing syntax" begin + g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7])) + g[:movie].z = rand(Float32, 64, 13); + g[:user, :rate, :movie].e = rand(Float32, 64, 4); + g[:user].x = rand(Float32, 64, 3); + @test size(g.ndata[:user].x) == (64, 3) + @test size(g.ndata[:movie].z) == (64, 13) + @test size(g.edata[(:user, :rate, :movie)].e) == (64, 4) +end + + +@testset "simplified constructor" begin + hg = rand_heterograph((:A => 10, :B => 20), + ((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), + ndata = (:A => rand(2, 10), + :B => (x = rand(3, 20), y = rand(4, 20))), + edata = (:A, :rel1, :B) => rand(5, 30), + gdata = 1) + + @test hg.num_nodes == Dict(:A => 10, :B => 20) + @test hg.num_edges == Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10) + @test hg.graph_indicator === nothing + @test hg.num_graphs == 1 + @test size(hg.ndata[:A].x) == (2, 10) + @test size(hg.ndata[:B].x) == (3, 20) + @test size(hg.ndata[:B].y) == (4, 20) + @test size(hg.edata[(:A, :rel1, :B)].e) == (5, 30) + @test hg.gdata == DataStore(u = 1) + + nA, nB = 10, 20 + edges1 = rand(1:nA, 20), rand(1:nB, 20) + edges2 = rand(1:nB, 30), rand(1:nA, 30) + hg = GNNHeteroGraph(((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2)) + @test hg.num_edges == Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) + + nA, nB = 10, 20 + edges1 = 
rand(1:nA, 20), rand(1:nB, 20) + edges2 = rand(1:nB, 30), rand(1:nA, 30) + hg = GNNHeteroGraph(((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2); + num_nodes = (:A => nA, :B => nB)) + @test hg.num_nodes == Dict(:A => 10, :B => 20) + @test hg.num_edges == Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) +end + +@testset "num_edge_types / num_node_types" begin + hg = rand_heterograph((:A => 10, :B => 20), + ((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), + ndata = (:A => rand(2, 10), + :B => (x = rand(3, 20), y = rand(4, 20))), + edata = (:A, :rel1, :B) => rand(5, 30), + gdata = 1) + @test num_edge_types(hg) == 2 + @test num_node_types(hg) == 2 + + g = rand_graph(10, 20) + @test num_edge_types(g) == 1 + @test num_node_types(g) == 1 +end + +@testset "numobs" begin + hg = rand_heterograph((:A => 10, :B => 20), + ((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), + ndata = (:A => rand(2, 10), + :B => (x = rand(3, 20), y = rand(4, 20))), + edata = (:A, :rel1, :B) => rand(5, 30), + gdata = 1) + @test MLUtils.numobs(hg) == 1 +end + +@testset "get/set node features" begin + d, n = 3, 5 + g = rand_bipartite_heterograph((n, 2*n), 15) + g[:A].x = rand(Float32, d, n) + g[:B].y = rand(Float32, d, 2*n) + + @test size(g[:A].x) == (d, n) + @test size(g[:B].y) == (d, 2*n) +end + +@testset "add_edges" begin + d, n = 3, 5 + g = rand_bipartite_heterograph((n, 2 * n), 15) + s, t = [1, 2, 3], [3, 2, 1] + ## Keep the same ntypes - construct with args + g1 = add_edges(g, (:A, :rel1, :B), s, t) + @test num_node_types(g1) == 2 + @test num_edge_types(g1) == 3 + for i in eachindex(s, t) + @test has_edge(g1, (:A, :rel1, :B), s[i], t[i]) + end + # no change to num_nodes + @test g1.num_nodes[:A] == n + @test g1.num_nodes[:B] == 2n + + ## Keep the same ntypes - construct with a pair + g2 = add_edges(g, (:A, :rel1, :B) => (s, t)) + @test num_node_types(g2) == 2 + @test num_edge_types(g2) == 3 + for i in eachindex(s, t) + @test has_edge(g2, (:A, :rel1, :B), s[i], t[i]) + end + # no change to 
num_nodes + @test g2.num_nodes[:A] == n + @test g2.num_nodes[:B] == 2n + + ## New ntype with num_nodes (applies only to the new ntype) and edata + edata = rand(Float32, d, length(s)) + g3 = add_edges(g, + (:A, :rel1, :C) => (s, t); + num_nodes = Dict(:A => 1, :B => 1, :C => 10), + edata) + @test num_node_types(g3) == 3 + @test num_edge_types(g3) == 3 + for i in eachindex(s, t) + @test has_edge(g3, (:A, :rel1, :C), s[i], t[i]) + end + # added edata + @test g3.edata[(:A, :rel1, :C)].e == edata + # no change to existing num_nodes + @test g3.num_nodes[:A] == n + @test g3.num_nodes[:B] == 2n + # new num_nodes added as per kwarg + @test g3.num_nodes[:C] == 10 +end + +@testset "add self loops" begin + g1 = GNNHeteroGraph((:A, :to, :B) => ([1,2,3,4], [3,2,1,5])) + g2 = add_self_loops(g1, (:A, :to, :B)) + @test g2.num_edges[(:A, :to, :B)] === g1.num_edges[(:A, :to, :B)] + g1 = GNNHeteroGraph((:A, :to, :A) => ([1,2,3,4], [3,2,1,5])) + g2 = add_self_loops(g1, (:A, :to, :A)) + @test g2.num_edges[(:A, :to, :A)] === g1.num_edges[(:A, :to, :A)] + g1.num_nodes[(:A)] +end + +## Cannot test this because DataStore is not an ordered collection +## Uncomment when/if it will be based on OrderedDict +# @testset "show" begin +# num_nodes = Dict(:A => 10, :B => 20); +# edges1 = rand(1:num_nodes[:A], 20), rand(1:num_nodes[:B], 20) +# edges2 = rand(1:num_nodes[:B], 30), rand(1:num_nodes[:A], 30) +# eindex = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2) +# ndata = Dict(:A => (x = rand(2, num_nodes[:A]), y = rand(3, num_nodes[:A])),:B => rand(10, num_nodes[:B])) +# edata= Dict((:A, :rel1, :B) => (x = rand(2, 20), y = rand(3, 20)),(:B, :rel2, :A) => rand(10, 30)) +# hg1 = GNNHeteroGraph(eindex; num_nodes) +# hg2 = GNNHeteroGraph(eindex; num_nodes, ndata,edata) +# hg3 = GNNHeteroGraph(eindex; num_nodes, ndata) +# @test sprint(show, hg1) == "GNNHeteroGraph(Dict(:A => 10, :B => 20), Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30))" +# @test sprint(show, hg2) == sprint(show, hg1) +# 
@test sprint(show, MIME("text/plain"), hg1; context=:compact => true) == "GNNHeteroGraph(Dict(:A => 10, :B => 20), Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30))" +# @test sprint(show, MIME("text/plain"), hg2; context=:compact => true) == sprint(show, MIME("text/plain"), hg1;context=:compact => true) +# @test sprint(show, MIME("text/plain"), hg1; context=:compact => false) == "GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)" +# @test sprint(show, MIME("text/plain"), hg2; context=:compact => false) == "GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}\n edata:\n\t(:A, :rel1, :B) => (x = 2×20 Matrix{Float64}, y = 3×20 Matrix{Float64})\n\t(:B, :rel2, :A) => e = 10×30 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), hg3; context=:compact => false) =="GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}" +# @test sprint(show, MIME("text/plain"), hg2; context=:compact => false) != sprint(show, MIME("text/plain"), hg3; context=:compact => false) +# end + +[.\GNNGraphs\test\mldatasets.jl] +dataset = Cora() +classes = dataset.metadata["classes"] +gml = dataset[1] +g = mldataset2gnngraph(dataset) +@test g isa GNNGraph +@test g.num_nodes == gml.num_nodes +@test g.num_edges == gml.num_edges +@test edge_index(g) === gml.edge_index + +[.\GNNGraphs\test\operators.jl] +@testset "intersect" begin + g = rand_graph(10, 20, graph_type = GRAPH_T) + @test intersect(g, g).num_edges == 20 +end + +[.\GNNGraphs\test\query.jl] +@testset "is_bidirected" begin + g = rand_graph(10, 20, bidirected = true, graph_type = GRAPH_T) + @test is_bidirected(g) + + g = rand_graph(10, 20, bidirected = false, 
graph_type = GRAPH_T) + @test !is_bidirected(g) +end + +@testset "has_multi_edges" begin if GRAPH_T == :coo + s = [1, 1, 2, 3] + t = [2, 2, 2, 4] + g = GNNGraph(s, t, graph_type = GRAPH_T) + @test has_multi_edges(g) + + s = [1, 2, 2, 3] + t = [2, 1, 2, 4] + g = GNNGraph(s, t, graph_type = GRAPH_T) + @test !has_multi_edges(g) +end end + +@testset "edges" begin + g = rand_graph(4, 10, graph_type = GRAPH_T) + @test edgetype(g) <: Graphs.Edge + for e in edges(g) + @test e isa Graphs.Edge + end +end + +@testset "has_isolated_nodes" begin + s = [1, 2, 3] + t = [2, 3, 2] + g = GNNGraph(s, t, graph_type = GRAPH_T) + @test has_isolated_nodes(g) == false + @test has_isolated_nodes(g, dir = :in) == true +end + +@testset "has_self_loops" begin + s = [1, 1, 2, 3] + t = [2, 2, 2, 4] + g = GNNGraph(s, t, graph_type = GRAPH_T) + @test has_self_loops(g) + + s = [1, 1, 2, 3] + t = [2, 2, 3, 4] + g = GNNGraph(s, t, graph_type = GRAPH_T) + @test !has_self_loops(g) +end + +@testset "degree" begin + @testset "unweighted" begin + s = [1, 1, 2, 3] + t = [2, 2, 2, 4] + g = GNNGraph(s, t, graph_type = GRAPH_T) + + @test degree(g) isa Vector{Int} + @test degree(g) == degree(g; dir = :out) == [2, 1, 1, 0] # default is outdegree + @test degree(g; dir = :in) == [0, 3, 0, 1] + @test degree(g; dir = :both) == [2, 4, 1, 1] + @test eltype(degree(g, Float32)) == Float32 + + if TEST_GPU + dev = CUDADevice() #TODO replace with `gpu_device()` + g_gpu = g |> dev + d = degree(g) + d_gpu = degree(g_gpu) + @test d_gpu isa CuVector{Int} + @test Array(d_gpu) == d + end + end + + @testset "weighted" begin + # weighted degree + s = [1, 1, 2, 3] + t = [2, 2, 2, 4] + eweight = Float32[0.1, 2.1, 1.2, 1] + g = GNNGraph((s, t, eweight), graph_type = GRAPH_T) + @test degree(g) ≈ [2.2, 1.2, 1.0, 0.0] + d = degree(g, edge_weight = false) + if GRAPH_T == :coo + @test d == [2, 1, 1, 0] + else + # Adjacency matrix representation cannot disambiguate multiple edges + # and edge weights + @test d == [1, 1, 1, 0] + end + 
@test eltype(d) <: Integer + @test degree(g, edge_weight = 2 * eweight) ≈ [4.4, 2.4, 2.0, 0.0] broken = (GRAPH_T != :coo) + + if TEST_GPU + dev = CUDADevice() #TODO replace with `gpu_device()` + g_gpu = g |> dev + d = degree(g) + d_gpu = degree(g_gpu) + @test d_gpu isa CuVector{Float32} + @test Array(d_gpu) ≈ d + end + @testset "gradient" begin + gw = gradient(eweight) do w + g = GNNGraph((s, t, w), graph_type = GRAPH_T) + sum(degree(g, edge_weight = false)) + end[1] + + @test gw === nothing + + gw = gradient(eweight) do w + g = GNNGraph((s, t, w), graph_type = GRAPH_T) + sum(degree(g, edge_weight = true)) + end[1] + + @test gw isa AbstractVector{Float32} + @test gw isa Vector{Float32} broken = (GRAPH_T == :sparse) + @test gw ≈ ones(Float32, length(gw)) + + gw = gradient(eweight) do w + g = GNNGraph((s, t, w), graph_type = GRAPH_T) + sum(degree(g, dir=:both, edge_weight=true)) + end[1] + + @test gw isa AbstractVector{Float32} + @test gw isa Vector{Float32} broken = (GRAPH_T == :sparse) + @test gw ≈ 2 * ones(Float32, length(gw)) + + grad = gradient(g) do g + sum(degree(g, edge_weight=false)) + end[1] + @test grad === nothing + + grad = gradient(g) do g + sum(degree(g, edge_weight=true)) + end[1] + + if GRAPH_T == :coo + @test grad.graph[3] isa Vector{Float32} + @test grad.graph[3] ≈ ones(Float32, length(gw)) + else + if GRAPH_T == :sparse + @test grad.graph isa AbstractSparseMatrix{Float32} + end + @test grad.graph isa AbstractMatrix{Float32} + + @test grad.graph ≈ [0.0 1.0 0.0 0.0 + 0.0 1.0 0.0 0.0 + 0.0 0.0 0.0 1.0 + 0.0 0.0 0.0 0.0] + end + + @testset "directed, degree dir=$dir" for dir in [:in, :out, :both] + g = rand_graph(10, 30, bidirected=false) + w = rand(Float32, 30) + s, t = edge_index(g) + + grad = gradient(w) do w + g = GNNGraph((s, t, w), graph_type = GRAPH_T) + sum(tanh.(degree(g; dir, edge_weight=true))) + end[1] + + ngrad = ngradient(w) do w + g = GNNGraph((s, t, w), graph_type = GRAPH_T) + sum(tanh.(degree(g; dir, edge_weight=true))) + end[1] + + 
@test grad ≈ ngrad + end + + @testset "heterognn, degree" begin + g = GNNHeteroGraph((:A, :to, :B) => ([1,1,2,3], [7,13,5,7])) + @test degree(g, (:A, :to, :B), dir = :out) == [2, 1, 1] + @test degree(g, (:A, :to, :B), dir = :in) == [0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1] + @test degree(g, (:A, :to, :B)) == [2, 1, 1] + end + end + end +end + +@testset "laplacian_matrix" begin + g = rand_graph(10, 30, graph_type = GRAPH_T) + A = adjacency_matrix(g) + D = Diagonal(vec(sum(A, dims = 2))) + L = laplacian_matrix(g) + @test eltype(L) == eltype(g) + @test L ≈ D - A +end + +@testset "laplacian_lambda_max" begin + s = [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + t = [2, 3, 4, 5, 1, 5, 1, 2, 3, 4] + g = GNNGraph(s, t) + @test laplacian_lambda_max(g) ≈ Float32(1.809017) + data1 = [g for i in 1:5] + gall1 = MLUtils.batch(data1) + @test laplacian_lambda_max(gall1) ≈ [Float32(1.809017) for i in 1:5] + data2 = [rand_graph(10, 20) for i in 1:3] + gall2 = MLUtils.batch(data2) + @test length(laplacian_lambda_max(gall2, add_self_loops=true)) == 3 +end + +@testset "adjacency_matrix" begin + a = sprand(5, 5, 0.5) + abin = map(x -> x > 0 ? 
1 : 0, a) + + g = GNNGraph(a, graph_type = GRAPH_T) + A = adjacency_matrix(g, Float32) + @test A ≈ a + @test eltype(A) == Float32 + + Abin = adjacency_matrix(g, Float32, weighted = false) + @test Abin ≈ abin + @test eltype(Abin) == Float32 + + @testset "gradient" begin + s = [1, 2, 3] + t = [2, 3, 1] + w = [0.1, 0.1, 0.2] + gw = gradient(w) do w + g = GNNGraph(s, t, w, graph_type = GRAPH_T) + A = adjacency_matrix(g, weighted = false) + sum(A) + end[1] + @test gw === nothing + + gw = gradient(w) do w + g = GNNGraph(s, t, w, graph_type = GRAPH_T) + A = adjacency_matrix(g, weighted = true) + sum(A) + end[1] + + @test gw == [1, 1, 1] + end + + @testset "khop_adj" begin + s = [1, 2, 3] + t = [2, 3, 1] + w = [0.1, 0.1, 0.2] + g = GNNGraph(s, t, w) + @test khop_adj(g, 2) == adjacency_matrix(g) * adjacency_matrix(g) + @test khop_adj(g, 2, Int8; weighted = false) == sparse([0 0 1; 1 0 0; 0 1 0]) + @test khop_adj(g, 2, Int8; dir = in, weighted = false) == + sparse([0 0 1; 1 0 0; 0 1 0]') + @test khop_adj(g, 1) == adjacency_matrix(g) + @test eltype(khop_adj(g, 4)) == Float64 + @test eltype(khop_adj(g, 10, Float32)) == Float32 + end +end + +if GRAPH_T == :coo + @testset "HeteroGraph" begin + @testset "graph_indicator" begin + gs = [rand_heterograph(Dict(:user => 10, :movie => 20, :actor => 30), + Dict((:user,:like,:movie) => 10, + (:actor,:rate,:movie)=>20)) for _ in 1:3] + g = MLUtils.batch(gs) + @test graph_indicator(g) == Dict(:user => [repeat([1], 10); repeat([2], 10); repeat([3], 10)], + :movie => [repeat([1], 20); repeat([2], 20); repeat([3], 20)], + :actor => [repeat([1], 30); repeat([2], 30); repeat([3], 30)]) + @test graph_indicator(g, :movie) == [repeat([1], 20); repeat([2], 20); repeat([3], 20)] + end + end +end + + +[.\GNNGraphs\test\runtests.jl] +using CUDA, cuDNN +using GNNGraphs +using GNNGraphs: getn, getdata +using Functors +using LinearAlgebra, Statistics, Random +using NNlib +import MLUtils +import StatsBase +using SparseArrays +using Graphs +using Zygote 
+using Test +using MLDatasets +using InlineStrings # not used but with the import we test #98 and #104 +using SimpleWeightedGraphs +using MLDataDevices: gpu_device, cpu_device, get_device +using MLDataDevices: CUDADevice + +CUDA.allowscalar(false) + +const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}} + +ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets + +include("test_utils.jl") + +tests = [ + "chainrules", + "datastore", + "gnngraph", + "convert", + "transform", + "operators", + "generate", + "query", + "sampling", + "gnnheterograph", + "temporalsnapshotsgnngraph", + "mldatasets", + "ext/SimpleWeightedGraphs" +] + +!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") + +for graph_type in (:coo, :dense, :sparse) + @info "Testing graph format :$graph_type" + global GRAPH_T = graph_type + global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) + # global GRAPH_T = :sparse + # global TEST_GPU = false + + @testset "$t" for t in tests + include("$t.jl") + end +end + +[.\GNNGraphs\test\sampling.jl] +if GRAPH_T == :coo + @testset "sample_neighbors" begin + # replace = false + dir = :in + nodes = 2:3 + g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T) + sg = sample_neighbors(g, nodes; dir) + @test sg.num_nodes == 10 + @test sg.num_edges == sum(degree(g, i; dir) for i in nodes) + @test size(sg.edata.EID) == (sg.num_edges,) + @test length(union(sg.edata.EID)) == length(sg.edata.EID) + adjlist = adjacency_list(g; dir) + s, t = edge_index(sg) + @test all(t .∈ Ref(nodes)) + for i in nodes + @test sort(neighbors(sg, i; dir)) == sort(neighbors(g, i; dir)) + end + + # replace = true + dir = :out + nodes = 2:3 + K = 2 + g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T) + sg = sample_neighbors(g, nodes, K; dir, replace = true) + @test sg.num_nodes == 10 + @test sg.num_edges == sum(K for i in nodes) + @test size(sg.edata.EID) == (sg.num_edges,) + adjlist = adjacency_list(g; dir) + s, t = edge_index(sg) + 
@test all(s .∈ Ref(nodes)) + for i in nodes + @test issubset(neighbors(sg, i; dir), adjlist[i]) + end + + # dropnodes = true + dir = :in + nodes = 2:3 + g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T) + g = GNNGraph(g, ndata = (x1 = rand(10),), edata = (e1 = rand(40),)) + sg = sample_neighbors(g, nodes; dir, dropnodes = true) + @test sg.num_edges == sum(degree(g, i; dir) for i in nodes) + @test size(sg.edata.EID) == (sg.num_edges,) + @test size(sg.ndata.NID) == (sg.num_nodes,) + @test sg.edata.e1 == g.edata.e1[sg.edata.EID] + @test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID] + @test length(union(sg.ndata.NID)) == length(sg.ndata.NID) + end +end +[.\GNNGraphs\test\temporalsnapshotsgnngraph.jl] +@testset "Constructor array TemporalSnapshotsGNNGraph" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + @test tsg.num_nodes == [10 for i in 1:5] + @test tsg.num_edges == [20 for i in 1:5] + wrsnapshots = [rand_graph(10,20), rand_graph(12,22)] + @test_throws AssertionError TemporalSnapshotsGNNGraph(wrsnapshots) +end + +@testset "==" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg1 = TemporalSnapshotsGNNGraph(snapshots) + tsg2 = TemporalSnapshotsGNNGraph(snapshots) + @test tsg1 == tsg2 + tsg3 = TemporalSnapshotsGNNGraph(snapshots[1:3]) + @test tsg1 != tsg3 + @test tsg1 !== tsg3 +end + +@testset "getindex" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + @test tsg[3] == snapshots[3] + @test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata) +end + +@testset "getproperty" begin + x = rand(10) + snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + @test tsg.tgdata == DataStore() + @test tsg.x == tsg.ndata.x == [x for i in 1:5] + @test_throws KeyError tsg.ndata.w + @test_throws ArgumentError tsg.w +end + +@testset "add/remove_snapshot" begin + snapshots = 
[rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + g = rand_graph(10, 20) + tsg = add_snapshot(tsg, 3, g) + @test tsg.num_nodes == [10 for i in 1:6] + @test tsg.num_edges == [20 for i in 1:6] + @test tsg.snapshots[3] == g + tsg = remove_snapshot(tsg, 3) + @test tsg.num_nodes == [10 for i in 1:5] + @test tsg.num_edges == [20 for i in 1:5] + @test tsg.snapshots == snapshots +end + +@testset "add/remove_snapshot" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + g = rand_graph(10, 20) + tsg2 = add_snapshot(tsg, 3, g) + @test tsg2.num_nodes == [10 for i in 1:6] + @test tsg2.num_edges == [20 for i in 1:6] + @test tsg2.snapshots[3] == g + @test tsg2.num_snapshots == 6 + @test tsg.num_nodes == [10 for i in 1:5] + @test tsg.num_edges == [20 for i in 1:5] + @test tsg.snapshots[2] === tsg2.snapshots[2] + @test tsg.snapshots[3] === tsg2.snapshots[4] + @test length(tsg.snapshots) == 5 + @test tsg.num_snapshots == 5 + + tsg21 = add_snapshot(tsg2, 7, g) + @test tsg21.num_snapshots == 7 + + tsg3 = remove_snapshot(tsg, 3) + @test tsg3.num_nodes == [10 for i in 1:4] + @test tsg3.num_edges == [20 for i in 1:4] + @test tsg3.snapshots == snapshots[[1,2,4,5]] +end + + +# @testset "add/remove_snapshot!" 
begin +# snapshots = [rand_graph(10, 20) for i in 1:5] +# tsg = TemporalSnapshotsGNNGraph(snapshots) +# g = rand_graph(10, 20) +# tsg2 = add_snapshot!(tsg, 3, g) +# @test tsg2.num_nodes == [10 for i in 1:6] +# @test tsg2.num_edges == [20 for i in 1:6] +# @test tsg2.snapshots[3] == g +# @test tsg2.num_snapshots == 6 +# @test tsg2 === tsg + +# tsg3 = remove_snapshot!(tsg, 3) +# @test tsg3.num_nodes == [10 for i in 1:4] +# @test tsg3.num_edges == [20 for i in 1:4] +# @test length(tsg3.snapshots) === 4 +# @test tsg3 === tsg +# end + +@testset "show" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data" + @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data" + @test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5" + tsg.tgdata.x=rand(4) + @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data" +end + +if TEST_GPU + @testset "gpu" begin + snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + tsg.tgdata.x = rand(5) + dev = CUDADevice() #TODO replace with `gpu_device()` + tsg = tsg |> dev + @test tsg.snapshots[1].ndata.x isa CuArray + @test tsg.snapshots[end].ndata.x isa CuArray + @test tsg.tgdata.x isa CuArray + @test tsg.num_nodes isa CuArray + @test tsg.num_edges isa CuArray + end +end + +[.\GNNGraphs\test\test_utils.jl] +using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA +CUDA.allowscalar(false) + +function ngradient(f, x...) + fdm = central_fdm(5, 1) + return FiniteDifferences.grad(fdm, f, x...) 
+end + +const rule_config = Zygote.ZygoteRuleConfig() + +# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed +function FiniteDifferences.to_vec(x::Integer) + Integer_from_vec(v) = x + return Int[x], Integer_from_vec +end + +# Test that forward pass on cpu and gpu are the same. +# Tests also gradient on cpu and gpu comparing with +# finite difference methods. +# Test gradients with respects to layer weights and to input. +# If `g` has edge features, it is assumed that the layer can +# use them in the forward pass as `l(g, x, e)`. +# Test also gradient with respect to `e`. +function test_layer(l, g::GNNGraph; atol = 1e-5, rtol = 1e-5, + exclude_grad_fields = [], + verbose = false, + test_gpu = TEST_GPU, + outsize = nothing, + outtype = :node) + + # TODO these give errors, probably some bugs in ChainRulesTestUtils + # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false) + # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false) + + isnothing(node_features(g)) && error("Plese add node data to the input graph") + fdm = central_fdm(5, 1) + + x = node_features(g) + e = edge_features(g) + use_edge_feat = !isnothing(e) + + x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad + xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g]) + + f(l, g::GNNGraph) = l(g) + f(l, g::GNNGraph, x, e) = use_edge_feat ? 
l(g, x, e) : l(g, x) + + loss(l, g::GNNGraph) = + if outtype == :node + sum(node_features(f(l, g))) + elseif outtype == :edge + sum(edge_features(f(l, g))) + elseif outtype == :graph + sum(graph_features(f(l, g))) + elseif outtype == :node_edge + gnew = f(l, g) + sum(node_features(gnew)) + sum(edge_features(gnew)) + end + + function loss(l, g::GNNGraph, x, e) + y = f(l, g, x, e) + if outtype == :node_edge + return sum(y[1]) + sum(y[2]) + else + return sum(y) + end + end + + # TEST OUTPUT + y = f(l, g, x, e) + if outtype == :node_edge + @assert y isa Tuple + @test eltype(y[1]) == eltype(x) + @test eltype(y[2]) == eltype(e) + @test all(isfinite, y[1]) + @test all(isfinite, y[2]) + if !isnothing(outsize) + @test size(y[1]) == outsize[1] + @test size(y[2]) == outsize[2] + end + else + @test eltype(y) == eltype(x) + @test all(isfinite, y) + if !isnothing(outsize) + @test size(y) == outsize + end + end + + # test same output on different graph formats + gcoo = GNNGraph(g, graph_type = :coo) + ycoo = f(l, gcoo, x, e) + if outtype == :node_edge + @test ycoo[1] ≈ y[1] + @test ycoo[2] ≈ y[2] + else + @test ycoo ≈ y + end + + g′ = f(l, g) + if outtype == :node + @test g′.ndata.x ≈ y + elseif outtype == :edge + @test g′.edata.e ≈ y + elseif outtype == :graph + @test g′.gdata.u ≈ y + elseif outtype == :node_edge + @test g′.ndata.x ≈ y[1] + @test g′.edata.e ≈ y[2] + else + @error "wrong outtype $outtype" + end + if test_gpu + ygpu = f(lgpu, ggpu, xgpu, egpu) + if outtype == :node_edge + @test ygpu[1] isa CuArray + @test eltype(ygpu[1]) == eltype(xgpu) + @test Array(ygpu[1]) ≈ y[1] + @test ygpu[2] isa CuArray + @test eltype(ygpu[2]) == eltype(xgpu) + @test Array(ygpu[2]) ≈ y[2] + else + @test ygpu isa CuArray + @test eltype(ygpu) == eltype(xgpu) + @test Array(ygpu) ≈ y + end + end + + # TEST x INPUT GRADIENT + x̄ = gradient(x -> loss(l, g, x, e), x)[1] + x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1] + @test eltype(x̄) == eltype(x) + @test x̄≈x̄_fd 
atol=atol rtol=rtol + + if test_gpu + x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1] + @test x̄gpu isa CuArray + @test eltype(x̄gpu) == eltype(x) + @test Array(x̄gpu)≈x̄ atol=atol rtol=rtol + end + + # TEST e INPUT GRADIENT + if e !== nothing + verbose && println("Test e gradient cpu") + ē = gradient(e -> loss(l, g, x, e), e)[1] + ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1] + @test eltype(ē) == eltype(e) + @test ē≈ē_fd atol=atol rtol=rtol + + if test_gpu + verbose && println("Test e gradient gpu") + ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1] + @test ēgpu isa CuArray + @test eltype(ēgpu) == eltype(ē) + @test Array(ēgpu)≈ē atol=atol rtol=rtol + end + end + + # TEST LAYER GRADIENT - l(g, x, e) + l̄ = gradient(l -> loss(l, g, x, e), l)[1] + l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1] + test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) + + if test_gpu + l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1] + test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose) + end + + # TEST LAYER GRADIENT - l(g) + l̄ = gradient(l -> loss(l, g), l)[1] + test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) + + return true +end + +function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5, + exclude_grad_fields = [], + verbose = false) + l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue + l̄fd = l̄fd isa Base.RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue + + for f in fieldnames(typeof(l)) + f ∈ exclude_grad_fields && continue + verbose && println("Test gradient of field $f...") + x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f) + test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose) + verbose && println("... 
field $f done!") + end + return true +end + +function test_approx_structs(x, g::Nothing, gfd; atol, rtol, kws...) + # finite diff gradients has to be zero if present + @test !(gfd isa AbstractArray) || isapprox(gfd, fill!(similar(gfd), 0); atol, rtol) +end + +function test_approx_structs(x::Union{AbstractArray, Number}, + g::Union{AbstractArray, Number}, gfd; atol, rtol, kws...) + @test eltype(g) == eltype(x) + if x isa CuArray + @test g isa CuArray + g = Array(g) + end + @test g≈gfd atol=atol rtol=rtol +end + +""" + to32(m) + +Convert the `eltype` of model's float parameters to `Float32`. +Preserves integer arrays. +""" +to32(m) = _paramtype(Float32, m) + +""" + to64(m) + +Convert the `eltype` of model's float parameters to `Float64`. +Preserves integer arrays. +""" +to64(m) = _paramtype(Float64, m) + +struct GNNEltypeAdaptor{T} end + +Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where T = convert(AbstractArray{T}, x) +Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Integer}) where T = x +Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x) + +_paramtype(::Type{T}, m) where T = fmap(adapt(GNNEltypeAdaptor{T}()), m) + +[.\GNNGraphs\test\transform.jl] +@testset "add self-loops" begin + A = [1 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + A2 = [2 1 0 0 + 0 1 1 0 + 0 0 1 1 + 1 0 0 1] + + g = GNNGraph(A; graph_type = GRAPH_T) + fg2 = add_self_loops(g) + @test adjacency_matrix(g) == A + @test g.num_edges == sum(A) + @test adjacency_matrix(fg2) == A2 + @test fg2.num_edges == sum(A2) +end + +@testset "batch" begin + g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10), + graph_type = GRAPH_T) + g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T) + g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T) + + g12 = MLUtils.batch([g1, g2]) + g12b = blockdiag(g1, g2) + @test g12 == g12b + + g123 = 
MLUtils.batch([g1, g2, g3]) + @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] + + # Allow wider eltype + g123 = MLUtils.batch(GNNGraph[g1, g2, g3]) + @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] + + + s, t = edge_index(g123) + @test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]] + @test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]] + @test node_features(g123)[:, 11:14] ≈ node_features(g2) + + # scalar graph features + g1 = GNNGraph(g1, gdata = rand()) + g2 = GNNGraph(g2, gdata = rand()) + g3 = GNNGraph(g3, gdata = rand()) + g123 = MLUtils.batch([g1, g2, g3]) + @test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u] + + # Batch of batches + g123123 = MLUtils.batch([g123, g123]) + @test g123123.graph_indicator == + [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)] + @test g123123.num_graphs == 6 +end + +@testset "unbatch" begin + g1 = rand_graph(10, 20, graph_type = GRAPH_T) + g2 = rand_graph(5, 10, graph_type = GRAPH_T) + g12 = MLUtils.batch([g1, g2]) + gs = MLUtils.unbatch([g1, g2]) + @test length(gs) == 2 + @test gs[1].num_nodes == 10 + @test gs[1].num_edges == 20 + @test gs[1].num_graphs == 1 + @test gs[2].num_nodes == 5 + @test gs[2].num_edges == 10 + @test gs[2].num_graphs == 1 +end + +@testset "batch/unbatch roundtrip" begin + n = 20 + c = 3 + ngraphs = 10 + gs = [rand_graph(n, c * n, ndata = rand(2, n), edata = rand(3, c * n), + graph_type = GRAPH_T) + for _ in 1:ngraphs] + gall = MLUtils.batch(gs) + gs2 = MLUtils.unbatch(gall) + @test gs2[1] == gs[1] + @test gs2[end] == gs[end] +end + +@testset "getgraph" begin + g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10), + graph_type = GRAPH_T) + g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T) + g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T) + g = MLUtils.batch([g1, g2, g3]) + + g2b, nodemap = 
getgraph(g, 2, nmap = true) + s, t = edge_index(g2b) + @test s == edge_index(g2)[1] + @test t == edge_index(g2)[2] + @test node_features(g2b) ≈ node_features(g2) + + g2c = getgraph(g, 2) + @test g2c isa GNNGraph{typeof(g.graph)} + + g1b, nodemap = getgraph(g1, 1, nmap = true) + @test g1b === g1 + @test nodemap == 1:(g1.num_nodes) +end + +@testset "remove_edges" begin + if GRAPH_T == :coo + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + w = [0.1, 0.2, 0.3, 0.4] + edata = ['a', 'b', 'c', 'd'] + g = GNNGraph(s, t, w, edata = edata, graph_type = GRAPH_T) + + # single edge removal + gnew = remove_edges(g, [1]) + new_s, new_t = edge_index(gnew) + @test gnew.num_edges == 3 + @test new_s == s[2:end] + @test new_t == t[2:end] + + # multiple edge removal + gnew = remove_edges(g, [1,2,4]) + new_s, new_t = edge_index(gnew) + new_w = get_edge_weight(gnew) + new_edata = gnew.edata.e + @test gnew.num_edges == 1 + @test new_s == [2] + @test new_t == [4] + @test new_w == [0.3] + @test new_edata == ['c'] + + # drop with probability + gnew = remove_edges(g, Float32(1.0)) + @test gnew.num_edges == 0 + + gnew = remove_edges(g, Float32(0.0)) + @test gnew.num_edges == g.num_edges + end +end + +@testset "add_edges" begin + if GRAPH_T == :coo + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + g = GNNGraph(s, t, graph_type = GRAPH_T) + snew = [1] + tnew = [4] + gnew = add_edges(g, snew, tnew) + @test gnew.num_edges == 5 + @test sort(inneighbors(gnew, 4)) == [1, 2] + + gnew2 = add_edges(g, (snew, tnew)) + @test gnew2 == gnew + @test get_edge_weight(gnew2) === nothing + + g = GNNGraph(s, t, edata = (e1 = rand(2, 4), e2 = rand(3, 4)), graph_type = GRAPH_T) + # @test_throws ErrorException add_edges(g, snew, tnew) + gnew = add_edges(g, snew, tnew, edata = (e1 = ones(2, 1), e2 = zeros(3, 1))) + @test all(gnew.edata.e1[:, 5] .== 1) + @test all(gnew.edata.e2[:, 5] .== 0) + + @testset "adding new nodes" begin + g = GNNGraph() + g = add_edges(g, ([1,3], [2, 1])) + @test g.num_nodes == 3 + @test g.num_edges == 2 + @test 
sort(inneighbors(g, 1)) == [3] + @test sort(outneighbors(g, 1)) == [2] + end + @testset "also add weights" begin + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + w = [1.0, 2.0, 3.0, 4.0] + snew = [1] + tnew = [4] + wnew = [5.] + + g = GNNGraph((s, t), graph_type = GRAPH_T) + gnew = add_edges(g, (snew, tnew, wnew)) + @test get_edge_weight(gnew) == [ones(length(s)); wnew] + + g = GNNGraph((s, t, w), graph_type = GRAPH_T) + gnew = add_edges(g, (snew, tnew, wnew)) + @test get_edge_weight(gnew) == [w; wnew] + end + end +end + +@testset "perturb_edges" begin if GRAPH_T == :coo + s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1] + g = GNNGraph((s, t)) + rng = MersenneTwister(42) + g_per = perturb_edges(rng, g, 0.5) + @test g_per.num_edges == 8 +end end + +@testset "remove_nodes" begin if GRAPH_T == :coo + #single node + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + eweights = [0.1, 0.2, 0.3, 0.4] + ndata = [1.0, 2.0, 3.0, 4.0, 5.0] + edata = ['a', 'b', 'c', 'd'] + + g = GNNGraph(s, t, eweights, ndata = ndata, edata = edata, graph_type = GRAPH_T) + + gnew = remove_nodes(g, [1]) + + snew = [1, 2] + tnew = [3, 4] + eweights_new = [0.3, 0.4] + ndata_new = [2.0, 3.0, 4.0, 5.0] + edata_new = ['c', 'd'] + + stest, ttest = edge_index(gnew) + eweightstest = get_edge_weight(gnew) + ndatatest = gnew.ndata.x + edatatest = gnew.edata.e + + + @test gnew.num_edges == 2 + @test gnew.num_nodes == 4 + @test snew == stest + @test tnew == ttest + @test eweights_new == eweightstest + @test ndata_new == ndatatest + @test edata_new == edatatest + + # multiple nodes + s = [1, 5, 2, 3] + t = [2, 3, 4, 5] + eweights = [0.1, 0.2, 0.3, 0.4] + ndata = [1.0, 2.0, 3.0, 4.0, 5.0] + edata = ['a', 'b', 'c', 'd'] + + g = GNNGraph(s, t, eweights, ndata = ndata, edata = edata, graph_type = GRAPH_T) + + gnew = remove_nodes(g, [1,4]) + snew = [3,2] + tnew = [2,3] + eweights_new = [0.2,0.4] + ndata_new = [2.0,3.0,5.0] + edata_new = ['b','d'] + + stest, ttest = edge_index(gnew) + eweightstest = get_edge_weight(gnew) + ndatatest = 
gnew.ndata.x + edatatest = gnew.edata.e + + @test gnew.num_edges == 2 + @test gnew.num_nodes == 3 + @test snew == stest + @test tnew == ttest + @test eweights_new == eweightstest + @test ndata_new == ndatatest + @test edata_new == edatatest +end end + +@testset "remove_nodes(g, p)" begin + if GRAPH_T == :coo + Random.seed!(42) + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + g = GNNGraph(s, t, graph_type = GRAPH_T) + + gnew = remove_nodes(g, 0.5) + @test gnew.num_nodes == 3 + + gnew = remove_nodes(g, 1.0) + @test gnew.num_nodes == 0 + + gnew = remove_nodes(g, 0.0) + @test gnew.num_nodes == 5 + end +end + +@testset "add_nodes" begin if GRAPH_T == :coo + g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T) + gnew = add_nodes(g, 5, ndata = ones(2, 5)) + @test gnew.num_nodes == g.num_nodes + 5 + @test gnew.num_edges == g.num_edges + @test gnew.num_graphs == g.num_graphs + @test all(gnew.ndata.x[:, 7:11] .== 1) +end end + +@testset "remove_self_loops" begin if GRAPH_T == :coo # add_edges and set_edge_weight only implemented for coo + g = rand_graph(10, 20, graph_type = GRAPH_T) + g1 = add_edges(g, [1:5;], [1:5;]) + @test g1.num_edges == g.num_edges + 5 + g2 = remove_self_loops(g1) + @test g2.num_edges == g.num_edges + @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) + + # with edge features and weights + g1 = GNNGraph(g1, edata = (e1 = ones(3, g1.num_edges), e2 = 2 * ones(g1.num_edges))) + g1 = set_edge_weight(g1, 3 * ones(g1.num_edges)) + g2 = remove_self_loops(g1) + @test g2.num_edges == g.num_edges + @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) + @test size(get_edge_weight(g2)) == (g2.num_edges,) + @test size(g2.edata.e1) == (3, g2.num_edges) + @test size(g2.edata.e2) == (g2.num_edges,) +end end + +@testset "remove_multi_edges" begin if GRAPH_T == :coo + g = rand_graph(10, 20, graph_type = GRAPH_T) + s, t = edge_index(g) + g1 = add_edges(g, s[1:5], t[1:5]) + @test g1.num_edges == g.num_edges + 5 + g2 = 
remove_multi_edges(g1, aggr = +) + @test g2.num_edges == g.num_edges + @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) + + # Default aggregation is + + g1 = GNNGraph(g1, edata = (e1 = ones(3, g1.num_edges), e2 = 2 * ones(g1.num_edges))) + g1 = set_edge_weight(g1, 3 * ones(g1.num_edges)) + g2 = remove_multi_edges(g1) + @test g2.num_edges == g.num_edges + @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) + @test count(g2.edata.e1[:, i] == 2 * ones(3) for i in 1:(g2.num_edges)) == 5 + @test count(g2.edata.e2[i] == 4 for i in 1:(g2.num_edges)) == 5 + w2 = get_edge_weight(g2) + @test count(w2[i] == 6 for i in 1:(g2.num_edges)) == 5 +end end + +@testset "negative_sample" begin if GRAPH_T == :coo + n, m = 10, 30 + g = rand_graph(n, m, bidirected = true, graph_type = GRAPH_T) + + # check bidirected=is_bidirected(g) default + gneg = negative_sample(g, num_neg_edges = 20) + @test gneg.num_nodes == g.num_nodes + @test gneg.num_edges == 20 + @test is_bidirected(gneg) + @test intersect(g, gneg).num_edges == 0 +end end + +@testset "rand_edge_split" begin if GRAPH_T == :coo + n, m = 100, 300 + + g = rand_graph(n, m, bidirected = true, graph_type = GRAPH_T) + # check bidirected=is_bidirected(g) default + g1, g2 = rand_edge_split(g, 0.9) + @test is_bidirected(g1) + @test is_bidirected(g2) + @test intersect(g1, g2).num_edges == 0 + @test g1.num_edges + g2.num_edges == g.num_edges + @test g2.num_edges < 50 + + g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T) + # check bidirected=is_bidirected(g) default + g1, g2 = rand_edge_split(g, 0.9) + @test !is_bidirected(g1) + @test !is_bidirected(g2) + @test intersect(g1, g2).num_edges == 0 + @test g1.num_edges + g2.num_edges == g.num_edges + @test g2.num_edges < 50 + + g1, g2 = rand_edge_split(g, 0.9, bidirected = false) + @test !is_bidirected(g1) + @test !is_bidirected(g2) + @test intersect(g1, g2).num_edges == 0 + @test g1.num_edges + g2.num_edges == g.num_edges + @test 
g2.num_edges < 50 +end end + +@testset "set_edge_weight" begin + g = rand_graph(10, 20, graph_type = GRAPH_T) + w = rand(20) + + gw = set_edge_weight(g, w) + @test get_edge_weight(gw) == w + + # now from weighted graph + s, t = edge_index(g) + g2 = GNNGraph(s, t, rand(20), graph_type = GRAPH_T) + gw2 = set_edge_weight(g2, w) + @test get_edge_weight(gw2) == w +end + +@testset "to_bidirected" begin if GRAPH_T == :coo + s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4] + w = [1.0, 2.0, 3.0, 4.0, 5.0] + e = [10.0, 20.0, 30.0, 40.0, 50.0] + g = GNNGraph(s, t, w, edata = e) + + g2 = to_bidirected(g) + @test g2.num_nodes == g.num_nodes + @test g2.num_edges == 7 + @test is_bidirected(g2) + @test !has_multi_edges(g2) + + s2, t2 = edge_index(g2) + w2 = get_edge_weight(g2) + @test s2 == [1, 2, 2, 3, 3, 4, 4] + @test t2 == [2, 1, 3, 2, 4, 3, 4] + @test w2 == [1, 1, 2, 2, 3.5, 3.5, 5] + @test g2.edata.e == [10.0, 10.0, 20.0, 20.0, 35.0, 35.0, 50.0] +end end + +@testset "to_unidirected" begin if GRAPH_T == :coo + s = [1, 2, 3, 4, 4] + t = [2, 3, 4, 3, 4] + w = [1.0, 2.0, 3.0, 4.0, 5.0] + e = [10.0, 20.0, 30.0, 40.0, 50.0] + g = GNNGraph(s, t, w, edata = e) + + g2 = to_unidirected(g) + @test g2.num_nodes == g.num_nodes + @test g2.num_edges == 4 + @test !has_multi_edges(g2) + + s2, t2 = edge_index(g2) + w2 = get_edge_weight(g2) + @test s2 == [1, 2, 3, 4] + @test t2 == [2, 3, 4, 4] + @test w2 == [1, 2, 3.5, 5] + @test g2.edata.e == [10.0, 20.0, 35.0, 50.0] +end end + +@testset "Graphs.Graph from GNNGraph" begin + g = rand_graph(10, 20, graph_type = GRAPH_T) + + G = Graphs.Graph(g) + @test nv(G) == g.num_nodes + @test ne(G) == g.num_edges ÷ 2 + + DG = Graphs.DiGraph(g) + @test nv(DG) == g.num_nodes + @test ne(DG) == g.num_edges +end + +@testset "random_walk_pe" begin + s = [1, 2, 2, 3] + t = [2, 1, 3, 2] + ndata = [-1, 0, 1] + g = GNNGraph(s, t, graph_type = GRAPH_T, ndata = ndata) + output = random_walk_pe(g, 3) + @test output == [0.0 0.0 0.0 + 0.5 1.0 0.5 + 0.0 0.0 0.0] +end + +@testset 
"HeteroGraphs" begin + @testset "batch" begin + gs = [rand_bipartite_heterograph((10, 15), 20) for _ in 1:5] + g = MLUtils.batch(gs) + @test g.num_nodes[:A] == 50 + @test g.num_nodes[:B] == 75 + @test g.num_edges[(:A,:to,:B)] == 100 + @test g.num_edges[(:B,:to,:A)] == 100 + @test g.num_graphs == 5 + @test g.graph_indicator == Dict(:A => vcat([fill(i, 10) for i in 1:5]...), + :B => vcat([fill(i, 15) for i in 1:5]...)) + + for gi in gs + gi.ndata[:A].x = ones(2, 10) + gi.ndata[:A].y = zeros(10) + gi.edata[(:A,:to,:B)].e = fill(2, 20) + gi.gdata.u = 7 + end + g = MLUtils.batch(gs) + @test g.ndata[:A].x == ones(2, 50) + @test g.ndata[:A].y == zeros(50) + @test g.edata[(:A,:to,:B)].e == fill(2, 100) + @test g.gdata.u == fill(7, 5) + + # Allow for wider eltype + g = MLUtils.batch(GNNHeteroGraph[g for g in gs]) + @test g.ndata[:A].x == ones(2, 50) + @test g.ndata[:A].y == zeros(50) + @test g.edata[(:A,:to,:B)].e == fill(2, 100) + @test g.gdata.u == fill(7, 5) + end + + @testset "batch non-similar edge types" begin + gs = [rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20)), + rand_heterograph((:A => 10, :B => 15), ((:A, :to1, :B) => 5, (:B, :to2, :B) => 16)), + rand_heterograph((:B => 15, :C => 5), ((:C, :to1, :B) => 5, (:B, :to2, :C) => 21)), + rand_heterograph((:A => 10, :B => 10, :C => 10), ((:A, :to1, :C) => 5, (:A, :to1, :B) => 5)), + rand_heterograph((:C => 20), ((:C, :to3, :C) => 10)) + ] + g = MLUtils.batch(gs) + + @test g.num_nodes[:A] == 10 + 10 + 10 + @test g.num_nodes[:B] == 14 + 15 + 15 + 10 + @test g.num_nodes[:C] == 5 + 10 + 20 + @test g.num_edges[(:A,:to1,:A)] == 5 + @test g.num_edges[(:A,:to1,:B)] == 20 + 5 + 5 + @test g.num_edges[(:A,:to1,:C)] == 5 + + @test g.num_edges[(:B,:to2,:B)] == 16 + @test g.num_edges[(:B,:to2,:C)] == 21 + + @test g.num_edges[(:C,:to1,:B)] == 5 + @test g.num_edges[(:C,:to3,:C)] == 10 + @test length(keys(g.num_edges)) == 7 + @test g.num_graphs == 5 + + function ndata_if_key(g, key, subkey, value) + 
if haskey(g.ndata, key) + g.ndata[key][subkey] = reduce(hcat, fill(value, g.num_nodes[key])) + end + end + + function edata_if_key(g, key, subkey, value) + if haskey(g.edata, key) + g.edata[key][subkey] = reduce(hcat, fill(value, g.num_edges[key])) + end + end + + for gi in gs + ndata_if_key(gi, :A, :x, [0]) + ndata_if_key(gi, :A, :y, ones(2)) + ndata_if_key(gi, :B, :x, ones(3)) + ndata_if_key(gi, :C, :y, zeros(4)) + edata_if_key(gi, (:A,:to1,:B), :x, [0]) + gi.gdata.u = 7 + end + + g = MLUtils.batch(gs) + + @test g.ndata[:A].x == reduce(hcat, fill(0, 10 + 10 + 10)) + @test g.ndata[:A].y == ones(2, 10 + 10 + 10) + @test g.ndata[:B].x == ones(3, 14 + 15 + 15 + 10) + @test g.ndata[:C].y == zeros(4, 5 + 10 + 20) + + @test g.edata[(:A,:to1,:B)].x == reduce(hcat, fill(0, 20 + 5 + 5)) + + @test g.gdata.u == fill(7, 5) + + # Allow for wider eltype + g = MLUtils.batch(GNNHeteroGraph[g for g in gs]) + @test g.ndata[:A].x == reduce(hcat, fill(0, 10 + 10 + 10)) + @test g.ndata[:A].y == ones(2, 10 + 10 + 10) + @test g.ndata[:B].x == ones(3, 14 + 15 + 15 + 10) + @test g.ndata[:C].y == zeros(4, 5 + 10 + 20) + + @test g.edata[(:A,:to1,:B)].x == reduce(hcat, fill(0, 20 + 5 + 5)) + + @test g.gdata.u == fill(7, 5) + end + + @testset "add_edges" begin + hg = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) + hg = add_edges(hg, (:B,:to,:A), [1, 1], [1,2]) + @test hg.num_edges == Dict((:A,:to,:B) => 4, (:B,:to,:A) => 2) + @test has_edge(hg, (:B,:to,:A), 1, 1) + @test has_edge(hg, (:B,:to,:A), 1, 2) + @test !has_edge(hg, (:B,:to,:A), 2, 1) + @test !has_edge(hg, (:B,:to,:A), 2, 2) + + @testset "new nodes" begin + hg = rand_bipartite_heterograph((2, 2), 3) + hg = add_edges(hg, (:C,:rel,:B) => ([1, 3], [1,2])) + @test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3) + @test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2) + s, t = edge_index(hg, (:C,:rel,:B)) + @test s == [1, 3] + @test t == [1, 2] + + hg = add_edges(hg, (:D,:rel,:F) => ([1, 3], 
[1,2])) + @test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3, :D => 3, :F => 2) + @test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2, (:D,:rel,:F) => 2) + s, t = edge_index(hg, (:D,:rel,:F)) + @test s == [1, 3] + @test t == [1, 2] + end + + @testset "also add weights" begin + hg = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7], [0.1, 0.2, 0.3, 0.4])) + hgnew = add_edges(hg, (:user, :like, :actor) => ([1, 2], [3, 4], [0.5, 0.6])) + @test hgnew.num_nodes[:user] == 3 + @test hgnew.num_nodes[:movie] == 13 + @test hgnew.num_nodes[:actor] == 4 + @test hgnew.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 2) + @test get_edge_weight(hgnew, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4] + @test get_edge_weight(hgnew, (:user, :like, :actor)) == [0.5, 0.6] + + hgnew2 = add_edges(hgnew, (:user, :like, :actor) => ([6, 7], [8, 10], [0.7, 0.8])) + @test hgnew2.num_nodes[:user] == 7 + @test hgnew2.num_nodes[:movie] == 13 + @test hgnew2.num_nodes[:actor] == 10 + @test hgnew2.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 4) + @test get_edge_weight(hgnew2, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4] + @test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8] + end + end + + @testset "add self-loops heterographs" begin + g = rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20)) + # Case in which haskey(g.graph, edge_t) passes + g = add_self_loops(g, (:A, :to1, :A)) + + @test g.num_edges[(:A, :to1, :A)] == 5 + 10 + @test g.num_edges[(:A, :to1, :B)] == 20 + # This test should not use length(keys(g.num_edges)) since that may be undefined behavior + @test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 2 + + # Case in which haskey(g.graph, edge_t) fails + g = add_self_loops(g, (:A, :to3, :A)) + + @test g.num_edges[(:A, :to1, :A)] == 5 + 10 + @test g.num_edges[(:A, :to1, :B)] == 20 + @test g.num_edges[(:A, :to3, :A)] == 10 + 
@test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 3 + + # Case with edge weights + g = GNNHeteroGraph(Dict((:A, :to1, :A) => ([1, 2, 3], [3, 2, 1], [2, 2, 2]), (:A, :to2, :B) => ([1, 4, 5], [1, 2, 3]))) + n = g.num_nodes[:A] + g = add_self_loops(g, (:A, :to1, :A)) + + @test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n)) + end +end + +@testset "ppr_diffusion" begin + if GRAPH_T == :coo + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + eweights = [0.1, 0.2, 0.3, 0.4] + + g = GNNGraph(s, t, eweights) + + g_new = ppr_diffusion(g) + w_new = get_edge_weight(g_new) + + check_ew = Float32[0.012749999 + 0.025499998 + 0.038249996 + 0.050999995] + + @test w_new ≈ check_ew + end +end +[.\GNNGraphs\test\utils.jl] +@testset "edge encoding/decoding" begin + # not is_bidirected + n = 5 + s = [1, 1, 2, 3, 3, 4, 5] + t = [1, 3, 1, 1, 2, 5, 5] + + # directed=true + idx, maxid = GNNGraphs.edge_encoding(s, t, n) + @test maxid == n^2 + @test idx == [1, 3, 6, 11, 12, 20, 25] + + sdec, tdec = GNNGraphs.edge_decoding(idx, n) + @test sdec == s + @test tdec == t + + n1, m1 = 10, 30 + g = rand_graph(n1, m1) + s1, t1 = edge_index(g) + idx, maxid = GNNGraphs.edge_encoding(s1, t1, n1) + sdec, tdec = GNNGraphs.edge_decoding(idx, n1) + @test sdec == s1 + @test tdec == t1 + + # directed=false + idx, maxid = GNNGraphs.edge_encoding(s, t, n, directed = false) + @test maxid == n * (n + 1) ÷ 2 + @test idx == [1, 3, 2, 3, 7, 14, 15] + + mask = s .> t + snew = copy(s) + tnew = copy(t) + snew[mask] .= t[mask] + tnew[mask] .= s[mask] + sdec, tdec = GNNGraphs.edge_decoding(idx, n, directed = false) + @test sdec == snew + @test tdec == tnew + + n1, m1 = 6, 8 + g = rand_graph(n1, m1) + s1, t1 = edge_index(g) + idx, maxid = GNNGraphs.edge_encoding(s1, t1, n1, directed = false) + sdec, tdec = GNNGraphs.edge_decoding(idx, n1, directed = false) + mask = s1 .> t1 + snew = copy(s1) + tnew = copy(t1) + snew[mask] .= t1[mask] + tnew[mask] .= s1[mask] + @test sdec == snew + @test tdec == tnew + + 
@testset "directed=false, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + g = GNNGraph(s, t) + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) + @test idxmax == n * (n - 1) ÷ 2 + @test idx == 1:idxmax + + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) + @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] + @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] + end + + @testset "directed=false, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) + @test idxmax == n * (n - 1) ÷ 2 + @test idx == 1:idxmax + + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) + @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] + @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] + end + + @testset "directed=true, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=true, self_loops=false) + @test idxmax == n^2 - n + @test idx == [1, 9, 3, 4, 6, 7, 8, 11, 12, 16] + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=true, self_loops=false) + @test snew == s + @test tnew == t + end +end + +@testset "color_refinement" begin + rng = MersenneTwister(17) + g = rand_graph(rng, 10, 20, graph_type = GRAPH_T) + x0 = ones(Int, 10) + x, ncolors, niters = color_refinement(g, x0) + @test ncolors == 8 + @test niters == 2 + @test x == [4, 5, 6, 7, 8, 5, 8, 9, 10, 11] + + x2, _, _ = color_refinement(g) + @test x2 == x +end +[.\GNNGraphs\test\ext\SimpleWeightedGraphs.jl] +@testset "simple_weighted_graph" begin + srcs = [1, 2, 1] + dsts = [2, 3, 
3] + wts = [0.5, 0.8, 2.0] + g = SimpleWeightedGraph(srcs, dsts, wts) + gd = SimpleWeightedDiGraph(srcs, dsts, wts) + gnn_g = GNNGraph(g) + gnn_gd = GNNGraph(gd) + @test get_edge_weight(gnn_g) == [0.5, 2, 0.5, 0.8, 2.0, 0.8] + @test get_edge_weight(gnn_gd) == [0.5, 2, 0.8] +end + +[.\GNNlib\ext\GNNlibCUDAExt.jl] +module GNNlibCUDAExt + +using CUDA +using Random, Statistics, LinearAlgebra +using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj +using GNNGraphs: GNNGraph, COO_T, SPARSE_T + +###### PROPAGATE SPECIALIZATIONS #################### + +## COPY_XJ + +## avoid the fast path on gpu until we have better cuda support +function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), + xi, xj::AnyCuMatrix, e) + propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) +end + +## E_MUL_XJ + +## avoid the fast path on gpu until we have better cuda support +function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), + xi, xj::AnyCuMatrix, e::AbstractVector) + propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e) +end + +## W_MUL_XJ + +## avoid the fast path on gpu until we have better cuda support +function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), + xi, xj::AnyCuMatrix, e::Nothing) + propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) +end + +# function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) +# A = adjacency_matrix(g, weighted=false) +# D = compute_degree(A) +# return xj * A * D +# end + +# # Zygote bug. 
Error with sparse matrix without nograd +# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) + +# Flux.Zygote.@nograd compute_degree + +end #module + +[.\GNNlib\src\GNNlib.jl] +module GNNlib + +using Statistics: mean +using LinearAlgebra, Random +using MLUtils: zeros_like +using NNlib +using NNlib: scatter, gather +using DataStructures: nlargest +using ChainRulesCore: @non_differentiable +using GNNGraphs +using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, + check_num_nodes, check_num_edges, + EType, NType # for heteroconvs + +include("utils.jl") +export reduce_nodes, + reduce_edges, + softmax_nodes, + softmax_edges, + broadcast_nodes, + broadcast_edges, + softmax_edge_neighbors + +include("msgpass.jl") +export apply_edges, + aggregate_neighbors, + propagate, + copy_xj, + copy_xi, + xi_dot_xj, + xi_sub_xj, + xj_sub_xi, + e_mul_xj, + w_mul_xj + +## The following methods are defined but not exported + +include("layers/basic.jl") +export dot_decoder + +include("layers/conv.jl") +export agnn_conv, + cg_conv, + cheb_conv, + d_conv, + edge_conv, + egnn_conv, + gat_conv, + gatv2_conv, + gated_graph_conv, + gcn_conv, + gin_conv, + gmm_conv, + graph_conv, + megnet_conv, + nn_conv, + res_gated_graph_conv, + sage_conv, + sg_conv, + tag_conv, + transformer_conv + +include("layers/temporalconv.jl") +export a3tgcn_conv + +include("layers/pool.jl") +export global_pool, + global_attention_pool, + set2set_pool, + topk_pool, + topk_index + +# include("layers/heteroconv.jl") # no functional part at the moment + +end #module + +[.\GNNlib\src\msgpass.jl] +""" + propagate(fmsg, g, aggr; [xi, xj, e]) + propagate(fmsg, g, aggr xi, xj, e=nothing) + +Performs message passing on graph `g`. Takes care of materializing the node features on each edge, +applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}`` +(depending on the return value of `fmsg`, an array or a named tuple of +arrays with last dimension's size `g.num_nodes`). 
+
+It can be decomposed in two steps:
+
+```julia
+m = apply_edges(fmsg, g, xi, xj, e)
+m̄ = aggregate_neighbors(g, aggr, m)
+```
+
+GNN layers typically call `propagate` in their forward pass,
+providing as input `fmsg` a closure.
+
+# Arguments
+
+- `g`: A `GNNGraph`.
+- `xi`: An array or a named tuple containing arrays whose last dimension's size
+    is `g.num_nodes`. It will be appropriately materialized on the
+    target node of each edge (see also [`edge_index`](@ref)).
+- `xj`: As `xi`, but to be materialized on edges' sources.
+- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`.
+- `fmsg`: A generic function that will be passed over to [`apply_edges`](@ref).
+    Has to take as inputs the edge-materialized `xi`, `xj`, and `e`
+    (arrays or named tuples of arrays whose last dimension's size is the size of
+    a batch of edges). Its output has to be an array or a named tuple of arrays
+    with the same batch size. If also `layer` is passed to propagate,
+    the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)`
+    instead of `fmsg(xi, xj, e)`.
+- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`.
+
+# Examples
+
+```julia
+using GraphNeuralNetworks, Flux
+
+struct GNNConv <: GNNLayer
+    W
+    b
+    σ
+end
+
+Flux.@layer GNNConv
+
+function GNNConv(ch::Pair{Int,Int}, σ=identity)
+    in, out = ch
+    W = Flux.glorot_uniform(out, in)
+    b = zeros(Float32, out)
+    GNNConv(W, b, σ)
+end
+
+function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
+    message(xi, xj, e) = l.W * xj
+    m̄ = propagate(message, g, +, xj=x)
+    return l.σ.(m̄ .+ l.b)
+end
+
+l = GNNConv(10 => 20)
+l(g, x)
+```
+
+See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref).
+""" +function propagate end + +function propagate(f, g::AbstractGNNGraph, aggr; xi = nothing, xj = nothing, e = nothing) + propagate(f, g, aggr, xi, xj, e) +end + +function propagate(f, g::AbstractGNNGraph, aggr, xi, xj, e = nothing) + m = apply_edges(f, g, xi, xj, e) + m̄ = aggregate_neighbors(g, aggr, m) + return m̄ +end + +## APPLY EDGES + +""" + apply_edges(fmsg, g; [xi, xj, e]) + apply_edges(fmsg, g, xi, xj, e=nothing) + +Returns the message from node `j` to node `i` applying +the message function `fmsg` on the edges in graph `g`. +In the message-passing scheme, the incoming messages +from the neighborhood of `i` will later be aggregated +in order to update the features of node `i` (see [`aggregate_neighbors`](@ref)). + +The function `fmsg` operates on batches of edges, therefore +`xi`, `xj`, and `e` are tensors whose last dimension +is the batch size, or can be named tuples of +such tensors. + +# Arguments + +- `g`: An `AbstractGNNGraph`. +- `xi`: An array or a named tuple containing arrays whose last dimension's size + is `g.num_nodes`. It will be appropriately materialized on the + target node of each edge (see also [`edge_index`](@ref)). +- `xj`: As `xi`, but now to be materialized on each edge's source node. +- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. +- `fmsg`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`. + These are arrays (or named tuples of arrays) whose last dimension' size is the size of + a batch of edges. The output of `f` has to be an array (or a named tuple of arrays) + with the same batch size. If also `layer` is passed to propagate, + the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` + instead of `fmsg(xi, xj, e)`. + +See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref). 
+""" +function apply_edges end + +function apply_edges(f, g::AbstractGNNGraph; xi = nothing, xj = nothing, e = nothing) + apply_edges(f, g, xi, xj, e) +end + +function apply_edges(f, g::AbstractGNNGraph, xi, xj, e = nothing) + check_num_nodes(g, (xj, xi)) + check_num_edges(g, e) + s, t = edge_index(g) # for heterographs, errors if more than one edge type + xi = GNNGraphs._gather(xi, t) # size: (D, num_nodes) -> (D, num_edges) + xj = GNNGraphs._gather(xj, s) + m = f(xi, xj, e) + return m +end + +## AGGREGATE NEIGHBORS +@doc raw""" + aggregate_neighbors(g, aggr, m) + +Given a graph `g`, edge features `m`, and an aggregation +operator `aggr` (e.g `+, min, max, mean`), returns the new node +features +```math +\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i} +``` + +Neighborhood aggregation is the second step of [`propagate`](@ref), +where it comes after [`apply_edges`](@ref). +""" +function aggregate_neighbors(g::GNNGraph, aggr, m) + check_num_edges(g, m) + s, t = edge_index(g) + return GNNGraphs._scatter(aggr, m, t, g.num_nodes) +end + +function aggregate_neighbors(g::GNNHeteroGraph, aggr, m) + check_num_edges(g, m) + s, t = edge_index(g) + dest_node_t = only(g.etypes)[3] + return GNNGraphs._scatter(aggr, m, t, g.num_nodes[dest_node_t]) +end + +### MESSAGE FUNCTIONS ### +""" + copy_xj(xi, xj, e) = xj +""" +copy_xj(xi, xj, e) = xj + +""" + copy_xi(xi, xj, e) = xi +""" +copy_xi(xi, xj, e) = xi + +""" + xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1) +""" +xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims = 1) + +""" + xi_sub_xj(xi, xj, e) = xi .- xj +""" +xi_sub_xj(xi, xj, e) = xi .- xj + +""" + xj_sub_xi(xi, xj, e) = xj .- xi +""" +xj_sub_xi(xi, xj, e) = xj .- xi + +""" + e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj + +Reshape `e` into broadcast compatible shape with `xj` +(by prepending singleton dimensions) then perform +broadcasted multiplication. 
+""" +function e_mul_xj(xi, xj::AbstractArray{Tj, Nj}, + e::AbstractArray{Te, Ne}) where {Tj, Te, Nj, Ne} + @assert Ne <= Nj + e = reshape(e, ntuple(_ -> 1, Nj - Ne)..., size(e)...) + return e .* xj +end + +""" + w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj + +Similar to [`e_mul_xj`](@ref) but specialized on scalar edge features (weights). +""" +w_mul_xj(xi, xj::AbstractArray, w::Nothing) = xj # same as copy_xj if no weights + +function w_mul_xj(xi, xj::AbstractArray{Tj, Nj}, w::AbstractVector) where {Tj, Nj} + w = reshape(w, ntuple(_ -> 1, Nj - 1)..., length(w)) + return w .* xj +end + +###### PROPAGATE SPECIALIZATIONS #################### +## See also the methods defined in the package extensions. + +## COPY_XJ + +function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e) + A = adjacency_matrix(g, weighted = false) + return xj * A +end + +## E_MUL_XJ + +# for weighted convolution +function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, + e::AbstractVector) + g = set_edge_weight(g, e) + A = adjacency_matrix(g, weighted = true) + return xj * A +end + + +## W_MUL_XJ + +# for weighted convolution +function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, + e::Nothing) + A = adjacency_matrix(g, weighted = true) + return xj * A +end + + +# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) +# A = adjacency_matrix(g, weighted=false) +# D = compute_degree(A) +# return xj * A * D +# end + +# # Zygote bug. Error with sparse matrix without nograd +# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) + +# Flux.Zygote.@nograd compute_degree + +[.\GNNlib\src\utils.jl] +ofeltype(x, y) = convert(float(eltype(x)), y) + +""" + reduce_nodes(aggr, g, x) + +For a batched graph `g`, return the graph-wise aggregation of the node +features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. 
+The returned array will have last dimension `g.num_graphs`. + +See also: [`reduce_edges`](@ref). +""" +function reduce_nodes(aggr, g::GNNGraph, x) + @assert size(x)[end] == g.num_nodes + indexes = graph_indicator(g) + return NNlib.scatter(aggr, x, indexes) +end + +""" + reduce_nodes(aggr, indicator::AbstractVector, x) + +Return the graph-wise aggregation of the node features `x` given the +graph indicator `indicator`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. + +See also [`graph_indicator`](@ref). +""" +function reduce_nodes(aggr, indicator::AbstractVector, x) + return NNlib.scatter(aggr, x, indicator) +end + +""" + reduce_edges(aggr, g, e) + +For a batched graph `g`, return the graph-wise aggregation of the edge +features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. +The returned array will have last dimension `g.num_graphs`. +""" +function reduce_edges(aggr, g::GNNGraph, e) + @assert size(e)[end] == g.num_edges + s, t = edge_index(g) + indexes = graph_indicator(g)[s] + return NNlib.scatter(aggr, e, indexes) +end + +""" + softmax_nodes(g, x) + +Graph-wise softmax of the node features `x`. +""" +function softmax_nodes(g::GNNGraph, x) + @assert size(x)[end] == g.num_nodes + gi = graph_indicator(g) + max_ = gather(scatter(max, x, gi), gi) + num = exp.(x .- max_) + den = reduce_nodes(+, g, num) + den = gather(den, gi) + return num ./ den +end + +""" + softmax_edges(g, e) + +Graph-wise softmax of the edge features `e`. +""" +function softmax_edges(g::GNNGraph, e) + @assert size(e)[end] == g.num_edges + gi = graph_indicator(g, edges = true) + max_ = gather(scatter(max, e, gi), gi) + num = exp.(e .- max_) + den = reduce_edges(+, g, num) + den = gather(den, gi) + return num ./ (den .+ eps(eltype(e))) +end + +@doc raw""" + softmax_edge_neighbors(g, e) + +Softmax over each node's neighborhood of the edge features `e`. 
+ +```math +\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}} + {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}. +``` +""" +function softmax_edge_neighbors(g::AbstractGNNGraph, e) + if g isa GNNHeteroGraph + for (key, value) in g.num_edges + @assert size(e)[end] == value + end + else + @assert size(e)[end] == g.num_edges + end + s, t = edge_index(g) + max_ = gather(scatter(max, e, t), t) + num = exp.(e .- max_) + den = gather(scatter(+, num, t), t) + return num ./ den +end + +""" + broadcast_nodes(g, x) + +Graph-wise broadcast array `x` of size `(*, g.num_graphs)` +to size `(*, g.num_nodes)`. +""" +function broadcast_nodes(g::GNNGraph, x) + @assert size(x)[end] == g.num_graphs + gi = graph_indicator(g) + return gather(x, gi) +end + +""" + broadcast_edges(g, x) + +Graph-wise broadcast array `x` of size `(*, g.num_graphs)` +to size `(*, g.num_edges)`. +""" +function broadcast_edges(g::GNNGraph, x) + @assert size(x)[end] == g.num_graphs + gi = graph_indicator(g, edges = true) + return gather(x, gi) +end + +expand_srcdst(g::AbstractGNNGraph, x) = throw(ArgumentError("Invalid input type, expected matrix or tuple of matrices.")) +expand_srcdst(g::AbstractGNNGraph, x::AbstractMatrix) = (x, x) +expand_srcdst(g::AbstractGNNGraph, x::Tuple{<:AbstractMatrix, <:AbstractMatrix}) = x + +# Replacement for Base.Fix1 to allow for multiple arguments +struct Fix1{F,X} + f::F + x::X +end + +(f::Fix1)(y...) = f.f(f.x, y...) 
+ +[.\GNNlib\src\layers\basic.jl] +function dot_decoder(g, x) + return apply_edges(xi_dot_xj, g, xi = x, xj = x) +end + +[.\GNNlib\src\layers\conv.jl] +####################### GCNConv ###################################### + +check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) = + throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs")) + +function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector) + if length(edge_weight) !== g.num_edges + throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))")) + end +end + +check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing + +function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} + check_gcnconv_input(g, edge_weight) + if conv_weight === nothing + weight = l.weight + else + weight = conv_weight + if size(weight) != size(l.weight) + throw(ArgumentError("The weight matrix has the wrong size. 
Expected $(size(l.weight)) but got $(size(weight))")) + end + end + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + # Pad weights with ones + # TODO for ADJMAT_T the new edges are not generally at the end + edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(weight) + if Dout < Din && !(g isa GNNHeteroGraph) + # multiply before convolution if it is more convenient, otherwise multiply after + # (this works only for homogenous graph) + x = weight * x + end + + xj, xi = expand_srcdst(g, x) # expand only after potential multiplication + T = eltype(xi) + + if g isa GNNHeteroGraph + din = degree(g, g.etypes[1], T; dir = :in) + dout = degree(g, g.etypes[1], T; dir = :out) + + cout = norm_fn(dout) + cin = norm_fn(din) + else + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight) + end + cin = cout = norm_fn(d) + end + xj = xj .* cout' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj = xj) + else + x = propagate(copy_xj, g, +, xj = xj) + end + x = x .* cin' + if Dout >= Din || g isa GNNHeteroGraph + x = weight * x + end + return l.σ.(x .+ l.bias) +end + +# when we also have edge_weight we need to convert the graph to COO +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} + g = GNNGraph(g, graph_type = :coo) + return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) +end + +####################### ChebConv ###################################### + +function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T} + check_num_nodes(g, X) + @assert size(X, 1) == size(l.weight, 2) "Input feature size must match input channel size." 
+ + L̃ = scaled_laplacian(g, eltype(X)) + + Z_prev = X + Z = X * L̃ + Y = view(l.weight, :, :, 1) * Z_prev + Y = Y .+ view(l.weight, :, :, 2) * Z + for k in 3:(l.k) + Z, Z_prev = 2 * Z * L̃ - Z_prev, Z + Y = Y .+ view(l.weight, :, :, k) * Z + end + return Y .+ l.bias +end + +####################### GraphConv ###################################### + +function graph_conv(l, g::AbstractGNNGraph, x) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) + x = l.weight1 * xi .+ l.weight2 * m + return l.σ.(x .+ l.bias) +end + +####################### GATConv ###################################### + +function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) + check_num_nodes(g, x) + @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" + @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" + + xj, xi = expand_srcdst(g, x) + + if l.add_self_loops + @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." 
+ g = add_self_loops(g) + end + + _, chout = l.channel + heads = l.heads + + Wxi = Wxj = l.dense_x(xj) + Wxi = Wxj = reshape(Wxj, chout, heads, :) + + if xi !== xj + Wxi = l.dense_x(xi) + Wxi = reshape(Wxi, chout, heads, :) + end + + # a hand-written message passing + message = Fix1(gat_message, l) + m = apply_edges(message, g, Wxi, Wxj, e) + α = softmax_edge_neighbors(g, m.logα) + α = dropout(α, l.dropout) + β = α .* m.Wxj + x = aggregate_neighbors(g, +, β) + + if !l.concat + x = mean(x, dims = 2) + end + x = reshape(x, :, size(x, 3)) # return a matrix + x = l.σ.(x .+ l.bias) + + return x +end + +function gat_message(l, Wxi, Wxj, e) + _, chout = l.channel + heads = l.heads + + if e === nothing + Wxx = vcat(Wxi, Wxj) + else + We = l.dense_e(e) + We = reshape(We, chout, heads, :) # chout × nheads × nnodes + Wxx = vcat(Wxi, Wxj, We) + end + aWW = sum(l.a .* Wxx, dims = 1) # 1 × nheads × nedges + slope = convert(eltype(aWW), l.negative_slope) + logα = leakyrelu.(aWW, slope) + return (; logα, Wxj) +end + +####################### GATv2Conv ###################################### + +function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) + check_num_nodes(g, x) + @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" + @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" + + xj, xi = expand_srcdst(g, x) + + if l.add_self_loops + @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." 
+ g = add_self_loops(g) + end + _, out = l.channel + heads = l.heads + + Wxi = reshape(l.dense_i(xi), out, heads, :) # out × heads × nnodes + Wxj = reshape(l.dense_j(xj), out, heads, :) # out × heads × nnodes + + message = Fix1(gatv2_message, l) + m = apply_edges(message, g, Wxi, Wxj, e) + α = softmax_edge_neighbors(g, m.logα) + α = dropout(α, l.dropout) + β = α .* m.Wxj + x = aggregate_neighbors(g, +, β) + + if !l.concat + x = mean(x, dims = 2) + end + x = reshape(x, :, size(x, 3)) + x = l.σ.(x .+ l.bias) + return x +end + +function gatv2_message(l, Wxi, Wxj, e) + _, out = l.channel + heads = l.heads + + Wx = Wxi + Wxj # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?" + if e !== nothing + Wx += reshape(l.dense_e(e), out, heads, :) + end + slope = convert(eltype(Wx), l.negative_slope) + logα = sum(l.a .* leakyrelu.(Wx, slope), dims = 1) # 1 × heads × nedges + return (; logα, Wxj) +end + +####################### GatedGraphConv ###################################### + +function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix) + check_num_nodes(g, x) + m, n = size(x) + @assert m <= l.dims "number of input features must be less or equal to output features." 
+ if m < l.dims + xpad = zeros_like(x, (l.dims - m, n)) + x = vcat(x, xpad) + end + h = x + for i in 1:(l.num_layers) + m = view(l.weight, :, :, i) * h + m = propagate(copy_xj, g, l.aggr; xj = m) + # in gru forward, hidden state is first argument, input is second + h, _ = l.gru(h, m) + end + return h +end + +####################### EdgeConv ###################################### + +function edge_conv(l, g::AbstractGNNGraph, x) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + + message = Fix1(edge_conv_message, l) + x = propagate(message, g, l.aggr; xi, xj, e = nothing) + return x +end + +edge_conv_message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi)) + +####################### GINConv ###################################### + +function gin_conv(l, g::AbstractGNNGraph, x) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + + m = propagate(copy_xj, g, l.aggr, xj = xj) + + return l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) +end + +####################### NNConv ###################################### + +function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e) + check_num_nodes(g, x) + message = Fix1(nn_conv_message, l) + m = propagate(message, g, l.aggr, xj = x, e = e) + return l.σ.(l.weight * x .+ m .+ l.bias) +end + +function nn_conv_message(l, xi, xj, e) + nin, nedges = size(xj) + W = reshape(l.nn(e), (:, nin, nedges)) + xj = reshape(xj, (nin, 1, nedges)) # needed by batched_mul + m = NNlib.batched_mul(W, xj) + return reshape(m, :, nedges) +end + +####################### SAGEConv ###################################### + +function sage_conv(l, g::AbstractGNNGraph, x) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) + x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) + return x +end + +####################### ResGatedConv ###################################### + +function res_gated_graph_conv(l, g::AbstractGNNGraph, x) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + + message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) 
.* xj.Vx + + Ax = l.A * xi + Bx = l.B * xj + Vx = l.V * xj + + m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx)) + + return l.σ.(l.U * xi .+ m .+ l.bias) +end + +####################### CGConv ###################################### + +function cg_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + + if e !== nothing + check_num_edges(g, e) + end + + message = Fix1(cg_message, l) + m = propagate(message, g, +, xi = xi, xj = xj, e = e) + + if l.residual + if size(x, 1) == size(m, 1) + m += x + else + @warn "number of output features different from number of input features, residual not applied." + end + end + + return m +end + +function cg_message(l, xi, xj, e) + if e !== nothing + z = vcat(xi, xj, e) + else + z = vcat(xi, xj) + end + return l.dense_f(z) .* l.dense_s(z) +end + +####################### AGNNConv ###################################### + +function agnn_conv(l, g::GNNGraph, x::AbstractMatrix) + check_num_nodes(g, x) + if l.add_self_loops + g = add_self_loops(g) + end + + xn = x ./ sqrt.(sum(x .^ 2, dims = 1)) + cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn) + α = softmax_edge_neighbors(g, l.β .* cos_dist) + + x = propagate(g, +; xj = x, e = α) do xi, xj, α + α .* xj + end + + return x +end + +####################### MegNetConv ###################################### + +function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) + check_num_nodes(g, x) + + ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e + l.ϕe(vcat(xi, xj, e)) + end + + xᵉ = aggregate_neighbors(g, l.aggr, ē) + + x̄ = l.ϕv(vcat(x, xᵉ)) + + return x̄, ē +end + +####################### GMMConv ###################################### + +function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) + (nin, ein), out = l.ch #Notational Simplicity + + @assert (ein == size(e)[1]&&g.num_edges == size(e)[2]) "Pseudo-cordinate dimension is not equal to (ein,num_edge)" + + 
num_edges = g.num_edges
+    w = reshape(e, (ein, 1, num_edges))
+    mu = reshape(l.mu, (ein, l.K, 1))
+
+    w = @. ((w - mu)^2) / 2
+    w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1))
+    w = exp.(sum(w, dims = 1)) # (1, K, num_edge)
+
+    xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes)
+
+    m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
+    m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)
+
+    m = l.σ(m .+ l.bias)
+
+    if l.residual
+        if size(x, 1) == size(m, 1)
+            m += x
+        else
+            @warn "Residual not applied : output feature is not equal to input_feature"
+        end
+    end
+
+    return m
+end
+
+####################### SGCConv ######################################
+
+# this layer is not stable enough to be supported by GNNHeteroGraph type
+# due to its looping mechanism
+function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
+        edge_weight::EW = nothing) where
+        {T, EW <: Union{Nothing, AbstractVector}}
+    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
+
+    if edge_weight !== nothing
+        @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
+    end
+
+    if l.add_self_loops
+        g = add_self_loops(g)
+        if edge_weight !== nothing
+            # Self-loops get unit weight. Fixed: was misspelled `onse_like`,
+            # which raised UndefVarError whenever add_self_loops=true and an
+            # explicit edge_weight was provided (cf. the same code in sg_conv).
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
+            @assert length(edge_weight) == g.num_edges
+        end
+    end
+    Dout, Din = size(l.weight)
+    if Dout < Din
+        x = l.weight * x
+    end
+    if edge_weight !== nothing
+        d = degree(g, T; dir = :in, edge_weight)
+    else
+        d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
+    end
+    c = 1 ./ sqrt.(d)
+    for iter in 1:(l.k)
+        x = x .* c'
+        if edge_weight !== nothing
+            x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
+        elseif l.use_edge_weight
+            x = propagate(w_mul_xj, g, +, xj = x)
+        else
+            x = propagate(copy_xj, g, +, xj = x)
+        end
+        x = x .* c'
+    end
+    if Dout >= Din
+        x = l.weight * x
+    end
+    return (x .+ 
l.bias) +end + +# when we also have edge_weight we need to convert the graph to COO +function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, + edge_weight::AbstractVector) + g = GNNGraph(g; graph_type=:coo) + return sgc_conv(l, g, x, edge_weight) +end + +####################### EGNNGConv ###################################### + +function egnn_conv(l, g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing) + if l.num_features.edge > 0 + @assert e!==nothing "Edge features must be provided." + end + @assert size(h, 1)==l.num_features.in "Input features must match layer input size." + + x_diff = apply_edges(xi_sub_xj, g, x, x) + sqnorm_xdiff = sum(x_diff .^ 2, dims = 1) + x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6) + + message = Fix1(egnn_message, l) + msg = apply_edges(message, g, + xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff)) + h_aggr = aggregate_neighbors(g, +, msg.h) + x_aggr = aggregate_neighbors(g, mean, msg.x) + + hnew = l.ϕh(vcat(h, h_aggr)) + if l.residual + h = h .+ hnew + else + h = hnew + end + x = x .+ x_aggr + return h, x +end + +function egnn_message(l, xi, xj, e) + if l.num_features.edge > 0 + f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e) + else + f = vcat(xi.h, xj.h, e.sqnorm_xdiff) + end + + msg_h = l.ϕe(f) + msg_x = l.ϕx(msg_h) .* e.x_diff + return (; x = msg_x, h = msg_h) +end + +######################## SGConv ###################################### + +# this layer is not stable enough to be supported by GNNHeteroGraph type +# due to it's looping mechanism +function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T}, + edge_weight::EW = nothing) where + {T, EW <: Union{Nothing, AbstractVector}} + @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" + + if edge_weight !== nothing + @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" + end + + if 
l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if Dout < Din + x = l.weight * x + end + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) + end + c = 1 ./ sqrt.(d) + for iter in 1:(l.k) + x = x .* c' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj = x) + else + x = propagate(copy_xj, g, +, xj = x) + end + x = x .* c' + end + if Dout >= Din + x = l.weight * x + end + return (x .+ l.bias) +end + +# when we also have edge_weight we need to convert the graph to COO +function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, + edge_weight::AbstractVector) + g = GNNGraph(g; graph_type=:coo) + return sg_conv(l, g, x, edge_weight) +end + +######################## TransformerConv ###################################### + +function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing} = nothing) + check_num_nodes(g, x) + + if l.add_self_loops + g = add_self_loops(g) + end + + out = l.channels[2] + heads = l.heads + W1x = !isnothing(l.W1) ? l.W1(x) : nothing + W2x = reshape(l.W2(x), out, heads, :) + W3x = reshape(l.W3(x), out, heads, :) + W4x = reshape(l.W4(x), out, heads, :) + W6e = !isnothing(l.W6) ? 
reshape(l.W6(e), out, heads, :) : nothing + + message_uij = Fix1(transformer_message_uij, l) + m = apply_edges(message_uij, g; xi = (; W3x), xj = (; W4x), e = (; W6e)) + α = softmax_edge_neighbors(g, m) + α_val = propagate(transformer_message_main, g, +; + xi = (; W3x), xj = (; W2x), e = (; W6e, α)) + + h = α_val + if l.concat + h = reshape(h, out * heads, :) # concatenate heads + else + h = mean(h, dims = 2) # average heads + h = reshape(h, out, :) + end + + if !isnothing(W1x) # root_weight + if !isnothing(l.W5) # gating + β = l.W5(vcat(h, W1x, h .- W1x)) + h = β .* W1x + (1.0f0 .- β) .* h + else + h += W1x + end + end + + if l.skip_connection + @assert size(h, 1)==size(x, 1) "In-channels must correspond to out-channels * heads if skip_connection is used" + h += x + end + if !isnothing(l.BN1) + h = l.BN1(h) + end + + if !isnothing(l.FF) + h1 = h + h = l.FF(h) + if l.skip_connection + h += h1 + end + if !isnothing(l.BN2) + h = l.BN2(h) + end + end + + return h +end + +# TODO remove l dependence +function transformer_message_uij(l, xi, xj, e) + key = xj.W4x + if !isnothing(e.W6e) + key += e.W6e + end + uij = sum(xi.W3x .* key, dims = 1) ./ l.sqrt_out + return uij +end + +function transformer_message_main(xi, xj, e) + val = xj.W2x + if !isnothing(e.W6e) + val += e.W6e + end + return e.α .* val +end + + +######################## TAGConv ###################################### + +function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T}, + edge_weight::EW = nothing) where + {T, EW <: Union{Nothing, AbstractVector}} + @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" + + if edge_weight !== nothing + @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" + end + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)] + 
@assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) + end + c = 1 ./ sqrt.(d) + + sum_pow = 0 + sum_total = 0 + for iter in 1:(l.k) + x = x .* c' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj = x) + else + x = propagate(copy_xj, g, +, xj = x) + end + x = x .* c' + + # On the first iteration, initialize sum_pow with the first propagated features + # On subsequent iterations, accumulate propagated features + if iter == 1 + sum_pow = x + sum_total = l.weight * sum_pow + else + sum_pow += x + # Weighted sum of features for each power of adjacency matrix + # This applies the weight matrix to the accumulated sum of propagated features + sum_total += l.weight * sum_pow + end + end + + return (sum_total .+ l.bias) +end + +# when we also have edge_weight we need to convert the graph to COO +function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, + edge_weight::AbstractVector) + g = GNNGraph(g; graph_type = :coo) + return l(g, x, edge_weight) +end + +######################## DConv ###################################### + +function d_conv(l, g::GNNGraph, x::AbstractMatrix) + #A = adjacency_matrix(g, weighted = true) + s, t = edge_index(g) + gt = GNNGraph(t, s, get_edge_weight(g)) + deg_out = degree(g; dir = :out) + deg_in = degree(g; dir = :in) + deg_out = Diagonal(deg_out) + deg_in = Diagonal(deg_in) + + h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x + + T0 = x + if l.k > 1 + # T1_in = T0 * deg_in * A' + #T1_out = T0 * deg_out' * A + T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out') + T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in) + h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out + end + for i in 2:l.k + T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in) + T2_in 
= 2 * T2_in - T0 + T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out') + T2_out = 2 * T2_out - T0 + h = h .+ l.weights[1,i,:,:] * T2_in .+ l.weights[2,i,:,:] * T2_out + T1_in = T2_in + T1_out = T2_out + end + return h .+ l.bias +end +[.\GNNlib\src\layers\pool.jl] + + +function global_pool(l, g::GNNGraph, x::AbstractArray) + return reduce_nodes(l.aggr, g, x) +end + +function global_attention_pool(l, g::GNNGraph, x::AbstractArray) + α = softmax_nodes(g, l.fgate(x)) + feats = α .* l.ffeat(x) + u = reduce_nodes(+, g, feats) + return u +end + +function topk_pool(t, X::AbstractArray) + y = t.p' * X / norm(t.p) + idx = topk_index(y, t.k) + t.Ã .= view(t.A, idx, idx) + X_ = view(X, :, idx) .* σ.(view(y, idx)') + return X_ +end + +function topk_index(y::AbstractVector, k::Int) + v = nlargest(k, y) + return collect(1:length(y))[y .>= v[end]] +end + +topk_index(y::Adjoint, k::Int) = topk_index(y', k) + +function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) + n_in = size(x, 1) + qstar = zeros_like(x, (2*n_in, g.num_graphs)) + for t in 1:l.num_iters + q = l.lstm(qstar) # [n_in, n_graphs] + qn = broadcast_nodes(g, q) # [n_in, n_nodes] + α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes] + r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs] + qstar = vcat(q, r) # [2*n_in, n_graphs] + end + return qstar +end + +[.\GNNlib\src\layers\temporalconv.jl] +function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) + h = a3tgcn.tgcn(g, x) + e = a3tgcn.dense1(h) + e = a3tgcn.dense2(e) + a = softmax(e, dims = 3) + c = sum(a .* h , dims = 3) + if length(size(c)) == 3 + c = dropdims(c, dims = 3) + end + return c +end + + +[.\GNNlib\test\msgpass_tests.jl] +@testitem "msgpass" setup=[SharedTestSetup] begin + #TODO test all graph types + GRAPH_T = :coo + in_channel = 10 + out_channel = 5 + num_V = 6 + num_E = 14 + T = Float32 + + adj = [0 1 0 0 0 0 + 1 0 0 1 1 1 + 0 0 0 0 0 1 + 0 1 0 0 1 0 + 0 1 0 1 0 1 + 0 1 1 0 1 0] + + X = rand(T, in_channel, num_V) + E = rand(T, in_channel, 
num_E)
+
+    g = GNNGraph(adj, graph_type = GRAPH_T)
+
+    @testset "propagate" begin
+        function message(xi, xj, e)
+            @test xi === nothing
+            @test e === nothing
+            ones(T, out_channel, size(xj, 2))
+        end
+
+        m = propagate(message, g, +, xj = X)
+
+        @test size(m) == (out_channel, num_V)
+
+        @testset "isolated nodes" begin
+            x1 = rand(1, 6)
+            g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6)
+            # Fixed: must propagate over `g1` (the graph containing the
+            # isolated node 6), not the outer `g`, otherwise this testset
+            # never exercises the isolated-node code path.
+            y1 = propagate((xi, xj, e) -> xj, g1, +, xj = x1)
+            @test size(y1) == (1, 6)
+        end
+    end
+
+    @testset "apply_edges" begin
+        m = apply_edges(g, e = E) do xi, xj, e
+            @test xi === nothing
+            @test xj === nothing
+            ones(out_channel, size(e, 2))
+        end
+
+        @test m == ones(out_channel, num_E)
+
+        # With NamedTuple input
+        m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e
+            @test xi === nothing
+            @test xj.b == 2 * xj.a
+            @test size(xj.a, 2) == size(xj.b, 2) == size(e, 2)
+            ones(out_channel, size(e, 2))
+        end
+
+        # NamedTuple output
+        m = apply_edges(g, e = E) do xi, xj, e
+            @test xi === nothing
+            @test xj === nothing
+            (; a = ones(out_channel, size(e, 2)))
+        end
+
+        @test m.a == ones(out_channel, num_E)
+
+        @testset "sizecheck" begin
+            x = rand(3, g.num_nodes - 1)
+            @test_throws AssertionError apply_edges(copy_xj, g, xj = x)
+            @test_throws AssertionError apply_edges(copy_xj, g, xi = x)
+
+            x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1))
+            @test_throws AssertionError apply_edges(copy_xj, g, xj = x)
+            @test_throws AssertionError apply_edges(copy_xj, g, xi = x)
+
+            e = rand(3, g.num_edges - 1)
+            @test_throws AssertionError apply_edges(copy_xj, g, e = e)
+        end
+    end
+
+    @testset "copy_xj" begin
+        n = 128
+        A = sprand(n, n, 0.1)
+        Adj = map(x -> x > 0 ? 
1 : 0, A) + X = rand(10, n) + + g = GNNGraph(A, ndata = X, graph_type = GRAPH_T) + + function spmm_copyxj_fused(g) + propagate(copy_xj, + g, +; xj = g.ndata.x) + end + + function spmm_copyxj_unfused(g) + propagate((xi, xj, e) -> xj, + g, +; xj = g.ndata.x) + end + + @test spmm_copyxj_unfused(g) ≈ X * Adj + @test spmm_copyxj_fused(g) ≈ X * Adj + end + + @testset "e_mul_xj and w_mul_xj for weighted conv" begin + n = 128 + A = sprand(n, n, 0.1) + Adj = map(x -> x > 0 ? 1 : 0, A) + X = rand(10, n) + + g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T) + + function spmm_unfused(g) + propagate((xi, xj, e) -> reshape(e, 1, :) .* xj, + g, +; xj = g.ndata.x, e = g.edata.e) + end + function spmm_fused(g) + propagate(e_mul_xj, + g, +; xj = g.ndata.x, e = g.edata.e) + end + + function spmm_fused2(g) + propagate(w_mul_xj, + g, +; xj = g.ndata.x) + end + + @test spmm_unfused(g) ≈ X * A + @test spmm_fused(g) ≈ X * A + @test spmm_fused2(g) ≈ X * A + end + + @testset "aggregate_neighbors" begin + @testset "sizecheck" begin + m = rand(2, g.num_edges - 1) + @test_throws AssertionError aggregate_neighbors(g, +, m) + + m = (a = rand(2, g.num_edges + 1), b = nothing) + @test_throws AssertionError aggregate_neighbors(g, +, m) + end + end + +end +[.\GNNlib\test\runtests.jl] +using GNNlib +using Test +using ReTestItems +using Random, Statistics + +runtests(GNNlib) + +[.\GNNlib\test\shared_testsetup.jl] +@testsetup module SharedTestSetup + +import Reexport: @reexport + +@reexport using GNNlib +@reexport using GNNGraphs +@reexport using NNlib +@reexport using MLUtils +@reexport using SparseArrays +@reexport using Test, Random, Statistics + +end +[.\GNNlib\test\utils_tests.jl] +@testitem "utils" setup=[SharedTestSetup] begin + # TODO test all graph types + GRAPH_T = :coo + De, Dx = 3, 2 + g = MLUtils.batch([rand_graph(10, 60, bidirected=true, + ndata = rand(Dx, 10), + edata = rand(De, 30), + graph_type = GRAPH_T) for i in 1:5]) + x = g.ndata.x + e = g.edata.e + + @testset 
"reduce_nodes" begin + r = reduce_nodes(mean, g, x) + @test size(r) == (Dx, g.num_graphs) + @test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2) + + r2 = reduce_nodes(mean, graph_indicator(g), x) + @test r2 == r + end + + @testset "reduce_edges" begin + r = reduce_edges(mean, g, e) + @test size(r) == (De, g.num_graphs) + @test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2) + end + + @testset "softmax_nodes" begin + r = softmax_nodes(g, x) + @test size(r) == size(x) + @test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2) + end + + @testset "softmax_edges" begin + r = softmax_edges(g, e) + @test size(r) == size(e) + @test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2) + end + + @testset "broadcast_nodes" begin + z = rand(4, g.num_graphs) + r = broadcast_nodes(g, z) + @test size(r) == (4, g.num_nodes) + @test r[:, 1] ≈ z[:, 1] + @test r[:, 10] ≈ z[:, 1] + @test r[:, 11] ≈ z[:, 2] + end + + @testset "broadcast_edges" begin + z = rand(4, g.num_graphs) + r = broadcast_edges(g, z) + @test size(r) == (4, g.num_edges) + @test r[:, 1] ≈ z[:, 1] + @test r[:, 60] ≈ z[:, 1] + @test r[:, 61] ≈ z[:, 2] + end + + @testset "softmax_edge_neighbors" begin + s = [1, 2, 3, 4] + t = [5, 5, 6, 6] + g2 = GNNGraph(s, t) + e2 = randn(Float32, 3, g2.num_edges) + z = softmax_edge_neighbors(g2, e2) + @test size(z) == size(e2) + @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2) + @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2) + end +end + + +[.\GNNLux\src\GNNLux.jl] +module GNNLux +using ConcreteStructs: @concrete +using NNlib: NNlib, sigmoid, relu, swish +using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, + initialparameters, initialstates, parameterlength, statelength +using Lux: Lux, Chain, Dense, GRUCell, + glorot_uniform, zeros32, + StatefulLuxLayer +using Reexport: @reexport +using Random: AbstractRNG +using GNNlib: GNNlib +@reexport using GNNGraphs + +include("layers/basic.jl") +export 
GNNLayer, + GNNContainerLayer, + GNNChain + +include("layers/conv.jl") +export AGNNConv, + CGConv, + ChebConv, + EdgeConv, + EGNNConv, + DConv, + GATConv, + GATv2Conv, + GatedGraphConv, + GCNConv, + GINConv, + # GMMConv, + GraphConv, + # MEGNetConv, + NNConv, + # ResGatedGraphConv, + # SAGEConv, + SGConv + # TAGConv, + # TransformerConv + + +end #module + +[.\GNNLux\src\layers\basic.jl] +""" + abstract type GNNLayer <: AbstractExplicitLayer end + +An abstract type from which graph neural network layers are derived. +It is Derived from Lux's `AbstractExplicitLayer` type. + +See also [`GNNChain`](@ref GNNLux.GNNChain). +""" +abstract type GNNLayer <: AbstractExplicitLayer end + +abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end + +@concrete struct GNNChain <: GNNContainerLayer{(:layers,)} + layers <: NamedTuple +end + +GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...) + +function GNNChain(; kw...) + :layers in Base.keys(kw) && + throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) + nt = NamedTuple{keys(kw)}(values(kw)) + nt = map(_wrapforchain, nt) + return GNNChain(nt) +end + +_wrapforchain(l::AbstractExplicitLayer) = l +_wrapforchain(l) = Lux.WrappedFunction(l) + +Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers)) +Base.getindex(c::GNNChain, i::Int) = c.layers[i] +Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) + +function Base.getproperty(c::GNNChain, name::Symbol) + hasfield(typeof(c), name) && return getfield(c, name) + layers = getfield(c, :layers) + hasfield(typeof(layers), name) && return getfield(layers, name) + throw(ArgumentError("$(typeof(c)) has no field or layer $name")) +end + +Base.length(c::GNNChain) = length(c.layers) +Base.lastindex(c::GNNChain) = lastindex(c.layers) +Base.firstindex(c::GNNChain) = firstindex(c.layers) + +LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end]) + 
+(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st) + +function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times + newst = (;) + for (name, l) in pairs(layers) + x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name)) + newst = merge(newst, (; name => s′)) + end + return x, newst +end + +_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;) +_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st) +_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) +_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) + +[.\GNNLux\src\layers\conv.jl] +_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false +_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple() +_getstate(s::StatefulLuxLayer{true}) = s.st +_getstate(s::StatefulLuxLayer{false}) = s.st_any + + +@concrete struct GCNConv <: GNNLayer + in_dims::Int + out_dims::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias + σ +end + +function GCNConv(ch::Pair{Int, Int}, σ = identity; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, + allow_fast_activation::Bool = true) + in_dims, out_dims = ch + σ = allow_fast_activation ? NNlib.fast_act(σ) : σ + return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::GCNConv) = l.use_bias ? 
l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims +LuxCore.outputsize(d::GCNConv) = (d.out_dims,) + +function Base.show(io::IO, l::GCNConv) + print(io, "GCNConv(", l.in_dims, " => ", l.out_dims) + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") + print(io, ")") +end + +(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) = + l(g, x, edge_weight, ps, st; conv_weight, norm_fn) + +function (l::GCNConv)(g, x, edge_weight, ps, st; + norm_fn = d -> 1 ./ sqrt.(d), + conv_weight=nothing, ) + + m = (; ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.σ) + y = GNNlib.gcn_conv(m, g, x, edge_weight, norm_fn, conv_weight) + return y, st +end + +@concrete struct ChebConv <: GNNLayer + in_dims::Int + out_dims::Int + use_bias::Bool + k::Int + init_weight + init_bias +end + +function ChebConv(ch::Pair{Int, Int}, k::Int; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true) + in_dims, out_dims = ch + return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::ChebConv) = l.use_bias ? 
l.in_dims * l.out_dims * l.k + l.out_dims : + l.in_dims * l.out_dims * l.k +LuxCore.statelength(d::ChebConv) = 0 +LuxCore.outputsize(d::ChebConv) = (d.out_dims,) + +function Base.show(io::IO, l::ChebConv) + print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k) + l.use_bias || print(io, ", use_bias=false") + print(io, ")") +end + +function (l::ChebConv)(g, x, ps, st) + m = (; ps.weight, bias = _getbias(ps), l.k) + y = GNNlib.cheb_conv(m, g, x) + return y, st + +end + +@concrete struct GraphConv <: GNNLayer + in_dims::Int + out_dims::Int + use_bias::Bool + init_weight + init_bias + σ + aggr +end + +function GraphConv(ch::Pair{Int, Int}, σ = identity; + aggr = +, + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + allow_fast_activation::Bool = true) + in_dims, out_dims = ch + σ = allow_fast_activation ? NNlib.fast_act(σ) : σ + return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv) + weight1 = l.init_weight(rng, l.out_dims, l.in_dims) + weight2 = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight1, weight2, bias) + else + return (; weight1, weight2) + end +end + +function LuxCore.parameterlength(l::GraphConv) + if l.use_bias + return 2 * l.in_dims * l.out_dims + l.out_dims + else + return 2 * l.in_dims * l.out_dims + end +end + +LuxCore.statelength(d::GraphConv) = 0 +LuxCore.outputsize(d::GraphConv) = (d.out_dims,) + +function Base.show(io::IO, l::GraphConv) + print(io, "GraphConv(", l.in_dims, " => ", l.out_dims) + (l.σ == identity) || print(io, ", ", l.σ) + (l.aggr == +) || print(io, ", aggr=", l.aggr) + l.use_bias || print(io, ", use_bias=false") + print(io, ")") +end + +function (l::GraphConv)(g, x, ps, st) + m = (; ps.weight1, ps.weight2, bias = _getbias(ps), + l.σ, l.aggr) + return GNNlib.graph_conv(m, g, x), st +end + + +@concrete struct AGNNConv <: GNNLayer + 
init_beta <: AbstractVector + add_self_loops::Bool + trainable::Bool +end + +function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true) + return AGNNConv([init_beta], add_self_loops, trainable) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::AGNNConv) + if l.trainable + return (; β = l.init_beta) + else + return (;) + end +end + +LuxCore.parameterlength(l::AGNNConv) = l.trainable ? 1 : 0 +LuxCore.statelength(d::AGNNConv) = 0 + +function Base.show(io::IO, l::AGNNConv) + print(io, "AGNNConv(", l.init_beta) + l.add_self_loops || print(io, ", add_self_loops=false") + l.trainable || print(io, ", trainable=false") + print(io, ")") +end + +function (l::AGNNConv)(g, x::AbstractMatrix, ps, st) + β = l.trainable ? ps.β : l.init_beta + m = (; β, l.add_self_loops) + return GNNlib.agnn_conv(m, g, x), st +end + +@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)} + in_dims::NTuple{2, Int} + out_dims::Int + dense_f + dense_s + residual::Bool + init_weight + init_bias +end + +CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) 
+ +function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, + use_bias = true, init_weight = glorot_uniform, init_bias = zeros32, + allow_fast_activation = true) + (nin, ein), out = ch + dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation) + dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation) + return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias) +end + +LuxCore.outputsize(l::CGConv) = (l.out_dims,) + +(l::CGConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function (l::CGConv)(g, x, e, ps, st) + dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f)) + dense_s = StatefulLuxLayer{true}(l.dense_s, ps.dense_s, _getstate(st, :dense_s)) + m = (; dense_f, dense_s, l.residual) + return GNNlib.cg_conv(m, g, x, e), st +end + +@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)} + nn <: AbstractExplicitLayer + aggr +end + +EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) + +function Base.show(io::IO, l::EdgeConv) + print(io, "EdgeConv(", l.nn) + print(io, ", aggr=", l.aggr) + print(io, ")") +end + + +function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps, st) + m = (; nn, l.aggr) + y = GNNlib.edge_conv(m, g, x) + stnew = _getstate(nn) + return y, stnew +end + + +@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)} + ϕe + ϕx + ϕh + num_features + residual::Bool +end + +function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false) + return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) +end + +#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py +function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1], + residual = false) + (in_size, edge_feat_size), out_size = ch + act_fn = swish + + # +1 for the radial feature: ||x_i - x_j||^2 + ϕe = 
Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn), + Dense(hidden_size => hidden_size, act_fn)) + + ϕh = Chain(Dense(in_size + hidden_size => hidden_size, swish), + Dense(hidden_size => out_size)) + + ϕx = Chain(Dense(hidden_size => hidden_size, swish), + Dense(hidden_size => 1, use_bias = false)) + + num_features = (in = in_size, edge = edge_feat_size, out = out_size, + hidden = hidden_size) + if residual + @assert in_size==out_size "Residual connection only possible if in_size == out_size" + end + return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) +end + +LuxCore.outputsize(l::EGNNConv) = (l.num_features.out,) + +(l::EGNNConv)(g, h, x, ps, st) = l(g, h, x, nothing, ps, st) + +function (l::EGNNConv)(g, h, x, e, ps, st) + ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) + ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx)) + ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh)) + m = (; ϕe, ϕx, ϕh, l.residual, l.num_features) + return GNNlib.egnn_conv(m, g, h, x, e), st +end + +function Base.show(io::IO, l::EGNNConv) + ne = l.num_features.edge + nin = l.num_features.in + nout = l.num_features.out + nh = l.num_features.hidden + print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh") + if l.residual + print(io, ", residual=true") + end + print(io, ")") +end + +@concrete struct DConv <: GNNLayer + in_dims::Int + out_dims::Int + k::Int + init_weight + init_bias + use_bias::Bool +end + +function DConv(ch::Pair{Int, Int}, k::Int; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias = true) + in, out = ch + return DConv(in, out, k, init_weight, init_bias, use_bias) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::DConv) + weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weights, bias) + else + return (; weights) + end +end + +LuxCore.outputsize(l::DConv) = (l.out_dims,) +LuxCore.parameterlength(l::DConv) = l.use_bias 
? 2 * l.in_dims * l.out_dims * l.k + l.out_dims : + 2 * l.in_dims * l.out_dims * l.k + +function (l::DConv)(g, x, ps, st) + m = (; ps.weights, bias = _getbias(ps), l.k) + return GNNlib.d_conv(m, g, x), st +end + +function Base.show(io::IO, l::DConv) + print(io, "DConv($(l.in_dims) => $(l.out_dims), k=$(l.k))") +end + +@concrete struct GATConv <: GNNLayer + dense_x + dense_e + init_weight + init_bias + use_bias::Bool + σ + negative_slope + channel::Pair{NTuple{2, Int}, Int} + heads::Int + concat::Bool + add_self_loops::Bool + dropout +end + + +GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) + +function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; + heads::Int = 1, concat::Bool = true, negative_slope = 0.2, + init_weight = glorot_uniform, init_bias = zeros32, + use_bias::Bool = true, + add_self_loops = true, dropout=0.0) + (in, ein), out = ch + if add_self_loops + @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." + end + + dense_x = Dense(in => out * heads, use_bias = false) + dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing + negative_slope = convert(Float32, negative_slope) + return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias, + σ, negative_slope, ch, heads, concat, add_self_loops, dropout) +end + +LuxCore.outputsize(l::GATConv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],) +##TODO: parameterlength + +function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv) + (in, ein), out = l.channel + dense_x = LuxCore.initialparameters(rng, l.dense_x) + a = l.init_weight(ein > 0 ? 3out : 2out, l.heads) + ps = (; dense_x, a) + if ein > 0 + ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e)) + end + if l.use_bias + ps = (ps..., bias = l.init_bias(rng, l.concat ? 
out * l.heads : out)) + end + return ps +end + +(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function (l::GATConv)(g, x, e, ps, st) + dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x)) + dense_e = l.dense_e === nothing ? nothing : + StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e)) + + m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ, + ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope) + return GNNlib.gat_conv(m, g, x, e), st +end + +function Base.show(io::IO, l::GATConv) + (in, ein), out = l.channel + print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads) + l.σ == identity || print(io, ", ", l.σ) + print(io, ", negative_slope=", l.negative_slope) + print(io, ")") +end + +@concrete struct GATv2Conv <: GNNLayer + dense_i + dense_j + dense_e + init_weight + init_bias + use_bias::Bool + σ + negative_slope + channel::Pair{NTuple{2, Int}, Int} + heads::Int + concat::Bool + add_self_loops::Bool + dropout +end + +function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) + GATv2Conv((ch[1], 0) => ch[2], args...; kws...) +end + +function GATv2Conv(ch::Pair{NTuple{2, Int}, Int}, + σ = identity; + heads::Int = 1, + concat::Bool = true, + negative_slope = 0.2, + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops = true, + dropout=0.0) + + (in, ein), out = ch + + if add_self_loops + @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." 
+ end + + dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias) + dense_j = Dense(in => out * heads; use_bias = false, init_weight) + if ein > 0 + dense_e = Dense(ein => out * heads; use_bias = false, init_weight) + else + dense_e = nothing + end + return GATv2Conv(dense_i, dense_j, dense_e, + init_weight, init_bias, use_bias, + σ, negative_slope, + ch, heads, concat, add_self_loops, dropout) +end + + +LuxCore.outputsize(l::GATv2Conv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],) +##TODO: parameterlength + +function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv) + (in, ein), out = l.channel + dense_i = LuxCore.initialparameters(rng, l.dense_i) + dense_j = LuxCore.initialparameters(rng, l.dense_j) + a = l.init_weight(out, l.heads) + ps = (; dense_i, dense_j, a) + if ein > 0 + ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e)) + end + if l.use_bias + ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out)) + end + return ps +end + +(l::GATv2Conv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function (l::GATv2Conv)(g, x, e, ps, st) + dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i)) + dense_j = StatefulLuxLayer{true}(l.dense_j, ps.dense_j, _getstate(st, :dense_j)) + dense_e = l.dense_e === nothing ? nothing : + StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e)) + + m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ, + ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope) + return GNNlib.gatv2_conv(m, g, x, e), st +end + +function Base.show(io::IO, l::GATv2Conv) + (in, ein), out = l.channel + print(io, "GATv2Conv(", ein == 0 ? 
in : (in, ein), " => ", out ÷ l.heads)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ", negative_slope=", l.negative_slope)
+    print(io, ")")
+end
+
+@concrete struct SGConv <: GNNLayer
+    in_dims::Int
+    out_dims::Int
+    k::Int
+    use_bias::Bool
+    add_self_loops::Bool
+    use_edge_weight::Bool
+    init_weight
+    init_bias
+end
+
+function SGConv(ch::Pair{Int, Int}, k = 1;
+                init_weight = glorot_uniform,
+                init_bias = zeros32,
+                use_bias::Bool = true,
+                add_self_loops::Bool = true,
+                use_edge_weight::Bool = false)
+    in_dims, out_dims = ch
+    return SGConv(in_dims, out_dims, k, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
+    weight = l.init_weight(rng, l.out_dims, l.in_dims)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        return (; weight, bias)
+    else
+        return (; weight)
+    end
+end
+
+LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
+LuxCore.outputsize(d::SGConv) = (d.out_dims,)
+
+function Base.show(io::IO, l::SGConv)
+    print(io, "SGConv(", l.in_dims, " => ", l.out_dims)
+    # FIX: was `l.k || print(...)` — `l.k` is an Int, so short-circuit `||` on it
+    # throws `TypeError: non-boolean (Int64) used in boolean context`. Only print
+    # the hop count when it differs from the default k = 1.
+    l.k == 1 || print(io, ", ", l.k)
+    l.use_bias || print(io, ", use_bias=false")
+    l.add_self_loops || print(io, ", add_self_loops=false")
+    !l.use_edge_weight || print(io, ", use_edge_weight=true")
+    print(io, ")")
+end
+
+(l::SGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
+
+function (l::SGConv)(g, x, edge_weight, ps, st)
+    m = (; ps.weight, bias = _getbias(ps),
+           l.add_self_loops, l.use_edge_weight, l.k)
+    y = GNNlib.sg_conv(m, g, x, edge_weight)
+    return y, st
+end
+
+@concrete struct GatedGraphConv <: GNNLayer
+    gru
+    init_weight
+    dims::Int
+    num_layers::Int
+    aggr
+end
+
+
+function GatedGraphConv(dims::Int, num_layers::Int;
+                        aggr = +, init_weight = glorot_uniform)
+    gru = GRUCell(dims => dims)
+    return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
+end
+
+LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)
+
+function 
LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv) + gru = LuxCore.initialparameters(rng, l.gru) + weight = l.init_weight(rng, l.dims, l.dims, l.num_layers) + return (; gru, weight) +end + +LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers + + +function (l::GatedGraphConv)(g, x, ps, st) + gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) + fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style + m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims) + return GNNlib.gated_graph_conv(m, g, x), st +end + +function Base.show(io::IO, l::GatedGraphConv) + print(io, "GatedGraphConv($(l.dims), $(l.num_layers)") + print(io, ", aggr=", l.aggr) + print(io, ")") +end + +@concrete struct GINConv <: GNNContainerLayer{(:nn,)} + nn <: AbstractExplicitLayer + ϵ <: Real + aggr +end + +GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) + +function (l::GINConv)(g, x, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps, st) + m = (; nn, l.ϵ, l.aggr) + y = GNNlib.gin_conv(m, g, x) + stnew = _getstate(nn) + return y, stnew +end + +function Base.show(io::IO, l::GINConv) + print(io, "GINConv($(l.nn)") + print(io, ", $(l.ϵ)") + print(io, ")") +end + +@concrete struct NNConv <: GNNContainerLayer{(:nn,)} + nn <: AbstractExplicitLayer + aggr + in_dims::Int + out_dims::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias + σ +end + + +function NNConv(ch::Pair{Int, Int}, nn, σ = identity; + aggr = +, + init_bias = zeros32, + use_bias::Bool = true, + init_weight = glorot_uniform, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, + allow_fast_activation::Bool = true) + in_dims, out_dims = ch + σ = allow_fast_activation ? 
NNlib.fast_act(σ) : σ
+    return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
+end
+
+function (l::NNConv)(g, x, edge_weight, ps, st)
+    nn = StatefulLuxLayer{true}(l.nn, ps, st)
+
+    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps),
+           l.add_self_loops, l.use_edge_weight, l.σ)
+    y = GNNlib.nn_conv(m, g, x, edge_weight)
+    stnew = _getstate(nn)
+    return y, stnew
+end
+
+LuxCore.outputsize(d::NNConv) = (d.out_dims,)
+
+function Base.show(io::IO, l::NNConv)
+    print(io, "NNConv($(l.nn)")
+    # FIX: removed `print(io, ", $(l.ϵ)")` — NNConv has no `ϵ` field (it was
+    # copy-pasted from GINConv's show method and would raise a FieldError).
+    l.σ == identity || print(io, ", ", l.σ)
+    l.use_bias || print(io, ", use_bias=false")
+    l.add_self_loops || print(io, ", add_self_loops=false")
+    !l.use_edge_weight || print(io, ", use_edge_weight=true")
+    print(io, ")")
+end
+
+[.\GNNLux\test\runtests.jl]
+using Test
+using Lux
+using GNNLux
+using Random, Statistics
+
+using ReTestItems
+# using Pkg, Preferences, Test
+# using InteractiveUtils, Hwloc
+
+runtests(GNNLux)
+
+[.\GNNLux\test\shared_testsetup.jl]
+@testsetup module SharedTestSetup
+
+import Reexport: @reexport
+
+@reexport using Test
+@reexport using GNNLux
+@reexport using Lux
+@reexport using StableRNGs
+@reexport using Random, Statistics
+
+using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
+
+export test_lux_layer
+
+function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
+                        outputsize=nothing, sizey=nothing, container=false,
+                        atol=1.0f-2, rtol=1.0f-2)
+
+    if container
+        @test l isa GNNContainerLayer
+    else
+        @test l isa GNNLayer
+    end
+
+    ps = LuxCore.initialparameters(rng, l)
+    st = LuxCore.initialstates(rng, l)
+    @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
+    @test LuxCore.statelength(l) == LuxCore.statelength(st)
+
+    y, st′ = l(g, x, ps, st)
+    @test eltype(y) == eltype(x)
+    if outputsize !== nothing
+        @test LuxCore.outputsize(l) == outputsize
+    end
+    if sizey !== nothing
+        @test size(y) == sizey
+    elseif outputsize 
!== nothing + @test size(y) == (outputsize..., g.num_nodes) + end + + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) +end + +end +[.\GNNLux\test\layers\basic_tests.jl] +@testitem "layers/basic" setup=[SharedTestSetup] begin + rng = StableRNG(17) + g = rand_graph(10, 40) + x = randn(rng, Float32, 3, 10) + + @testset "GNNLayer" begin + @test GNNLayer <: LuxCore.AbstractExplicitLayer + end + + @testset "GNNContainerLayer" begin + @test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer + end + + @testset "GNNChain" begin + @test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} + c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) + test_lux_layer(rng, c, g, x, outputsize=(3,), container=true) + end +end + +[.\GNNLux\test\layers\conv_tests.jl] +@testitem "layers/conv" setup=[SharedTestSetup] begin + rng = StableRNG(1234) + g = rand_graph(10, 40) + in_dims = 3 + out_dims = 5 + x = randn(rng, Float32, in_dims, 10) + + @testset "GCNConv" begin + l = GCNConv(in_dims => out_dims, tanh) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "ChebConv" begin + l = ChebConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "GraphConv" begin + l = GraphConv(in_dims => out_dims, tanh) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "AGNNConv" begin + l = AGNNConv(init_beta=1.0f0) + test_lux_layer(rng, l, g, x, sizey=(in_dims, 10)) + end + + @testset "EdgeConv" begin + nn = Chain(Dense(2*in_dims => 2, tanh), Dense(2 => out_dims)) + l = EdgeConv(nn, aggr = +) + test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) + end + + @testset "CGConv" begin + l = CGConv(in_dims => in_dims, residual = true) + test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true) + end + + @testset "DConv" begin + l = DConv(in_dims => out_dims, 
2) + test_lux_layer(rng, l, g, x, outputsize=(5,)) + end + + @testset "EGNNConv" begin + hin = 6 + hout = 7 + hidden = 8 + l = EGNNConv(hin => hout, hidden) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + h = randn(rng, Float32, hin, g.num_nodes) + (hnew, xnew), stnew = l(g, h, x, ps, st) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_dims, g.num_nodes) + end + + @testset "GATConv" begin + x = randn(rng, Float32, 6, 10) + + l = GATConv(6 => 8, heads=2) + test_lux_layer(rng, l, g, x, outputsize=(16,)) + + l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5) + test_lux_layer(rng, l, g, x, outputsize=(8,)) + + #TODO test edge + end + + @testset "GATv2Conv" begin + x = randn(rng, Float32, 6, 10) + + l = GATv2Conv(6 => 8, heads=2) + test_lux_layer(rng, l, g, x, outputsize=(16,)) + + l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5) + test_lux_layer(rng, l, g, x, outputsize=(8,)) + + #TODO test edge + end + + @testset "SGConv" begin + l = SGConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end + + @testset "GatedGraphConv" begin + l = GatedGraphConv(in_dims, 3) + test_lux_layer(rng, l, g, x, outputsize=(in_dims,)) + end + + @testset "GINConv" begin + nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) + l = GINConv(nn, 0.5) + test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) + end + + @testset "NNConv" begin + edim = 10 + nn = Dense(edim, out_dims * in_dims) + l = NNConv(in_dims => out_dims, nn, tanh, bias = true, aggr = +) + test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) + end +end + +[.\perf\bench_gnn.jl] +using SparseArrays +using GraphNeuralNetworks +using BenchmarkTools +import Random: seed! 
+using LinearAlgebra + +n = 1024 +seed!(0) +A = sprand(n, n, 0.01) +b = rand(1, n) +B = rand(100, n) + +g = GNNGraph(A, + ndata = (; b = b, B = B), + edata = (; A = reshape(A.nzval, 1, :)), + graph_type = :coo) + +function spmv(g) + propagate((xi, xj, e) -> e .* xj, # same as e_mul_xj + g, +; xj = g.ndata.b, e = g.edata.A) +end + +function spmm1(g) + propagate((xi, xj, e) -> e .* xj, # same as e_mul_xj + g, +; xj = g.ndata.B, e = g.edata.A) +end +function spmm2(g) + propagate(e_mul_xj, + g, +; xj = g.ndata.B, e = vec(g.edata.A)) +end + +# @assert isequal(spmv(g), b * A) # true +# @btime spmv(g) # ~5 ms +# @btime b * A # ~32 us + +@assert isequal(spmm1(g), B * A) # true +@assert isequal(spmm2(g), B * A) # true +@btime spmm1(g) # ~9 ms +@btime spmm2(g) # ~9 ms +@btime B * A # ~400 us + +function spmm_copyxj_fused(g) + propagate(copy_xj, + g, +; xj = g.ndata.B) +end + +function spmm_copyxj_unfused(g) + propagate((xi, xj, e) -> xj, + g, +; xj = g.ndata.B) +end + +Adj = map(x -> x > 0 ? 1 : 0, A) +@assert spmm_copyxj_unfused(g) ≈ B * Adj +@assert spmm_copyxj_fused(g) ≈ B * Adj # bug fixed in #107 + +@btime spmm_copyxj_fused(g) # 268.614 μs (22 allocations: 1.13 MiB) +@btime spmm_copyxj_unfused(g) # 4.263 ms (52855 allocations: 12.23 MiB) +@btime B * Adj # 196.135 μs (2 allocations: 800.05 KiB) + +println() + +[.\perf\neural_ode_mnist.jl] +# Load the packages +using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations +using Flux: onehotbatch, onecold +using Flux.Losses: logitcrossentropy +using Flux +using Statistics: mean +using MLDatasets +using CUDA +# CUDA.allowscalar(false) # Some scalar indexing is still done by DiffEqFlux + +# device = cpu # `gpu` not working yet +device = CUDA.functional() ? 
gpu : cpu + +# LOAD DATA +X, y = MNIST(:train)[:] +y = onehotbatch(y, 0:9) + +# Define the Neural GDE +diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2]) + +nin, nhidden, nout = 28 * 28, 100, 10 +epochs = 10 + +node_chain = Chain(Dense(nhidden => nhidden, tanh), + Dense(nhidden => nhidden)) |> device + +node = NeuralODE(node_chain, + (0.0f0, 1.0f0), Tsit5(), save_everystep = false, + reltol = 1e-3, abstol = 1e-3, save_start = false) |> device + +model = Chain(Flux.flatten, + Dense(nin => nhidden, relu), + node, + diffeqsol_to_array, + Dense(nhidden, nout)) |> device + +# # Training + +# ## Optimizer +opt = Flux.setup(Adam(0.01), model) + +function eval_loss_accuracy(X, y) + ŷ = model(X) + l = logitcrossentropy(ŷ, y) + acc = mean(onecold(ŷ) .== onecold(y)) + return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) +end + +# ## Training Loop +for epoch in 1:epochs + grad = gradient(model) do model + ŷ = model(X) + logitcrossentropy(ŷ, y) + end + Flux.update!(opt, model, grad[1]) + @show eval_loss_accuracy(X, y) +end + +[.\perf\node_classification_cora_geometricflux.jl] +# An example of semi-supervised node classification + +using Flux +using Flux: onecold, onehotbatch +using Flux.Losses: logitcrossentropy +using GeometricFlux, GraphSignals +using MLDatasets: Cora +using Statistics, Random +using CUDA +CUDA.allowscalar(false) + +function eval_loss_accuracy(X, y, ids, model) + ŷ = model(X) + l = logitcrossentropy(ŷ[:, ids], y[:, ids]) + acc = mean(onecold(ŷ[:, ids]) .== onecold(y[:, ids])) + return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) +end + +# arguments for the `train` function +Base.@kwdef mutable struct Args + η = 1.0f-3 # learning rate + epochs = 100 # number of epochs + seed = 17 # set seed > 0 for reproducibility + usecuda = true # if true use cuda (if available) + nhidden = 128 # dimension of hidden features + infotime = 10 # report every `infotime` epochs +end + +function train(; kws...) + args = Args(; kws...) 
+ + args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + device = gpu + args.seed > 0 && CUDA.seed!(args.seed) + @info "Training on GPU" + else + device = cpu + @info "Training on CPU" + end + + # LOAD DATA + data = Cora.dataset() + g = FeaturedGraph(data.adjacency_list) |> device + X = data.node_features |> device + y = onehotbatch(data.node_labels, 1:(data.num_classes)) |> device + train_ids = data.train_indices |> device + val_ids = data.val_indices |> device + test_ids = data.test_indices |> device + ytrain = y[:, train_ids] + + nin, nhidden, nout = size(X, 1), args.nhidden, data.num_classes + + ## DEFINE MODEL + model = Chain(GCNConv(g, nin => nhidden, relu), + Dropout(0.5), + GCNConv(g, nhidden => nhidden, relu), + Dense(nhidden, nout)) |> device + + opt = Flux.setup(Adam(args.η), model) + + @info g + + ## LOGGING FUNCTION + function report(epoch) + train = eval_loss_accuracy(X, y, train_ids, model) + test = eval_loss_accuracy(X, y, test_ids, model) + println("Epoch: $epoch Train: $(train) Test: $(test)") + end + + ## TRAINING + report(0) + for epoch in 1:(args.epochs) + grad = Flux.gradient(model) do model + ŷ = model(X) + logitcrossentropy(ŷ[:, train_ids], ytrain) + end + + Flux.update!(opt, model, grad[1]) + + epoch % args.infotime == 0 && report(epoch) + end +end + +train(usecuda = false) + +[.\perf\perf.jl] +using Flux, GraphNeuralNetworks, Graphs, BenchmarkTools, CUDA +using DataFrames, Statistics, JLD2, SparseArrays +CUDA.device!(2) +CUDA.allowscalar(false) + +BenchmarkTools.ratio(::Missing, x) = Inf +BenchmarkTools.ratio(x, ::Missing) = 0.0 +BenchmarkTools.ratio(::Missing, ::Missing) = missing + +function run_single_benchmark(N, c, D, CONV; gtype = :lg) + data = erdos_renyi(N, c / (N - 1), seed = 17) + X = randn(Float32, D, N) + + g = GNNGraph(data; ndata = X, graph_type = gtype) + g_gpu = g |> gpu + + m = CONV(D => D) + ps = Flux.params(m) + + m_gpu = m |> gpu + ps_gpu = Flux.params(m_gpu) + + res = Dict() + + 
res["CPU_FWD"] = @benchmark $m($g)
+    res["CPU_GRAD"] = @benchmark gradient(() -> sum($m($g).ndata.x), $ps)
+
+    try
+        res["GPU_FWD"] = @benchmark CUDA.@sync($m_gpu($g_gpu)) teardown=(GC.gc(); CUDA.reclaim())
+    catch
+        res["GPU_FWD"] = missing
+    end
+
+    try
+        res["GPU_GRAD"] = @benchmark CUDA.@sync(gradient(() -> sum($m_gpu($g_gpu).ndata.x),
+                                                         $ps_gpu)) teardown=(GC.gc(); CUDA.reclaim())
+    catch
+        res["GPU_GRAD"] = missing
+    end
+
+    return res
+end
+
+"""
+    run_benchmarks(;
+                   Ns = [10, 100, 1000, 10000],
+                   c = 6,
+                   D = 100,
+                   layers = [GCNConv, GraphConv, GATConv]
+                   )
+
+Benchmark GNN layers on Erdos-Renyi random graphs
+with average degree `c`. Benchmarks are performed for each graph size in the list `Ns`.
+`D` is the number of node features.
+"""
+function run_benchmarks(;
+                        Ns = [10, 100, 1000, 10000],
+                        c = 6,
+                        D = 100,
+                        layers = [GCNConv, GATConv],
+                        gtypes = [:coo, :sparse, :dense])
+    df = DataFrame(N = Int[], c = Float64[], layer = String[], gtype = Symbol[],
+                   time_cpu = Any[], time_gpu = Any[]) |> allowmissing
+
+    for gtype in gtypes
+        for N in Ns
+            println("## GRAPH_TYPE = $gtype N = $N")
+            for CONV in layers
+                res = run_single_benchmark(N, c, D, CONV; gtype)
+                # FIX: run_single_benchmark stores results under the keys
+                # "CPU_FWD"/"CPU_GRAD"/"GPU_FWD"/"GPU_GRAD"; the old lookups
+                # res["CPU"] / res["GPU"] raised a KeyError. Report the
+                # forward-pass timings here.
+                row = (; layer = "$CONV",
+                       N = N,
+                       c = c,
+                       gtype = gtype,
+                       time_cpu = ismissing(res["CPU_FWD"]) ? missing : median(res["CPU_FWD"]),
+                       time_gpu = ismissing(res["GPU_FWD"]) ? missing : median(res["GPU_FWD"]))
+                push!(df, row)
+            end
+        end
+    end
+
+    df.gpu_to_cpu = ratio.(df.time_gpu, df.time_cpu)
+    sort!(df, [:layer, :N, :c, :gtype])
+    return df
+end
+
+# df = run_benchmarks()
+# for g in groupby(df, :layer); println(g, "\n"); end
+
+# @save "perf/perf_master_20210803_carlo.jld2" dfmaster=df
+## or
+# @save "perf/perf_pr.jld2" dfpr=df
+
+function compare(dfpr, dfmaster; on = [:N, :c, :gtype, :layer])
+    df = outerjoin(dfpr, dfmaster; on = on, makeunique = true,
+                   renamecols = :_pr => :_master)
+    df.pr_to_master_cpu = ratio.(df.time_cpu_pr, df.time_cpu_master)
+    df.pr_to_master_gpu = ratio.(df.time_gpu_pr, df.time_gpu_master)
+    return df[:, [:N, :c, :gtype, :layer, :pr_to_master_cpu, :pr_to_master_gpu]]
+end
+
+# @load "perf/perf_pr.jld2" dfpr
+# @load "perf/perf_master.jld2" dfmaster
+# compare(dfpr, dfmaster)
+
+[.\src\deprecations.jl]
+
+# V1.0 deprecations
+# TODO for some reason this is not working
+# @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight)
+# @deprecate (l::GNNLayer)(gs::AbstractVector{<:GNNGraph}, args...; kws...) l(MLUtils.batch(gs), args...; kws...) 
+[.\src\GraphNeuralNetworks.jl] +module GraphNeuralNetworks + +using Statistics: mean +using LinearAlgebra, Random +using Flux +using Flux: glorot_uniform, leakyrelu, GRUCell, batch +using MacroTools: @forward +using NNlib +using NNlib: scatter, gather +using ChainRulesCore +using Reexport +using MLUtils: zeros_like + +using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, + check_num_nodes, check_num_edges, + EType, NType # for heteroconvs + +@reexport using GNNGraphs +@reexport using GNNlib + +include("layers/basic.jl") +export GNNLayer, + GNNChain, + WithGraph, + DotDecoder + +include("layers/conv.jl") +export AGNNConv, + CGConv, + ChebConv, + DConv, + EdgeConv, + EGNNConv, + GATConv, + GATv2Conv, + GatedGraphConv, + GCNConv, + GINConv, + GMMConv, + GraphConv, + MEGNetConv, + NNConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv + +include("layers/heteroconv.jl") +export HeteroGraphConv + +include("layers/temporalconv.jl") +export TGCN, + A3TGCN, + GConvLSTM, + GConvGRU, + DCGRU + +include("layers/pool.jl") +export GlobalPool, + GlobalAttentionPool, + Set2Set, + TopKPool, + topk_index + +include("deprecations.jl") + +end + +[.\src\layers\basic.jl] +""" + abstract type GNNLayer end + +An abstract type from which graph neural network layers are derived. + +See also [`GNNChain`](@ref). +""" +abstract type GNNLayer end + +# Forward pass with graph-only input. +# To be specialized by layers also needing edge features as input (e.g. NNConv). +(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) + +""" + WithGraph(model, g::GNNGraph; traingraph=false) + +A type wrapping the `model` and tying it to the graph `g`. +In the forward pass, can only take feature arrays as inputs, +returning `model(g, x...; kws...)`. + +If `traingraph=false`, the graph's parameters won't be part of +the `trainable` parameters in the gradient updates. 
+ +# Examples + +```julia +g = GNNGraph([1,2,3], [2,3,1]) +x = rand(Float32, 2, 3) +model = SAGEConv(2 => 3) +wg = WithGraph(model, g) +# No need to feed the graph to `wg` +@assert wg(x) == model(g, x) + +g2 = GNNGraph([1,1,2,3], [2,4,1,1]) +x2 = rand(Float32, 2, 4) +# WithGraph will ignore the internal graph if fed with a new one. +@assert wg(g2, x2) == model(g2, x2) +``` +""" +struct WithGraph{M, G <: GNNGraph} + model::M + g::G + traingraph::Bool +end + +WithGraph(model, g::GNNGraph; traingraph = false) = WithGraph(model, g, traingraph) + +Flux.@layer :expand WithGraph +Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model) + +(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...) +(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...) + +""" + GNNChain(layers...) + GNNChain(name = layer, ...) + +Collects multiple layers / functions to be called in sequence +on given input graph and input node features. + +It allows to compose layers in a sequential fashion as `Flux.Chain` +does, propagating the output of each layer to the next one. +In addition, `GNNChain` handles the input graph as well, providing it +as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type. + +`GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`, +and if names are given, `m[:name] == m[1]` etc. 
+ +# Examples + +```jldoctest +julia> using Flux, GraphNeuralNetworks + +julia> m = GNNChain(GCNConv(2=>5), + BatchNorm(5), + x -> relu.(x), + Dense(5, 4)) +GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)) + +julia> x = randn(Float32, 2, 3); + +julia> g = rand_graph(3, 6) +GNNGraph: + num_nodes = 3 + num_edges = 6 + +julia> m(g, x) +4×3 Matrix{Float32}: + -0.795592 -0.795592 -0.795592 + -0.736409 -0.736409 -0.736409 + 0.994925 0.994925 0.994925 + 0.857549 0.857549 0.857549 + +julia> m2 = GNNChain(enc = m, + dec = DotDecoder()) +GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder()) + +julia> m2(g, x) +1×6 Matrix{Float32}: + 2.90053 2.90053 2.90053 2.90053 2.90053 2.90053 + +julia> m2[:enc](g, x) == m(g, x) +true +``` +""" +struct GNNChain{T <: Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer + layers::T +end + +Flux.@layer :expand GNNChain + +GNNChain(xs...) = GNNChain(xs) + +function GNNChain(; kw...) + :layers in Base.keys(kw) && + throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) + isempty(kw) && return GNNChain(()) + GNNChain(values(kw)) +end + +@forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last, + Base.iterate, Base.lastindex, Base.keys, Base.firstindex + +(c::GNNChain)(g::GNNGraph, x) = _applychain(c.layers, g, x) +(c::GNNChain)(g::GNNGraph) = _applychain(c.layers, g) + +## TODO see if this is faster for small chains +## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180 +# @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N} +# symbols = vcat(:x, [gensym() for _ in 1:N]) +# calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N] +# Expr(:block, calls...) 
+# end +# _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), x) + +function _applychain(layers, g::GNNGraph, x) # type-unstable path, helps compile times + for l in layers + x = _applylayer(l, g, x) + end + return x +end + +function _applychain(layers, g::GNNGraph) # type-unstable path, helps compile times + for l in layers + g = _applylayer(l, g) + end + return g +end + +# # explicit input +_applylayer(l, g::GNNGraph, x) = l(x) +_applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) + +# input from graph +_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata = l(node_features(g))) +_applylayer(l::GNNLayer, g::GNNGraph) = l(g) + +# # Handle Flux.Parallel +function _applylayer(l::Parallel, g::GNNGraph) + GNNGraph(g, ndata = _applylayer(l, g, node_features(g))) +end + +function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) + closures = map(f -> (x -> _applylayer(f, g, x)), l.layers) + return Parallel(l.connection, closures)(x) +end + +Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]) +function Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) + GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) +end + +function Base.show(io::IO, c::GNNChain) + print(io, "GNNChain(") + _show_layers(io, c.layers) + print(io, ")") +end + +_show_layers(io, layers::Tuple) = join(io, layers, ", ") +function _show_layers(io, layers::NamedTuple) + join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") +end +function _show_layers(io, layers::AbstractVector) + (print(io, "["); join(io, layers, ", "); print(io, "]")) +end + +""" + DotDecoder() + +A graph neural network layer that +for given input graph `g` and node features `x`, +returns the dot product `x_i ⋅ xj` on each edge. 
+ +# Examples + +```jldoctest +julia> g = rand_graph(5, 6) +GNNGraph: + num_nodes = 5 + num_edges = 6 + +julia> dotdec = DotDecoder() +DotDecoder() + +julia> dotdec(g, rand(2, 5)) +1×6 Matrix{Float64}: + 0.345098 0.458305 0.106353 0.345098 0.458305 0.106353 +``` +""" +struct DotDecoder <: GNNLayer end + +(::DotDecoder)(g, x) = GNNlib.dot_decoder(g, x) + +[.\src\layers\conv.jl] +@doc raw""" + GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight]) + +Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907). + +Performs the operation +```math +\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j +``` +where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees. + +If the input graph has weighted edges and `use_edge_weight=true`, than ``a_{ij}`` will be computed as +```math +a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}} +``` + +The input to the layer is a node feature array `X` of size `(num_features, num_nodes)` +and optionally an edge weight vector. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `σ`: Activation function. Default `identity`. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. +- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). + If `add_self_loops=true` the new weights will be set to 1. + This option is ignored if the `edge_weight` is explicitly provided in the forward pass. + Default `false`. 
+ +# Forward + + (::GCNConv)(g::GNNGraph, x, edge_weight = nothing; norm_fn = d -> 1 ./ sqrt.(d), conv_weight = nothing) -> AbstractMatrix + +Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, +and optionally an edge weight vector. Returns a node feature matrix of size +`[out, num_nodes]`. + +The `norm_fn` parameter allows for custom normalization of the graph convolution operation by passing a function as argument. +By default, it computes ``\frac{1}{\sqrt{d}}`` i.e the inverse square root of the degree (`d`) of each node in the graph. +If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix instead of the weights stored in the model. + +# Examples + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = GCNConv(3 => 5) + +# forward pass +y = l(g, x) # size: 5 × num_nodes + +# convolution with edge weights and custom normalization function +w = [1.1, 0.1, 2.3, 0.5] +custom_norm_fn(d) = 1 ./ sqrt.(d + 1) # Custom normalization function +y = l(g, x, w; norm_fn = custom_norm_fn) + +# Edge weights can also be embedded in the graph. +g = GNNGraph(s, t, w) +l = GCNConv(3 => 5, use_edge_weight=true) +y = l(g, x) # same as l(g, x, w) +``` +""" +struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer + weight::W + bias::B + σ::F + add_self_loops::Bool + use_edge_weight::Bool +end + +Flux.@layer GCNConv + +function GCNConv(ch::Pair{Int, Int}, σ = identity; + init = glorot_uniform, + bias::Bool = true, + add_self_loops = true, + use_edge_weight = false) + in, out = ch + W = init(out, in) + b = bias ? 
Flux.create_bias(W, true, out) : false + GCNConv(W, b, σ, add_self_loops, use_edge_weight) +end + + +function (l::GCNConv)(g, x, edge_weight = nothing; + norm_fn = d -> 1 ./ sqrt.(d), + conv_weight = nothing) + + return GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) +end + + +function Base.show(io::IO, l::GCNConv) + out, in = size(l.weight) + print(io, "GCNConv($in => $out") + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end + +@doc raw""" + ChebConv(in => out, k; bias=true, init=glorot_uniform) + +Chebyshev spectral graph convolutional layer from +paper [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375). + +Implements + +```math +X' = \sum^{K-1}_{k=0} W^{(k)} Z^{(k)} +``` + +where ``Z^{(k)}`` is the ``k``-th term of Chebyshev polynomials, and can be calculated by the following recursive form: + +```math +\begin{aligned} +Z^{(0)} &= X \\ +Z^{(1)} &= \hat{L} X \\ +Z^{(k)} &= 2 \hat{L} Z^{(k-1)} - Z^{(k-2)} +\end{aligned} +``` + +with ``\hat{L}`` the [`scaled_laplacian`](@ref). + +# Arguments + +- `in`: The dimension of input features. +- `out`: The dimension of output features. +- `k`: The order of Chebyshev polynomial. +- `bias`: Add learnable bias. +- `init`: Weights' initializer. + +# Examples + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = ChebConv(3 => 5, 5) + +# forward pass +y = l(g, x) # size: 5 × num_nodes +``` +""" +struct ChebConv{W <: AbstractArray{<:Number, 3}, B} <: GNNLayer + weight::W + bias::B + k::Int +end + +function ChebConv(ch::Pair{Int, Int}, k::Int; + init = glorot_uniform, bias::Bool = true) + in, out = ch + W = init(out, in, k) + b = bias ? 
Flux.create_bias(W, true, out) : false + ChebConv(W, b, k) +end + +Flux.@layer ChebConv + +(l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x) + +function Base.show(io::IO, l::ChebConv) + out, in, k = size(l.weight) + print(io, "ChebConv(", in, " => ", out) + print(io, ", k=", k) + print(io, ")") +end + +@doc raw""" + GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform) + +Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244). + +Performs: +```math +\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j +``` + +where the aggregation type is selected by `aggr`. + +# Arguments + +- `in`: The dimension of input features. +- `out`: The dimension of output features. +- `σ`: Activation function. +- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). +- `bias`: Add learnable bias. +- `init`: Weights' initializer. + +# Examples + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) + +# forward pass +y = l(g, x) +``` +""" +struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer + weight1::W + weight2::W + bias::B + σ::F + aggr::A +end + +Flux.@layer GraphConv + +function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, + init = glorot_uniform, bias::Bool = true) + in, out = ch + W1 = init(out, in) + W2 = init(out, in) + b = bias ? 
Flux.create_bias(W1, true, out) : false + GraphConv(W1, W2, b, σ, aggr) +end + +(l::GraphConv)(g, x) = GNNlib.graph_conv(l, g, x) + +function Base.show(io::IO, l::GraphConv) + in_channel = size(l.weight1, ndims(l.weight1)) + out_channel = size(l.weight1, ndims(l.weight1) - 1) + print(io, "GraphConv(", in_channel, " => ", out_channel) + l.σ == identity || print(io, ", ", l.σ) + print(io, ", aggr=", l.aggr) + print(io, ")") +end + +@doc raw""" + GATConv(in => out, [σ; heads, concat, init, bias, negative_slope, add_self_loops]) + GATConv((in, ein) => out, ...) + +Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.org/abs/1710.10903). + +Implements the operation +```math +\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j +``` +where the attention coefficients ``\alpha_{ij}`` are given by +```math +\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i; W \mathbf{x}_j])) +``` +with ``z_i`` a normalization factor. + +In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass +and the attention coefficients will be calculated as +```math +\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W_e \mathbf{e}_{j\to i}; W \mathbf{x}_i; W \mathbf{x}_j])) +``` + +# Arguments + +- `in`: The dimension of input node features. +- `ein`: The dimension of input edge features. Default 0 (i.e. no edge features passed in the forward). +- `out`: The dimension of output node features. +- `σ`: Activation function. Default `identity`. +- `bias`: Learn the additive bias if true. Default `true`. +- `heads`: Number attention heads. Default `1`. +- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`. +- `negative_slope`: The parameter of LeakyReLU.Default `0.2`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. 
+- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`. + +# Examples + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; heads=2, concat=true) + +# forward pass +y = l(g, x) +``` +""" +struct GATConv{DX<:Dense,DE<:Union{Dense, Nothing},DV,T,A<:AbstractMatrix,F,B} <: GNNLayer + dense_x::DX + dense_e::DE + bias::B + a::A + σ::F + negative_slope::T + channel::Pair{NTuple{2, Int}, Int} + heads::Int + concat::Bool + add_self_loops::Bool + dropout::DV +end + +Flux.@layer GATConv +Flux.trainable(l::GATConv) = (; l.dense_x, l.dense_e, l.bias, l.a) + +GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) + +function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; + heads::Int = 1, concat::Bool = true, negative_slope = 0.2, + init = glorot_uniform, bias::Bool = true, add_self_loops = true, dropout=0.0) + (in, ein), out = ch + if add_self_loops + @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." + end + + dense_x = Dense(in, out * heads, bias = false) + dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing + b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false + a = init(ein > 0 ? 3out : 2out, heads) + negative_slope = convert(Float32, negative_slope) + GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout) +end + +(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + +(l::GATConv)(g, x, e = nothing) = GNNlib.gat_conv(l, g, x, e) + +function Base.show(io::IO, l::GATConv) + (in, ein), out = l.channel + print(io, "GATConv(", ein == 0 ? 
in : (in, ein), " => ", out ÷ l.heads) + l.σ == identity || print(io, ", ", l.σ) + print(io, ", negative_slope=", l.negative_slope) + print(io, ")") +end + +@doc raw""" + GATv2Conv(in => out, [σ; heads, concat, init, bias, negative_slope, add_self_loops]) + GATv2Conv((in, ein) => out, ...) + + +GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491). + +Implements the operation +```math +\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j +``` +where the attention coefficients ``\alpha_{ij}`` are given by +```math +\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU(W_2 \mathbf{x}_i + W_1 \mathbf{x}_j)) +``` +with ``z_i`` a normalization factor. + +In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass +and the attention coefficients will be calculated as +```math +\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU(W_3 \mathbf{e}_{j\to i} + W_2 \mathbf{x}_i + W_1 \mathbf{x}_j)). +``` + +# Arguments + +- `in`: The dimension of input node features. +- `ein`: The dimension of input edge features. Default 0 (i.e. no edge features passed in the forward). +- `out`: The dimension of output node features. +- `σ`: Activation function. Default `identity`. +- `bias`: Learn the additive bias if true. Default `true`. +- `heads`: Number attention heads. Default `1`. +- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`. +- `negative_slope`: The parameter of LeakyReLU.Default `0.2`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. +- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`. 
+ +# Examples +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +ein = 3 +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false) + +# edge features +e = randn(Float32, ein, length(s)) + +# forward pass +y = l(g, x, e) +``` +""" +struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer + dense_i::A1 + dense_j::A2 + dense_e::A3 + bias::B + a::C + σ::F + negative_slope::T + channel::Pair{NTuple{2, Int}, Int} + heads::Int + concat::Bool + add_self_loops::Bool + dropout::DV +end + +Flux.@layer GATv2Conv +Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a) + +function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) + GATv2Conv((ch[1], 0) => ch[2], args...; kws...) +end + +function GATv2Conv(ch::Pair{NTuple{2, Int}, Int}, + σ = identity; + heads::Int = 1, + concat::Bool = true, + negative_slope = 0.2, + init = glorot_uniform, + bias::Bool = true, + add_self_loops = true, + dropout=0.0) + (in, ein), out = ch + + if add_self_loops + @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." + end + + dense_i = Dense(in, out * heads; bias = bias, init = init) + dense_j = Dense(in, out * heads; bias = false, init = init) + if ein > 0 + dense_e = Dense(ein, out * heads; bias = false, init = init) + else + dense_e = nothing + end + b = bias ? Flux.create_bias(dense_i.weight, true, concat ? 
out * heads : out) : false + a = init(out, heads) + return GATv2Conv(dense_i, dense_j, dense_e, + b, a, σ, negative_slope, ch, heads, concat, + add_self_loops, dropout) +end + +(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + +(l::GATv2Conv)(g, x, e=nothing) = GNNlib.gatv2_conv(l, g, x, e) + +function Base.show(io::IO, l::GATv2Conv) + (in, ein), out = l.channel + print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads) + l.σ == identity || print(io, ", ", l.σ) + print(io, ", negative_slope=", l.negative_slope) + print(io, ")") +end + +@doc raw""" + GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform) + +Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493). + +Implements the recursion +```math +\begin{aligned} +\mathbf{h}^{(0)}_i &= [\mathbf{x}_i; \mathbf{0}] \\ +\mathbf{h}^{(l)}_i &= GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j) +\end{aligned} +``` + +where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing through GRU. The dimension of input ``\mathbf{x}_i`` needs to be less or equal to `out`. + +# Arguments + +- `out`: The dimension of output features. +- `num_layers`: The number of recursion steps. +- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). +- `init`: Weight initialization function. 
+ +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +out_channel = 5 +num_layers = 3 +g = GNNGraph(s, t) + +# create layer +l = GatedGraphConv(out_channel, num_layers) + +# forward pass +y = l(g, x) +``` +""" +struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer + weight::W + gru::R + dims::Int + num_layers::Int + aggr::A +end + +Flux.@layer GatedGraphConv + +function GatedGraphConv(dims::Int, num_layers::Int; + aggr = +, init = glorot_uniform) + w = init(dims, dims, num_layers) + gru = GRUCell(dims => dims) + GatedGraphConv(w, gru, dims, num_layers, aggr) +end + + +(l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H) + +function Base.show(io::IO, l::GatedGraphConv) + print(io, "GatedGraphConv($(l.dims), $(l.num_layers)") + print(io, ", aggr=", l.aggr) + print(io, ")") +end + +@doc raw""" + EdgeConv(nn; aggr=max) + +Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829). + +Performs the operation +```math +\mathbf{x}_i' = \square_{j \in N(i)}\, nn([\mathbf{x}_i; \mathbf{x}_j - \mathbf{x}_i]) +``` + +where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron. + +# Arguments + +- `nn`: A (possibly learnable) function. +- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). 
+ +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +g = GNNGraph(s, t) + +# create layer +l = EdgeConv(Dense(2 * in_channel, out_channel), aggr = +) + +# forward pass +y = l(g, x) +``` +""" +struct EdgeConv{NN, A} <: GNNLayer + nn::NN + aggr::A +end + +Flux.@layer :expand EdgeConv + +EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) + +(l::EdgeConv)(g, x) = GNNlib.edge_conv(l, g, x) + +function Base.show(io::IO, l::EdgeConv) + print(io, "EdgeConv(", l.nn) + print(io, ", aggr=", l.aggr) + print(io, ")") +end + +@doc raw""" + GINConv(f, ϵ; aggr=+) + +Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf). + +Implements the graph convolution +```math +\mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right) +``` +where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron. + +# Arguments + +- `f`: A (possibly learnable) function acting on node features. +- `ϵ`: Weighting factor. + +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +g = GNNGraph(s, t) + +# create dense layer +nn = Dense(in_channel, out_channel) + +# create layer +l = GINConv(nn, 0.01f0, aggr = mean) + +# forward pass +y = l(g, x) +``` +""" +struct GINConv{R <: Real, NN, A} <: GNNLayer + nn::NN + ϵ::R + aggr::A +end + +Flux.@layer :expand GINConv +Flux.trainable(l::GINConv) = (nn = l.nn,) + +GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) + +(l::GINConv)(g, x) = GNNlib.gin_conv(l, g, x) + +function Base.show(io::IO, l::GINConv) + print(io, "GINConv($(l.nn)") + print(io, ", $(l.ϵ)") + print(io, ")") +end + +@doc raw""" + NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform) + +The continuous kernel-based convolutional operator from the +[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper. 
+This convolution is also known as the edge-conditioned convolution from the +[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper. + +Performs the operation + +```math +\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j +``` + +where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron). +Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`, +the function `f` will return an batched matrices array whose size is `(out, in, num_edges)`. +For convenience, also functions returning a single `(out*in, num_edges)` matrix are allowed. + +# Arguments + +- `in`: The dimension of input features. +- `out`: The dimension of output features. +- `f`: A (possibly learnable) function acting on edge features. +- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). +- `σ`: Activation function. +- `bias`: Add learnable bias. +- `init`: Weights' initializer. + +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +edim = 10 +g = GNNGraph(s, t) + +# create dense layer +nn = Dense(edim => out_channel * in_channel) + +# create layer +l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +) + +# forward pass +y = l(g, x) +``` +""" +struct NNConv{W, B, NN, F, A} <: GNNLayer + weight::W + bias::B + nn::NN + σ::F + aggr::A +end + +Flux.@layer :expand NNConv + +function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true, + init = glorot_uniform) + in, out = ch + W = init(out, in) + b = bias ? 
Flux.create_bias(W, true, out) : false + return NNConv(W, b, nn, σ, aggr) +end + +(l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e) + +(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + +function Base.show(io::IO, l::NNConv) + out, in = size(l.weight) + print(io, "NNConv($in => $out") + print(io, ", aggr=", l.aggr) + print(io, ")") +end + +@doc raw""" + SAGEConv(in => out, σ=identity; aggr=mean, bias=true, init=glorot_uniform) + +GraphSAGE convolution layer from paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf). + +Performs: +```math +\mathbf{x}_i' = W \cdot [\mathbf{x}_i; \square_{j \in \mathcal{N}(i)} \mathbf{x}_j] +``` + +where the aggregation type is selected by `aggr`. + +# Arguments + +- `in`: The dimension of input features. +- `out`: The dimension of output features. +- `σ`: Activation function. +- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). +- `bias`: Add learnable bias. +- `init`: Weights' initializer. + +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +g = GNNGraph(s, t) + +# create layer +l = SAGEConv(in_channel => out_channel, tanh, bias = false, aggr = +) + +# forward pass +y = l(g, x) +``` +""" +struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer + weight::W + bias::B + σ::F + aggr::A +end + +Flux.@layer SAGEConv + +function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, + init = glorot_uniform, bias::Bool = true) + in, out = ch + W = init(out, 2 * in) + b = bias ? 
Flux.create_bias(W, true, out) : false + SAGEConv(W, b, σ, aggr) +end + +(l::SAGEConv)(g, x) = GNNlib.sage_conv(l, g, x) + +function Base.show(io::IO, l::SAGEConv) + out_channel, in_channel = size(l.weight) + print(io, "SAGEConv(", in_channel ÷ 2, " => ", out_channel) + l.σ == identity || print(io, ", ", l.σ) + print(io, ", aggr=", l.aggr) + print(io, ")") +end + +@doc raw""" + ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true) + +The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper. + +The layer's forward pass is given by + +```math +\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big), +``` +where the edge gates ``\eta_{ij}`` are given by + +```math +\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j). +``` + +# Arguments + +- `in`: The dimension of input features. +- `out`: The dimension of output features. +- `act`: Activation function. +- `init`: Weight matrices' initializing function. +- `bias`: Learn an additive bias if true. + +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +in_channel = 3 +out_channel = 5 +g = GNNGraph(s, t) + +# create layer +l = ResGatedGraphConv(in_channel => out_channel, tanh, bias = true) + +# forward pass +y = l(g, x) +``` +""" +struct ResGatedGraphConv{W, B, F} <: GNNLayer + A::W + B::W + U::W + V::W + bias::B + σ::F +end + +Flux.@layer ResGatedGraphConv + +function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; + init = glorot_uniform, bias::Bool = true) + in, out = ch + A = init(out, in) + B = init(out, in) + U = init(out, in) + V = init(out, in) + b = bias ? 
Flux.create_bias(A, true, out) : false + return ResGatedGraphConv(A, B, U, V, b, σ) +end + +(l::ResGatedGraphConv)(g, x) = GNNlib.res_gated_graph_conv(l, g, x) + +function Base.show(io::IO, l::ResGatedGraphConv) + out_channel, in_channel = size(l.A) + print(io, "ResGatedGraphConv(", in_channel, " => ", out_channel) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end + +@doc raw""" + CGConv((in, ein) => out, act=identity; bias=true, init=glorot_uniform, residual=false) + CGConv(in => out, ...) + +The crystal graph convolutional layer from the paper +[Crystal Graph Convolutional Neural Networks for an Accurate and +Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf). +Performs the operation + +```math +\mathbf{x}_i' = \mathbf{x}_i + \sum_{j\in N(i)}\sigma(W_f \mathbf{z}_{ij} + \mathbf{b}_f)\, act(W_s \mathbf{z}_{ij} + \mathbf{b}_s) +``` + +where ``\mathbf{z}_{ij}`` is the node and edge features concatenation +``[\mathbf{x}_i; \mathbf{x}_j; \mathbf{e}_{j\to i}]`` +and ``\sigma`` is the sigmoid function. +The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same +as the input size. + +# Arguments + +- `in`: The dimension of input node features. +- `ein`: The dimension of input edge features. +If `ein` is not given, assumes that no edge features are passed as input in the forward pass. +- `out`: The dimension of output node features. +- `act`: Activation function. +- `bias`: Add learnable bias. +- `init`: Weights' initializer. +- `residual`: Add a residual connection. 
+ +# Examples + +```julia +g = rand_graph(5, 6) +x = rand(Float32, 2, g.num_nodes) +e = rand(Float32, 3, g.num_edges) + +l = CGConv((2, 3) => 4, tanh) +y = l(g, x, e) # size: (4, num_nodes) + +# No edge features +l = CGConv(2 => 4, tanh) +y = l(g, x) # size: (4, num_nodes) +``` +""" +struct CGConv{D1, D2} <: GNNLayer + ch::Pair{NTuple{2, Int}, Int} + dense_f::D1 + dense_s::D2 + residual::Bool +end + +Flux.@layer CGConv + +CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) + +function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, + bias = true, init = glorot_uniform) + (nin, ein), out = ch + dense_f = Dense(2nin + ein, out, sigmoid; bias, init) + dense_s = Dense(2nin + ein, out, act; bias, init) + return CGConv(ch, dense_f, dense_s, residual) +end + +(l::CGConv)(g, x, e = nothing) = GNNlib.cg_conv(l, g, x, e) + + +(l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + +function Base.show(io::IO, l::CGConv) + print(io, "CGConv($(l.ch)") + l.dense_s.σ == identity || print(io, ", ", l.dense_s.σ) + print(io, ", residual=$(l.residual)") + print(io, ")") +end + +@doc raw""" + AGNNConv(; init_beta=1.0f0, trainable=true, add_self_loops=true) + +Attention-based Graph Neural Network layer from paper [Attention-based +Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735). + +The forward pass is given by +```math +\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} \mathbf{x}_j +``` +where the attention coefficients ``\alpha_{ij}`` are given by +```math +\alpha_{ij} =\frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}} + {\sum_{j'}e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}} +``` +with the cosine distance defined by +```math +\cos(\mathbf{x}_i, \mathbf{x}_j) = + \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert} +``` +and ``\beta`` a trainable parameter if `trainable=true`. 
+ +# Arguments + +- `init_beta`: The initial value of ``\beta``. Default 1.0f0. +- `trainable`: If true, ``\beta`` is trainable. Default `true`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. + +# Examples: + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +g = GNNGraph(s, t) + +# create layer +l = AGNNConv(init_beta=2.0f0) + +# forward pass +y = l(g, x) +``` +""" +struct AGNNConv{A <: AbstractVector} <: GNNLayer + β::A + add_self_loops::Bool + trainable::Bool +end + +Flux.@layer AGNNConv + +Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;) + +function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true) + AGNNConv([init_beta], add_self_loops, trainable) +end + +(l::AGNNConv)(g, x) = GNNlib.agnn_conv(l, g, x) + +@doc raw""" + MEGNetConv(ϕe, ϕv; aggr=mean) + MEGNetConv(in => out; aggr=mean) + +Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf) +paper. In the forward pass, takes as inputs node features `x` and edge features `e` and returns +updated features `x'` and `e'` according to + +```math +\begin{aligned} +\mathbf{e}_{i\to j}' = \phi_e([\mathbf{x}_i;\, \mathbf{x}_j;\, \mathbf{e}_{i\to j}]),\\ +\mathbf{x}_{i}' = \phi_v([\mathbf{x}_i;\, \square_{j\in \mathcal{N}(i)}\,\mathbf{e}_{j\to i}']). +\end{aligned} +``` + +`aggr` defines the aggregation to be performed. + +If the neural networks `ϕe` and `ϕv` are not provided, they will be constructed from +the `in` and `out` arguments instead as multi-layer perceptron with one hidden layer and `relu` +activations. 
+ +# Examples + +```julia +g = rand_graph(10, 30) +x = randn(Float32, 3, 10) +e = randn(Float32, 3, 30) +m = MEGNetConv(3 => 3) +x′, e′ = m(g, x, e) +``` +""" +struct MEGNetConv{TE, TV, A} <: GNNLayer + ϕe::TE + ϕv::TV + aggr::A +end + +Flux.@layer :expand MEGNetConv + +MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) + +function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) + nin, nout = ch + ϕe = Chain(Dense(3nin, nout, relu), + Dense(nout, nout)) + + ϕv = Chain(Dense(nin + nout, nout, relu), + Dense(nout, nout)) + + return MEGNetConv(ϕe, ϕv; aggr) +end + +function (l::MEGNetConv)(g::GNNGraph) + x, e = l(g, node_features(g), edge_features(g)) + return GNNGraph(g, ndata = x, edata = e) +end + +(l::MEGNetConv)(g, x, e) = GNNlib.megnet_conv(l, g, x, e) + +@doc raw""" + GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false) + +Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402) +Performs the operation +```math +\mathbf{x}_i' = \mathbf{x}_i + \frac{1}{|N(i)|} \sum_{j\in N(i)}\frac{1}{K}\sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j +``` +where ``w^a_{k}(e^a)`` for feature `a` and kernel `k` is given by +```math +w^a_{k}(e^a) = \exp(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k(e^a - \mu^a_k)) +``` +``\Theta_k, \mu^a_k, (\Sigma^{-1})^a_k`` are learnable parameters. + +The input to the layer is a node feature array `x` of size `(num_features, num_nodes)` and +edge pseudo-coordinate array `e` of size `(num_features, num_edges)` +The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same +as the input size. + +# Arguments + +- `in`: Number of input node features. +- `ein`: Number of input edge features. +- `out`: Number of output features. +- `σ`: Activation function. Default `identity`. +- `K`: Number of kernels. Default `1`. 
+- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. +- `residual`: Residual conncetion. Default `false`. + +# Examples + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +g = GNNGraph(s,t) +nin, ein, out, K = 4, 10, 7, 8 +x = randn(Float32, nin, g.num_nodes) +e = randn(Float32, ein, g.num_edges) + +# create layer +l = GMMConv((nin, ein) => out, K=K) + +# forward pass +l(g, x, e) +``` +""" +struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer + mu::A + sigma_inv::A + bias::B + σ::F + ch::Pair{NTuple{2, Int}, Int} + K::Int + dense_x::Dense + residual::Bool +end + +Flux.@layer GMMConv + +function GMMConv(ch::Pair{NTuple{2, Int}, Int}, + σ = identity; + K::Int = 1, + bias::Bool = true, + init = Flux.glorot_uniform, + residual = false) + (nin, ein), out = ch + mu = init(ein, K) + sigma_inv = init(ein, K) + b = bias ? Flux.create_bias(mu, true, out) : false + dense_x = Dense(nin, out * K, bias = false) + GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual) +end + +(l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e) + +(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + +function Base.show(io::IO, l::GMMConv) + (nin, ein), out = l.ch + print(io, "GMMConv((", nin, ",", ein, ")=>", out) + l.σ == identity || print(io, ", σ=", l.dense_s.σ) + print(io, ", K=", l.K) + l.residual == true || print(io, ", residual=", l.residual) + print(io, ")") +end + +@doc raw""" + SGConv(int => out, k=1; [bias, init, add_self_loops, use_edge_weight]) + +SGC layer from [Simplifying Graph Convolutional Networks](https://arxiv.org/pdf/1902.07153.pdf) +Performs operation +```math +H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta +``` +where ``\tilde{A}`` is ``A + I``. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `k` : Number of hops k. Default `1`. +- `bias`: Add learnable bias. Default `true`. 
+- `init`: Weights' initializer. Default `glorot_uniform`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. +- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). + If `add_self_loops=true` the new weights will be set to 1. Default `false`. + +# Examples + +```julia +# create data +s = [1,1,2,3] +t = [2,3,1,1] +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = SGConv(3 => 5; add_self_loops = true) + +# forward pass +y = l(g, x) # size: 5 × num_nodes + +# convolution with edge weights +w = [1.1, 0.1, 2.3, 0.5] +y = l(g, x, w) + +# Edge weights can also be embedded in the graph. +g = GNNGraph(s, t, w) +l = SGConv(3 => 5, add_self_loops = true, use_edge_weight=true) +y = l(g, x) # same as l(g, x, w) +``` +""" +struct SGConv{A <: AbstractMatrix, B} <: GNNLayer + weight::A + bias::B + k::Int + add_self_loops::Bool + use_edge_weight::Bool +end + +Flux.@layer SGConv + +function SGConv(ch::Pair{Int, Int}, k = 1; + init = glorot_uniform, + bias::Bool = true, + add_self_loops = true, + use_edge_weight = false) + in, out = ch + W = init(out, in) + b = bias ? Flux.create_bias(W, true, out) : false + return SGConv(W, b, k, add_self_loops, use_edge_weight) +end + +(l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight) + +function Base.show(io::IO, l::SGConv) + out, in = size(l.weight) + print(io, "SGConv($in => $out") + l.k == 1 || print(io, ", ", l.k) + print(io, ")") +end + +@doc raw""" + TAGConv(in => out, k=3; bias=true, init=glorot_uniform, add_self_loops=true, use_edge_weight=false) + +TAGConv layer from [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/pdf/1710.10370.pdf). +This layer extends the idea of graph convolutions by applying filters that adapt to the topology of the data. 
+It performs the operation: + +```math +H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k} +``` + +where `A` is the adjacency matrix of the graph, `D` is the degree matrix, `X` is the input feature matrix, and ``{\Theta}_{k}`` is a unique weight matrix for each hop `k`. + +# Arguments +- `in`: Number of input features. +- `out`: Number of output features. +- `k`: Maximum number of hops to consider. Default is `3`. +- `bias`: Whether to include a learnable bias term. Default is `true`. +- `init`: Initialization function for the weights. Default is `glorot_uniform`. +- `add_self_loops`: Whether to add self-loops to the adjacency matrix. Default is `true`. +- `use_edge_weight`: If `true`, edge weights are considered in the computation (if available). Default is `false`. + +# Examples + +```julia +# Example graph data +s = [1, 1, 2, 3] +t = [2, 3, 1, 1] +g = GNNGraph(s, t) # Create a graph +x = randn(Float32, 3, g.num_nodes) # Random features for each node + +# Create a TAGConv layer +l = TAGConv(3 => 5, k=3; add_self_loops=true) + +# Apply the TAGConv layer +y = l(g, x) # Output size: 5 × num_nodes +``` +""" +struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer + weight::A + bias::B + k::Int + add_self_loops::Bool + use_edge_weight::Bool +end + +Flux.@layer TAGConv + +function TAGConv(ch::Pair{Int, Int}, k = 3; + init = glorot_uniform, + bias::Bool = true, + add_self_loops = true, + use_edge_weight = false) + in, out = ch + W = init(out, in) + b = bias ? 
Flux.create_bias(W, true, out) : false + return TAGConv(W, b, k, add_self_loops, use_edge_weight) +end + +(l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight) + +function Base.show(io::IO, l::TAGConv) + out, in = size(l.weight) + print(io, "TAGConv($in => $out") + l.k == 1 || print(io, ", ", l.k) + print(io, ")") +end + +@doc raw""" + EGNNConv((in, ein) => out; hidden_size=2in, residual=false) + EGNNConv(in => out; hidden_size=2in, residual=false) + +Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph +Neural Networks](https://arxiv.org/abs/2102.09844). + +The layer performs the following operation: + +```math +\begin{aligned} +\mathbf{m}_{j\to i} &=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\ +\mathbf{x}_i' &= \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\ +\mathbf{m}_i &= C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\ +\mathbf{h}_i' &= \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i) +\end{aligned} +``` +where ``\mathbf{h}_i``, ``\mathbf{x}_i``, ``\mathbf{e}_{j\to i}`` are invariant node features, equivariant node +features, and edge features respectively. ``\phi_e``, ``\phi_h``, and +``\phi_x`` are two-layer MLPs. `C` is a constant for normalization, +computed as ``1/|\mathcal{N}(i)|``. + + +# Constructor Arguments + +- `in`: Number of input features for `h`. +- `out`: Number of output features for `h`. +- `ein`: Number of input edge features. +- `hidden_size`: Hidden representation size. +- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`. + +# Forward Pass + + l(g, x, h, e=nothing) + +## Forward Pass Arguments: + +- `g` : The graph. +- `x` : Matrix of equivariant node coordinates. +- `h` : Matrix of invariant node features. +- `e` : Matrix of invariant edge features. Default `nothing`. + +Returns updated `h` and `x`. 

# Examples

```julia
g = rand_graph(10, 10)
h = randn(Float32, 5, g.num_nodes)
x = randn(Float32, 3, g.num_nodes)
egnn = EGNNConv(5 => 6, 10)
hnew, xnew = egnn(g, h, x)
```
"""
struct EGNNConv{TE, TX, TH, NF} <: GNNLayer
    ϕe::TE            # edge MLP
    ϕx::TX            # coordinate-update MLP
    ϕh::TH            # node-feature MLP
    num_features::NF  # named tuple (in, edge, out, hidden)
    residual::Bool    # skip connection on h (requires in == out)
end

Flux.@layer EGNNConv

# Convenience constructor for the no-edge-features case (ein = 0).
function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
    return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
end

#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
                  residual = false)
    (in_size, edge_feat_size), out_size = ch
    act_fn = swish

    # +1 for the radial feature: ||x_i - x_j||^2
    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
               Dense(hidden_size => hidden_size, act_fn))

    ϕh = Chain(Dense(in_size + hidden_size, hidden_size, swish),
               Dense(hidden_size, out_size))

    ϕx = Chain(Dense(hidden_size, hidden_size, swish),
               Dense(hidden_size, 1, bias = false))

    num_features = (in = in_size, edge = edge_feat_size, out = out_size,
                    hidden = hidden_size)
    if residual
        @assert in_size==out_size "Residual connection only possible if in_size == out_size"
    end
    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end

# Forward pass delegates to the shared implementation in GNNlib.
(l::EGNNConv)(g, h, x, e = nothing) = GNNlib.egnn_conv(l, g, h, x, e)

function Base.show(io::IO, l::EGNNConv)
    ne = l.num_features.edge
    nin = l.num_features.in
    nout = l.num_features.out
    nh = l.num_features.hidden
    print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh")
    if l.residual
        print(io, ", residual=true")
    end
    print(io, ")")
end

@doc raw"""
    TransformerConv((in, ein) => out; [heads, concat, init, add_self_loops, bias_qkv,
                    bias_root, root_weight, gating, skip_connection, batch_norm, ff_channels]))

The transformer-like multi head attention convolutional operator from the
[Masked Label Prediction: Unified Message Passing Model for Semi-Supervised
Classification](https://arxiv.org/abs/2009.03509) paper, which also considers
edge features.
It further contains options to also be configured as the transformer-like convolutional operator from the
[Attention, Learn to Solve Routing Problems!](https://arxiv.org/abs/1803.08475) paper,
including a successive feed-forward network as well as skip layers and batch normalization.

The layer's basic forward pass is given by
```math
x_i' = W_1x_i + \sum_{j\in N(i)} \alpha_{ij} (W_2 x_j + W_6e_{ij})
```
where the attention scores are
```math
\alpha_{ij} = \mathrm{softmax}\left(\frac{(W_3x_i)^T(W_4x_j+
W_6e_{ij})}{\sqrt{d}}\right).
```

Optionally, a combination of the aggregated value with transformed root node features
by a gating mechanism via
```math
x'_i = \beta_i W_1 x_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
\alpha_{i,j} W_2 x_j \right)}_{=m_i}
```
with
```math
\beta_i = \textrm{sigmoid}(W_5^{\top} [ W_1 x_i, m_i, W_1 x_i - m_i ]).
```
can be performed.

# Arguments

- `in`: Dimension of input features, which also corresponds to the dimension of
  the output features.
- `ein`: Dimension of the edge features; if 0, no edge features will be used.
- `out`: Dimension of the output.
- `heads`: Number of heads in output. Default `1`.
- `concat`: Concatenate layer output or not. If not, layer output is averaged
  over the heads. Default `true`.
- `init`: Weight matrices' initializing function. Default `glorot_uniform`.
- `add_self_loops`: Add self loops to the input graph. Default `false`.
- `bias_qkv`: If set, bias is used in the key, query and value transformations for nodes.
  Default `true`.
- `bias_root`: If set, the layer will also learn an additive bias for the root when root
  weight is used. Default `true`.
- `root_weight`: If set, the layer will add the transformed root node features
  to the output. Default `true`.
- `gating`: If set, will combine aggregation and transformed root node features by a
  gating mechanism. Default `false`.
- `skip_connection`: If set, a skip connection will be made from the input and
  added to the output. Default `false`.
- `batch_norm`: If set, a batch normalization will be applied to the output. Default `false`.
- `ff_channels`: If positive, a feed-forward NN is appended, with the first having the given
  number of hidden nodes; this NN also gets a skip connection and batch normalization
  if the respective parameters are set. Default: `0`.

# Examples

```julia
N, in_channel, out_channel = 4, 3, 5
ein, heads = 2, 3
g = GNNGraph([1,1,2,4], [2,3,1,1])
l = TransformerConv((in_channel, ein) => in_channel; heads, gating = true, bias_qkv = true)
x = rand(Float32, in_channel, N)
e = rand(Float32, ein, g.num_edges)
l(g, x, e)
```
"""
struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLayer
    W1::TW1    # root/self transform (`nothing` when root_weight=false)
    W2::TW2    # value transform (applied to x_j, see docstring formula)
    W3::TW3    # query transform (applied to x_i)
    W4::TW4    # key transform (applied to x_j)
    W5::TW5    # gating transform (`nothing` when gating=false)
    W6::TW6    # edge-feature transform (`nothing` when ein == 0)
    FF::TFF    # optional feed-forward block
    BN1::TBN1  # optional batch norm after attention
    BN2::TBN2  # optional batch norm after feed-forward
    channels::Pair{NTuple{2, Int}, Int}
    heads::Int
    add_self_loops::Bool
    concat::Bool
    skip_connection::Bool
    sqrt_out::Float32  # √out, the attention-score normalizer
end

Flux.@layer TransformerConv

function Flux.trainable(l::TransformerConv)
    (; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
end

function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
    TransformerConv((ch[1], 0) => ch[2], args...; kws...)
end

function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
                         heads::Int = 1,
                         concat::Bool = true,
                         init = glorot_uniform,
                         add_self_loops::Bool = false,
                         bias_qkv = true,
                         bias_root::Bool = true,
                         root_weight::Bool = true,
                         gating::Bool = false,
                         skip_connection::Bool = false,
                         batch_norm::Bool = false,
                         ff_channels::Int = 0)
    (in, ein), out = ch

    if add_self_loops
        @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
    end

    W1 = root_weight ?
         Dense(in, out * (concat ? heads : 1); bias = bias_root, init = init) : nothing
    W2 = Dense(in => out * heads; bias = bias_qkv, init = init)
    W3 = Dense(in => out * heads; bias = bias_qkv, init = init)
    W4 = Dense(in => out * heads; bias = bias_qkv, init = init)
    out_mha = out * (concat ? heads : 1)
    W5 = gating ? Dense(3 * out_mha => 1, sigmoid; bias = false, init = init) : nothing
    W6 = ein > 0 ? Dense(ein => out * heads; bias = bias_qkv, init = init) : nothing
    FF = ff_channels > 0 ?
         Chain(Dense(out_mha => ff_channels, relu),
               Dense(ff_channels => out_mha)) : nothing
    BN1 = batch_norm ? BatchNorm(out_mha) : nothing
    BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing

    return TransformerConv(W1, W2, W3, W4, W5, W6, FF, BN1, BN2,
                           ch, heads, add_self_loops, concat, skip_connection,
                           Float32(√out))
end

# Forward pass delegates to the shared implementation in GNNlib.
(l::TransformerConv)(g, x, e = nothing) = GNNlib.transformer_conv(l, g, x, e)

function (l::TransformerConv)(g::GNNGraph)
    GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
end

function Base.show(io::IO, l::TransformerConv)
    (in, ein), out = l.channels
    print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end

"""
    DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)

Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).

# Arguments

- `ch`: Pair of input and output dimensions.
- `k`: Number of diffusion steps.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `bias`: Add learnable bias. Default `true`.

# Examples
```
julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10));

julia> dconv = DConv(2 => 4, 4)
DConv(2 => 4, 4)

julia> y = dconv(g, g.ndata.x);

julia> size(y)
(4, 10)
```
"""
struct DConv <: GNNLayer
    in::Int                 # input feature dimension
    out::Int                # output feature dimension
    weights::AbstractArray  # (2, k, out, in) diffusion weights
    # NOTE(fix): `bias` was typed `::AbstractArray`, but the constructor stores
    # the Bool `false` when `bias = false`, which made `DConv(ch, k; bias = false)`
    # throw a TypeError on construction. Leave the field untyped so it can hold
    # either a bias vector or `false` (Flux's "no bias" sentinel).
    bias
    k::Int                  # number of diffusion steps
end

Flux.@layer DConv

function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
    in, out = ch
    weights = init(2, k, out, in)
    # `false` is Flux's no-op bias: it broadcasts as zero and is not trainable.
    b = bias ? Flux.create_bias(weights, true, out) : false
    return DConv(in, out, weights, b, k)
end

# Forward pass delegates to the shared implementation in GNNlib.
(l::DConv)(g, x) = GNNlib.d_conv(l, g, x)

function Base.show(io::IO, l::DConv)
    print(io, "DConv($(l.in) => $(l.out), $(l.k))")
end

[.\src\layers\heteroconv.jl]
@doc raw"""
    HeteroGraphConv(itr; aggr = +)
    HeteroGraphConv(pairs...; aggr = +)

A convolutional layer for heterogeneous graphs.

The `itr` argument is an iterator of `pairs` of the form `edge_t => layer`, where `edge_t` is a
3-tuple of the form `(src_node_type, edge_type, dst_node_type)`, and `layer` is a
convolutional layer for homogeneous graphs.

Each convolution is applied to the corresponding relation.
Since a node type can be involved in multiple relations, the single convolution outputs
have to be aggregated using the `aggr` function. The default is to sum the outputs.

# Forward Arguments

* `g::GNNHeteroGraph`: The input graph.
* `x::Union{NamedTuple,Dict}`: The input node features. The keys are node types and the
  values are node feature tensors.

# Examples

```jldoctest
julia> g = rand_bipartite_heterograph((10, 15), 20)
GNNHeteroGraph:
  num_nodes: Dict(:A => 10, :B => 15)
  num_edges: Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)

julia> x = (A = rand(Float32, 64, 10), B = rand(Float32, 64, 15));

julia> layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu),
                               (:B, :to, :A) => GraphConv(64 => 32, relu));

julia> y = layer(g, x); # output is a named tuple

julia> size(y.A) == (32, 10) && size(y.B) == (32, 15)
true
```
"""
struct HeteroGraphConv
    etypes::Vector{EType}       # one edge type (src, rel, dst) per layer
    layers::Vector{<:GNNLayer}  # the per-relation convolutions
    aggr::Function              # combines outputs landing on the same dst node type
end

Flux.@layer HeteroGraphConv

HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr)
HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr)

function HeteroGraphConv(itr; aggr = +)
    etypes = [k[1] for k in itr]
    layers = [k[2] for k in itr]
    return HeteroGraphConv(etypes, layers, aggr)
end

function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::Union{NamedTuple,Dict})
    # Apply the layer registered for edge type `et` to that relation's subgraph,
    # feeding it the (source, destination) feature pair.
    function forw(l, et)
        sg = edge_type_subgraph(g, et)
        node1_t, _, node2_t = et
        return l(sg, (x[node1_t], x[node2_t]))
    end
    outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)]
    dst_ntypes = [et[3] for et in hgc.etypes]
    return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
end

# Aggregate per-relation outputs `outs` by destination node type with `aggr`.
function _reduceby_node_t(aggr, outs, ntypes)
    function _reduce(node_t)
        idxs = findall(x -> x == node_t, ntypes)
        if length(idxs) == 0
            return nothing
        elseif length(idxs) == 1
            return outs[idxs[1]]
        else
            return foldl(aggr, outs[i] for i in idxs)
        end
    end
    # workaround to provide the aggregation once per unique node type,
    # gradient is not needed
    unique_ntypes = ChainRulesCore.ignore_derivatives() do
        unique(ntypes)
    end
    vals = [_reduce(node_t) for node_t in unique_ntypes]
    return NamedTuple{tuple(unique_ntypes...)}(vals)
end

function Base.show(io::IO, hgc::HeteroGraphConv)
    if get(io, :compact, false)
        print(io, "HeteroGraphConv(aggr=$(hgc.aggr))")
    else
        println(io, "HeteroGraphConv(aggr=$(hgc.aggr)):")
        for (i, (et, layer)) in enumerate(zip(hgc.etypes, hgc.layers))
            print(io, " $(et => layer)")
            if i < length(hgc.etypes)
                print(io, "\n")
            end
        end
    end
end

[.\src\layers\pool.jl]
@doc raw"""
    GlobalPool(aggr)

Global pooling layer for graph neural networks.
Takes a graph and feature nodes as inputs
and performs the operation

```math
\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
```

where ``V`` is the set of nodes of the input graph and
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
Commonly used aggregations are `mean`, `max`, and `+`.

See also [`reduce_nodes`](@ref).

# Examples

```julia
using Flux, GraphNeuralNetworks, Graphs

pool = GlobalPool(mean)

g = GNNGraph(erdos_renyi(10, 4))
X = rand(32, 10)
pool(g, X) # => 32x1 matrix


g = Flux.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
X = rand(32, 50)
pool(g, X) # => 32x5 matrix
```
"""
struct GlobalPool{F} <: GNNLayer
    aggr::F  # node aggregation function, e.g. mean, max, +
end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray) = GNNlib.global_pool(l, g, x)

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

@doc raw"""
    GlobalAttentionPool(fgate, ffeat=identity)

Global soft attention layer from the [Gated Graph Sequence Neural
Networks](https://arxiv.org/abs/1511.05493) paper

```math
\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
```

where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
operation:

```math
\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
                {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
```

# Arguments

- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
  It is typically expressed by a neural network.

- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
  It is typically expressed by a neural network.

# Examples

```julia
chin = 6
chout = 5

fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
pool = GlobalAttentionPool(fgate, ffeat)

g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
                         ndata=rand(Float32, chin, 10))
                for i=1:3])

u = pool(g, g.ndata.x)

@assert size(u) == (chout, g.num_graphs)
```
"""
struct GlobalAttentionPool{G, F}
    fgate::G  # scores each node (scalar output)
    ffeat::F  # transforms node features before the attention-weighted sum
end

Flux.@layer GlobalAttentionPool

GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

(l::GlobalAttentionPool)(g, x) = GNNlib.global_attention_pool(l, g, x)

(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

"""
    TopKPool(adj, k, in_channel)

Top-k pooling layer.

# Arguments

- `adj`: Adjacency matrix of a graph.
- `k`: Top-k nodes are selected to pool together.
- `in_channel`: The dimension of input channel.
"""
struct TopKPool{T, S}
    A::AbstractMatrix{T}  # adjacency matrix of the input graph
    k::Int                # number of nodes kept after pooling
    p::AbstractVector{S}  # learnable projection vector
    Ã::AbstractMatrix{T}  # pooled (k × k) adjacency matrix
end

function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform)
    TopKPool(adj, k, init(in_channel), similar(adj, k, k))
end

(t::TopKPool)(x::AbstractArray) = topk_pool(t, x)


@doc raw"""
    Set2Set(n_in, n_iters, n_layers = 1)

Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391).

For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times:
```math
\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*)
\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)}
\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i
\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}]
```
where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers,
input size `2*n_in` and output size `n_in`.

Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`.
"""
struct Set2Set{L} <: GNNLayer
    lstm::L         # LSTM (or Chain of LSTMs) driving the query updates
    num_iters::Int  # number of Set2Set iterations
end

Flux.@layer Set2Set

function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
    @assert n_layers >= 1
    n_out = 2 * n_in

    if n_layers == 1
        lstm = LSTM(n_out => n_in)
    else
        # First layer maps the concatenated query [q; r] back to n_in,
        # subsequent layers are n_in => n_in.
        layers = [LSTM(n_out => n_in)]
        for _ in 2:n_layers
            push!(layers, LSTM(n_in => n_in))
        end
        lstm = Chain(layers...)
    end

    return Set2Set(lstm, n_iters)
end

function (l::Set2Set)(g, x)
    # Reset the LSTM state so each forward pass starts from the initial query.
    Flux.reset!(l.lstm)
    return GNNlib.set2set_pool(l, g, x)
end

(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

[.\src\layers\temporalconv.jl]
# Adapting Flux.Recur to work with GNNGraphs
function (m::Flux.Recur)(g::GNNGraph, x)
    m.state, y = m.cell(m.state, g, x)
    return y
end

# Apply the recurrence along the last (time) dimension of a 3d input array.
function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
    h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
    sze = size(h[1])
    reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

struct TGCNCell <: GNNLayer
    conv::GCNConv        # spatial convolution
    gru::Flux.GRUv3Cell  # temporal recurrence
    state0               # initial hidden state
    in::Int
    out::Int
end

Flux.@layer TGCNCell

function TGCNCell(ch::Pair{Int, Int};
                  bias::Bool = true,
                  init = Flux.glorot_uniform,
                  init_state = Flux.zeros32,
                  add_self_loops = false,
                  use_edge_weight = true)
    in, out = ch
    conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops,
                   use_edge_weight)
    gru = Flux.GRUv3Cell(out, out)
    state0 = init_state(out, 1)
    return TGCNCell(conv, gru, state0, in, out)
end

function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray)
    x̃ = tgcn.conv(g, x)
    h, x̃ = tgcn.gru(h, x̃)
    return h, x̃
end

function Base.show(io::IO, tgcn::TGCNCell)
    print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))")
end

"""
    TGCN(in => out; [bias, init, init_state, add_self_loops, use_edge_weight])

Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper
[T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).

Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial hidden state of the GRU layer. Default `zeros32`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
  If `add_self_loops=true` the new weights will be set to 1.
  This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
  Default `false`.
# Examples

```jldoctest
julia> tgcn = TGCN(2 => 6)
Recur(
  TGCNCell(
    GCNConv(2 => 6, σ),                 # 18 parameters
    GRUv3Cell(6 => 6),                  # 240 parameters
    Float32[0.0; 0.0; … ; 0.0; 0.0;;],  # 6 parameters  (all zero)
    2,
    6,
  ),
)         # Total: 8 trainable arrays, 264 parameters,
          # plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB.

julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5);

julia> y = tgcn(g, x);

julia> size(y)
(6, 5)

julia> Flux.reset!(tgcn);

julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20
(6, 5, 20)
```

!!! warning "Batch size changes"
    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
"""
TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))

Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)

# make TGCN compatible with GNNChain
(l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)


"""
    A3TGCN(in => out; [bias, init, init_state, add_self_loops, use_edge_weight])

Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph
Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf).

Performs a TGCN layer, followed by a soft attention layer.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial hidden state of the GRU layer. Default `zeros32`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
  If `add_self_loops=true` the new weights will be set to 1.
  This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
  Default `false`.
# Examples

```jldoctest
julia> a3tgcn = A3TGCN(2 => 6)
A3TGCN(2 => 6)

julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5);

julia> y = a3tgcn(g,x);

julia> size(y)
(6, 5)

julia> Flux.reset!(a3tgcn);

julia> y = a3tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20));

julia> size(y)
(6, 5)
```

!!! warning "Batch size changes"
    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
+""" +struct A3TGCN <: GNNLayer + tgcn::Flux.Recur{TGCNCell} + dense1::Dense + dense2::Dense + in::Int + out::Int +end + +Flux.@layer A3TGCN + +function A3TGCN(ch::Pair{Int, Int}, + bias::Bool = true, + init = Flux.glorot_uniform, + init_state = Flux.zeros32, + add_self_loops = false, + use_edge_weight = true) + in, out = ch + tgcn = TGCN(in => out; bias, init, init_state, add_self_loops, use_edge_weight) + dense1 = Dense(out, out) + dense2 = Dense(out, out) + return A3TGCN(tgcn, dense1, dense2, in, out) +end + +function (a3tgcn::A3TGCN)(g::GNNGraph, x::AbstractArray) + h = a3tgcn.tgcn(g, x) + e = a3tgcn.dense1(h) + e = a3tgcn.dense2(e) + a = softmax(e, dims = 3) + c = sum(a .* h , dims = 3) + if length(size(c)) == 3 + c = dropdims(c, dims = 3) + end + return c +end + +function Base.show(io::IO, a3tgcn::A3TGCN) + print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") +end + +struct GConvGRUCell <: GNNLayer + conv_x_r::ChebConv + conv_h_r::ChebConv + conv_x_z::ChebConv + conv_h_z::ChebConv + conv_x_h::ChebConv + conv_h_h::ChebConv + k::Int + state0 + in::Int + out::Int +end + +Flux.@layer GConvGRUCell + +function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; + bias::Bool = true, + init = Flux.glorot_uniform, + init_state = Flux.zeros32) + in, out = ch + # reset gate + conv_x_r = ChebConv(in => out, k; bias, init) + conv_h_r = ChebConv(out => out, k; bias, init) + # update gate + conv_x_z = ChebConv(in => out, k; bias, init) + conv_h_z = ChebConv(out => out, k; bias, init) + # new gate + conv_x_h = ChebConv(in => out, k; bias, init) + conv_h_h = ChebConv(out => out, k; bias, init) + state0 = init_state(out, n) + return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out) +end + +function (ggru::GConvGRUCell)(h, g::GNNGraph, x) + r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h) + r = Flux.sigmoid_fast(r) + z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h) + z = Flux.sigmoid_fast(z) + h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* 
h)
    h̃ = Flux.tanh_fast(h̃)
    h = (1 .- z) .* h̃ .+ z .* h
    return h, h
end

function Base.show(io::IO, ggru::GConvGRUCell)
    print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))")
end

"""
    GConvGRU(in => out, k, n; [bias, init, init_state])

Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).

Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Chebyshev polynomial order.
- `n`: Number of nodes in the graph.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial hidden state of the GRU layer. Default `zeros32`.

# Examples

```jldoctest
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);

julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes);

julia> y = ggru(g1, x1);

julia> size(y)
(5, 5)

julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);

julia> z = ggru(g2, x2);

julia> size(z)
(5, 5, 30)
```
"""
GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)

# make GConvGRU compatible with GNNChain
(l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)

struct GConvLSTMCell <: GNNLayer
    conv_x_i::ChebConv  # input gate, input part
    conv_h_i::ChebConv  # input gate, hidden part
    w_i                 # input gate, peephole weight on cell state
    b_i                 # input gate, bias
    conv_x_f::ChebConv  # forget gate
    conv_h_f::ChebConv
    w_f
    b_f
    conv_x_c::ChebConv  # cell-state candidate
    conv_h_c::ChebConv
    w_c
    b_c
    conv_x_o::ChebConv  # output gate
    conv_h_o::ChebConv
    w_o
    b_o
    k::Int              # Chebyshev polynomial order
    state0              # initial (h, c) state pair
    in::Int
    out::Int
end

Flux.@layer GConvLSTMCell

function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int;
                       bias::Bool = true,
                       init = Flux.glorot_uniform,
                       init_state = Flux.zeros32)
    in, out = ch
    # input gate
    conv_x_i = ChebConv(in => out, k; bias, init)
    conv_h_i = ChebConv(out => out, k; bias, init)
    w_i = init(out, 1)
    b_i = bias ? Flux.create_bias(w_i, true, out) : false
    # forget gate
    conv_x_f = ChebConv(in => out, k; bias, init)
    conv_h_f = ChebConv(out => out, k; bias, init)
    w_f = init(out, 1)
    b_f = bias ? Flux.create_bias(w_f, true, out) : false
    # cell state
    conv_x_c = ChebConv(in => out, k; bias, init)
    conv_h_c = ChebConv(out => out, k; bias, init)
    w_c = init(out, 1)
    b_c = bias ? Flux.create_bias(w_c, true, out) : false
    # output gate
    conv_x_o = ChebConv(in => out, k; bias, init)
    conv_h_o = ChebConv(out => out, k; bias, init)
    w_o = init(out, 1)
    b_o = bias ?
Flux.create_bias(w_o, true, out) : false
    state0 = (init_state(out, n), init_state(out, n))
    return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i,
                         conv_x_f, conv_h_f, w_f, b_f,
                         conv_x_c, conv_h_c, w_c, b_c,
                         conv_x_o, conv_h_o, w_o, b_o,
                         k, state0, in, out)
end

function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x)
    # input gate
    i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i
    i = Flux.sigmoid_fast(i)
    # forget gate
    f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f
    f = Flux.sigmoid_fast(f)
    # cell state
    c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c)
    # output gate — note this peephole term uses the *updated* cell state `c`
    o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o
    o = Flux.sigmoid_fast(o)
    h = o .* Flux.tanh_fast(c)
    return (h, c), h
end

function Base.show(io::IO, gclstm::GConvLSTMCell)
    print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))")
end

"""
    GConvLSTM(in => out, k, n; [bias, init, init_state])

Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).

Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Chebyshev polynomial order.
- `n`: Number of nodes in the graph.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial hidden state of the LSTM layer. Default `zeros32`.

# Examples

```jldoctest
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);

julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes);

julia> y = gclstm(g1, x1);

julia> size(y)
(5, 5)

julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);

julia> z = gclstm(g2, x2);

julia> size(z)
(5, 5, 30)
```
"""
GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
# NOTE(review): the argument name `tgcn` looks like a copy-paste from TGCN;
# it binds a GConvLSTMCell. Behavior is unaffected.
Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)

# make GConvLSTM compatible with GNNChain
(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)

struct DCGRUCell
    in::Int
    out::Int
    state0          # initial hidden state
    k::Int          # number of diffusion steps
    dconv_u::DConv  # update gate
    dconv_r::DConv  # reset gate
    dconv_c::DConv  # candidate state
end

Flux.@layer DCGRUCell

function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
    in, out = ch
    # Each gate convolves the concatenation [x; h], hence `in + out` channels.
    dconv_u = DConv((in + out) => out, k; bias=bias, init=init)
    dconv_r = DConv((in + out) => out, k; bias=bias, init=init)
    dconv_c = DConv((in + out) => out, k; bias=bias, init=init)
    state0 = init_state(out, n)
    return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c)
end

function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
    h̃ = vcat(x, h)
    z = dcgru.dconv_u(g, h̃)
    z = NNlib.sigmoid_fast.(z)
    r = dcgru.dconv_r(g, h̃)
    r = NNlib.sigmoid_fast.(r)
    ĥ = vcat(x, h .* r)
    c = dcgru.dconv_c(g, ĥ)
    c = tanh.(c)
    h = z .* h + (1 .- z) .* c
    return h, h
end

function Base.show(io::IO, dcgru::DCGRUCell)
    print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))")
end

"""
    DCGRU(in => out, k, n; [bias, init, init_state])

Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural
Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).

Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Diffusion step.
- `n`: Number of nodes in the graph.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial hidden state of the GRU layer. Default `zeros32`.

# Examples

```jldoctest
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);

julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes);

julia> y = dcgru(g1, x1);

julia> size(y)
(5, 5)

julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);

julia> z = dcgru(g2, x2);

julia> size(z)
(5, 5, 30)
```
"""
DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)

# make DCGRU compatible with GNNChain
(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

# Apply a (non-recurrent) convolution snapshot-wise to a temporal graph:
# `x` is a vector with one feature array per snapshot.
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::ChebConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::GATConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::GATv2Conv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::GatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::CGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::SGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::TransformerConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::GCNConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::ResGatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::SAGEConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

function (l::GraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    return l.(tg.snapshots, x)
end

[.\test\runtests.jl]
using CUDA
using GraphNeuralNetworks
using GNNGraphs: sort_edge_index
using GNNGraphs: getn, getdata
using Functors
using Flux
using Flux: gpu
using LinearAlgebra, Statistics, Random
using NNlib
import MLUtils
using SparseArrays
using Graphs
using Zygote
using Test
using MLDatasets
using InlineStrings # not used but with the import we test #98 and #104

CUDA.allowscalar(false)

const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}}

ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets

include("test_utils.jl")

tests = [
    "layers/basic",
    "layers/conv",
    "layers/heteroconv",
    "layers/temporalconv",
    "layers/pool",
    "examples/node_classification_cora",
]

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")

# @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse)
for graph_type in (:coo, :dense, :sparse)

    @info "Testing graph format :$graph_type"
    global GRAPH_T = graph_type
    global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse)

    @testset "$t" for t in tests
        startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMemory on github's CI
        include("$t.jl")
    end
end

[.\test\test_utils.jl]
using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA
CUDA.allowscalar(false)

# Reference gradient via finite differences, used to check AD gradients.
function ngradient(f, x...)
+ fdm = central_fdm(5, 1) + return FiniteDifferences.grad(fdm, f, x...) +end + +const rule_config = Zygote.ZygoteRuleConfig() + +# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed +function FiniteDifferences.to_vec(x::Integer) + Integer_from_vec(v) = x + return Int[x], Integer_from_vec +end + +# Test that forward pass on cpu and gpu are the same. +# Tests also gradient on cpu and gpu comparing with +# finite difference methods. +# Test gradients with respects to layer weights and to input. +# If `g` has edge features, it is assumed that the layer can +# use them in the forward pass as `l(g, x, e)`. +# Test also gradient with respect to `e`. +function test_layer(l, g::GNNGraph; atol = 1e-5, rtol = 1e-5, + exclude_grad_fields = [], + verbose = false, + test_gpu = TEST_GPU, + outsize = nothing, + outtype = :node) + + # TODO these give errors, probably some bugs in ChainRulesTestUtils + # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false) + # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false) + + isnothing(node_features(g)) && error("Plese add node data to the input graph") + fdm = central_fdm(5, 1) + + x = node_features(g) + e = edge_features(g) + use_edge_feat = !isnothing(e) + + x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad + xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g]) + + f(l, g::GNNGraph) = l(g) + f(l, g::GNNGraph, x, e) = use_edge_feat ? 
l(g, x, e) : l(g, x) + + loss(l, g::GNNGraph) = + if outtype == :node + sum(node_features(f(l, g))) + elseif outtype == :edge + sum(edge_features(f(l, g))) + elseif outtype == :graph + sum(graph_features(f(l, g))) + elseif outtype == :node_edge + gnew = f(l, g) + sum(node_features(gnew)) + sum(edge_features(gnew)) + end + + function loss(l, g::GNNGraph, x, e) + y = f(l, g, x, e) + if outtype == :node_edge + return sum(y[1]) + sum(y[2]) + else + return sum(y) + end + end + + # TEST OUTPUT + y = f(l, g, x, e) + if outtype == :node_edge + @assert y isa Tuple + @test eltype(y[1]) == eltype(x) + @test eltype(y[2]) == eltype(e) + @test all(isfinite, y[1]) + @test all(isfinite, y[2]) + if !isnothing(outsize) + @test size(y[1]) == outsize[1] + @test size(y[2]) == outsize[2] + end + else + @test eltype(y) == eltype(x) + @test all(isfinite, y) + if !isnothing(outsize) + @test size(y) == outsize + end + end + + # test same output on different graph formats + gcoo = GNNGraph(g, graph_type = :coo) + ycoo = f(l, gcoo, x, e) + if outtype == :node_edge + @test ycoo[1] ≈ y[1] + @test ycoo[2] ≈ y[2] + else + @test ycoo ≈ y + end + + g′ = f(l, g) + if outtype == :node + @test g′.ndata.x ≈ y + elseif outtype == :edge + @test g′.edata.e ≈ y + elseif outtype == :graph + @test g′.gdata.u ≈ y + elseif outtype == :node_edge + @test g′.ndata.x ≈ y[1] + @test g′.edata.e ≈ y[2] + else + @error "wrong outtype $outtype" + end + if test_gpu + ygpu = f(lgpu, ggpu, xgpu, egpu) + if outtype == :node_edge + @test ygpu[1] isa CuArray + @test eltype(ygpu[1]) == eltype(xgpu) + @test Array(ygpu[1]) ≈ y[1] + @test ygpu[2] isa CuArray + @test eltype(ygpu[2]) == eltype(xgpu) + @test Array(ygpu[2]) ≈ y[2] + else + @test ygpu isa CuArray + @test eltype(ygpu) == eltype(xgpu) + @test Array(ygpu) ≈ y + end + end + + # TEST x INPUT GRADIENT + x̄ = gradient(x -> loss(l, g, x, e), x)[1] + x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1] + @test eltype(x̄) == eltype(x) + @test x̄≈x̄_fd 
atol=atol rtol=rtol + + if test_gpu + x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1] + @test x̄gpu isa CuArray + @test eltype(x̄gpu) == eltype(x) + @test Array(x̄gpu)≈x̄ atol=atol rtol=rtol + end + + # TEST e INPUT GRADIENT + if e !== nothing + verbose && println("Test e gradient cpu") + ē = gradient(e -> loss(l, g, x, e), e)[1] + ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1] + @test eltype(ē) == eltype(e) + @test ē≈ē_fd atol=atol rtol=rtol + + if test_gpu + verbose && println("Test e gradient gpu") + ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1] + @test ēgpu isa CuArray + @test eltype(ēgpu) == eltype(ē) + @test Array(ēgpu)≈ē atol=atol rtol=rtol + end + end + + # TEST LAYER GRADIENT - l(g, x, e) + l̄ = gradient(l -> loss(l, g, x, e), l)[1] + l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1] + test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) + + if test_gpu + l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1] + test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose) + end + + # TEST LAYER GRADIENT - l(g) + l̄ = gradient(l -> loss(l, g), l)[1] + test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) + + return true +end + +function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5, + exclude_grad_fields = [], + verbose = false) + l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue + l̄fd = l̄fd isa Base.RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue + + for f in fieldnames(typeof(l)) + f ∈ exclude_grad_fields && continue + verbose && println("Test gradient of field $f...") + x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f) + test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose) + verbose && println("... 
field $f done!") + end + return true +end + +function test_approx_structs(x, g::Nothing, gfd; atol, rtol, kws...) + # finite diff gradients has to be zero if present + @test !(gfd isa AbstractArray) || isapprox(gfd, fill!(similar(gfd), 0); atol, rtol) +end + +function test_approx_structs(x::Union{AbstractArray, Number}, + g::Union{AbstractArray, Number}, gfd; atol, rtol, kws...) + @test eltype(g) == eltype(x) + if x isa CuArray + @test g isa CuArray + g = Array(g) + end + @test g≈gfd atol=atol rtol=rtol +end + +""" + to32(m) + +Convert the `eltype` of model's float parameters to `Float32`. +Preserves integer arrays. +""" +to32(m) = _paramtype(Float32, m) + +""" + to64(m) + +Convert the `eltype` of model's float parameters to `Float64`. +Preserves integer arrays. +""" +to64(m) = _paramtype(Float64, m) + +struct GNNEltypeAdaptor{T} end + +Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where T = convert(AbstractArray{T}, x) +Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Integer}) where T = x +Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x) + +_paramtype(::Type{T}, m) where T = fmap(adapt(GNNEltypeAdaptor{T}()), m) + +[.\test\examples\node_classification_cora.jl] +using Flux +using Flux: onecold, onehotbatch +using Flux.Losses: logitcrossentropy +using GraphNeuralNetworks +using MLDatasets: Cora +using Statistics, Random +using CUDA +CUDA.allowscalar(false) + +function eval_loss_accuracy(X, y, ids, model, g) + ŷ = model(g, X) + l = logitcrossentropy(ŷ[:, ids], y[:, ids]) + acc = mean(onecold(ŷ[:, ids]) .== onecold(y[:, ids])) + return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) +end + +# arguments for the `train` function +Base.@kwdef mutable struct Args + η = 5.0f-3 # learning rate + epochs = 10 # number of epochs + seed = 17 # set seed > 0 for reproducibility + usecuda = false # if true use cuda (if available) + nhidden = 64 # dimension 
of hidden features +end + +function train(Layer; verbose = false, kws...) + args = Args(; kws...) + args.seed > 0 && Random.seed!(args.seed) + + if args.usecuda && CUDA.functional() + device = Flux.gpu + args.seed > 0 && CUDA.seed!(args.seed) + else + device = Flux.cpu + end + + # LOAD DATA + dataset = Cora() + classes = dataset.metadata["classes"] + g = mldataset2gnngraph(dataset) |> device + X = g.ndata.features + y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged + train_mask = g.ndata.train_mask + test_mask = g.ndata.test_mask + ytrain = y[:, train_mask] + + nin, nhidden, nout = size(X, 1), args.nhidden, length(classes) + + ## DEFINE MODEL + model = GNNChain(Layer(nin, nhidden), + # Dropout(0.5), + Layer(nhidden, nhidden), + Dense(nhidden, nout)) |> device + + opt = Flux.setup(Adam(args.η), model) + + ## TRAINING + function report(epoch) + train = eval_loss_accuracy(X, y, train_mask, model, g) + test = eval_loss_accuracy(X, y, test_mask, model, g) + println("Epoch: $epoch Train: $(train) Test: $(test)") + end + + verbose && report(0) + @time for epoch in 1:(args.epochs) + grad = Flux.gradient(model) do model + ŷ = model(g, X) + logitcrossentropy(ŷ[:, train_mask], ytrain) + end + Flux.update!(opt, model, grad[1]) + verbose && report(epoch) + end + + train_res = eval_loss_accuracy(X, y, train_mask, model, g) + test_res = eval_loss_accuracy(X, y, test_mask, model, g) + return train_res, test_res +end + +function train_many(; usecuda = false) + for (layer, Layer) in [ + ("GCNConv", (nin, nout) -> GCNConv(nin => nout, relu)), + ("ResGatedGraphConv", (nin, nout) -> ResGatedGraphConv(nin => nout, relu)), + ("GraphConv", (nin, nout) -> GraphConv(nin => nout, relu, aggr = mean)), + ("SAGEConv", (nin, nout) -> SAGEConv(nin => nout, relu)), + ("GATConv", (nin, nout) -> GATConv(nin => nout, relu)), + ("GINConv", (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr = mean)), + 
("TransformerConv", + (nin, nout) -> TransformerConv(nin => nout, concat = false, + add_self_loops = true, root_weight = false, + heads = 2)), + ## ("ChebConv", (nin, nout) -> ChebConv(nin => nout, 2)), # not working on gpu + ## ("NNConv", (nin, nout) -> NNConv(nin => nout)), # needs edge features + ## ("GatedGraphConv", (nin, nout) -> GatedGraphConv(nout, 2)), # needs nin = nout + ## ("EdgeConv",(nin, nout) -> EdgeConv(Dense(2nin, nout, relu))), # Fits the training set but does not generalize well + ] + @show layer + @time train_res, test_res = train(Layer; usecuda, verbose = false) + # @show train_res, test_res + @test train_res.acc > 94 + @test test_res.acc > 69 + end +end + +train_many(usecuda = false) +if TEST_GPU + train_many(usecuda = true) +end + +[.\test\layers\basic.jl] +@testset "GNNChain" begin + n, din, d, dout = 10, 3, 4, 2 + deg = 4 + + g = GNNGraph(random_regular_graph(n, deg), + graph_type = GRAPH_T, + ndata = randn(Float32, din, n)) + x = g.ndata.x + + gnn = GNNChain(GCNConv(din => d), + LayerNorm(d), + x -> tanh.(x), + GraphConv(d => d, tanh), + Dropout(0.5), + Dense(d, dout)) + + testmode!(gnn) + + test_layer(gnn, g, rtol = 1e-5, exclude_grad_fields = [:μ, :σ²]) + + @testset "constructor with names" begin + m = GNNChain(GCNConv(din => d), + LayerNorm(d), + x -> tanh.(x), + Dense(d, dout)) + + m2 = GNNChain(enc = m, + dec = DotDecoder()) + + @test m2[:enc] === m + @test m2(g, x) == m2[:dec](g, m2[:enc](g, x)) + end + + @testset "constructor with vector" begin + m = GNNChain(GCNConv(din => d), + LayerNorm(d), + x -> tanh.(x), + Dense(d, dout)) + m2 = GNNChain([m.layers...]) + @test m2(g, x) == m(g, x) + end + + @testset "Parallel" begin + AddResidual(l) = Parallel(+, identity, l) + + gnn = GNNChain(GraphConv(din => d, tanh), + LayerNorm(d), + AddResidual(GraphConv(d => d, tanh)), + BatchNorm(d), + Dense(d, dout)) + + trainmode!(gnn) + + test_layer(gnn, g, rtol = 1e-4, atol=1e-4, exclude_grad_fields = [:μ, :σ²]) + end + + @testset "Only graph 
input" begin + nin, nout = 2, 4 + ndata = rand(Float32, nin, 3) + edata = rand(Float32, nin, 3) + g = GNNGraph([1, 1, 2], [2, 3, 3], ndata = ndata, edata = edata) + m = NNConv(nin => nout, Dense(2, nin * nout, tanh)) + chain = GNNChain(m) + y = m(g, g.ndata.x, g.edata.e) + @test m(g).ndata.x == y + @test chain(g).ndata.x == y + end +end + +@testset "WithGraph" begin + x = rand(Float32, 2, 3) + g = GNNGraph([1, 2, 3], [2, 3, 1], ndata = x) + model = SAGEConv(2 => 3) + wg = WithGraph(model, g) + # No need to feed the graph to `wg` + @test wg(x) == model(g, x) + @test Flux.params(wg) == Flux.params(model) + g2 = GNNGraph([1, 1, 2, 3], [2, 4, 1, 1]) + x2 = rand(Float32, 2, 4) + # WithGraph will ignore the internal graph if fed with a new one. + @test wg(g2, x2) == model(g2, x2) + + wg = WithGraph(model, g, traingraph = false) + @test length(Flux.params(wg)) == length(Flux.params(model)) + + wg = WithGraph(model, g, traingraph = true) + @test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g)) +end + +@testset "Flux restructure" begin + chain = GNNChain(GraphConv(2 => 2)) + params, restructure = Flux.destructure(chain) + @test restructure(params) isa GNNChain +end + +[.\test\layers\conv.jl] +RTOL_LOW = 1e-2 +RTOL_HIGH = 1e-5 +ATOL_LOW = 1e-3 + +in_channel = 3 +out_channel = 5 +N = 4 +T = Float32 + +adj1 = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + +g1 = GNNGraph(adj1, + ndata = rand(T, in_channel, N), + graph_type = GRAPH_T) + +adj_single_vertex = [0 0 0 1 + 0 0 0 0 + 0 0 0 1 + 1 0 1 0] + +g_single_vertex = GNNGraph(adj_single_vertex, + ndata = rand(T, in_channel, N), + graph_type = GRAPH_T) + +test_graphs = [g1, g_single_vertex] + +@testset "GCNConv" begin + l = GCNConv(in_channel => out_channel) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + l = GCNConv(in_channel => out_channel, tanh, bias = false) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, 
g.num_nodes)) + end + + l = GCNConv(in_channel => out_channel, add_self_loops = false) + test_layer(l, g1, rtol = RTOL_HIGH, outsize = (out_channel, g1.num_nodes)) + + @testset "edge weights & custom normalization" begin + s = [2, 3, 1, 3, 1, 2] + t = [1, 1, 2, 2, 3, 3] + w = T[1, 2, 3, 4, 5, 6] + g = GNNGraph((s, t, w), ndata = ones(T, 1, 3), graph_type = GRAPH_T) + x = g.ndata.x + custom_norm_fn(d) = 1 ./ sqrt.(d) + l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true) + l.weight .= 1 + d = degree(g, dir = :in, edge_weight = true) + y = l(g, x) + @test y[1, 1] ≈ w[1] / √(d[1] * d[2]) + w[2] / √(d[1] * d[3]) + @test y[1, 2] ≈ w[3] / √(d[2] * d[1]) + w[4] / √(d[2] * d[3]) + @test y ≈ l(g, x, w; norm_fn = custom_norm_fn) # checking without custom + + # test gradient with respect to edge weights + w = rand(T, 6) + x = rand(T, 1, 3) + g = GNNGraph((s, t, w), ndata = x, graph_type = GRAPH_T, edata = w) + l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true) + @test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{T} # redundant test but more explicit + test_layer(l, g, rtol = RTOL_HIGH, outsize = (1, g.num_nodes), test_gpu = false) + end + + @testset "conv_weight" begin + l = GraphNeuralNetworks.GCNConv(in_channel => out_channel) + w = zeros(T, out_channel, in_channel) + g1 = GNNGraph(adj1, ndata = ones(T, in_channel, N)) + @test l(g1, g1.ndata.x, conv_weight = w) == zeros(T, out_channel, N) + a = rand(T, in_channel, N) + g2 = GNNGraph(adj1, ndata = a) + @test l(g2, g2.ndata.x, conv_weight = w) == w * a + end +end + +@testset "ChebConv" begin + k = 2 + l = ChebConv(in_channel => out_channel, k) + @test size(l.weight) == (out_channel, in_channel, k) + @test size(l.bias) == (out_channel,) + @test l.k == k + for g in test_graphs + g = add_self_loops(g) + test_layer(l, g, rtol = RTOL_HIGH, test_gpu = TEST_GPU, + outsize = (out_channel, g.num_nodes)) + end + + @testset "bias=false" begin + @test length(Flux.params(ChebConv(2 => 3, 3))) == 
2 + @test length(Flux.params(ChebConv(2 => 3, 3, bias = false))) == 1 + end +end + +@testset "GraphConv" begin + l = GraphConv(in_channel => out_channel) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + l = GraphConv(in_channel => out_channel, tanh, bias = false, aggr = mean) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + @testset "bias=false" begin + @test length(Flux.params(GraphConv(2 => 3))) == 3 + @test length(Flux.params(GraphConv(2 => 3, bias = false))) == 2 + end +end + +@testset "GATConv" begin + for heads in (1, 2), concat in (true, false) + l = GATConv(in_channel => out_channel; heads, concat, dropout=0) + for g in test_graphs + test_layer(l, g, rtol = RTOL_LOW, + exclude_grad_fields = [:negative_slope, :dropout], + outsize = (concat ? heads * out_channel : out_channel, + g.num_nodes)) + end + end + + @testset "edge features" begin + ein = 3 + l = GATConv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0) + g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges)) + test_layer(l, g, rtol = RTOL_LOW, + exclude_grad_fields = [:negative_slope, :dropout], + outsize = (out_channel, g.num_nodes)) + end + + @testset "num params" begin + l = GATConv(2 => 3, add_self_loops = false) + @test length(Flux.params(l)) == 3 + l = GATConv((2, 4) => 3, add_self_loops = false) + @test length(Flux.params(l)) == 4 + l = GATConv((2, 4) => 3, add_self_loops = false, bias = false) + @test length(Flux.params(l)) == 3 + end +end + +@testset "GATv2Conv" begin + for heads in (1, 2), concat in (true, false) + l = GATv2Conv(in_channel => out_channel, tanh; heads, concat, dropout=0) + for g in test_graphs + test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW, + exclude_grad_fields = [:negative_slope, :dropout], + outsize = (concat ? 
heads * out_channel : out_channel, + g.num_nodes)) + end + end + + @testset "edge features" begin + ein = 3 + l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0) + g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges)) + test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW, + exclude_grad_fields = [:negative_slope, :dropout], + outsize = (out_channel, g.num_nodes)) + end + + @testset "num params" begin + l = GATv2Conv(2 => 3, add_self_loops = false) + @test length(Flux.params(l)) == 5 + l = GATv2Conv((2, 4) => 3, add_self_loops = false) + @test length(Flux.params(l)) == 6 + l = GATv2Conv((2, 4) => 3, add_self_loops = false, bias = false) + @test length(Flux.params(l)) == 4 + end +end + +@testset "GatedGraphConv" begin + num_layers = 3 + l = GatedGraphConv(out_channel, num_layers) + @test size(l.weight) == (out_channel, out_channel, num_layers) + + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end +end + +@testset "EdgeConv" begin + l = EdgeConv(Dense(2 * in_channel, out_channel), aggr = +) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end +end + +@testset "GINConv" begin + nn = Dense(in_channel, out_channel) + + l = GINConv(nn, 0.01f0, aggr = mean) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + @test !in(:eps, Flux.trainable(l)) +end + +@testset "NNConv" begin + edim = 10 + nn = Dense(edim, out_channel * in_channel) + + l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +) + for g in test_graphs + g = GNNGraph(g, edata = rand(T, edim, g.num_edges)) + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end +end + +@testset "SAGEConv" begin + l = SAGEConv(in_channel => out_channel) + @test l.aggr == mean + + l = SAGEConv(in_channel => out_channel, tanh, bias = false, aggr = +) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, 
outsize = (out_channel, g.num_nodes)) + end +end + +@testset "ResGatedGraphConv" begin + l = ResGatedGraphConv(in_channel => out_channel, tanh, bias = true) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end +end + +@testset "CGConv" begin + edim = 10 + l = CGConv((in_channel, edim) => out_channel, tanh, residual = false, bias = true) + for g in test_graphs + g = GNNGraph(g, edata = rand(T, edim, g.num_edges)) + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + # no edge features + l1 = CGConv(in_channel => out_channel, tanh, residual = false, bias = true) + @test l1(g1, g1.ndata.x) == l1(g1).ndata.x + @test l1(g1, g1.ndata.x, nothing) == l1(g1).ndata.x +end + +@testset "AGNNConv" begin + l = AGNNConv(trainable=false, add_self_loops=false) + @test l.β == [1.0f0] + @test l.add_self_loops == false + @test l.trainable == false + Flux.trainable(l) == (;) + + l = AGNNConv(init_beta=2.0f0) + @test l.β == [2.0f0] + @test l.add_self_loops == true + @test l.trainable == true + Flux.trainable(l) == (; β = [1f0]) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (in_channel, g.num_nodes)) + end +end + +@testset "MEGNetConv" begin + l = MEGNetConv(in_channel => out_channel, aggr = +) + for g in test_graphs + g = GNNGraph(g, edata = rand(T, in_channel, g.num_edges)) + test_layer(l, g, rtol = RTOL_LOW, + outtype = :node_edge, + outsize = ((out_channel, g.num_nodes), (out_channel, g.num_edges))) + end +end + +@testset "GMMConv" begin + ein_channel = 10 + K = 5 + l = GMMConv((in_channel, ein_channel) => out_channel, K = K) + for g in test_graphs + g = GNNGraph(g, edata = rand(Float32, ein_channel, g.num_edges)) + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end +end + +@testset "SGConv" begin + K = [1, 2, 3] # for different number of hops + for k in K + l = SGConv(in_channel => out_channel, k, add_self_loops = true) + for g in test_graphs + 
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + l = SGConv(in_channel => out_channel, k, add_self_loops = true) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + end +end + +@testset "TAGConv" begin + K = [1, 2, 3] + for k in K + l = TAGConv(in_channel => out_channel, k, add_self_loops = true) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + l = TAGConv(in_channel => out_channel, k, add_self_loops = true) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + end +end + +@testset "EGNNConv" begin + hin = 5 + hout = 5 + hidden = 5 + l = EGNNConv(hin => hout, hidden) + g = rand_graph(10, 20, graph_type = GRAPH_T) + x = rand(T, in_channel, g.num_nodes) + h = randn(T, hin, g.num_nodes) + hnew, xnew = l(g, h, x) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_channel, g.num_nodes) +end + +@testset "TransformerConv" begin + ein = 2 + heads = 3 + # used like in Kool et al., 2019 + l = TransformerConv(in_channel * heads => in_channel; heads, add_self_loops = true, + root_weight = false, ff_channels = 10, skip_connection = true, + batch_norm = false) + # batch_norm=false here for tests to pass; true in paper + for adj in [adj1, adj_single_vertex] + g = GNNGraph(adj, ndata = rand(T, in_channel * heads, size(adj, 1)), + graph_type = GRAPH_T) + test_layer(l, g, rtol = RTOL_LOW, + exclude_grad_fields = [:negative_slope], + outsize = (in_channel * heads, g.num_nodes)) + end + # used like in Shi et al., 2021 + l = TransformerConv((in_channel, ein) => in_channel; heads, gating = true, + bias_qkv = true) + for g in test_graphs + g = GNNGraph(g, edata = rand(T, ein, g.num_edges)) + test_layer(l, g, rtol = RTOL_LOW, + exclude_grad_fields = [:negative_slope], + outsize = (in_channel * heads, g.num_nodes)) + end + # test averaging heads + l = TransformerConv(in_channel 
=> in_channel; heads, concat = false,
+                        bias_root = false,
+                        root_weight = false)
+    for g in test_graphs
+        test_layer(l, g, rtol = RTOL_LOW,
+                   exclude_grad_fields = [:negative_slope],
+                   outsize = (in_channel, g.num_nodes))
+    end
+end
+
+@testset "DConv" begin
+    K = [1, 2, 3] # for different number of hops
+    for k in K
+        l = DConv(in_channel => out_channel, k)
+        for g in test_graphs
+            test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
+        end
+    end
+end
[.\test\layers\heteroconv.jl]
+@testset "HeteroGraphConv" begin
+    d, n = 3, 5
+    g = rand_bipartite_heterograph((n, 2*n), 15)
+    hg = rand_bipartite_heterograph((2,3), 6)
+
+    model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d),
+                             (:B,:to,:A) => GraphConv(d => d)])
+
+    # Exercise both supported input containers: NamedTuple and Dict.
+    # (A leftover debug assignment used to overwrite `x` with the Dict form,
+    # so the NamedTuple path was never actually tested.)
+    for x in [
+        (A = rand(Float32, d, n), B = rand(Float32, d, 2n)),
+        Dict(:A => rand(Float32, d, n), :B => rand(Float32, d, 2n))
+        ]
+        y = model(g, x)
+
+        grad, dx = gradient((model, x) -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model, x)
+        ngrad, ndx = ngradient((model, x) -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model, x)
+
+        @test grad.layers[1].weight1 ≈ ngrad.layers[1].weight1 rtol=1e-4
+        @test grad.layers[1].weight2 ≈ ngrad.layers[1].weight2 rtol=1e-4
+        @test grad.layers[1].bias ≈ ngrad.layers[1].bias rtol=1e-4
+        @test grad.layers[2].weight1 ≈ ngrad.layers[2].weight1 rtol=1e-4
+        @test grad.layers[2].weight2 ≈ ngrad.layers[2].weight2 rtol=1e-4
+        @test grad.layers[2].bias ≈ ngrad.layers[2].bias rtol=1e-4
+
+        @test dx[:A] ≈ ndx[:A] rtol=1e-4
+        @test dx[:B] ≈ ndx[:B] rtol=1e-4
+    end
+
+    @testset "Constructor from pairs" begin
+        layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, tanh),
+                                (:B, :to, :A) => GraphConv(64 => 32, tanh));
+        @test length(layer.etypes) == 2
+    end
+
+    @testset "Destination node aggregation" begin
+        # deterministic setup to validate the aggregation
+        d, n =
3, 5 + g = GNNHeteroGraph(((:A, :to, :B) => ([1, 1, 2, 3], [1, 2, 2, 3]), + (:B, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3]), + (:C, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3])); num_nodes = Dict(:A => n, :B => n, :C => n)) + model = HeteroGraphConv([ + (:A, :to, :B) => GraphConv(d => d, init = ones, bias = false), + (:B, :to, :A) => GraphConv(d => d, init = ones, bias = false), + (:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = +) + x = (A = rand(Float32, d, n), B = rand(Float32, d, n), C = rand(Float32, d, n)) + y = model(g, x) + weights = ones(Float32, d, d) + + ### Test default summation aggregation + # B2 has 2 edges from A and itself (sense check) + expected = sum(weights * x.A[:, [1, 2]]; dims = 2) .+ weights * x.B[:, [2]] + output = y.B[:, [2]] + @test expected ≈ output + + # B5 has only itself + @test weights * x.B[:, [5]] ≈ y.B[:, [5]] + + # A1 has 1 edge from B, 1 from C and twice itself + expected = sum(weights * x.B[:, [1]] + weights * x.C[:, [1]]; dims = 2) .+ + 2 * weights * x.A[:, [1]] + output = y.A[:, [1]] + @test expected ≈ output + + # A2 has 2 edges from B, 2 from C and twice itself + expected = sum(weights * x.B[:, [1, 2]] + weights * x.C[:, [1, 2]]; dims = 2) .+ + 2 * weights * x.A[:, [2]] + output = y.A[:, [2]] + @test expected ≈ output + + # A5 has only itself but twice + @test 2 * weights * x.A[:, [5]] ≈ y.A[:, [5]] + + #### Test different aggregation function + model2 = HeteroGraphConv([ + (:A, :to, :B) => GraphConv(d => d, init = ones, bias = false), + (:B, :to, :A) => GraphConv(d => d, init = ones, bias = false), + (:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = -) + y2 = model2(g, x) + # B no change + @test y.B ≈ y2.B + + # A1 has 1 edge from B, 1 from C, itself cancels out + expected = sum(weights * x.B[:, [1]] - weights * x.C[:, [1]]; dims = 2) + output = y2.A[:, [1]] + @test expected ≈ output + + # A2 has 2 edges from B, 2 from C, itself cancels out + expected = sum(weights * x.B[:, [1, 2]] - 
weights * x.C[:, [1, 2]]; dims = 2) + output = y2.A[:, [2]] + @test expected ≈ output + end + + @testset "CGConv" begin + x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, tanh), + (:B, :to, :A) => CGConv(4 => 2, tanh)); + y = layers(hg, x); + @test size(y.A) == (2,2) && size(y.B) == (2,3) + end + + @testset "EdgeConv" begin + x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), + (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); + y = layers(hg, x); + @test size(y.A) == (2,2) && size(y.B) == (2,3) + end + + @testset "SAGEConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, tanh, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, tanh, bias = false, aggr = +)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + + @testset "GATConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GATConv(4 => 2), + (:B, :to, :A) => GATConv(4 => 2)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + + @testset "GINConv" begin + x = (A = rand(4, 2), B = rand(4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4), + (:B, :to, :A) => GINConv(Dense(4, 2), 0.4)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + + @testset "ResGatedGraphConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => ResGatedGraphConv(4 => 2), + (:B, :to, :A) => ResGatedGraphConv(4 => 2)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + + @testset "GATv2Conv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GATv2Conv(4 => 2), + (:B, :to, :A) => GATv2Conv(4 => 2)); + y 
= layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + + @testset "GCNConv" begin + g = rand_bipartite_heterograph((2,3), 6) + x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), + (:B, :to, :A) => GCNConv(4 => 2, tanh)); + y = layers(g, x); + @test size(y.A) == (2,2) && size(y.B) == (2,3) + end +end + +[.\test\layers\pool.jl] +@testset "GlobalPool" begin + p = GlobalPool(+) + n = 10 + chin = 6 + X = rand(Float32, 6, n) + g = GNNGraph(random_regular_graph(n, 4), ndata = X, graph_type = GRAPH_T) + u = p(g, X) + @test u ≈ sum(X, dims = 2) + + ng = 3 + g = Flux.batch([GNNGraph(random_regular_graph(n, 4), + ndata = rand(Float32, chin, n), + graph_type = GRAPH_T) + for i in 1:ng]) + u = p(g, g.ndata.x) + @test size(u) == (chin, ng) + @test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2) + @test p(g).gdata.u == u + + test_layer(p, g, rtol = 1e-5, exclude_grad_fields = [:aggr], outtype = :graph) +end + +@testset "GlobalAttentionPool" begin + n = 10 + chin = 6 + chout = 5 + ng = 3 + + fgate = Dense(chin, 1) + ffeat = Dense(chin, chout) + p = GlobalAttentionPool(fgate, ffeat) + @test length(Flux.params(p)) == 4 + + g = Flux.batch([GNNGraph(random_regular_graph(n, 4), + ndata = rand(Float32, chin, n), + graph_type = GRAPH_T) + for i in 1:ng]) + + test_layer(p, g, rtol = 1e-5, outtype = :graph, outsize = (chout, ng)) +end + +@testset "TopKPool" begin + N = 10 + k, in_channel = 4, 7 + X = rand(in_channel, N) + for T in [Bool, Float64] + adj = rand(T, N, N) + p = TopKPool(adj, k, in_channel) + @test eltype(p.p) === Float32 + @test size(p.p) == (in_channel,) + @test eltype(p.Ã) === T + @test size(p.Ã) == (k, k) + y = p(X) + @test size(y) == (in_channel, k) + end +end + +@testset "topk_index" begin + X = [8, 7, 6, 5, 4, 3, 2, 1] + @test topk_index(X, 4) == [1, 2, 3, 4] + @test topk_index(X', 4) == [1, 2, 3, 4] +end + +@testset "Set2Set" begin + n_in = 3 + n_iters = 2 + n_layers = 1 + g = 
batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5]) + g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes)) + l = Set2Set(n_in, n_iters, n_layers) + y = l(g, node_features(g)) + @test size(y) == (2 * n_in, g.num_graphs) + + ## TODO the numerical gradient seems to be 3 times smaller than zygote one + # test_layer(l, g, rtol = 1e-4, atol=1e-4, outtype = :graph, outsize = (2 * n_in, g.num_graphs), + # verbose=true, exclude_grad_fields = [:state0, :state]) +end +[.\test\layers\temporalconv.jl] +in_channel = 3 +out_channel = 5 +N = 4 +S = 5 +T = Float32 + +g1 = GNNGraph(rand_graph(N,8), + ndata = rand(T, in_channel, N), + graph_type = :sparse) + +tg = TemporalSnapshotsGNNGraph([g1 for _ in 1:S]) + +@testset "TGCNCell" begin + tgcn = GraphNeuralNetworks.TGCNCell(in_channel => out_channel) + h, x̃ = tgcn(tgcn.state0, g1, g1.ndata.x) + @test size(h) == (out_channel, N) + @test size(x̃) == (out_channel, N) + @test h == x̃ +end + +@testset "TGCN" begin + tgcn = TGCN(in_channel => out_channel) + @test size(Flux.gradient(x -> sum(tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) + @test size(model(g1, g1.ndata.x)) == (1, N) + @test model(g1) isa GNNGraph +end + +@testset "A3TGCN" begin + a3tgcn = A3TGCN(in_channel => out_channel) + @test size(Flux.gradient(x -> sum(a3tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1)) + @test size(model(g1, g1.ndata.x)) == (1, N) + @test model(g1) isa GNNGraph +end + +@testset "GConvLSTMCell" begin + gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes) + (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) + @test size(h) == (out_channel, N) + @test size(c) == (out_channel, N) +end + +@testset "GConvLSTM" begin + gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes) + @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), 
g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) +end + +@testset "GConvGRUCell" begin + gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes) + h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) + @test size(h) == (out_channel, N) +end + +@testset "GConvGRU" begin + gconvlstm = GConvGRU(in_channel => out_channel, 2, g1.num_nodes) + @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) + @test size(model(g1, g1.ndata.x)) == (1, N) + @test model(g1) isa GNNGraph +end + +@testset "DCGRU" begin + dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) + @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) + @test size(model(g1, g1.ndata.x)) == (1, N) + @test model(g1) isa GNNGraph +end + +@testset "GINConv" begin + ginconv = GINConv(Dense(in_channel => out_channel),0.3) + @test length(ginconv(tg, tg.ndata.x)) == S + @test size(ginconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "ChebConv" begin + chebconv = ChebConv(in_channel => out_channel, 5) + @test length(chebconv(tg, tg.ndata.x)) == S + @test size(chebconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(chebconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "GATConv" begin + gatconv = GATConv(in_channel => out_channel) + @test length(gatconv(tg, tg.ndata.x)) == S + @test size(gatconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(gatconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "GATv2Conv" begin + gatv2conv = GATv2Conv(in_channel => out_channel) + @test length(gatv2conv(tg, 
tg.ndata.x)) == S + @test size(gatv2conv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(gatv2conv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "GatedGraphConv" begin + gatedgraphconv = GatedGraphConv(5, 5) + @test length(gatedgraphconv(tg, tg.ndata.x)) == S + @test size(gatedgraphconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(gatedgraphconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "CGConv" begin + cgconv = CGConv(in_channel => out_channel) + @test length(cgconv(tg, tg.ndata.x)) == S + @test size(cgconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(cgconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "SGConv" begin + sgconv = SGConv(in_channel => out_channel) + @test length(sgconv(tg, tg.ndata.x)) == S + @test size(sgconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(sgconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "TransformerConv" begin + transformerconv = TransformerConv(in_channel => out_channel) + @test length(transformerconv(tg, tg.ndata.x)) == S + @test size(transformerconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(transformerconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "GCNConv" begin + gcnconv = GCNConv(in_channel => out_channel) + @test length(gcnconv(tg, tg.ndata.x)) == S + @test size(gcnconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(gcnconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "ResGatedGraphConv" begin + resgatedconv = ResGatedGraphConv(in_channel => out_channel, tanh) + @test length(resgatedconv(tg, tg.ndata.x)) == S + @test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "SAGEConv" begin + sageconv = SAGEConv(in_channel => out_channel) + @test length(sageconv(tg, tg.ndata.x)) == S + @test 
size(sageconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(sageconv(tg, x))), tg.ndata.x)[1]) == S +end + +@testset "GraphConv" begin + graphconv = GraphConv(in_channel => out_channel, tanh) + @test length(graphconv(tg, tg.ndata.x)) == S + @test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N) + @test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S +end + + diff --git a/sccript.py b/sccript.py new file mode 100644 index 000000000..9250023d2 --- /dev/null +++ b/sccript.py @@ -0,0 +1,16 @@ +import os + +def main(): + with open("data.txt", "w", encoding="utf-8") as outfile: + for root, _, files in os.walk("."): + for filename in files: + if filename.endswith(".jl"): + filepath = os.path.join(root, filename) + print(f"Processing: {filepath}") # Add print statement here + outfile.write(f"[{filepath}]\n") + with open(filepath, "r", encoding="utf-8") as infile: + outfile.write(infile.read()) + outfile.write("\n") + +if __name__ == "__main__": + main() \ No newline at end of file From 90fc120b241d8766c509050cf376a00bc6dce65b Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 13:41:00 +0530 Subject: [PATCH 12/41] fixing --- GNNLux/test/layers/conv_tests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9f3722b05..3af1c0430 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -97,6 +97,7 @@ @testset "NNConv" begin edim = 10 nn = Dense(edim, out_dims * in_dims) + g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) end From 0dae0bc83e2e65a2ab5d8a0194e1682ced65903b Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:25:18 +0530 Subject: [PATCH 13/41] Delete data.txt --- data.txt | 25145 
----------------------------------------------------- 1 file changed, 25145 deletions(-) delete mode 100644 data.txt diff --git a/data.txt b/data.txt deleted file mode 100644 index 5174930c5..000000000 --- a/data.txt +++ /dev/null @@ -1,25145 +0,0 @@ -[.\docs\make.jl] -using GraphNeuralNetworks -using GNNGraphs -using Flux -using NNlib -using Graphs -using SparseArrays -using Pluto, PlutoStaticHTML # for tutorials -using Documenter, DemoCards -using DocumenterInterLinks - - -tutorials, tutorials_cb, tutorial_assets = makedemos("tutorials") -assets = [] -isnothing(tutorial_assets) || push!(assets, tutorial_assets) - -interlinks = InterLinks( - "NNlib" => "https://fluxml.ai/NNlib.jl/stable/", - "Graphs" => "https://juliagraphs.org/Graphs.jl/stable/") - - -DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, - :(using GraphNeuralNetworks, Graphs, SparseArrays, NNlib, Flux); - recursive = true) - -prettyurls = get(ENV, "CI", nothing) == "true" -mathengine = MathJax3() - -makedocs(; - modules = [GraphNeuralNetworks, GNNGraphs, GNNlib], - doctest = false, - clean = true, - plugins = [interlinks], - format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing), - sitename = "GraphNeuralNetworks.jl", - pages = ["Home" => "index.md", - "Graphs" => ["gnngraph.md", "heterograph.md", "temporalgraph.md"], - "Message Passing" => "messagepassing.md", - "Model Building" => "models.md", - "Datasets" => "datasets.md", - "Tutorials" => tutorials, - "API Reference" => [ - "GNNGraph" => "api/gnngraph.md", - "Basic Layers" => "api/basic.md", - "Convolutional Layers" => "api/conv.md", - "Pooling Layers" => "api/pool.md", - "Message Passing" => "api/messagepassing.md", - "Heterogeneous Graphs" => "api/heterograph.md", - "Temporal Graphs" => "api/temporalgraph.md", - "Utils" => "api/utils.md", - ], - "Developer Notes" => "dev.md", - "Summer Of Code" => "gsoc.md", - ]) - -tutorials_cb() - -deploydocs(repo = 
"github.com/CarloLucibello/GraphNeuralNetworks.jl.git") - -[.\docs\tutorials\introductory_tutorials\gnn_intro_pluto.jl] -### A Pluto.jl notebook ### -# v0.19.45 - -#> [frontmatter] -#> author = "[Carlo Lucibello](https://github.com/CarloLucibello)" -#> title = "Hands-on introduction to Graph Neural Networks" -#> date = "2022-05-22" -#> description = "A beginner level introduction to graph machine learning using GraphNeuralNetworks.jl" -#> cover = "assets/intro_1.png" - -using Markdown -using InteractiveUtils - -# ╔═╡ 42c84361-222a-46c4-b81f-d33eb41635c9 -begin - using Flux - using Flux: onecold, onehotbatch, logitcrossentropy - using MLDatasets - using LinearAlgebra, Random, Statistics - import GraphMakie - import CairoMakie as Makie - using Graphs - using PlutoUI - using GraphNeuralNetworks -end - -# ╔═╡ 03a9e023-e682-4ea3-a10b-14c4d101b291 -md""" -*This Pluto notebook is a Julia adaptation of the Pytorch Geometric tutorials that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).* - -Recently, deep learning on graphs has emerged to one of the hottest research fields in the deep learning community. -Here, **Graph Neural Networks (GNNs)** aim to generalize classical deep learning concepts to irregular structured data (in contrast to images or texts) and to enable neural networks to reason about objects and their relations. 
- -This is done by following a simple **neural message passing scheme**, where node features ``\mathbf{x}_i^{(\ell)}`` of all nodes ``i \in \mathcal{V}`` in a graph ``\mathcal{G} = (\mathcal{V}, \mathcal{E})`` are iteratively updated by aggregating localized information from their neighbors ``\mathcal{N}(i)``: - -```math -\mathbf{x}_i^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_i^{(\ell)}, \left\{ \mathbf{x}_j^{(\ell)} : j \in \mathcal{N}(i) \right\} \right) -``` - -This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the **[GraphNeuralNetworks.jl library](https://github.com/CarloLucibello/GraphNeuralNetworks.jl)**. -GraphNeuralNetworks.jl is an extension library to the popular deep learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/), and consists of various methods and utilities to ease the implementation of Graph Neural Networks. - -Let's first import the packages we need: -""" - -# ╔═╡ 361e0948-d91a-11ec-2d95-2db77435a0c1 -begin - ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation - Random.seed!(17) # for reproducibility -end; - -# ╔═╡ ef96f5ae-724d-4b8e-b7d7-c116ad1c3279 -md""" -Following [Kipf et al. (2017)](https://arxiv.org/abs/1609.02907), let's dive into the world of GNNs by looking at a simple graph-structured example, the well-known [**Zachary's karate club network**](https://en.wikipedia.org/wiki/Zachary%27s_karate_club). This graph describes a social network of 34 members of a karate club and documents links between members who interacted outside the club. Here, we are interested in detecting communities that arise from the member's interaction. 
- -GraphNeuralNetworks.jl provides utilities to convert [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl)'s datasets to its own type: -""" - -# ╔═╡ 4ba372d4-7a6a-41e0-92a0-9547a78e2898 -dataset = MLDatasets.KarateClub() - -# ╔═╡ 55aca2f0-4bbb-4d3a-9777-703896cfc548 -md""" -After initializing the `KarateClub` dataset, we first can inspect some of its properties. -For example, we can see that this dataset holds exactly **one graph**. -Furthermore, the graph holds exactly **4 classes**, which represent the community each node belongs to. -""" - -# ╔═╡ a1d35896-0f52-4c8b-b7dc-ec65649237c8 -karate = dataset[1] - -# ╔═╡ 48d7df25-9190-45c9-9829-140f452e5151 -karate.node_data.labels_comm - -# ╔═╡ 4598bf67-5448-4ce5-8be8-a473ab1a6a07 -md""" -Now we convert the single-graph dataset to a `GNNGraph`. Moreover, we add a an array of node features, a **34-dimensional feature vector** for each node which uniquely describes the members of the karate club. We also add a training mask selecting the nodes to be used for training in our semi-supervised node classification task. -""" - -# ╔═╡ 8d41a9fa-eefe-40c9-8cc3-cd503cf7434d -begin - # convert a MLDataset.jl's dataset to a GNNGraphs (or a collection of graphs) - g = mldataset2gnngraph(dataset) - - x = zeros(Float32, g.num_nodes, g.num_nodes) - x[diagind(x)] .= 1 - - train_mask = [true, false, false, false, true, false, false, false, true, - false, false, false, false, false, false, false, false, false, false, false, - false, false, false, false, true, false, false, false, false, false, - false, false, false, false] - - labels = g.ndata.labels_comm - y = onehotbatch(labels, 0:3) - - g = GNNGraph(g, ndata = (; x, y, train_mask)) -end - -# ╔═╡ c42c7f73-f84e-4e72-9af4-a6421af57f0d -md""" -Let's now look at the underlying graph in more detail: -""" - -# ╔═╡ a7ad9de3-3e18-4aff-b118-a4d798a2f4ec -with_terminal() do - # Gather some statistics about the graph. 
- println("Number of nodes: $(g.num_nodes)") - println("Number of edges: $(g.num_edges)") - println("Average node degree: $(g.num_edges / g.num_nodes)") - println("Number of training nodes: $(sum(g.ndata.train_mask))") - println("Training node label rate: $(mean(g.ndata.train_mask))") - # println("Has isolated nodes: $(has_isolated_nodes(g))") - println("Has self-loops: $(has_self_loops(g))") - println("Is undirected: $(is_bidirected(g))") -end - -# ╔═╡ 1e362709-a0d0-45d5-b2fd-a91c45fa317a -md""" -Each graph in GraphNeuralNetworks.jl is represented by a `GNNGraph` object, which holds all the information to describe its graph representation. -We can print the data object anytime via `print(g)` to receive a short summary about its attributes and their shapes. - -The `g` object holds 3 attributes: -- `g.ndata`: contains node-related information. -- `g.edata`: holds edge-related information. -- `g.gdata`: this stores the global data, therefore neither node nor edge-specific features. - -These attributes are `NamedTuples` that can store multiple feature arrays: we can access a specific set of features e.g. `x`, with `g.ndata.x`. - - -In our task, `g.ndata.train_mask` describes for which nodes we already know their community assignments. In total, we are only aware of the ground-truth labels of 4 nodes (one for each community), and the task is to infer the community assignment for the remaining nodes. - -The `g` object also provides some **utility functions** to infer some basic properties of the underlying graph. -For example, we can easily infer whether there exist isolated nodes in the graph (*i.e.* there exists no edge to any node), whether the graph contains self-loops (*i.e.*, ``(v, v) \in \mathcal{E}``), or whether the graph is bidirected (*i.e.*, for each edge ``(v, w) \in \mathcal{E}`` there also exists the edge ``(w, v) \in \mathcal{E}``). 
- -Let us now inspect the `edge_index` method: - -""" - -# ╔═╡ d627736a-fd5a-4cdc-bd4e-89ff8b8c55bd -edge_index(g) - -# ╔═╡ 98bb86d2-a7b9-4110-8851-8829a9f9b4d0 -md""" -By printing `edge_index(g)`, we can understand how GraphNeuralNetworks.jl represents graph connectivity internally. -We can see that for each edge, `edge_index` holds a tuple of two node indices, where the first value describes the node index of the source node and the second value describes the node index of the destination node of an edge. - -This representation is known as the **COO format (coordinate format)** commonly used for representing sparse matrices. -Instead of holding the adjacency information in a dense representation ``\mathbf{A} \in \{ 0, 1 \}^{|\mathcal{V}| \times |\mathcal{V}|}``, GraphNeuralNetworks.jl represents graphs sparsely, which refers to only holding the coordinates/values for which entries in ``\mathbf{A}`` are non-zero. - -Importantly, GraphNeuralNetworks.jl does not distinguish between directed and undirected graphs, and treats undirected graphs as a special case of directed graphs in which reverse edges exist for every entry in the `edge_index`. - -Since a `GNNGraph` is an `AbstractGraph` from the `Graphs.jl` library, it supports graph algorithms and visualization tools from the wider julia graph ecosystem: -""" - -# ╔═╡ 9820cc77-ae0a-454a-86b6-a23dbc56b6fd -GraphMakie.graphplot(g |> to_unidirected, node_size = 20, node_color = labels, - arrow_show = false) - -# ╔═╡ 86135c51-950c-4c08-b9e0-6c892234ff87 -md""" - -## Implementing Graph Neural Networks - -After learning about GraphNeuralNetworks.jl's data handling, it's time to implement our first Graph Neural Network! - -For this, we will use on of the most simple GNN operators, the **GCN layer** ([Kipf et al. 
(2017)](https://arxiv.org/abs/1609.02907)), which is defined as - -```math -\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} -``` - -where ``\mathbf{W}^{(\ell + 1)}`` denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and ``c_{w,v}`` refers to a fixed normalization coefficient for each edge. - -GraphNeuralNetworks.jl implements this layer via `GCNConv`, which can be executed by passing in the node feature representation `x` and the COO graph connectivity representation `edge_index`. - -With this, we are ready to create our first Graph Neural Network by defining our network architecture: -""" - -# ╔═╡ 88d1e59f-73d6-46ee-87e8-35beb7bc7674 -begin - struct GCN - layers::NamedTuple - end - - Flux.@layer GCN # provides parameter collection, gpu movement and more - - function GCN(num_features, num_classes) - layers = (conv1 = GCNConv(num_features => 4), - conv2 = GCNConv(4 => 4), - conv3 = GCNConv(4 => 2), - classifier = Dense(2, num_classes)) - return GCN(layers) - end - - function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix) - l = gcn.layers - x = l.conv1(g, x) - x = tanh.(x) - x = l.conv2(g, x) - x = tanh.(x) - x = l.conv3(g, x) - x = tanh.(x) # Final GNN embedding space. - out = l.classifier(x) - # Apply a final (linear) classifier. - return out, x - end -end - -# ╔═╡ 9838189c-5cf6-4f21-b58e-3bb905408ad3 -md""" - -Here, we first initialize all of our building blocks in the constructor and define the computation flow of our network in the call method. -We first define and stack **three graph convolution layers**, which corresponds to aggregating 3-hop neighborhood information around each node (all nodes up to 3 "hops" away). -In addition, the `GCNConv` layers reduce the node feature dimensionality to ``2``, *i.e.*, ``34 \rightarrow 4 \rightarrow 4 \rightarrow 2``. Each `GCNConv` layer is enhanced by a `tanh` non-linearity. 
- -After that, we apply a single linear transformation (`Flux.Dense` that acts as a classifier to map our nodes to 1 out of the 4 classes/communities. - -We return both the output of the final classifier as well as the final node embeddings produced by our GNN. -We proceed to initialize our final model via `GCN()`, and printing our model produces a summary of all its used sub-modules. - -### Embedding the Karate Club Network - -Let's take a look at the node embeddings produced by our GNN. -Here, we pass in the initial node features `x` and the graph information `g` to the model, and visualize its 2-dimensional embedding. -""" - -# ╔═╡ ad2c2e51-08ec-4ddc-9b5c-668a3688db12 -begin - num_features = 34 - num_classes = 4 - gcn = GCN(num_features, num_classes) -end - -# ╔═╡ ce26c963-0438-4ab2-b5c6-520272beef2b -_, h = gcn(g, g.ndata.x) - -# ╔═╡ e545e74f-0a3c-4d18-9cc7-557ca60be567 -function visualize_embeddings(h; colors = nothing) - xs = h[1, :] |> vec - ys = h[2, :] |> vec - Makie.scatter(xs, ys, color = labels, markersize = 20) -end - -# ╔═╡ 26138606-2e8d-435b-aa1a-b6159a0d2739 -visualize_embeddings(h, colors = labels) - -# ╔═╡ b9359c7d-b7fe-412d-8f5e-55ba6bccb4e9 -md""" -Remarkably, even before training the weights of our model, the model produces an embedding of nodes that closely resembles the community-structure of the graph. -Nodes of the same color (community) are already closely clustered together in the embedding space, although the weights of our model are initialized **completely at random** and we have not yet performed any training so far! -This leads to the conclusion that GNNs introduce a strong inductive bias, leading to similar embeddings for nodes that are close to each other in the input graph. - -### Training on the Karate Club Network - -But can we do better? Let's look at an example on how to train our network parameters based on the knowledge of the community assignments of 4 nodes in the graph (one for each community). 
- -Since everything in our model is differentiable and parameterized, we can add some labels, train the model and observe how the embeddings react. -Here, we make use of a semi-supervised or transductive learning procedure: we simply train against one node per class, but are allowed to make use of the complete input graph data. - -Training our model is very similar to any other Flux model. -In addition to defining our network architecture, we define a loss criterion (here, `logitcrossentropy`), and initialize a stochastic gradient optimizer (here, `Adam`). -After that, we perform multiple rounds of optimization, where each round consists of a forward and backward pass to compute the gradients of our model parameters w.r.t. to the loss derived from the forward pass. -If you are not new to Flux, this scheme should appear familiar to you. - -Note that our semi-supervised learning scenario is achieved by the following line: -``` -loss = logitcrossentropy(ŷ[:,train_mask], y[:,train_mask]) -``` -While we compute node embeddings for all of our nodes, we **only make use of the training nodes for computing the loss**. -Here, this is implemented by filtering the output of the classifier `out` and ground-truth labels `data.y` to only contain the nodes in the `train_mask`. 
- -Let us now start training and see how our node embeddings evolve over time (best experienced by explicitly running the code): -""" - -# ╔═╡ 912560a1-9c72-47bd-9fce-9702b346b603 -begin - model = GCN(num_features, num_classes) - opt = Flux.setup(Adam(1e-2), model) - epochs = 2000 - - emb = h - function report(epoch, loss, h) - # p = visualize_embeddings(h) - @info (; epoch, loss) - end - - report(0, 10.0, emb) - for epoch in 1:epochs - loss, grad = Flux.withgradient(model) do model - ŷ, emb = model(g, g.ndata.x) - logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) - end - - Flux.update!(opt, model, grad[1]) - if epoch % 200 == 0 - report(epoch, loss, emb) - end - end -end - -# ╔═╡ c8a217c9-0087-41f0-90c8-aac29bc1c996 -ŷ, emb_final = model(g, g.ndata.x) - -# ╔═╡ 727b24bc-0b1e-4ebd-b8ef-987015751e38 -# train accuracy -mean(onecold(ŷ[:, train_mask]) .== onecold(y[:, train_mask])) - -# ╔═╡ 8c60ec7e-46b0-40f7-bf6a-6228a31e1f66 -# test accuracy -mean(onecold(ŷ[:, .!train_mask]) .== onecold(y[:, .!train_mask])) - -# ╔═╡ 44d9f8cf-1023-48ad-a01f-07e59f4b4226 -visualize_embeddings(emb_final, colors = labels) - -# ╔═╡ a8841d35-97f9-431d-acab-abf478ce91a9 -md""" -As one can see, our 3-layer GCN model manages to linearly separating the communities and classifying most of the nodes correctly. - -Furthermore, we did this all with a few lines of code, thanks to the GraphNeuralNetworks.jl which helped us out with data handling and GNN implementations. 
-""" - -# ╔═╡ 00000000-0000-0000-0000-000000000001 -PLUTO_PROJECT_TOML_CONTENTS = """ -[deps] -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[compat] -CairoMakie = "~0.12.5" -Flux = "~0.14.16" -GraphMakie = "~0.5.12" -GraphNeuralNetworks = "~0.6.19" -Graphs = "~1.11.2" -MLDatasets = "~0.7.16" -PlutoUI = "~0.7.59" -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000002 -PLUTO_MANIFEST_TOML_CONTENTS = """ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "0bbe321bcd3061714ce11e8a8428022b3809de5f" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.AbstractPlutoDingetjes]] -deps = ["Pkg"] -git-tree-sha1 = "6e1d2a35f2f90a4bc7c2ed98079b2ba09c35b83a" -uuid = "6e696c72-6542-2067-7265-42206c756150" -version = "1.3.2" - -[[deps.AbstractTrees]] -git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.5" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = 
"0.1.37" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.AliasTables]] -deps = ["PtrArrays", "Random"] -git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" -uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" -version = "1.1.3" - -[[deps.Animations]] -deps = ["Colors"] -git-tree-sha1 = "e81c509d2c8e49592413bfb0bb3b08150056c79d" -uuid = "27a7e980-b3e6-11e9-2bcd-0b925532e340" -version = "0.4.1" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.AtomsBase]] -deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", 
"StaticArrays", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" -uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.5" - -[[deps.Automa]] -deps = ["PrecompileTools", "TranscodingStreams"] -git-tree-sha1 = "014bc22d6c400a7703c0f5dc1fdc302440cf88be" -uuid = "67c07d97-cdcb-5c2c-af73-a7f9c32a568b" -version = "1.0.4" - -[[deps.AxisAlgorithms]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] -git-tree-sha1 = "01b8ccb13d68535d73d2b0c23e39bd23155fb712" -uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" -version = "1.1.0" - -[[deps.AxisArrays]] -deps = ["Dates", "IntervalSets", "IterTools", "RangeArrays"] -git-tree-sha1 = "16351be62963a67ac4083f748fdb3cca58bfd52f" -uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" -version = "0.4.7" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = 
"aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" - -[[deps.Bzip2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" -uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CRC32c]] -uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" - -[[deps.CRlibm_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e329286945d0cfc04456972ea732551869af1cfc" -uuid = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0" -version = "1.0.1+0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" - -[[deps.Cairo]] -deps = ["Cairo_jll", "Colors", "Glib_jll", "Graphics", "Libdl", "Pango_jll"] -git-tree-sha1 = "d0b3f8b4ad16cb0a2988c6788646a5e6a17b6b1b" -uuid = "159f3aea-2a34-519c-b102-8c37f9878175" -version = "1.0.5" - -[[deps.CairoMakie]] -deps = ["CRC32c", "Cairo", "Cairo_jll", "Colors", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools"] -git-tree-sha1 = "e4da5095557f24713bae4c9f50e34ff4d3b959c0" -uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -version = "0.12.5" - -[[deps.Cairo_jll]] -deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", 
"Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" -uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.Chemfiles]] -deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" -uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" -version = "0.10.41" - -[[deps.Chemfiles_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" -uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" -version = "0.10.4+0" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" - -[[deps.ColorBrewer]] -deps = ["Colors", "JSON", "Test"] -git-tree-sha1 = "61c5334f33d91e570e1d0c3eb5465835242582c4" -uuid = "a2cac450-b92f-5266-8821-25eda20663c8" -version = "0.4.0" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", 
"ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" 
-uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" -weakdeps = ["IntervalSets", "StaticArrays"] - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Contour]] -git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" -uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataDeps]] -deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] -git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" -uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.13" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - 
-[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelaunayTriangulation]] -deps = ["EnumX", "ExactPredicates", "Random"] -git-tree-sha1 = "078c716cbb032242df18b960e8b1fec6b1b0b9f9" -uuid = "927a84f5-c5f4-47a5-9785-b46e178433df" -version = "1.0.5" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", "SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Distributions]] -deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = 
"0.25.109" - - [deps.Distributions.extensions] - DistributionsChainRulesCoreExt = "ChainRulesCore" - DistributionsDensityInterfaceExt = "DensityInterface" - DistributionsTestExt = "Test" - - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - -[[deps.EarCut_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e3290f2d49e661fbd94046d7e3726ffcb2d41053" -uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5" -version = "2.2.4+0" - -[[deps.EnumX]] -git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" -uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" -version = "1.0.4" - -[[deps.ExactPredicates]] -deps = ["IntervalArithmetic", "Random", "StaticArrays"] -git-tree-sha1 = "b3f2ff58735b5f024c392fde763f29b057e4b025" -uuid = "429591f6-91af-11e9-00e2-59fbe8cec110" -version = "2.2.8" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.10" - -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.6.2+0" - -[[deps.Extents]] -git-tree-sha1 = "94997910aca72897524d2237c41eb852153b0f65" -uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" -version = "0.1.3" - 
-[[deps.FFMPEG_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] -git-tree-sha1 = "ab3f7e1819dba9434a3a5126510c8fda3a4e7000" -uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" -version = "6.1.1+0" - -[[deps.FFTW]] -deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] -git-tree-sha1 = "4820348781ae578893311153d69049a93d05f39d" -uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.8.0" - -[[deps.FFTW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" -uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" -version = "3.3.10+0" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" - -[[deps.FilePaths]] -deps = ["FilePathsBase", "MacroTools", "Reexport", "Requires"] -git-tree-sha1 = "919d9412dbf53a2e6fe74af62a73ceed0bce0629" -uuid = "8fc22ac5-c921-52a6-82fd-178b2807b824" -version = "0.8.3" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - 
-[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" -weakdeps = ["PDMats", "SparseArrays", "Statistics"] - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.Fontconfig_jll]] -deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] -git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" -uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.13.96+0" - -[[deps.Format]] -git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" -uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" -version = "1.3.7" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = 
"f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.FreeType]] -deps = ["CEnum", "FreeType2_jll"] -git-tree-sha1 = "907369da0f8e80728ab49c1c7e09327bf0d6d999" -uuid = "b38be410-82b0-50bf-ab77-7b57e271db43" -version = "4.1.1" - -[[deps.FreeType2_jll]] -deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" -uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.13.2+0" - -[[deps.FreeTypeAbstraction]] -deps = ["ColorVectorSpace", "Colors", "FreeType", "GeometryBasics"] -git-tree-sha1 = "2493cdfd0740015955a8e46de4ef28f49460d8bc" -uuid = "663a7486-cb36-511b-a19d-713bb74d65c9" -version = "0.10.3" - -[[deps.FriBidi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" -uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.14+0" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.GZip]] -deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.2" - -[[deps.GeoInterface]] -deps = ["Extents"] -git-tree-sha1 = "9fff8990361d5127b770e3454488360443019bb3" -uuid = 
"cf35fbd7-0cd7-5166-be24-54bfbe79505f" -version = "1.3.5" - -[[deps.GeometryBasics]] -deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] -git-tree-sha1 = "b62f2b2d76cee0d61a2ef2b3118cd2a3215d3134" -uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" -version = "0.4.11" - -[[deps.Gettext_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" -uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" -version = "0.21.0+0" - -[[deps.Glib_jll]] -deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" -uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.GraphMakie]] -deps = ["DataStructures", "GeometryBasics", "Graphs", "LinearAlgebra", "Makie", "NetworkLayout", "PolynomialRoots", "SimpleTraits", "StaticArrays"] -git-tree-sha1 = "c8c3ece1211905888da48e16f438af85e951ea55" -uuid = "1ecd5474-83a3-4783-bb4f-06765db800d2" -version = "0.5.12" - -[[deps.GraphNeuralNetworks]] -deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" -uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" -version = "0.6.19" - - [deps.GraphNeuralNetworks.extensions] - GraphNeuralNetworksCUDAExt = "CUDA" - GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" - - [deps.GraphNeuralNetworks.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - SimpleWeightedGraphs = 
"47aef6b3-ad0c-573a-a1e2-d07658019622" - -[[deps.Graphics]] -deps = ["Colors", "LinearAlgebra", "NaNMath"] -git-tree-sha1 = "d61890399bc535850c4bf08e4e0d3a7ad0f21cbd" -uuid = "a2bd30eb-e257-5431-a919-1863eab51364" -version = "1.1.2" - -[[deps.Graphite2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" -uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" -version = "1.3.14+0" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" - -[[deps.GridLayoutBase]] -deps = ["GeometryBasics", "InteractiveUtils", "Observables"] -git-tree-sha1 = "fc713f007cff99ff9e50accba6373624ddd33588" -uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" -version = "0.11.0" - -[[deps.Grisu]] -git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" -uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" -version = "1.0.2" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.2" - - [deps.HDF5.extensions] - MPIExt = "MPI" - - [deps.HDF5.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - -[[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+3" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", 
"Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.HarfBuzz_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" -uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+1" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" - -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" - -[[deps.Hyperscript]] -deps = ["Test"] -git-tree-sha1 = "179267cfa5e712760cd43dcae385d7ea90cc25a4" -uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" -version = "0.0.5" - -[[deps.HypertextLiteral]] -deps = ["Tricks"] -git-tree-sha1 = "7134810b1afce04bbc1045ca1985fbe81ce17653" -uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" -version = "0.9.5" - -[[deps.IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.5" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.ImageAxes]] -deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] -git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" -uuid = "2803e5a7-5153-5ecf-9a86-9b4c37f5f5ac" -version = "0.6.11" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = 
"eb49b82c172811fd2c86759fa0553a2221feb909" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" - -[[deps.ImageCore]] -deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] -git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.2" - -[[deps.ImageIO]] -deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] -git-tree-sha1 = "437abb322a41d527c197fa800455f79d414f0a3c" -uuid = "82e4d734-157c-48bb-816b-45c225c6df19" -version = "0.6.8" - -[[deps.ImageMetadata]] -deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] -git-tree-sha1 = "355e2b974f2e3212a75dfb60519de21361ad3cb7" -uuid = "bc367c6b-8a6b-528e-b4bd-a4b897500b49" -version = "0.9.9" - -[[deps.ImageShow]] -deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" -uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.8" - -[[deps.Imath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "0936ba688c6d201805a83da835b55c61a180db52" -uuid = "905a6f67-0a94-5f89-b386-d35d92009cd1" -version = "3.1.11+0" - -[[deps.IndirectArrays]] -git-tree-sha1 = "012e604e1c7458645cb8b436f8fba789a51b257f" -uuid = "9b13fd28-a010-5f03-acff-a1bbcff69959" -version = "1.0.0" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - 
ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.IntelOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "14eb2b542e748570b56446f4c50fbfb2306ebc45" -uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2024.2.0+0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - -[[deps.Interpolations]] -deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] -git-tree-sha1 = "88a101217d7cb38a7b481ccd50d21876e1d1b0e0" -uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.15.1" -weakdeps = ["Unitful"] - - [deps.Interpolations.extensions] - InterpolationsUnitfulExt = "Unitful" - -[[deps.IntervalArithmetic]] -deps = ["CRlibm_jll", "MacroTools", "RoundingEmulator"] -git-tree-sha1 = "433b0bb201cd76cb087b017e49244f10394ebe9c" -uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" -version = "0.22.14" - - [deps.IntervalArithmetic.extensions] - IntervalArithmeticDiffRulesExt = "DiffRules" - IntervalArithmeticForwardDiffExt = "ForwardDiff" - IntervalArithmeticRecipesBaseExt = "RecipesBase" - - [deps.IntervalArithmetic.weakdeps] - DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" - -[[deps.IntervalSets]] -git-tree-sha1 = "dba9ddf07f77f60450fe5d2e2beb9854d9a49bd0" -uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.10" - - [deps.IntervalSets.extensions] - IntervalSetsRandomExt = "Random" - IntervalSetsRecipesBaseExt = "RecipesBase" - 
IntervalSetsStatisticsExt = "Statistics" - - [deps.IntervalSets.weakdeps] - Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.Isoband]] -deps = ["isoband_jll"] -git-tree-sha1 = "f9b6d97355599074dc867318950adaa6f9946137" -uuid = "f1662d9f-8043-43de-a69a-05efc1cc6ff4" -version = "0.1.1" - -[[deps.IterTools]] -git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023" -uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" -version = "1.10.0" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" -uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JSON3]] -deps = ["Dates", 
"Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JpegTurbo]] -deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "fa6d0bcff8583bac20f1ffa708c3913ca605c611" -uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.5" - -[[deps.JpegTurbo_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" -uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KernelDensity]] -deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] -git-tree-sha1 = "7d703202e65efa1369de1279c162b915e245eed1" -uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" -version = "0.6.9" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] -git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.7.1" - -[[deps.LAME_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] 
-git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" -uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" -version = "3.100.2+0" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" -weakdeps = ["BFloat16s"] - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" - -[[deps.LLVMOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" -uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "15.0.7+0" - -[[deps.LZO_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" -uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", 
"MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libffi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" -uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+1" - -[[deps.Libgcrypt_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] -git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" -uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.8.11+0" - -[[deps.Libgpg_error_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" -uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.49.0+0" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.Libmount_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" -uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.40.1+0" - -[[deps.Libuuid_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" -uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.40.1+0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - 
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MAT]] -deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.7" - -[[deps.MIMEs]] -git-tree-sha1 = "65f28ad4b594aebe22157d6fac869786a255b7eb" -uuid = "6c6e2e6c-3030-632d-7369-2d6c69616d65" -version = "0.1.4" - -[[deps.MKL_jll]] -deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "oneTBB_jll"] -git-tree-sha1 = "f046ccd0c6db2832a9f639e2c669c6fe867e5f4f" -uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2024.2.0+0" - -[[deps.MLDatasets]] -deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = 
"b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.11" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.Makie]] -deps = ["Animations", "Base64", "CRC32c", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "Dates", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG_jll", "FileIO", "FilePaths", "FixedPointNumbers", "Format", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Scratch", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun", "Unitful"] -git-tree-sha1 = "863b9e666b5a099c8835e85476a5834f9d77c4c1" -uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -version = 
"0.21.5" - -[[deps.MakieCore]] -deps = ["ColorTypes", "GeometryBasics", "IntervalSets", "Observables"] -git-tree-sha1 = "c1c950560397ee68ad7302ee0e3efa1b07466a2f" -uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" -version = "0.8.4" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MathTeXEngine]] -deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "UnicodeFun"] -git-tree-sha1 = "e1641f32ae592e415e3dbae7f4a188b5316d4b62" -uuid = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53" -version = "0.6.1" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MicroCollections]] -deps = ["Accessors", "BangBang", "InitialValues"] -git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.2.0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.4+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = 
"0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - NNlibFFTWExt = "FFTW" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" - -[[deps.Netpbm]] -deps = ["FileIO", "ImageCore", "ImageMetadata"] -git-tree-sha1 = "d92b107dbb887293622df7697a2223f9f8176fcd" -uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" -version = "1.1.1" - -[[deps.NetworkLayout]] -deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "StaticArrays"] -git-tree-sha1 = "91bb2fedff8e43793650e7a677ccda6e6e6e166b" -uuid = "46757867-2c16-5918-afeb-47bfcb05e46a" -version = "0.4.6" 
-weakdeps = ["Graphs"] - - [deps.NetworkLayout.extensions] - NetworkLayoutGraphsExt = "Graphs" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.Observables]] -git-tree-sha1 = "7438a59546cf62428fc9d1bc94729146d37a7225" -uuid = "510215fc-4207-5dde-b226-833fc4488ee2" -version = "0.5.5" - -[[deps.OffsetArrays]] -git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.14.1" -weakdeps = ["Adapt"] - - [deps.OffsetArrays.extensions] - OffsetArraysAdaptExt = "Adapt" - -[[deps.Ogg_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" -uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" -version = "1.3.5+1" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenEXR]] -deps = ["Colors", "FileIO", "OpenEXR_jll"] -git-tree-sha1 = "327f53360fdb54df7ecd01e96ef1983536d1e633" -uuid = "52e1d378-f018-4a11-a4be-720524705ac7" -version = "0.3.2" - -[[deps.OpenEXR_jll]] -deps = ["Artifacts", "Imath_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "8292dd5c8a38257111ada2174000a33745b06d4e" -uuid = "18a262bb-aa17-5467-a713-aee519bc75cb" -version = "3.2.4+0" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.6+0" - 
-[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" - -[[deps.Opus_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" -uuid = "91d4177d-7536-5919-b921-800302f37372" -version = "1.3.2+0" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.31" - -[[deps.PNGFiles]] -deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "67186a2bc9a90f9f85ff3cc8277868961fb57cbd" -uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.3" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", 
"TOML"] - -[[deps.Packing]] -deps = ["GeometryBasics"] -git-tree-sha1 = "ec3edfe723df33528e085e632414499f26650501" -uuid = "19eb6ba3-879d-56ad-ad62-d5c202156566" -version = "0.5.0" - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - -[[deps.Pango_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"] -git-tree-sha1 = "cb5a2ab6763464ae0f19c86c56c63d4a2b0f5bda" -uuid = "36c8627f-9965-5494-a995-c6b170f724f3" -version = "1.52.2+0" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PeriodicTable]] -deps = ["Base64", "Unitful"] -git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" -uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" -version = "1.2.1" - -[[deps.Pickle]] -deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.5" - -[[deps.Pixman_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" -uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.43.4+0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PkgVersion]] -deps = ["Pkg"] -git-tree-sha1 = "f9501cc0430a26bc3d156ae1b5b0c1b47af4d6da" -uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" -version = "0.3.3" - 
-[[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" -uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" - -[[deps.PlutoUI]] -deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"] -git-tree-sha1 = "ab55ee1510ad2af0ff674dbcced5e94921f867a9" -uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" -version = "0.7.59" - -[[deps.PolygonOps]] -git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" -uuid = "647866c9-e3ac-4575-94e7-e3d426903924" -version = "0.1.2" - -[[deps.PolynomialRoots]] -git-tree-sha1 = "5f807b5345093487f733e520a1b7395ee9324825" -uuid = "3a141323-8675-5d76-9d11-e1df1406c778" -version = "1.0.0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - 
-[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.10.2" - -[[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" -uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" - -[[deps.QOI]] -deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] -git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" -uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" -version = "1.0.0" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.9.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RangeArrays]] -git-tree-sha1 = "b9039e93773ddcfc828f12aadf7115b4b4d225f5" -uuid = "b3c3ace0-ae52-54e7-9d0b-2c1406fd6b9d" -version = "0.3.2" - -[[deps.Ratios]] -deps = ["Requires"] -git-tree-sha1 = "1342a47bf3260ee108163042310d26f2be5ec90b" -uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" -version = "0.4.5" -weakdeps = ["FixedPointNumbers"] - - [deps.Ratios.extensions] - RatiosFixedPointNumbersExt = "FixedPointNumbers" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.RelocatableFolders]] -deps = ["SHA", "Scratch"] -git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" -uuid = 
"05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.1" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.2+0" - -[[deps.RoundingEmulator]] -git-tree-sha1 = "40b9edad2e5287e05bd413a38f61a8ff55b9557b" -uuid = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705" -version = "0.2.1" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.SIMD]] -deps = ["PrecompileTools"] -git-tree-sha1 = "2803cab51702db743f3fda07dd1745aadfbf43bd" -uuid = "fdea26ae-647d-5447-a871-4b548cad5224" -version = "3.5.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.ShaderAbstractions]] -deps = ["ColorTypes", "FixedPointNumbers", "GeometryBasics", "LinearAlgebra", "Observables", "StaticArrays", "StructArrays", "Tables"] -git-tree-sha1 = "79123bc60c5507f035e6d1d9e563bb2971954ec8" -uuid = "65257c39-d410-5151-9873-9b3e5be5013e" -version = "0.4.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", 
"Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.Showoff]] -deps = ["Dates", "Grisu"] -git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" -uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" -version = "1.0.3" - -[[deps.SignedDistanceFields]] -deps = ["Random", "Statistics", "Test"] -git-tree-sha1 = "d263a08ec505853a5ff1c1ebde2070419e3f28e9" -uuid = "73760f76-fbc4-59ce-8f25-708e95d2df96" -version = "0.4.0" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sixel]] -deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] -git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" -uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" -version = "0.1.3" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = 
"2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StatsFuns]] -deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" -uuid = 
"4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" -weakdeps = ["ChainRulesCore", "InverseFunctions"] - - [deps.StatsFuns.extensions] - StatsFunsChainRulesCoreExt = "ChainRulesCore" - StatsFunsInverseFunctionsExt = "InverseFunctions" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" - - [deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - - [deps.StridedViews.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.7" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version 
= "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TiffImages]] -deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "SIMD", "UUIDs"] -git-tree-sha1 = "bc7fd5c91041f44636b2c134041f7e5263ce58ae" -uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" -version = "0.10.0" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - -[[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" - - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - 
TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.Tricks]] -git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" -uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" -version = "0.1.8" - -[[deps.TriplotBase]] -git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b" -uuid = "981d1d27-644d-49a2-9326-4793e63143c3" -version = "0.1.0" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnicodeFun]] -deps = ["REPL"] -git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" -uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" -version = "0.4.1" - -[[deps.Unitful]] -deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.21.0" -weakdeps = ["ConstructionBase", "InverseFunctions"] - - [deps.Unitful.extensions] - ConstructionBaseUnitfulExt = "ConstructionBase" - InverseFunctionsUnitfulExt = "InverseFunctions" - -[[deps.UnitfulAtomic]] -deps = ["Unitful"] -git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" -uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" -version = "1.0.0" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = 
"bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" - -[[deps.VectorInterface]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" -uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" -version = "0.4.6" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WoodburyMatrices]] -deps = ["LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "c1a7aa6219628fcd757dede0ca95e245c5cd9511" -uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" -version = "1.0.0" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.XML2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" -uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.1+0" - -[[deps.XSLT_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] -git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" -uuid = "aed1982a-8fda-507f-9586-7b0439959a61" -version = "1.1.41+0" - -[[deps.Xorg_libX11_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] -git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" -uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" -version = "1.8.6+0" - -[[deps.Xorg_libXau_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" -uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" -version = "1.0.11+0" - -[[deps.Xorg_libXdmcp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" -uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" -version = "1.1.4+0" - 
-[[deps.Xorg_libXext_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" -uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" -version = "1.3.6+0" - -[[deps.Xorg_libXrender_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" -uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" -version = "0.9.11+0" - -[[deps.Xorg_libpthread_stubs_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" -uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" -version = "0.1.1+0" - -[[deps.Xorg_libxcb_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] -git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" -uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" -version = "1.17.0+0" - -[[deps.Xorg_xtrans_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" -uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" -version = "1.5.0+0" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - 
ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.isoband_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51b5eeb3f98367157a7a12a1fb0aa5328946c03c" -uuid = "9a68df92-36a6-505f-a73e-abb412b6bfb4" -version = "0.2.3+0" - -[[deps.libaec_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" -uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" -version = "1.1.2+0" - -[[deps.libaom_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" -uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" -version = "3.9.0+0" - -[[deps.libass_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" -uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" -version = "0.15.1+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.libfdk_aac_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" -uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" -version = "2.0.2+0" - -[[deps.libpng_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" -uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" - -[[deps.libsixel_jll]] -deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Pkg", "libpng_jll"] -git-tree-sha1 = 
"d4f63314c8aa1e48cd22aa0c17ed76cd1ae48c3c" -uuid = "075b6546-f08a-558a-be8f-8157d0f608a5" -version = "1.10.3+0" - -[[deps.libvorbis_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] -git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" -uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" -version = "1.3.7+2" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.oneTBB_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "7d0ea0f4895ef2f5cb83645fa689e52cb55cf493" -uuid = "1317d2d5-d96f-522e-a858-c73665f53c3e" -version = "2021.12.0+0" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" - -[[deps.x264_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" -uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a" -version = "2021.5.5+0" - -[[deps.x265_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" -uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" -version = "3.5.0+0" -""" - -# ╔═╡ Cell order: -# ╟─03a9e023-e682-4ea3-a10b-14c4d101b291 -# ╠═42c84361-222a-46c4-b81f-d33eb41635c9 -# ╠═361e0948-d91a-11ec-2d95-2db77435a0c1 -# ╟─ef96f5ae-724d-4b8e-b7d7-c116ad1c3279 -# ╠═4ba372d4-7a6a-41e0-92a0-9547a78e2898 -# ╟─55aca2f0-4bbb-4d3a-9777-703896cfc548 -# ╠═a1d35896-0f52-4c8b-b7dc-ec65649237c8 -# ╠═48d7df25-9190-45c9-9829-140f452e5151 -# ╟─4598bf67-5448-4ce5-8be8-a473ab1a6a07 -# ╠═8d41a9fa-eefe-40c9-8cc3-cd503cf7434d -# ╟─c42c7f73-f84e-4e72-9af4-a6421af57f0d -# ╠═a7ad9de3-3e18-4aff-b118-a4d798a2f4ec -# ╟─1e362709-a0d0-45d5-b2fd-a91c45fa317a -# ╠═d627736a-fd5a-4cdc-bd4e-89ff8b8c55bd -# ╟─98bb86d2-a7b9-4110-8851-8829a9f9b4d0 -# ╠═9820cc77-ae0a-454a-86b6-a23dbc56b6fd -# ╟─86135c51-950c-4c08-b9e0-6c892234ff87 -# ╠═88d1e59f-73d6-46ee-87e8-35beb7bc7674 -# 
╟─9838189c-5cf6-4f21-b58e-3bb905408ad3 -# ╠═ad2c2e51-08ec-4ddc-9b5c-668a3688db12 -# ╠═ce26c963-0438-4ab2-b5c6-520272beef2b -# ╠═e545e74f-0a3c-4d18-9cc7-557ca60be567 -# ╠═26138606-2e8d-435b-aa1a-b6159a0d2739 -# ╟─b9359c7d-b7fe-412d-8f5e-55ba6bccb4e9 -# ╠═912560a1-9c72-47bd-9fce-9702b346b603 -# ╠═c8a217c9-0087-41f0-90c8-aac29bc1c996 -# ╠═727b24bc-0b1e-4ebd-b8ef-987015751e38 -# ╠═8c60ec7e-46b0-40f7-bf6a-6228a31e1f66 -# ╠═44d9f8cf-1023-48ad-a01f-07e59f4b4226 -# ╟─a8841d35-97f9-431d-acab-abf478ce91a9 -# ╟─00000000-0000-0000-0000-000000000001 -# ╟─00000000-0000-0000-0000-000000000002 - -[.\docs\tutorials\introductory_tutorials\graph_classification_pluto.jl] -### A Pluto.jl notebook ### -# v0.19.45 - -#> [frontmatter] -#> author = "[Carlo Lucibello](https://github.com/CarloLucibello)" -#> title = "Graph Classification with Graph Neural Networks" -#> date = "2022-05-23" -#> description = "Tutorial for Graph Classification using GraphNeuralNetworks.jl" -#> cover = "assets/graph_classification.gif" - -using Markdown -using InteractiveUtils - -# ╔═╡ 361e0948-d91a-11ec-2d95-2db77435a0c1 -# ╠═╡ show_logs = false -begin - using Flux - using Flux: onecold, onehotbatch, logitcrossentropy - using Flux: DataLoader - using GraphNeuralNetworks - using MLDatasets - using MLUtils - using LinearAlgebra, Random, Statistics - - ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation - Random.seed!(17) # for reproducibility -end; - -# ╔═╡ 15136fd8-f9b2-4841-9a95-9de7b8969687 -md""" -*This Pluto notebook is a julia adaptation of the Pytorch Geometric tutorials that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).* - -In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**. 
-Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties. -Here, we want to embed entire graphs, and we want to embed those graphs in such a way so that they are linearly separable given a task at hand. - - -The most common task for graph classification is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not. - -The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl. -Let's load and inspect one of the smaller ones, the **MUTAG dataset**: -""" - -# ╔═╡ f6e86958-e96f-4c77-91fc-c72d8967575c -dataset = TUDataset("MUTAG") - -# ╔═╡ 24f76360-8599-46c8-a49f-4c31f02eb7d8 -dataset.graph_data.targets |> union - -# ╔═╡ 5d5e5152-c860-4158-8bc7-67ee1022f9f8 -g1, y1 = dataset[1] #get the first graph and target - -# ╔═╡ 33163dd2-cb35-45c7-ae5b-d4854d141773 -reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union - -# ╔═╡ a8d6a133-a828-4d51-83c4-fb44f9d5ede1 -reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union - -# ╔═╡ 3b3e0a79-264b-47d7-8bda-2a6db7290828 -md""" -This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**. - -By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**. -It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes). -However, for the sake of simplicity, we will not make use of edge labels. 
-""" - -# ╔═╡ 7f7750ff-b7fa-4fe2-a5a8-6c9c26c479bb -md""" -We now convert the MLDatasets.jl graph types to our `GNNGraph`s and we also onehot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict): -""" - -# ╔═╡ 936c09f6-ee62-4bc2-a0c6-749a66080fd2 -begin - graphs = mldataset2gnngraph(dataset) - graphs = [GNNGraph(g, - ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)), - edata = nothing) - for g in graphs] - y = onehotbatch(dataset.graph_data.targets, [-1, 1]) -end - -# ╔═╡ 2c6ccfdd-cf11-415b-b398-95e5b0b2bbd4 -md"""We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing: -""" - -# ╔═╡ 519477b2-8323-4ece-a7eb-141e9841117c -train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs - -# ╔═╡ 3c3d5038-0ef6-47d7-a1b7-50880c5f3a0b -begin - train_loader = DataLoader(train_data, batchsize = 32, shuffle = true) - test_loader = DataLoader(test_data, batchsize = 32, shuffle = false) -end - -# ╔═╡ f7778e2d-2e2a-4fc8-83b0-5242e4ec5eb4 -md""" -Here, we opt for a `batch_size` of 32, leading to 5 (randomly shuffled) mini-batches, containing all ``4 \cdot 32+22 = 150`` graphs. -""" - -# ╔═╡ 2a1c501e-811b-4ddd-887b-91e8c929c8b7 -md""" -## Mini-batching of graphs - -Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization. -In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. -The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`. 
- -However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. -Therefore, GraphNeuralNetworks.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension). - -This procedure has some crucial advantages over other batching procedures: - -1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs. - -2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges. - -GraphNeuralNetworks.jl can **batch multiple graphs into a single giant graph**: -""" - -# ╔═╡ a142610a-d862-42a9-88af-c8d8b6825650 -vec_gs, _ = first(train_loader) - -# ╔═╡ 6faaf637-a0ff-468c-86b5-b0a7250258d6 -MLUtils.batch(vec_gs) - -# ╔═╡ e314b25f-e904-4c39-bf60-24cddf91fe9d -md""" -Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch: - -```math -\textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ] -``` -""" - -# ╔═╡ ac69571a-998b-4630-afd6-f3d405618bc5 -md""" -## Training a Graph Neural Network (GNN) - -Training a GNN for graph classification usually follows a simple recipe: - -1. Embed each node by performing multiple rounds of message passing -2. Aggregate node embeddings into a unified graph embedding (**readout layer**) -3. 
Train a final classifier on the graph embedding - -There exists multiple **readout layers** in literature, but the most common one is to simply take the average of node embeddings: - -```math -\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v -``` - -GraphNeuralNetworks.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`. - -The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training: -""" - -# ╔═╡ 04402032-18a4-42b5-ad04-19b286bd29b7 -function create_model(nin, nh, nout) - GNNChain(GCNConv(nin => nh, relu), - GCNConv(nh => nh, relu), - GCNConv(nh => nh), - GlobalPool(mean), - Dropout(0.5), - Dense(nh, nout)) -end - -# ╔═╡ 2313fd8d-6e84-4bde-bacc-fb697dc33cbb -md""" -Here, we again make use of the `GCNConv` with ``\mathrm{ReLU}(x) = \max(x, 0)`` activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer. 
- -Let's train our network for a few epochs to see how well it performs on the training as well as test set: -""" - -# ╔═╡ c956ed97-fa5c-45c6-84dd-39f3e37d8070 -function eval_loss_accuracy(model, data_loader, device) - loss = 0.0 - acc = 0.0 - ntot = 0 - for (g, y) in data_loader - g, y = MLUtils.batch(g) |> device, y |> device - n = length(y) - ŷ = model(g, g.ndata.x) - loss += logitcrossentropy(ŷ, y) * n - acc += mean((ŷ .> 0) .== y) * n - ntot += n - end - return (loss = round(loss / ntot, digits = 4), - acc = round(acc * 100 / ntot, digits = 2)) -end - -# ╔═╡ 968c7087-7637-4844-9509-dd838cf99a8c -function train!(model; epochs = 200, η = 1e-2, infotime = 10) - # device = Flux.gpu # uncomment this for GPU training - device = Flux.cpu - model = model |> device - opt = Flux.setup(Adam(1e-3), model) - - function report(epoch) - train = eval_loss_accuracy(model, train_loader, device) - test = eval_loss_accuracy(model, test_loader, device) - @info (; epoch, train, test) - end - - report(0) - for epoch in 1:epochs - for (g, y) in train_loader - g, y = MLUtils.batch(g) |> device, y |> device - grad = Flux.gradient(model) do model - ŷ = model(g, g.ndata.x) - logitcrossentropy(ŷ, y) - end - Flux.update!(opt, model, grad[1]) - end - epoch % infotime == 0 && report(epoch) - end -end - -# ╔═╡ dedf18d8-4281-49fa-adaf-bd57fc15095d -begin - nin = 7 - nh = 64 - nout = 2 - model = create_model(nin, nh, nout) - train!(model) -end - -# ╔═╡ 3454b311-9545-411d-b47a-b43724b84c36 -md""" -As one can see, our model reaches around **74% test accuracy**. -Reasons for the fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and usually disappear once one applies GNNs to larger datasets. - -## (Optional) Exercise - -Can we do better than this? -As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. 
(2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**. -An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information: - -```math -\mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)} -``` - -This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl. - -As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`. -This should bring you close to **82% test accuracy**. -""" - -# ╔═╡ 93e08871-2929-4279-9f8a-587168617365 -md""" -## Conclusion - -In this chapter, you have learned how to apply GNNs to the task of graph classification. -You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings. 
-""" - -# ╔═╡ 00000000-0000-0000-0000-000000000001 -PLUTO_PROJECT_TOML_CONTENTS = """ -[deps] -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[compat] -Flux = "~0.14.16" -GraphNeuralNetworks = "~0.6.19" -MLDatasets = "~0.7.16" -MLUtils = "~0.4.4" -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000002 -PLUTO_MANIFEST_TOML_CONTENTS = """ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "4d31565cd40e53ce5e158a179486a694e9c7da67" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful 
= "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.AtomsBase]] -deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" -uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.5" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - 
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = 
"SparseArrays" - -[[deps.Chemfiles]] -deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" -uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" -version = "0.10.41" - -[[deps.Chemfiles_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" -uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" -version = "0.10.4+0" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = 
"34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataDeps]] -deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] -git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" -uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.13" - -[[deps.DataFrames]] -deps = 
["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", "SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - 
DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.10" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = 
"90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" -uuid = 
"0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.GZip]] -deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.2" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.GraphNeuralNetworks]] -deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" -uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" -version = "0.6.19" - - [deps.GraphNeuralNetworks.extensions] - GraphNeuralNetworksCUDAExt = "CUDA" - GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" - - [deps.GraphNeuralNetworks.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.2" - - [deps.HDF5.extensions] - MPIExt = "MPI" - - [deps.HDF5.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - -[[deps.HDF5_jll]] -deps = ["Artifacts", 
"CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+3" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" - -[[deps.ImageCore]] -deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] -git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.2" - -[[deps.ImageShow]] -deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" -uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.8" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - 
-[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" -uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = 
"692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] -git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.7.1" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" -weakdeps = ["BFloat16s"] - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = 
"b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - 
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MAT]] -deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.7" - -[[deps.MLDatasets]] -deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.11" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", 
"CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MicroCollections]] -deps = ["Accessors", "BangBang", "InitialValues"] -git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.2.0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.4+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] 
-deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - NNlibFFTWExt = "FFTW" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OffsetArrays]] -git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.14.1" -weakdeps = ["Adapt"] - - [deps.OffsetArrays.extensions] - OffsetArraysAdaptExt = "Adapt" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" 
-version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.6+0" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - 
-[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PeriodicTable]] -deps = ["Base64", "Unitful"] -git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" -uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" -version = "1.2.1" - -[[deps.Pickle]] -deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.5" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] 
-deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" 
-version = "1.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - 
-[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" - - [deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - - [deps.StridedViews.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.7" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StructTypes]] 
-deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - -[[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" 
-uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" - - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Unitful]] -deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.21.0" -weakdeps = ["ConstructionBase", "InverseFunctions"] - - [deps.Unitful.extensions] - ConstructionBaseUnitfulExt = "ConstructionBase" - InverseFunctionsUnitfulExt = "InverseFunctions" - -[[deps.UnitfulAtomic]] -deps = ["Unitful"] -git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" -uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" -version = "1.0.0" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" - -[[deps.VectorInterface]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" -uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" 
-version = "0.4.6" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.libaec_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" -uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" -version = "1.1.2+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - 
-[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" -""" - -# ╔═╡ Cell order: -# ╠═361e0948-d91a-11ec-2d95-2db77435a0c1 -# ╟─15136fd8-f9b2-4841-9a95-9de7b8969687 -# ╠═f6e86958-e96f-4c77-91fc-c72d8967575c -# ╠═24f76360-8599-46c8-a49f-4c31f02eb7d8 -# ╠═5d5e5152-c860-4158-8bc7-67ee1022f9f8 -# ╠═33163dd2-cb35-45c7-ae5b-d4854d141773 -# ╠═a8d6a133-a828-4d51-83c4-fb44f9d5ede1 -# ╟─3b3e0a79-264b-47d7-8bda-2a6db7290828 -# ╟─7f7750ff-b7fa-4fe2-a5a8-6c9c26c479bb -# ╠═936c09f6-ee62-4bc2-a0c6-749a66080fd2 -# ╟─2c6ccfdd-cf11-415b-b398-95e5b0b2bbd4 -# ╠═519477b2-8323-4ece-a7eb-141e9841117c -# ╠═3c3d5038-0ef6-47d7-a1b7-50880c5f3a0b -# ╟─f7778e2d-2e2a-4fc8-83b0-5242e4ec5eb4 -# ╟─2a1c501e-811b-4ddd-887b-91e8c929c8b7 -# ╠═a142610a-d862-42a9-88af-c8d8b6825650 -# ╠═6faaf637-a0ff-468c-86b5-b0a7250258d6 -# ╟─e314b25f-e904-4c39-bf60-24cddf91fe9d -# ╟─ac69571a-998b-4630-afd6-f3d405618bc5 -# ╠═04402032-18a4-42b5-ad04-19b286bd29b7 -# ╟─2313fd8d-6e84-4bde-bacc-fb697dc33cbb -# ╠═c956ed97-fa5c-45c6-84dd-39f3e37d8070 -# ╠═968c7087-7637-4844-9509-dd838cf99a8c -# ╠═dedf18d8-4281-49fa-adaf-bd57fc15095d -# ╟─3454b311-9545-411d-b47a-b43724b84c36 -# ╟─93e08871-2929-4279-9f8a-587168617365 -# ╟─00000000-0000-0000-0000-000000000001 -# ╟─00000000-0000-0000-0000-000000000002 - -[.\docs\tutorials\introductory_tutorials\node_classification_pluto.jl] -### A Pluto.jl notebook ### -# v0.19.45 - -#> [frontmatter] -#> author = "[Deeptendu Santra](https://github.com/Dsantra92)" -#> title = "Node Classification with Graph Neural Networks" -#> date = "2022-09-25" -#> description = "Tutorial for Node classification using GraphNeuralNetworks.jl" -#> cover = "assets/node_classsification.gif" - -using Markdown -using InteractiveUtils - -# ╔═╡ 5463330a-0161-11ed-1b18-936030a32bbf -# ╠═╡ show_logs = false -begin - using 
MLDatasets - using GraphNeuralNetworks - using Flux - using Flux: onecold, onehotbatch, logitcrossentropy - using Plots - using PlutoUI - using TSne - using Random - using Statistics - - ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" - Random.seed!(17) # for reproducibility -end; - -# ╔═╡ ca2f0293-7eac-4d9a-9a2f-fda47fd95a99 -md""" -In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning). -""" - -# ╔═╡ 4455f18c-2bd9-42ed-bce3-cfe6561eab23 -md""" -## Import -Let us start off by importing some libraries. We will be using Flux.jl and `GraphNeuralNetworks.jl` for our tutorial. -""" - -# ╔═╡ 0d556a7c-d4b6-4cef-806c-3e1712de0791 -md""" -## Visualize -We want to visualize the the outputs of the results using t-distributed stochastic neighbor embedding (tsne) to embed our output embeddings onto a 2D plane. -""" - -# ╔═╡ 997b5387-3811-4998-a9d1-7981b58b9e09 -function visualize_tsne(out, targets) - z = tsne(out, 2) - scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) -end - -# ╔═╡ 4b6fa18d-7ccd-4c07-8dc3-ded4d7da8562 -md""" -## Dataset: Cora - -For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents classified into one of seven classes and 5429 links. Each node represent articles/documents and the edges between these nodes if one of them cite each other. - -Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. - -This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. 
-""" - -# ╔═╡ edab1e3a-31f6-471f-9835-5b1f97e5cf3f -dataset = Cora() - -# ╔═╡ d73a2db5-9417-4b2c-a9f5-b7d499a53fcb -md""" -Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. -""" - -# ╔═╡ 32bb90c1-c802-4c0c-a620-5d3b8f3f2477 -dataset.metadata - -# ╔═╡ 3438ee7f-bfca-465d-85df-13379622d415 -md""" -The `graphs` variable GraphDataset contains the graph. The `Cora` dataset contains only 1 graph. -""" - -# ╔═╡ eec6fb60-0774-4f2a-bcb7-dbc28ab747a6 -dataset.graphs - -# ╔═╡ bd2fd04d-7fb0-4b31-959b-bddabe681754 -md""" -There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. -""" - -# ╔═╡ b29c3a02-c21b-4b10-aa04-b90bcc2931d8 -g = mldataset2gnngraph(dataset) - -# ╔═╡ 16d9fbad-d4dc-4b51-9576-1736d228e2b3 -with_terminal() do - # Gather some statistics about the graph. - println("Number of nodes: $(g.num_nodes)") - println("Number of edges: $(g.num_edges)") - println("Average node degree: $(g.num_edges / g.num_nodes)") - println("Number of training nodes: $(sum(g.ndata.train_mask))") - println("Training node label rate: $(mean(g.ndata.train_mask))") - # println("Has isolated nodes: $(has_isolated_nodes(g))") - println("Has self-loops: $(has_self_loops(g))") - println("Is undirected: $(is_bidirected(g))") -end - -# ╔═╡ 923d061c-25c3-4826-8147-9afa3dbd5bac -md""" -Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. -We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. -For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). -This results in a training node label rate of only 5%. 
- -We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). -""" - -# ╔═╡ 28e00b95-56db-4d36-a205-fd24d3c54e17 -begin - x = g.ndata.features - # we onehot encode both the node labels (what we want to predict): - y = onehotbatch(g.ndata.targets, 1:7) - train_mask = g.ndata.train_mask - num_features = size(x)[1] - hidden_channels = 16 - num_classes = dataset.metadata["num_classes"] -end; - -# ╔═╡ fa743000-604f-4d28-99f1-46ab2f884b8e -md""" -## Multi-layer Perception Network (MLP) - -In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. - -Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): -""" - -# ╔═╡ f972f61b-2001-409b-9190-ac2c0652829a -begin - struct MLP - layers::NamedTuple - end - - Flux.@layer :expand MLP - - function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5) - layers = (hidden = Dense(num_features => hidden_channels), - drop = Dropout(drop_rate), - classifier = Dense(hidden_channels => num_classes)) - return MLP(layers) - end - - function (model::MLP)(x::AbstractMatrix) - l = model.layers - x = l.hidden(x) - x = relu(x) - x = l.drop(x) - x = l.classifier(x) - return x - end -end - -# ╔═╡ 4dade64a-e28e-42c7-8ad5-93fc04724d4d -md""" -### Training a Multilayer Perceptron - -Our MLP is defined by two linear layers and enhanced by [ReLU](https://fluxml.ai/Flux.jl/stable/models/nnlib/#NNlib.relu) non-linearity and [Dropout](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Dropout). -Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. 
- -Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/tutorials/introductory_tutorials/gnn_intro_pluto/#Hands-on-introduction-to-Graph-Neural-Networks). -We again make use of the **cross entropy loss** and **Adam optimizer**. -This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). -""" - -# ╔═╡ 05979cfe-439c-4abc-90cd-6ca2a05f6e0f -function train(model::MLP, data::AbstractMatrix, epochs::Int, opt) - Flux.trainmode!(model) - - for epoch in 1:epochs - loss, grad = Flux.withgradient(model) do model - ŷ = model(data) - logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) - end - - Flux.update!(opt, model, grad[1]) - if epoch % 200 == 0 - @show epoch, loss - end - end -end - -# ╔═╡ a3f420e1-7521-4df9-b6d5-fc0a1fd05095 -function accuracy(model::MLP, x::AbstractMatrix, y::Flux.OneHotArray, mask::BitVector) - Flux.testmode!(model) - mean(onecold(model(x))[mask] .== onecold(y)[mask]) -end - -# ╔═╡ b18384fe-b8ae-4f51-bd73-d129d5e70f98 -md""" -After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. -Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: -""" - -# ╔═╡ 54a2972e-b107-47c8-bf7e-eb51b4ccbe02 -md""" -As one can see, our MLP performs rather bad with only about 47% test accuracy. -But why does the MLP do not perform better? -The main reason for that is that this model suffers from heavy overfitting due to only having access to a **small amount of training nodes**, and therefore generalizes poorly to unseen node representations. - -It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**. 
-That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model. -""" - -# ╔═╡ 623e7b53-046c-4858-89d9-13caae45255d -md""" -## Training a Graph Convolutional Neural Network (GNN) - -Following-up on [the first part of this tutorial](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/tutorials/introductory_tutorials/node_classification_pluto/#Multi-layer-Perception-Network-(MLP)), we replace the `Dense` linear layers by the [`GCNConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GCNConv) module. -To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as - -```math -\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} -``` - -where ``\mathbf{W}^{(\ell + 1)}`` denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. -In contrast, a single `Linear` layer is defined as - -```math -\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} -``` - -which does not make use of neighboring node information. -""" - -# ╔═╡ eb36a46c-f139-425e-8a93-207bc4a16f89 -begin - struct GCN - layers::NamedTuple - end - - Flux.@layer GCN # provides parameter collection, gpu movement and more - - function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5) - layers = (conv1 = GCNConv(num_features => hidden_channels), - drop = Dropout(drop_rate), - conv2 = GCNConv(hidden_channels => num_classes)) - return GCN(layers) - end - - function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix) - l = gcn.layers - x = l.conv1(g, x) - x = relu.(x) - x = l.drop(x) - x = l.conv2(g, x) - return x - end -end - -# ╔═╡ 20b5f802-abce-49e1-a442-f381e80c0f85 -md""" -Now let's visualize the node embeddings of our **untrained** GCN network. 
-""" - -# ╔═╡ b295adce-b37e-45f3-963a-3699d714e36d -# ╠═╡ show_logs = false -begin - gcn = GCN(num_features, num_classes, hidden_channels) - h_untrained = gcn(g, x) |> transpose - visualize_tsne(h_untrained, g.ndata.targets) -end - -# ╔═╡ 5538970f-b273-4122-9d50-7deb049e6934 -md""" -We certainly can do better by training our model. -The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model. -""" - -# ╔═╡ 901d9478-9a12-4122-905d-6cfc6d80e84c -function train(model::GCN, g::GNNGraph, x::AbstractMatrix, epochs::Int, opt) - Flux.trainmode!(model) - - for epoch in 1:epochs - loss, grad = Flux.withgradient(model) do model - ŷ = model(g, x) - logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) - end - - Flux.update!(opt, model, grad[1]) - if epoch % 200 == 0 - @show epoch, loss - end - end -end - -# ╔═╡ 026911dd-6a27-49ce-9d41-21e01646c10a -# ╠═╡ show_logs = false -begin - mlp = MLP(num_features, num_classes, hidden_channels) - opt_mlp = Flux.setup(Adam(1e-3), mlp) - epochs = 2000 - train(mlp, g.ndata.features, epochs, opt_mlp) -end - -# ╔═╡ 65d9fd3d-1649-4b95-a106-f26fa4ab9bce -function accuracy(model::GCN, g::GNNGraph, x::AbstractMatrix, y::Flux.OneHotArray, - mask::BitVector) - Flux.testmode!(model) - mean(onecold(model(g, x))[mask] .== onecold(y)[mask]) -end - -# ╔═╡ b2302697-1e20-4721-ae93-0b121ff9ce8f -accuracy(mlp, g.ndata.features, y, .!train_mask) - -# ╔═╡ 20be52b1-1c33-4f54-b5c0-fecc4e24fbb5 -# ╠═╡ show_logs = false -begin - opt_gcn = Flux.setup(Adam(1e-2), gcn) - train(gcn, g, x, epochs, opt_gcn) -end - -# ╔═╡ 5aa99aff-b5ed-40ec-a7ec-0ba53385e6bd -md""" -Now let's evaluate the loss of our trained GCN. 
-""" - -# ╔═╡ 2163d0d8-0661-4d11-a09e-708769011d35 -with_terminal() do - train_accuracy = accuracy(gcn, g, g.ndata.features, y, train_mask) - test_accuracy = accuracy(gcn, g, g.ndata.features, y, .!train_mask) - - println("Train accuracy: $(train_accuracy)") - println("Test accuracy: $(test_accuracy)") -end - -# ╔═╡ 6cd49f3f-a415-4b6a-9323-4d6aa6b87f18 -md""" -**There it is!** -By simply swapping the linear layers with GNN layers, we can reach **75.77% of test accuracy**! -This is in stark contrast to the 59% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance. - -We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category. -""" - -# ╔═╡ 7a93a802-6774-42f9-b6da-7ae614464e72 -# ╠═╡ show_logs = false -begin - Flux.testmode!(gcn) # inference mode - - out_trained = gcn(g, x) |> transpose - visualize_tsne(out_trained, g.ndata.targets) -end - -# ╔═╡ 50a409fd-d80b-4c48-a51b-173c39a6dcb4 -md""" -## (Optional) Exercises - -1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**. - -2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? - -3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? 
Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. -""" - -# ╔═╡ c343419f-a1d7-45a0-b600-2c868588b33a -md""" -## Conclusion -In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification. -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000001 -PLUTO_PROJECT_TOML_CONTENTS = """ -[deps] -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -TSne = "24678dba-d5e9-5843-a4c6-250288b04835" - -[compat] -Flux = "~0.14.16" -GraphNeuralNetworks = "~0.6.19" -MLDatasets = "~0.7.16" -Plots = "~1.40.5" -PlutoUI = "~0.7.59" -TSne = "~1.3.0" -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000002 -PLUTO_MANIFEST_TOML_CONTENTS = """ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "fb2b669c9e43473fabf01e07c834a510ae36fa5e" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.AbstractPlutoDingetjes]] -deps = ["Pkg"] -git-tree-sha1 = "6e1d2a35f2f90a4bc7c2ed98079b2ba09c35b83a" -uuid = 
"6e696c72-6542-2067-7265-42206c756150" -version = "1.3.2" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.AtomsBase]] -deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", 
"UnitfulAtomic"] -git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" -uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.5" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" - -[[deps.Bzip2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" -uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" - -[[deps.CEnum]] -git-tree-sha1 = 
"389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" - -[[deps.Cairo_jll]] -deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" -uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.Chemfiles]] -deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" -uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" -version = "0.10.41" - -[[deps.Chemfiles_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" -uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" -version = "0.10.4+0" - -[[deps.CodecZlib]] -deps = 
["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - 
[deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Contour]] -git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" -uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataDeps]] -deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] -git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" -uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.13" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = 
"04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", "SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - 
-[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.EpollShim_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" -uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" -version = "0.0.20230411+0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.10" - -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.6.2+0" - -[[deps.FFMPEG]] -deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" -uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" - -[[deps.FFMPEG_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] -git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" -uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" -version = "4.4.4+1" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", 
"Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.Fontconfig_jll]] -deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] -git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" -uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.13.96+0" - -[[deps.Format]] 
-git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" -uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" -version = "1.3.7" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.FreeType2_jll]] -deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" -uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.13.2+0" - -[[deps.FriBidi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" -uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.14+0" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GLFW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] -git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" -uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" -version = "3.4.0+0" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - 
-[[deps.GR]] -deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" -uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.7" - -[[deps.GR_jll]] -deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" -uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.7+0" - -[[deps.GZip]] -deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.2" - -[[deps.Gettext_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" -uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" -version = "0.21.0+0" - -[[deps.Glib_jll]] -deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" -uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.GraphNeuralNetworks]] -deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" -uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" -version = "0.6.19" - - 
[deps.GraphNeuralNetworks.extensions] - GraphNeuralNetworksCUDAExt = "CUDA" - GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" - - [deps.GraphNeuralNetworks.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" - -[[deps.Graphite2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" -uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" -version = "1.3.14+0" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" - -[[deps.Grisu]] -git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" -uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" -version = "1.0.2" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.2" - - [deps.HDF5.extensions] - MPIExt = "MPI" - - [deps.HDF5.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - -[[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.2+1" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] 
-git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.HarfBuzz_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" -uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+1" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" - -[[deps.Hyperscript]] -deps = ["Test"] -git-tree-sha1 = "179267cfa5e712760cd43dcae385d7ea90cc25a4" -uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" -version = "0.0.5" - -[[deps.HypertextLiteral]] -deps = ["Tricks"] -git-tree-sha1 = "7134810b1afce04bbc1045ca1985fbe81ce17653" -uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" -version = "0.9.5" - -[[deps.IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.5" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" - -[[deps.ImageCore]] -deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] -git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.2" - -[[deps.ImageShow]] -deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" 
-uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.8" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" -uuid = 
"033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" - -[[deps.JLFzf]] -deps = ["Pipe", "REPL", "Random", "fzf_jll"] -git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" -uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" -version = "0.1.7" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JpegTurbo_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" -uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", 
"VectorInterface"] -git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.7.1" - -[[deps.LAME_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" -uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" -version = "3.100.2+0" - -[[deps.LERC_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" -uuid = "88015f11-f218-50d7-93a8-a6af411a945d" -version = "3.0.0+1" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" -weakdeps = ["BFloat16s"] - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" - -[[deps.LLVMOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" -uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "15.0.7+0" - -[[deps.LZO_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" -uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.Latexify]] -deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] -git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50" -uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.16.4" - - [deps.Latexify.extensions] - DataFramesExt = "DataFrames" - SymEngineExt = "SymEngine" - - 
[deps.Latexify.weakdeps] - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libffi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" -uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+1" - -[[deps.Libgcrypt_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] -git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" -uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.8.11+0" - -[[deps.Libglvnd_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] -git-tree-sha1 = "6f73d1dd803986947b2c750138528a999a6c7733" -uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" -version = "1.6.0+0" - -[[deps.Libgpg_error_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" -uuid 
= "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.49.0+0" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.Libmount_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" -uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.40.1+0" - -[[deps.Libtiff_jll]] -deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" -uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.5.1+1" - -[[deps.Libuuid_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" -uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.40.1+0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MAT]] -deps = 
["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.7" - -[[deps.MIMEs]] -git-tree-sha1 = "65f28ad4b594aebe22157d6fac869786a255b7eb" -uuid = "6c6e2e6c-3030-632d-7369-2d6c69616d65" -version = "0.1.4" - -[[deps.MLDatasets]] -deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.11" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" - -[[deps.MacroTools]] -deps = 
["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.Measures]] -git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" -uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" -version = "0.3.2" - -[[deps.MicroCollections]] -deps = ["Accessors", "BangBang", "InitialValues"] -git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.2.0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.4+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", 
"Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - NNlibFFTWExt = "FFTW" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OffsetArrays]] -git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.14.1" -weakdeps = ["Adapt"] - - [deps.OffsetArrays.extensions] - OffsetArraysAdaptExt = "Adapt" - -[[deps.Ogg_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" -uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" -version = "1.3.5+1" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = 
"963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] -git-tree-sha1 = "2f0a1d8c79bc385ec3fcda12830c9d0e72b30e71" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "5.0.4+0" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" - -[[deps.Opus_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" -uuid = "91d4177d-7536-5919-b921-800302f37372" -version = "1.3.2+0" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid 
= "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PeriodicTable]] -deps = ["Base64", "Unitful"] -git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" -uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" -version = "1.2.1" - -[[deps.Pickle]] -deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.5" - -[[deps.Pipe]] -git-tree-sha1 = "6842804e7867b115ca9de748a0cf6b364523c16d" -uuid = "b98c9c47-44ae-5843-9183-064241ee97a0" -version = "1.3.0" - -[[deps.Pixman_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" -uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.43.4+0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PlotThemes]] -deps = ["PlotUtils", "Statistics"] -git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" -uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "3.2.0" - -[[deps.PlotUtils]] -deps = ["ColorSchemes", 
"Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" -uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" - -[[deps.Plots]] -deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] -git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" -uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.40.5" - - [deps.Plots.extensions] - FileIOExt = "FileIO" - GeometryBasicsExt = "GeometryBasics" - IJuliaExt = "IJulia" - ImageInTerminalExt = "ImageInTerminal" - UnitfulExt = "Unitful" - - [deps.Plots.weakdeps] - FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" - GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" - IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" - ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.PlutoUI]] -deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"] -git-tree-sha1 = "ab55ee1510ad2af0ff674dbcced5e94921f867a9" -uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" -version = "0.7.59" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = 
"aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.10.2" - -[[deps.Qt6Base_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] -git-tree-sha1 = "492601870742dcd38f233b23c3ec629628c1d724" -uuid = "c0090381-4147-56d7-9ebc-da0b1113ec56" -version = "6.7.1+1" - -[[deps.Qt6Declarative_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6ShaderTools_jll"] -git-tree-sha1 = "e5dd466bf2569fe08c91a2cc29c1003f4797ac3b" -uuid = "629bc702-f1f5-5709-abd5-49b8460ea067" -version = "6.7.1+2" - -[[deps.Qt6ShaderTools_jll]] -deps = ["Artifacts", "JLLWrappers", 
"Libdl", "Qt6Base_jll"] -git-tree-sha1 = "1a180aeced866700d4bebc3120ea1451201f16bc" -uuid = "ce943373-25bb-56aa-8eca-768745ed7b5a" -version = "6.7.1+1" - -[[deps.Qt6Wayland_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6Declarative_jll"] -git-tree-sha1 = "729927532d48cf79f49070341e1d918a65aba6b0" -uuid = "e99dba38-086e-5de3-a5b1-6e4c66e897c3" -version = "6.7.1+1" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.RecipesBase]] -deps = ["PrecompileTools"] -git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.4" - -[[deps.RecipesPipeline]] -deps = ["Dates", "NaNMath", "PlotUtils", "PrecompileTools", "RecipesBase"] -git-tree-sha1 = "45cf9fd0ca5839d06ef333c8201714e888486342" -uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c" -version = "0.6.12" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.RelocatableFolders]] -deps = ["SHA", "Scratch"] -git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" -uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.1" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = 
"ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.Showoff]] -deps = ["Dates", "Grisu"] -git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" -uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" -version = "1.0.3" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = 
"2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" - - 
[deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - - [deps.StridedViews.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.7" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TSne]] -deps = ["Distances", "LinearAlgebra", "Printf", "ProgressMeter", "Statistics"] -git-tree-sha1 = "6f1dfbf9dad6958439816fa9c5fa20898203fdf4" -uuid = "24678dba-d5e9-5843-a4c6-250288b04835" -version = "1.3.0" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps 
= ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - -[[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" - - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.Tricks]] -git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" -uuid = 
"410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" -version = "0.1.8" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnicodeFun]] -deps = ["REPL"] -git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" -uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" -version = "0.4.1" - -[[deps.Unitful]] -deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.21.0" -weakdeps = ["ConstructionBase", "InverseFunctions"] - - [deps.Unitful.extensions] - ConstructionBaseUnitfulExt = "ConstructionBase" - InverseFunctionsUnitfulExt = "InverseFunctions" - -[[deps.UnitfulAtomic]] -deps = ["Unitful"] -git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" -uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" -version = "1.0.0" - -[[deps.UnitfulLatexify]] -deps = ["LaTeXStrings", "Latexify", "Unitful"] -git-tree-sha1 = "975c354fcd5f7e1ddcc1f1a23e6e091d99e99bc8" -uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" -version = "1.6.4" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" - -[[deps.Unzip]] -git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" -uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d" -version = "0.2.0" - -[[deps.VectorInterface]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" -uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" -version = "0.4.6" - -[[deps.Vulkan_Loader_jll]] -deps = ["Artifacts", 
"JLLWrappers", "Libdl", "Wayland_jll", "Xorg_libX11_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] -git-tree-sha1 = "2f0486047a07670caad3a81a075d2e518acc5c59" -uuid = "a44049a8-05dd-5a78-86c9-5fde0876e88c" -version = "1.3.243+0" - -[[deps.Wayland_jll]] -deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "7558e29847e99bc3f04d6569e82d0f5c54460703" -uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89" -version = "1.21.0+1" - -[[deps.Wayland_protocols_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "93f43ab61b16ddfb2fd3bb13b3ce241cafb0e6c9" -uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" -version = "1.31.0+0" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.XML2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" -uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.1+0" - -[[deps.XSLT_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] -git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" -uuid = "aed1982a-8fda-507f-9586-7b0439959a61" -version = "1.1.41+0" - -[[deps.XZ_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" -uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" -version = "5.4.6+0" - -[[deps.Xorg_libICE_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "326b4fea307b0b39892b3e85fa451692eda8d46c" -uuid = "f67eecfb-183a-506d-b269-f58e52b52d7c" -version = "1.1.1+0" - -[[deps.Xorg_libSM_jll]] -deps = 
["Artifacts", "JLLWrappers", "Libdl", "Xorg_libICE_jll"] -git-tree-sha1 = "3796722887072218eabafb494a13c963209754ce" -uuid = "c834827a-8449-5923-a945-d239c165b7dd" -version = "1.2.4+0" - -[[deps.Xorg_libX11_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] -git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" -uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" -version = "1.8.6+0" - -[[deps.Xorg_libXau_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" -uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" -version = "1.0.11+0" - -[[deps.Xorg_libXcursor_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXfixes_jll", "Xorg_libXrender_jll"] -git-tree-sha1 = "12e0eb3bc634fa2080c1c37fccf56f7c22989afd" -uuid = "935fb764-8cf2-53bf-bb30-45bb1f8bf724" -version = "1.2.0+4" - -[[deps.Xorg_libXdmcp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" -uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" -version = "1.1.4+0" - -[[deps.Xorg_libXext_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" -uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" -version = "1.3.6+0" - -[[deps.Xorg_libXfixes_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "0e0dc7431e7a0587559f9294aeec269471c991a4" -uuid = "d091e8ba-531a-589c-9de9-94069b037ed8" -version = "5.0.3+4" - -[[deps.Xorg_libXi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXfixes_jll"] -git-tree-sha1 = "89b52bc2160aadc84d707093930ef0bffa641246" -uuid = "a51aa0fd-4e3c-5386-b890-e753decda492" -version = "1.7.10+4" - -[[deps.Xorg_libXinerama_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll"] -git-tree-sha1 = "26be8b1c342929259317d8b9f7b53bf2bb73b123" -uuid = 
"d1454406-59df-5ea1-beac-c340f2130bc3" -version = "1.1.4+4" - -[[deps.Xorg_libXrandr_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll"] -git-tree-sha1 = "34cea83cb726fb58f325887bf0612c6b3fb17631" -uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" -version = "1.5.2+4" - -[[deps.Xorg_libXrender_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" -uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" -version = "0.9.11+0" - -[[deps.Xorg_libpthread_stubs_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" -uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" -version = "0.1.1+0" - -[[deps.Xorg_libxcb_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] -git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" -uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" -version = "1.17.0+0" - -[[deps.Xorg_libxkbfile_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "730eeca102434283c50ccf7d1ecdadf521a765a4" -uuid = "cc61e674-0454-545c-8b26-ed2c68acab7a" -version = "1.1.2+0" - -[[deps.Xorg_xcb_util_cursor_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_jll", "Xorg_xcb_util_renderutil_jll"] -git-tree-sha1 = "04341cb870f29dcd5e39055f895c39d016e18ccd" -uuid = "e920d4aa-a673-5f3a-b3d7-f755a4d47c43" -version = "0.1.4+0" - -[[deps.Xorg_xcb_util_image_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "0fab0a40349ba1cba2c1da699243396ff8e94b97" -uuid = "12413925-8142-5f55-bb0e-6d7ca50bb09b" -version = "0.4.0+1" - -[[deps.Xorg_xcb_util_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll"] -git-tree-sha1 = "e7fd7b2881fa2eaa72717420894d3938177862d1" -uuid = "2def613f-5ad1-5310-b15b-b15d46f528f5" 
-version = "0.4.0+1" - -[[deps.Xorg_xcb_util_keysyms_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "d1151e2c45a544f32441a567d1690e701ec89b00" -uuid = "975044d2-76e6-5fbe-bf08-97ce7c6574c7" -version = "0.4.0+1" - -[[deps.Xorg_xcb_util_renderutil_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "dfd7a8f38d4613b6a575253b3174dd991ca6183e" -uuid = "0d47668e-0667-5a69-a72c-f761630bfb7e" -version = "0.3.9+1" - -[[deps.Xorg_xcb_util_wm_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "e78d10aab01a4a154142c5006ed44fd9e8e31b67" -uuid = "c22f9ab0-d5fe-5066-847c-f4bb1cd4e361" -version = "0.4.1+1" - -[[deps.Xorg_xkbcomp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxkbfile_jll"] -git-tree-sha1 = "330f955bc41bb8f5270a369c473fc4a5a4e4d3cb" -uuid = "35661453-b289-5fab-8a00-3d9160c6a3a4" -version = "1.4.6+0" - -[[deps.Xorg_xkeyboard_config_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xkbcomp_jll"] -git-tree-sha1 = "691634e5453ad362044e2ad653e79f3ee3bb98c3" -uuid = "33bec58e-1273-512f-9401-5d533626f822" -version = "2.39.0+0" - -[[deps.Xorg_xtrans_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" -uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" -version = "1.5.0+0" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zstd_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" -uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.6+0" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", 
"FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.eudev_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] -git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba" -uuid = "35ca27e7-8b34-5b7f-bca9-bdc33f59eb06" -version = "3.2.9+0" - -[[deps.fzf_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" -uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" -version = "0.43.0+0" - -[[deps.gperf_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "3516a5630f741c9eecb3720b1ec9d8edc3ecc033" -uuid = "1a1c6b14-54f6-533d-8383-74cd7377aa70" -version = "3.1.1+0" - -[[deps.libaec_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" -uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" -version = "1.1.2+0" - -[[deps.libaom_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" -uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" -version = "3.9.0+0" - -[[deps.libass_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", 
"JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" -uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" -version = "0.15.1+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.libevdev_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22" -uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" -version = "1.11.0+0" - -[[deps.libfdk_aac_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" -uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" -version = "2.0.2+0" - -[[deps.libinput_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"] -git-tree-sha1 = "ad50e5b90f222cfe78aa3d5183a20a12de1322ce" -uuid = "36db933b-70db-51c0-b978-0f229ee0e533" -version = "1.18.0+0" - -[[deps.libpng_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" -uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" - -[[deps.libvorbis_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] -git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" -uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" -version = "1.3.7+2" - -[[deps.mtdev_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "814e154bdb7be91d78b6802843f76b6ece642f11" -uuid = "009596ad-96f7-51b1-9f1b-5ce2d5e8a71e" -version = "1.1.6+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" - -[[deps.x264_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" -uuid = 
"1270edf5-f2f9-52d2-97e9-ab00b5d0237a" -version = "2021.5.5+0" - -[[deps.x265_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" -uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" -version = "3.5.0+0" - -[[deps.xkbcommon_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] -git-tree-sha1 = "9c304562909ab2bab0262639bd4f444d7bc2be37" -uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" -version = "1.4.1+1" -""" - -# ╔═╡ Cell order: -# ╟─ca2f0293-7eac-4d9a-9a2f-fda47fd95a99 -# ╟─4455f18c-2bd9-42ed-bce3-cfe6561eab23 -# ╠═5463330a-0161-11ed-1b18-936030a32bbf -# ╟─0d556a7c-d4b6-4cef-806c-3e1712de0791 -# ╠═997b5387-3811-4998-a9d1-7981b58b9e09 -# ╟─4b6fa18d-7ccd-4c07-8dc3-ded4d7da8562 -# ╠═edab1e3a-31f6-471f-9835-5b1f97e5cf3f -# ╟─d73a2db5-9417-4b2c-a9f5-b7d499a53fcb -# ╠═32bb90c1-c802-4c0c-a620-5d3b8f3f2477 -# ╟─3438ee7f-bfca-465d-85df-13379622d415 -# ╠═eec6fb60-0774-4f2a-bcb7-dbc28ab747a6 -# ╟─bd2fd04d-7fb0-4b31-959b-bddabe681754 -# ╠═b29c3a02-c21b-4b10-aa04-b90bcc2931d8 -# ╠═16d9fbad-d4dc-4b51-9576-1736d228e2b3 -# ╟─923d061c-25c3-4826-8147-9afa3dbd5bac -# ╠═28e00b95-56db-4d36-a205-fd24d3c54e17 -# ╟─fa743000-604f-4d28-99f1-46ab2f884b8e -# ╠═f972f61b-2001-409b-9190-ac2c0652829a -# ╟─4dade64a-e28e-42c7-8ad5-93fc04724d4d -# ╠═05979cfe-439c-4abc-90cd-6ca2a05f6e0f -# ╠═a3f420e1-7521-4df9-b6d5-fc0a1fd05095 -# ╠═026911dd-6a27-49ce-9d41-21e01646c10a -# ╟─b18384fe-b8ae-4f51-bd73-d129d5e70f98 -# ╠═b2302697-1e20-4721-ae93-0b121ff9ce8f -# ╟─54a2972e-b107-47c8-bf7e-eb51b4ccbe02 -# ╟─623e7b53-046c-4858-89d9-13caae45255d -# ╠═eb36a46c-f139-425e-8a93-207bc4a16f89 -# ╟─20b5f802-abce-49e1-a442-f381e80c0f85 -# ╠═b295adce-b37e-45f3-963a-3699d714e36d -# ╟─5538970f-b273-4122-9d50-7deb049e6934 -# ╠═901d9478-9a12-4122-905d-6cfc6d80e84c -# ╠═65d9fd3d-1649-4b95-a106-f26fa4ab9bce -# ╠═20be52b1-1c33-4f54-b5c0-fecc4e24fbb5 -# 
╟─5aa99aff-b5ed-40ec-a7ec-0ba53385e6bd -# ╠═2163d0d8-0661-4d11-a09e-708769011d35 -# ╟─6cd49f3f-a415-4b6a-9323-4d6aa6b87f18 -# ╠═7a93a802-6774-42f9-b6da-7ae614464e72 -# ╟─50a409fd-d80b-4c48-a51b-173c39a6dcb4 -# ╟─c343419f-a1d7-45a0-b600-2c868588b33a -# ╟─00000000-0000-0000-0000-000000000001 -# ╟─00000000-0000-0000-0000-000000000002 - -[.\docs\tutorials_broken\temporal_graph_classification_pluto.jl] -### A Pluto.jl notebook ### -# v0.19.45 - -#> [frontmatter] -#> author = "[Aurora Rossi](https://github.com/aurorarossi)" -#> title = "Temporal Graph classification with Graph Neural Networks" -#> date = "2024-03-06" -#> description = "Temporal Graph classification with GraphNeuralNetworks.jl" -#> cover = "assets/brain_gnn.gif" - -using Markdown -using InteractiveUtils - -# ╔═╡ b8df1800-c69d-4e18-8a0a-097381b62a4c -begin - using Flux - using GraphNeuralNetworks - using Statistics, Random - using LinearAlgebra - using MLDatasets: TemporalBrains - using CUDA - using cuDNN -end - -# ╔═╡ 69d00ec8-da47-11ee-1bba-13a14e8a6db2 -md"In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying. - -We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU. -" - -# ╔═╡ ef8406e4-117a-4cc6-9fa5-5028695b1a4f -md" -## Import - -We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others. 
-" - -# ╔═╡ 2544d468-1430-4986-88a9-be4df2a7cf27 -md" -## Dataset: TemporalBrains -The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects obtained from resting-state functional MRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation). -Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions. - -The graph nodes represent brain regions and their number is fixed at 102 for each of the 27 snapshots, while the edges, representing functional connectivity, change over time. -For each snapshot, the feature of a node represents the average activation of the node during that snapshot. -Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+). -The network's edge weights are binarized, and the threshold is set to 0.6 by default. -" - -# ╔═╡ f2dbc66d-b8b7-46ae-ad5b-cbba1af86467 -brain_dataset = TemporalBrains() - -# ╔═╡ d9e4722d-6f02-4d41-955c-8bb3e411e404 -md"After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format. -So we create a function called `data_loader` that implements the latter and splits the dataset into the training set that will be used to train the model and the test set that will be used to test the performance of the model. 
-" - -# ╔═╡ bb36237a-5545-47d0-a873-7ddff3efe8ba -function data_loader(brain_dataset) - graphs = brain_dataset.graphs - dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs)) - for i in 1:length(graphs) - graph = graphs[i] - dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(graph.snapshots)) - # Add graph and node features - for t in 1:27 - s = dataset[i].snapshots[t] - s.ndata.x = [I(102); s.ndata.x'] - end - dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"])) - end - # Split the dataset into a 80% training set and a 20% test set - train_loader = dataset[1:200] - test_loader = dataset[201:250] - return train_loader, test_loader -end; - -# ╔═╡ d4732340-9179-4ada-b82e-a04291d745c2 -md" -The first part of the `data_loader` function calls the `mlgraph2gnngraph` function for each snapshot, which takes the graph and converts it to a `GNNGraph`. The vector of `GNNGraph`s is then rewritten to a `TemporalSnapshotsGNNGraph`. - -The second part adds the graph and node features to the temporal graphs, in particular it adds the one-hot encoding of the label of the graph (in this case we directly use the identity matrix) and appends the mean activation of the node of the snapshot (which is contained in the vector `dataset[i].snapshots[t].ndata.x`, where `i` is the index indicating the subject and `t` is the snapshot). For the graph feature, it adds the one-hot encoding of gender. - -The last part splits the dataset. -" - - -# ╔═╡ ec088a59-2fc2-426a-a406-f8f8d6784128 -md" -## Model - -We now implement a simple model that takes a `TemporalSnapshotsGNNGraph` as input. -It consists of a `GINConv` applied independently to each snapshot, a `GlobalPool` to get an embedding for each snapshot, a pooling on the time dimension to get an embedding for the whole temporal graph, and finally a `Dense` layer. - -First, we start by adapting the `GlobalPool` to the `TemporalSnapshotsGNNGraphs`. 
-" - -# ╔═╡ 5ea98df9-4920-4c94-9472-3ef475af89fd -function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector) - h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], length(h)) -end - -# ╔═╡ cfda2cf4-d08b-4f46-bd39-02ae3ed53369 -md" -Then we implement the constructor of the model, which we call `GenderPredictionModel`, and the foward pass. -" - -# ╔═╡ 2eedd408-67ee-47b2-be6f-2caec94e95b5 -begin - struct GenderPredictionModel - gin::GINConv - mlp::Chain - globalpool::GlobalPool - f::Function - dense::Dense - end - - Flux.@layer GenderPredictionModel - - function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) - mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) - gin = GINConv(mlp, 0.5) - globalpool = GlobalPool(mean) - f = x -> mean(x, dims = 2) - dense = Dense(nhidden, 2) - GenderPredictionModel(gin, mlp, globalpool, f, dense) - end - - function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph) - h = m.gin(g, g.ndata.x) - h = m.globalpool(g, h) - h = m.f(h) - m.dense(h) - end - -end - -# ╔═╡ 76780020-406d-4803-9af0-d928e54fc18c -md" -## Training - -We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the `logitbinarycrossentropy` as the loss function, which is typically used as the loss in two-class classification, where the labels are given in a one-hot format. -The accuracy expresses the number of correct classifications. 
-" - -# ╔═╡ 0a1e07b0-a4f3-4a4b-bcd1-7fe200967cf8 -lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y); - -# ╔═╡ cc2ebdcf-72de-4a3b-af46-5bddab6689cc -function eval_loss_accuracy(model, data_loader) - error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader]) - acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader]) - return (loss = error, acc = acc) -end; - -# ╔═╡ d64be72e-8c1f-4551-b4f2-28c8b78466c0 -function train(dataset; usecuda::Bool, kws...) - - if usecuda && CUDA.functional() #check if GPU is available - my_device = gpu - @info "Training on GPU" - else - my_device = cpu - @info "Training on CPU" - end - - function report(epoch) - train_loss, train_acc = eval_loss_accuracy(model, train_loader) - test_loss, test_acc = eval_loss_accuracy(model, test_loader) - println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") - return (train_loss, train_acc, test_loss, test_acc) - end - - model = GenderPredictionModel() |> my_device - - opt = Flux.setup(Adam(1.0f-3), model) - - train_loader, test_loader = data_loader(dataset) - train_loader = train_loader |> my_device - test_loader = test_loader |> my_device - - report(0) - for epoch in 1:100 - for g in train_loader - grads = Flux.gradient(model) do model - ŷ = model(g) - lossfunction(vec(ŷ), g.tgdata.g) - end - Flux.update!(opt, model, grads[1]) - end - if epoch % 10 == 0 - report(epoch) - end - end - return model -end; - - -# ╔═╡ 483f17ba-871c-4769-88bd-8ec781d1909d -train(brain_dataset; usecuda = true) - -# ╔═╡ b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5 -md" -We set up the training on the GPU because training takes a lot of time, especially when working on the CPU. -" - -# ╔═╡ cb4eed19-2658-411d-886c-e0c9c2b44219 -md" -## Conclusions - -In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. 
We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but can be improved by fine-tuning the parameters and training on more data. -" - -# ╔═╡ 00000000-0000-0000-0000-000000000001 -PLUTO_PROJECT_TOML_CONTENTS = """ -[deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[compat] -CUDA = "~5.4.3" -Flux = "~0.14.16" -GraphNeuralNetworks = "~0.6.19" -MLDatasets = "~0.7.16" -cuDNN = "~1.3.2" -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000002 -PLUTO_MANIFEST_TOML_CONTENTS = """ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "25724970092e282d6cd2d6ea9e021d61f3714205" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = 
"94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.AtomsBase]] -deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" -uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.5" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - 
BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" - -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", 
"Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] -git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.4.3" - - [deps.CUDA.extensions] - ChainRulesCoreExt = "ChainRulesCore" - EnzymeCoreExt = "EnzymeCore" - SpecialFunctionsExt = "SpecialFunctions" - - [deps.CUDA.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - -[[deps.CUDA_Driver_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "97df9d4d6be8ac6270cb8fd3b8fc413690820cbd" -uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.9.1+1" - -[[deps.CUDA_Runtime_Discovery]] -deps = ["Libdl"] -git-tree-sha1 = "f3b237289a5a77c759b2dd5d4c2ff641d67c4030" -uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.3.4" - -[[deps.CUDA_Runtime_jll]] -deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a" -uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.14.1+0" - -[[deps.CUDNN_jll]] -deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" -uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" -version = "9.0.0+1" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" -weakdeps 
= ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.Chemfiles]] -deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" -uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" -version = "0.10.41" - -[[deps.Chemfiles_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" -uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" -version = "0.10.4+0" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", 
"UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataDeps]] -deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] -git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" -uuid = 
"124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.13" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", 
"SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.10" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" -uuid = 
"1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" - -[[deps.Future]] -deps = ["Random"] -uuid = 
"9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.26.7" - -[[deps.GZip]] -deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.2" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.GraphNeuralNetworks]] -deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" -uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" -version = "0.6.19" - - [deps.GraphNeuralNetworks.extensions] - GraphNeuralNetworksCUDAExt = "CUDA" - GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" - - [deps.GraphNeuralNetworks.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = 
"ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.2" - - [deps.HDF5.extensions] - MPIExt = "MPI" - - [deps.HDF5.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - -[[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+3" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" - -[[deps.ImageCore]] -deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] 
-git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.2" - -[[deps.ImageShow]] -deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" -uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.8" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = 
"a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" -uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JuliaNVTXCallbacks_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" -uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" -version = "0.2.1+0" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] 
-git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.7.1" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" -weakdeps = ["BFloat16s"] - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" - -[[deps.LLVMLoopInfo]] -git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" -uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" -version = "1.0.0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = 
"8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MAT]] -deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.7" - -[[deps.MLDatasets]] -deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = 
"d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.11" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MicroCollections]] -deps = ["Accessors", "BangBang", "InitialValues"] 
-git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.2.0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.4+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - NNlibFFTWExt = "FFTW" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NVTX]] -deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] -git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" -uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" 
-version = "0.3.4" - -[[deps.NVTX_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" -uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" -version = "3.1.0+2" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OffsetArrays]] -git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.14.1" -weakdeps = ["Adapt"] - - [deps.OffsetArrays.extensions] - OffsetArraysAdaptExt = "Adapt" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.6+0" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] 
-git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PeriodicTable]] -deps = ["Base64", "Unitful"] -git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" -uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" -version = "1.2.1" - -[[deps.Pickle]] -deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.5" - 
-[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.7.0" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = 
"043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = 
"6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = 
"10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" -weakdeps = ["CUDA"] - - [deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.7" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = 
"4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.24" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - -[[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" - - 
[deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Unitful]] -deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.21.0" -weakdeps = ["ConstructionBase", "InverseFunctions"] - - [deps.Unitful.extensions] - ConstructionBaseUnitfulExt = "ConstructionBase" - InverseFunctionsUnitfulExt = "InverseFunctions" - -[[deps.UnitfulAtomic]] -deps = ["Unitful"] -git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" -uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" -version = "1.0.0" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" - -[[deps.VectorInterface]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" -uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" -version = "0.4.6" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", 
"InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.cuDNN]] -deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] -git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" -uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.3.2" - -[[deps.libaec_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" -uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" -version = "1.1.2+0" - 
-[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" -""" - -# ╔═╡ Cell order: -# ╟─69d00ec8-da47-11ee-1bba-13a14e8a6db2 -# ╟─ef8406e4-117a-4cc6-9fa5-5028695b1a4f -# ╠═b8df1800-c69d-4e18-8a0a-097381b62a4c -# ╟─2544d468-1430-4986-88a9-be4df2a7cf27 -# ╠═f2dbc66d-b8b7-46ae-ad5b-cbba1af86467 -# ╟─d9e4722d-6f02-4d41-955c-8bb3e411e404 -# ╠═bb36237a-5545-47d0-a873-7ddff3efe8ba -# ╟─d4732340-9179-4ada-b82e-a04291d745c2 -# ╟─ec088a59-2fc2-426a-a406-f8f8d6784128 -# ╠═5ea98df9-4920-4c94-9472-3ef475af89fd -# ╟─cfda2cf4-d08b-4f46-bd39-02ae3ed53369 -# ╠═2eedd408-67ee-47b2-be6f-2caec94e95b5 -# ╟─76780020-406d-4803-9af0-d928e54fc18c -# ╠═0a1e07b0-a4f3-4a4b-bcd1-7fe200967cf8 -# ╠═cc2ebdcf-72de-4a3b-af46-5bddab6689cc -# ╠═d64be72e-8c1f-4551-b4f2-28c8b78466c0 -# ╠═483f17ba-871c-4769-88bd-8ec781d1909d -# ╟─b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5 -# ╟─cb4eed19-2658-411d-886c-e0c9c2b44219 -# ╟─00000000-0000-0000-0000-000000000001 -# ╟─00000000-0000-0000-0000-000000000002 - -[.\docs\tutorials_broken\traffic_prediction.jl] -### A Pluto.jl notebook ### -# v0.19.45 - -#> [frontmatter] -#> author = "[Aurora Rossi](https://github.com/aurorarossi)" -#> title = "Traffic Prediction using recurrent Temporal Graph Convolutional Network" -#> date = "2023-08-21" -#> description = "Traffic Prediction using GraphNeuralNetworks.jl" -#> cover = "assets/traffic.gif" - -using Markdown -using InteractiveUtils - -# ╔═╡ 1f95ad97-a007-4724-84db-392b0026e1a4 -begin - using GraphNeuralNetworks - using Flux - using Flux.Losses: mae - using MLDatasets: METRLA - using Statistics - using Plots -end - -# ╔═╡ 5fdab668-4003-11ee-33f5-3953225b0c0f -md" -In this tutorial, we will learn how to use 
a recurrent Temporal Graph Convolutional Network (TGCN) to predict traffic in a spatio-temporal setting. Traffic forecasting is the problem of predicting future traffic trends on a road network given historical traffic data, such as, in our case, traffic speed and time of day. -" - -# ╔═╡ 3dd0ce32-2339-4d5a-9a6f-1f662bc5500b -md" -## Import - -We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others. -" - -# ╔═╡ ec5caeb6-1f95-4cb9-8739-8cadba29a22d -md" -## Dataset: METR-LA - -We use the `METR-LA` dataset from the paper [Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926.pdf), which contains traffic data from loop detectors in the highway of Los Angeles County. The dataset contains traffic speed data from March 1, 2012 to June 30, 2012. The data is collected every 5 minutes, resulting in 12 observations per hour, from 207 sensors. Each sensor is a node in the graph, and the edges represent the distances between the sensors. -" - -# ╔═╡ f531e39c-6842-494a-b4ac-8904321098c9 -dataset_metrla = METRLA(; num_timesteps = 3) - -# ╔═╡ d5ebf9aa-cec8-4417-baaf-f2e8e19f1cad - g = dataset_metrla[1] - -# ╔═╡ dc2d5e98-2201-4754-bfc6-8ed2bbb82153 -md" -`edge_data` contains the weights of the edges of the graph and -`node_data` contains a node feature vector and a target vector. The latter vectors contain batches of dimension `num_timesteps`, which means that they contain vectors with the node features and targets of `num_timesteps` time steps. Two consecutive batches are shifted by one-time step. -The node features are the traffic speed of the sensors and the time of the day, and the targets are the traffic speed of the sensors in the next time step. 
-Let's see some examples: -" - -# ╔═╡ 0dde5fd3-72d0-4b15-afb3-9a5b102327c9 -size(g.node_data.features[1]) - -# ╔═╡ f7a6d572-28cf-4d69-a9be-d49f367eca37 -md" -The first dimension correspond to the two features (first line the speed value and the second line the time of the day), the second to the nodes and the third to the number of timestep `num_timesteps`. -" - -# ╔═╡ 3d5503bc-bb97-422e-9465-becc7d3dbe07 -size(g.node_data.targets[1]) - -# ╔═╡ 3569715d-08f5-4605-b946-9ef7ccd86ae5 -md" -In the case of the targets the first dimension is 1 because they store just the speed value. -" - -# ╔═╡ aa4eb172-2a42-4c01-a6ef-c6c95208d5b2 -g.node_data.features[1][:,1,:] - -# ╔═╡ 367ed417-4f53-44d4-8135-0c91c842a75f -g.node_data.features[2][:,1,:] - -# ╔═╡ 7c084eaa-655c-4251-a342-6b6f4df76ddb -g.node_data.targets[1][:,1,:] - -# ╔═╡ bf0d820d-32c0-4731-8053-53d5d499e009 -function plot_data(data,sensor) - p = plot(legend=false, xlabel="Time (h)", ylabel="Normalized speed") - plotdata = [] - for i in 1:3:length(data) - push!(plotdata,data[i][1,sensor,:]) - end - plotdata = reduce(vcat,plotdata) - plot!(p, collect(1:length(data)), plotdata, color = :green, xticks =([i for i in 0:50:250], ["$(i)" for i in 0:4:24])) - return p -end - -# ╔═╡ cb89d1a3-b4ff-421a-8717-a0b7f21dea1a -plot_data(g.node_data.features[1:288],1) - -# ╔═╡ 3b49a612-3a04-4eb5-bfbc-360614f4581a -md" -Now let's construct the static graph, the temporal features and targets from the dataset. -" - -# ╔═╡ 95d8bd24-a40d-409f-a1e7-4174428ef860 -begin - graph = GNNGraph(g.edge_index; edata = g.edge_data, g.num_nodes) - features = g.node_data.features - targets = g.node_data.targets -end; - -# ╔═╡ fde2ac9e-b121-4105-8428-1820b9c17a43 -md" -Now let's construct the `train_loader` and `data_loader`. 
-" - - -# ╔═╡ 111b7d5d-c7e3-44c0-9e5e-2ed1a86854d3 -begin - train_loader = zip(features[1:200], targets[1:200]) - test_loader = zip(features[2001:2288], targets[2001:2288]) -end; - -# ╔═╡ 572a6633-875b-4d7e-9afc-543b442948fb -md" -## Model: T-GCN - -We use the T-GCN model from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction] (https://arxiv.org/pdf/1811.05320.pdf), which consists of a graph convolutional network (GCN) and a gated recurrent unit (GRU). The GCN is used to capture spatial features from the graph, and the GRU is used to capture temporal features from the feature time series. -" - -# ╔═╡ 5502f4fa-3201-4980-b766-2ab88b175b11 -model = GNNChain(TGCN(2 => 100), Dense(100, 1)) - -# ╔═╡ 4a1ec34a-1092-4b4a-b8a8-bd91939ffd9e -md" -![](https://www.researchgate.net/profile/Haifeng-Li-3/publication/335353434/figure/fig4/AS:851870352437249@1580113127759/The-architecture-of-the-Gated-Recurrent-Unit-model.jpg) -" - -# ╔═╡ 755a88c2-c2e5-46d1-9582-af4b2c5a6bbd -md" -## Training - -We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the mean absolute error (MAE) as the loss function. 
-" - -# ╔═╡ e83253b2-9f3a-44e2-a747-cce1661657c4 -function train(graph, train_loader, model) - - opt = Flux.setup(Adam(0.001), model) - - for epoch in 1:100 - for (x, y) in train_loader - x, y = (x, y) - grads = Flux.gradient(model) do model - ŷ = model(graph, x) - Flux.mae(ŷ, y) - end - Flux.update!(opt, model, grads[1]) - end - - if epoch % 10 == 0 - loss = mean([Flux.mae(model(graph,x), y) for (x, y) in train_loader]) - @show epoch, loss - end - end - return model -end - -# ╔═╡ 85a923da-3027-4f71-8db6-96852c115c03 -train(graph, train_loader, model) - -# ╔═╡ 39c82234-97ea-48d6-98dd-915f072b7f85 -function plot_predicted_data(graph,features,targets, sensor) - p = plot(xlabel="Time (h)", ylabel="Normalized speed") - prediction = [] - grand_truth = [] - for i in 1:3:length(features) - push!(grand_truth,targets[i][1,sensor,:]) - push!(prediction, model(graph, features[i])[1,sensor,:]) - end - prediction = reduce(vcat,prediction) - grand_truth = reduce(vcat, grand_truth) - plot!(p, collect(1:length(features)), grand_truth, color = :blue, label = "Grand Truth", xticks =([i for i in 0:50:250], ["$(i)" for i in 0:4:24])) - plot!(p, collect(1:length(features)), prediction, color = :red, label= "Prediction") - return p -end - -# ╔═╡ 8c3a903b-2c8a-4d4f-8eef-74d5611f2ce4 -plot_predicted_data(graph,features[301:588],targets[301:588], 1) - -# ╔═╡ 2c5f6250-ee7a-41b1-9551-bcfeba83ca8b -accuracy(ŷ, y) = 1 - Statistics.norm(y-ŷ)/Statistics.norm(y) - -# ╔═╡ 1008dad4-d784-4c38-a7cf-d9b64728e28d -mean([accuracy(model(graph,x), y) for (x, y) in test_loader]) - -# ╔═╡ 8d0e8b9f-226f-4bff-9deb-046e6a897b71 -md"The accuracy is not very good but can be improved by training using more data. We used a small subset of the dataset for this tutorial because of the computational cost of training the model. From the plot of the predictions, we can see that the model is able to capture the general trend of the traffic speed, but it is not able to capture the peaks of the traffic." 
- -# ╔═╡ a7e4bb23-6687-476a-a0c2-1b2736873d9d -md" -## Conclusion - -In this tutorial, we learned how to use a recurrent temporal graph convolutional network to predict traffic in a spatio-temporal setting. We used the TGCN model, which consists of a graph convolutional network (GCN) and a gated recurrent unit (GRU). We then trained the model for 100 epochs on a small subset of the METR-LA dataset. The accuracy of the model is not very good, but it can be improved by training on more data. -" - -# ╔═╡ 00000000-0000-0000-0000-000000000001 -PLUTO_PROJECT_TOML_CONTENTS = """ -[deps] -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[compat] -Flux = "~0.14.16" -GraphNeuralNetworks = "~0.6.19" -MLDatasets = "~0.7.16" -Plots = "~1.40.5" -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000002 -PLUTO_MANIFEST_TOML_CONTENTS = """ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "8742c1fb8ae152ad31b34471cf90f234c1b8b06c" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - 
AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.AtomsBase]] -deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" -uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.5" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = 
"e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" - -[[deps.Bzip2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" -uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" - -[[deps.Cairo_jll]] -deps = ["Artifacts", "Bzip2_jll", 
"CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" -uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.Chemfiles]] -deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" -uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" -version = "0.10.41" - -[[deps.Chemfiles_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" -uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" -version = "0.10.4+0" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = 
"b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" - - [deps.ConstructionBase.extensions] 
- ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Contour]] -git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" -uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataDeps]] -deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] -git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" -uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.13" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = 
["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", "SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.EpollShim_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" -uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" -version = "0.0.20230411+0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = 
"0.1.10" - -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.6.2+0" - -[[deps.FFMPEG]] -deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" -uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" - -[[deps.FFMPEG_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] -git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" -uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" -version = "4.4.4+1" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - 
FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.Fontconfig_jll]] -deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] -git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" -uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.13.96+0" - -[[deps.Format]] -git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" -uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" -version = "1.3.7" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = 
"StaticArrays" - -[[deps.FreeType2_jll]] -deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" -uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.13.2+0" - -[[deps.FriBidi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" -uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.14+0" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GLFW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] -git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" -uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" -version = "3.4.0+0" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.GR]] -deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" -uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.7" - -[[deps.GR_jll]] -deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", 
"Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" -uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.7+0" - -[[deps.GZip]] -deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.2" - -[[deps.Gettext_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" -uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" -version = "0.21.0+0" - -[[deps.Glib_jll]] -deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" -uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.GraphNeuralNetworks]] -deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" -uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" -version = "0.6.19" - - [deps.GraphNeuralNetworks.extensions] - GraphNeuralNetworksCUDAExt = "CUDA" - GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" - - [deps.GraphNeuralNetworks.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" - -[[deps.Graphite2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" -uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" -version = "1.3.14+0" - -[[deps.Graphs]] -deps = 
["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" - -[[deps.Grisu]] -git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" -uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" -version = "1.0.2" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.2" - - [deps.HDF5.extensions] - MPIExt = "MPI" - - [deps.HDF5.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - -[[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.2+1" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.HarfBuzz_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" -uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+1" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = 
"5e19e1e4fa3e71b774ce746274364aef0234634e" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" - -[[deps.ImageCore]] -deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] -git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.2" - -[[deps.ImageShow]] -deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" -uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.8" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - 
-[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" -uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" - -[[deps.JLFzf]] -deps = ["Pipe", "REPL", "Random", "fzf_jll"] -git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" -uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" -version = "0.1.7" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JpegTurbo_jll]] -deps = 
["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" -uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] -git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.7.1" - -[[deps.LAME_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" -uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" -version = "3.100.2+0" - -[[deps.LERC_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" -uuid = "88015f11-f218-50d7-93a8-a6af411a945d" -version = "3.0.0+1" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" -weakdeps = ["BFloat16s"] - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" -uuid = 
"dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" - -[[deps.LLVMOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" -uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "15.0.7+0" - -[[deps.LZO_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" -uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.Latexify]] -deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] -git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50" -uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.16.4" - - [deps.Latexify.extensions] - DataFramesExt = "DataFrames" - SymEngineExt = "SymEngine" - - [deps.Latexify.weakdeps] - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", 
"Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libffi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" -uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+1" - -[[deps.Libgcrypt_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] -git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" -uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.8.11+0" - -[[deps.Libglvnd_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] -git-tree-sha1 = "6f73d1dd803986947b2c750138528a999a6c7733" -uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" -version = "1.6.0+0" - -[[deps.Libgpg_error_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" -uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.49.0+0" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.Libmount_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" -uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.40.1+0" - -[[deps.Libtiff_jll]] -deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" -uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.5.1+1" - -[[deps.Libuuid_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" -uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.40.1+0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = 
"37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MAT]] -deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.7" - -[[deps.MLDatasets]] -deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = 
"f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.11" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.Measures]] -git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" -uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" -version = "0.3.2" - -[[deps.MicroCollections]] -deps = ["Accessors", "BangBang", "InitialValues"] -git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.2.0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", 
"JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.4+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - NNlibFFTWExt = "FFTW" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - 
-[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OffsetArrays]] -git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.14.1" -weakdeps = ["Adapt"] - - [deps.OffsetArrays.extensions] - OffsetArraysAdaptExt = "Adapt" - -[[deps.Ogg_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" -uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" -version = "1.3.5+1" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] -git-tree-sha1 = "2f0a1d8c79bc385ec3fcda12830c9d0e72b30e71" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "5.0.4+0" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.OpenSpecFun_jll]] 
-deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" - -[[deps.Opus_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" -uuid = "91d4177d-7536-5919-b921-800302f37372" -version = "1.3.2+0" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PeriodicTable]] -deps = ["Base64", "Unitful"] -git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" -uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" -version = "1.2.1" - -[[deps.Pickle]] -deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.5" - -[[deps.Pipe]] -git-tree-sha1 = 
"6842804e7867b115ca9de748a0cf6b364523c16d" -uuid = "b98c9c47-44ae-5843-9183-064241ee97a0" -version = "1.3.0" - -[[deps.Pixman_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" -uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.43.4+0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PlotThemes]] -deps = ["PlotUtils", "Statistics"] -git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" -uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "3.2.0" - -[[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" -uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" - -[[deps.Plots]] -deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] -git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" -uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.40.5" - - [deps.Plots.extensions] - FileIOExt = "FileIO" - GeometryBasicsExt = "GeometryBasics" - IJuliaExt = "IJulia" - ImageInTerminalExt = "ImageInTerminal" - UnitfulExt = "Unitful" - - [deps.Plots.weakdeps] - FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" - GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" - IJulia 
= "7073ff75-c697-5162-941a-fcdaad2a7d2a" - ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.Qt6Base_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] -git-tree-sha1 = "492601870742dcd38f233b23c3ec629628c1d724" -uuid = "c0090381-4147-56d7-9ebc-da0b1113ec56" -version = "6.7.1+1" - -[[deps.Qt6Declarative_jll]] -deps = 
["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6ShaderTools_jll"] -git-tree-sha1 = "e5dd466bf2569fe08c91a2cc29c1003f4797ac3b" -uuid = "629bc702-f1f5-5709-abd5-49b8460ea067" -version = "6.7.1+2" - -[[deps.Qt6ShaderTools_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll"] -git-tree-sha1 = "1a180aeced866700d4bebc3120ea1451201f16bc" -uuid = "ce943373-25bb-56aa-8eca-768745ed7b5a" -version = "6.7.1+1" - -[[deps.Qt6Wayland_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Qt6Base_jll", "Qt6Declarative_jll"] -git-tree-sha1 = "729927532d48cf79f49070341e1d918a65aba6b0" -uuid = "e99dba38-086e-5de3-a5b1-6e4c66e897c3" -version = "6.7.1+1" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.RecipesBase]] -deps = ["PrecompileTools"] -git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.4" - -[[deps.RecipesPipeline]] -deps = ["Dates", "NaNMath", "PlotUtils", "PrecompileTools", "RecipesBase"] -git-tree-sha1 = "45cf9fd0ca5839d06ef333c8201714e888486342" -uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c" -version = "0.6.12" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.RelocatableFolders]] -deps = ["SHA", "Scratch"] -git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" -uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.1" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = 
"ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.Showoff]] -deps = ["Dates", "Grisu"] -git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" -uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" -version = "1.0.3" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", 
"SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = 
"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" - - [deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - - [deps.StridedViews.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.7" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" 
-version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - -[[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" - - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" 
-uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnicodeFun]] -deps = ["REPL"] -git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" -uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" -version = "0.4.1" - -[[deps.Unitful]] -deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.21.0" -weakdeps = ["ConstructionBase", "InverseFunctions"] - - [deps.Unitful.extensions] - ConstructionBaseUnitfulExt = "ConstructionBase" - InverseFunctionsUnitfulExt = "InverseFunctions" - -[[deps.UnitfulAtomic]] -deps = ["Unitful"] -git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" -uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" -version = "1.0.0" - -[[deps.UnitfulLatexify]] -deps = ["LaTeXStrings", "Latexify", "Unitful"] -git-tree-sha1 = "975c354fcd5f7e1ddcc1f1a23e6e091d99e99bc8" -uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" -version = "1.6.4" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" - -[[deps.Unzip]] -git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" -uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d" -version = "0.2.0" - -[[deps.VectorInterface]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" -uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" -version = "0.4.6" - -[[deps.Vulkan_Loader_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Wayland_jll", "Xorg_libX11_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] -git-tree-sha1 = 
"2f0486047a07670caad3a81a075d2e518acc5c59" -uuid = "a44049a8-05dd-5a78-86c9-5fde0876e88c" -version = "1.3.243+0" - -[[deps.Wayland_jll]] -deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "7558e29847e99bc3f04d6569e82d0f5c54460703" -uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89" -version = "1.21.0+1" - -[[deps.Wayland_protocols_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "93f43ab61b16ddfb2fd3bb13b3ce241cafb0e6c9" -uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" -version = "1.31.0+0" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.XML2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" -uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.1+0" - -[[deps.XSLT_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] -git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" -uuid = "aed1982a-8fda-507f-9586-7b0439959a61" -version = "1.1.41+0" - -[[deps.XZ_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" -uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" -version = "5.4.6+0" - -[[deps.Xorg_libICE_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "326b4fea307b0b39892b3e85fa451692eda8d46c" -uuid = "f67eecfb-183a-506d-b269-f58e52b52d7c" -version = "1.1.1+0" - -[[deps.Xorg_libSM_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libICE_jll"] -git-tree-sha1 = "3796722887072218eabafb494a13c963209754ce" -uuid = 
"c834827a-8449-5923-a945-d239c165b7dd" -version = "1.2.4+0" - -[[deps.Xorg_libX11_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] -git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" -uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" -version = "1.8.6+0" - -[[deps.Xorg_libXau_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" -uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" -version = "1.0.11+0" - -[[deps.Xorg_libXcursor_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXfixes_jll", "Xorg_libXrender_jll"] -git-tree-sha1 = "12e0eb3bc634fa2080c1c37fccf56f7c22989afd" -uuid = "935fb764-8cf2-53bf-bb30-45bb1f8bf724" -version = "1.2.0+4" - -[[deps.Xorg_libXdmcp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" -uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" -version = "1.1.4+0" - -[[deps.Xorg_libXext_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" -uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" -version = "1.3.6+0" - -[[deps.Xorg_libXfixes_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "0e0dc7431e7a0587559f9294aeec269471c991a4" -uuid = "d091e8ba-531a-589c-9de9-94069b037ed8" -version = "5.0.3+4" - -[[deps.Xorg_libXi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll", "Xorg_libXfixes_jll"] -git-tree-sha1 = "89b52bc2160aadc84d707093930ef0bffa641246" -uuid = "a51aa0fd-4e3c-5386-b890-e753decda492" -version = "1.7.10+4" - -[[deps.Xorg_libXinerama_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libXext_jll"] -git-tree-sha1 = "26be8b1c342929259317d8b9f7b53bf2bb73b123" -uuid = "d1454406-59df-5ea1-beac-c340f2130bc3" -version = "1.1.4+4" - -[[deps.Xorg_libXrandr_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", 
"Xorg_libXext_jll", "Xorg_libXrender_jll"] -git-tree-sha1 = "34cea83cb726fb58f325887bf0612c6b3fb17631" -uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" -version = "1.5.2+4" - -[[deps.Xorg_libXrender_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" -uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" -version = "0.9.11+0" - -[[deps.Xorg_libpthread_stubs_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" -uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" -version = "0.1.1+0" - -[[deps.Xorg_libxcb_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] -git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" -uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" -version = "1.17.0+0" - -[[deps.Xorg_libxkbfile_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] -git-tree-sha1 = "730eeca102434283c50ccf7d1ecdadf521a765a4" -uuid = "cc61e674-0454-545c-8b26-ed2c68acab7a" -version = "1.1.2+0" - -[[deps.Xorg_xcb_util_cursor_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_jll", "Xorg_xcb_util_renderutil_jll"] -git-tree-sha1 = "04341cb870f29dcd5e39055f895c39d016e18ccd" -uuid = "e920d4aa-a673-5f3a-b3d7-f755a4d47c43" -version = "0.1.4+0" - -[[deps.Xorg_xcb_util_image_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "0fab0a40349ba1cba2c1da699243396ff8e94b97" -uuid = "12413925-8142-5f55-bb0e-6d7ca50bb09b" -version = "0.4.0+1" - -[[deps.Xorg_xcb_util_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll"] -git-tree-sha1 = "e7fd7b2881fa2eaa72717420894d3938177862d1" -uuid = "2def613f-5ad1-5310-b15b-b15d46f528f5" -version = "0.4.0+1" - -[[deps.Xorg_xcb_util_keysyms_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] 
-git-tree-sha1 = "d1151e2c45a544f32441a567d1690e701ec89b00" -uuid = "975044d2-76e6-5fbe-bf08-97ce7c6574c7" -version = "0.4.0+1" - -[[deps.Xorg_xcb_util_renderutil_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "dfd7a8f38d4613b6a575253b3174dd991ca6183e" -uuid = "0d47668e-0667-5a69-a72c-f761630bfb7e" -version = "0.3.9+1" - -[[deps.Xorg_xcb_util_wm_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_xcb_util_jll"] -git-tree-sha1 = "e78d10aab01a4a154142c5006ed44fd9e8e31b67" -uuid = "c22f9ab0-d5fe-5066-847c-f4bb1cd4e361" -version = "0.4.1+1" - -[[deps.Xorg_xkbcomp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxkbfile_jll"] -git-tree-sha1 = "330f955bc41bb8f5270a369c473fc4a5a4e4d3cb" -uuid = "35661453-b289-5fab-8a00-3d9160c6a3a4" -version = "1.4.6+0" - -[[deps.Xorg_xkeyboard_config_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_xkbcomp_jll"] -git-tree-sha1 = "691634e5453ad362044e2ad653e79f3ee3bb98c3" -uuid = "33bec58e-1273-512f-9401-5d533626f822" -version = "2.39.0+0" - -[[deps.Xorg_xtrans_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" -uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" -version = "1.5.0+0" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zstd_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" -uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.6+0" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", 
"MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.eudev_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] -git-tree-sha1 = "431b678a28ebb559d224c0b6b6d01afce87c51ba" -uuid = "35ca27e7-8b34-5b7f-bca9-bdc33f59eb06" -version = "3.2.9+0" - -[[deps.fzf_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" -uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" -version = "0.43.0+0" - -[[deps.gperf_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "3516a5630f741c9eecb3720b1ec9d8edc3ecc033" -uuid = "1a1c6b14-54f6-533d-8383-74cd7377aa70" -version = "3.1.1+0" - -[[deps.libaec_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" -uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" -version = "1.1.2+0" - -[[deps.libaom_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" -uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" -version = "3.9.0+0" - -[[deps.libass_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" -uuid = 
"0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" -version = "0.15.1+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.libevdev_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22" -uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" -version = "1.11.0+0" - -[[deps.libfdk_aac_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" -uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" -version = "2.0.2+0" - -[[deps.libinput_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"] -git-tree-sha1 = "ad50e5b90f222cfe78aa3d5183a20a12de1322ce" -uuid = "36db933b-70db-51c0-b978-0f229ee0e533" -version = "1.18.0+0" - -[[deps.libpng_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" -uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" - -[[deps.libvorbis_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] -git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" -uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" -version = "1.3.7+2" - -[[deps.mtdev_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "814e154bdb7be91d78b6802843f76b6ece642f11" -uuid = "009596ad-96f7-51b1-9f1b-5ce2d5e8a71e" -version = "1.1.6+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" - -[[deps.x264_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" -uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a" -version = "2021.5.5+0" - -[[deps.x265_jll]] -deps = ["Artifacts", "JLLWrappers", 
"Libdl", "Pkg"] -git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" -uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" -version = "3.5.0+0" - -[[deps.xkbcommon_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] -git-tree-sha1 = "9c304562909ab2bab0262639bd4f444d7bc2be37" -uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" -version = "1.4.1+1" -""" - -# ╔═╡ Cell order: -# ╟─5fdab668-4003-11ee-33f5-3953225b0c0f -# ╟─3dd0ce32-2339-4d5a-9a6f-1f662bc5500b -# ╠═1f95ad97-a007-4724-84db-392b0026e1a4 -# ╟─ec5caeb6-1f95-4cb9-8739-8cadba29a22d -# ╠═f531e39c-6842-494a-b4ac-8904321098c9 -# ╠═d5ebf9aa-cec8-4417-baaf-f2e8e19f1cad -# ╟─dc2d5e98-2201-4754-bfc6-8ed2bbb82153 -# ╠═0dde5fd3-72d0-4b15-afb3-9a5b102327c9 -# ╟─f7a6d572-28cf-4d69-a9be-d49f367eca37 -# ╠═3d5503bc-bb97-422e-9465-becc7d3dbe07 -# ╟─3569715d-08f5-4605-b946-9ef7ccd86ae5 -# ╠═aa4eb172-2a42-4c01-a6ef-c6c95208d5b2 -# ╠═367ed417-4f53-44d4-8135-0c91c842a75f -# ╠═7c084eaa-655c-4251-a342-6b6f4df76ddb -# ╠═bf0d820d-32c0-4731-8053-53d5d499e009 -# ╠═cb89d1a3-b4ff-421a-8717-a0b7f21dea1a -# ╟─3b49a612-3a04-4eb5-bfbc-360614f4581a -# ╠═95d8bd24-a40d-409f-a1e7-4174428ef860 -# ╟─fde2ac9e-b121-4105-8428-1820b9c17a43 -# ╠═111b7d5d-c7e3-44c0-9e5e-2ed1a86854d3 -# ╟─572a6633-875b-4d7e-9afc-543b442948fb -# ╠═5502f4fa-3201-4980-b766-2ab88b175b11 -# ╟─4a1ec34a-1092-4b4a-b8a8-bd91939ffd9e -# ╟─755a88c2-c2e5-46d1-9582-af4b2c5a6bbd -# ╠═e83253b2-9f3a-44e2-a747-cce1661657c4 -# ╠═85a923da-3027-4f71-8db6-96852c115c03 -# ╠═39c82234-97ea-48d6-98dd-915f072b7f85 -# ╠═8c3a903b-2c8a-4d4f-8eef-74d5611f2ce4 -# ╠═2c5f6250-ee7a-41b1-9551-bcfeba83ca8b -# ╠═1008dad4-d784-4c38-a7cf-d9b64728e28d -# ╟─8d0e8b9f-226f-4bff-9deb-046e6a897b71 -# ╟─a7e4bb23-6687-476a-a0c2-1b2736873d9d -# ╟─00000000-0000-0000-0000-000000000001 -# ╟─00000000-0000-0000-0000-000000000002 - -[.\examples\graph_classification_temporalbrains.jl] -# Example of graph classification when graphs are 
temporal and modeled as `TemporalSnapshotsGNNGraphs'. -# In this code, we train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. -# The dataset used is the TemporalBrains dataset from the MLDataset.jl package, and the accuracy achieved with the model reaches 65-70% (it can be improved by fine-tuning the parameters of the model). -# Author: Aurora Rossi - -# Load packages -using Flux -using Flux.Losses: mae -using GraphNeuralNetworks -using CUDA -using Statistics, Random -using LinearAlgebra -using MLDatasets -CUDA.allowscalar(false) - -# Load data -MLdataset = TemporalBrains() -graphs = MLdataset.graphs - -# Function to transform the graphs from the MLDatasets format to the TemporalSnapshotsGNNGraph format -# and split the dataset into a training and a test set -function data_loader(graphs) - dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs)) - for i in 1:length(graphs) - gr = graphs[i] - dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(gr.snapshots)) - for t in 1:27 - dataset[i].snapshots[t].ndata.x = reduce( - vcat, [I(102), dataset[i].snapshots[t].ndata.x']) - end - dataset[i].tgdata.g = Float32.(Array(Flux.onehot(gr.graph_data.g, ["F", "M"]))) - end - # Split the dataset into a 80% training set and a 20% test set - train_loader = dataset[1:800] - test_loader = dataset[801:1000] - return train_loader, test_loader -end - -# Arguments for the train function -Base.@kwdef mutable struct Args - η = 1.0f-3 # learning rate - epochs = 200 # number of epochs - seed = -5 # set seed > 0 for reproducibility - usecuda = true # if true use cuda (if available) - nhidden = 128 # dimension of hidden features - infotime = 10 # report every `infotime` epochs -end - -# Adapt GlobalPool to work with TemporalSnapshotsGNNGraph -function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector) - h = 
[reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], length(h)) -end - -# Define the model -struct GenderPredictionModel - gin::GINConv - mlp::Chain - globalpool::GlobalPool - f::Function - dense::Dense -end - -Flux.@layer GenderPredictionModel - -function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) - mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) - gin = GINConv(mlp, 0.5) - globalpool = GlobalPool(mean) - f = x -> mean(x, dims = 2) - dense = Dense(nhidden, 2) - GenderPredictionModel(gin, mlp, globalpool, f, dense) -end - -function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph) - h = m.gin(g, g.ndata.x) - h = m.globalpool(g, h) - h = m.f(h) - m.dense(h) -end - -# Train the model - -function train(graphs; kws...) - args = Args(; kws...) - args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - my_device = gpu - args.seed > 0 && CUDA.seed!(args.seed) - @info "Training on GPU" - else - my_device = cpu - @info "Training on CPU" - end - - lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y) |> my_device - - function eval_loss_accuracy(model, data_loader) - error = mean([lossfunction(model(g), gpu(g.tgdata.g)) for g in data_loader]) - acc = mean([round( - 100 * - mean(Flux.onecold(model(g)) .== Flux.onecold(gpu(g.tgdata.g))); - digits = 2) for g in data_loader]) - return (loss = error, acc = acc) - end - - function report(epoch) - train_loss, train_acc = eval_loss_accuracy(model, train_loader) - test_loss, test_acc = eval_loss_accuracy(model, test_loader) - println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") - return (train_loss, train_acc, test_loss, test_acc) - end - - model = GenderPredictionModel() |> my_device - - opt = Flux.setup(Adam(args.η), model) - - train_loader, test_loader = data_loader(graphs) # it takes a while to load the data - - train_loader = 
train_loader |> my_device - test_loader = test_loader |> my_device - - report(0) - for epoch in 1:(args.epochs) - for g in train_loader - grads = Flux.gradient(model) do model - ŷ = model(g) - lossfunction(vec(ŷ), g.tgdata.g) - end - Flux.update!(opt, model, grads[1]) - end - if args.infotime > 0 && epoch % args.infotime == 0 - report(epoch) - end - end - return model -end - -model = train(graphs) -[.\examples\graph_classification_tudataset.jl] -# An example of graph classification - -using Flux -using Flux: onecold, onehotbatch -using Flux.Losses: logitbinarycrossentropy -using Flux: DataLoader -using GraphNeuralNetworks -using MLDatasets: TUDataset -using Statistics, Random -using MLUtils -using CUDA -CUDA.allowscalar(false) - -function eval_loss_accuracy(model, data_loader, device) - loss = 0.0 - acc = 0.0 - ntot = 0 - for (g, y) in data_loader - g, y = (g, y) |> device - n = length(y) - ŷ = model(g, g.ndata.x) |> vec - loss += logitbinarycrossentropy(ŷ, y) * n - acc += mean((ŷ .> 0) .== y) * n - ntot += n - end - return (loss = round(loss / ntot, digits = 4), - acc = round(acc * 100 / ntot, digits = 2)) -end - -function getdataset() - tudata = TUDataset("MUTAG") - display(tudata) - graphs = mldataset2gnngraph(tudata) - oh(x) = Float32.(onehotbatch(x, 0:6)) - graphs = [GNNGraph(g, ndata = oh(g.ndata.targets)) for g in graphs] - y = (1 .+ Float32.(tudata.graph_data.targets)) ./ 2 - @assert all(∈([0, 1]), y) # binary classification - return graphs, y -end - -# arguments for the `train` function -Base.@kwdef mutable struct Args - η = 1.0f-3 # learning rate - batchsize = 32 # batch size (number of graphs in each batch) - epochs = 200 # number of epochs - seed = 17 # set seed > 0 for reproducibility - usecuda = true # if true use cuda (if available) - nhidden = 128 # dimension of hidden features - infotime = 10 # report every `infotime` epochs -end - -function train(; kws...) - args = Args(; kws...) 
- args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - device = gpu - args.seed > 0 && CUDA.seed!(args.seed) - @info "Training on GPU" - else - device = cpu - @info "Training on CPU" - end - - # LOAD DATA - NUM_TRAIN = 150 - - dataset = getdataset() - train_data, test_data = splitobs(dataset, at = NUM_TRAIN, shuffle = true) - - train_loader = DataLoader(train_data; args.batchsize, shuffle = true, collate = true) - test_loader = DataLoader(test_data; args.batchsize, shuffle = false, collate = true) - - # DEFINE MODEL - - nin = size(dataset[1][1].ndata.x, 1) - nhidden = args.nhidden - - model = GNNChain(GraphConv(nin => nhidden, relu), - GraphConv(nhidden => nhidden, relu), - GlobalPool(mean), - Dense(nhidden, 1)) |> device - - opt = Flux.setup(Adam(args.η), model) - - # LOGGING FUNCTION - - function report(epoch) - train = eval_loss_accuracy(model, train_loader, device) - test = eval_loss_accuracy(model, test_loader, device) - println("Epoch: $epoch Train: $(train) Test: $(test)") - end - - # TRAIN - - report(0) - for epoch in 1:(args.epochs) - for (g, y) in train_loader - g, y = (g, y) |> device - grads = Flux.gradient(model) do model - ŷ = model(g, g.ndata.x) |> vec - logitbinarycrossentropy(ŷ, y) - end - Flux.update!(opt, model, grads[1]) - end - epoch % args.infotime == 0 && report(epoch) - end -end - -train() - -[.\examples\link_prediction_pubmed.jl] -# An example of link prediction using negative and positive samples. 
-# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py -# See the comparison paper https://arxiv.org/pdf/2102.12557.pdf for more details - -using Flux -using Flux: onecold, onehotbatch -using Flux.Losses: logitbinarycrossentropy -using GraphNeuralNetworks -using MLDatasets: PubMed -using Statistics, Random, LinearAlgebra -using CUDA -CUDA.allowscalar(false) - -# arguments for the `train` function -Base.@kwdef mutable struct Args - η = 1.0f-3 # learning rate - epochs = 200 # number of epochs - seed = 17 # set seed > 0 for reproducibility - usecuda = true # if true use cuda (if available) - nhidden = 64 # dimension of hidden features - infotime = 10 # report every `infotime` epochs -end - -# We define our own edge prediction layer but could also -# use GraphNeuralNetworks.DotDecoder instead. -struct DotPredictor end - -function (::DotPredictor)(g, x) - z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims = 1), g, xi = x, xj = x) - # z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with built-in method - return vec(z) -end - -function train(; kws...) - args = Args(; kws...) 
- - args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - device = gpu - args.seed > 0 && CUDA.seed!(args.seed) - @info "Training on GPU" - else - device = cpu - @info "Training on CPU" - end - - ### LOAD DATA - g = mldataset2gnngraph(PubMed()) - - # Print some info - display(g) - @show is_bidirected(g) - @show has_self_loops(g) - @show has_multi_edges(g) - @show mean(degree(g)) - isbidir = is_bidirected(g) - - # Move to device - g = g |> device - X = g.ndata.features - - #### TRAIN/TEST splits - # With bidirected graph, we make sure that an edge and its reverse - # are in the same split - train_pos_g, test_pos_g = rand_edge_split(g, 0.9, bidirected = isbidir) - test_neg_g = negative_sample(g, num_neg_edges = test_pos_g.num_edges, - bidirected = isbidir) - - ### DEFINE MODEL ######### - nin, nhidden = size(X, 1), args.nhidden - - # We embed the graph with positive training edges in the model - model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu), - GCNConv(nhidden => nhidden)), - train_pos_g) |> device - - pred = DotPredictor() - - opt = Flux.setup(Adam(args.η), model) - - ### LOSS FUNCTION ############ - - function loss(model, pos_g, neg_g = nothing; with_accuracy = false) - h = model(X) - if neg_g === nothing - # We sample a negative graph at each training step - neg_g = negative_sample(pos_g, bidirected = isbidir) - end - pos_score = pred(pos_g, h) - neg_score = pred(neg_g, h) - scores = [pos_score; neg_score] - labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)] - l = logitbinarycrossentropy(scores, labels) - if with_accuracy - acc = 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0) - return l, acc - else - return l - end - end - - ### LOGGING FUNCTION - function report(epoch) - train_loss, train_acc = loss(model, train_pos_g, with_accuracy = true) - test_loss, test_acc = loss(model, test_pos_g, test_neg_g, with_accuracy = true) - println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, 
test_acc))") - end - - ### TRAINING - report(0) - for epoch in 1:(args.epochs) - grads = Flux.gradient(model -> loss(model, train_pos_g), model) - Flux.update!(opt, model, grads[1]) - epoch % args.infotime == 0 && report(epoch) - end -end - -train() - -[.\examples\neural_ode_cora.jl] -# Load the packages -using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations -using Flux: onehotbatch, onecold -using Flux.Losses: logitcrossentropy -using Flux -using Statistics: mean -using MLDatasets: Cora -using CUDA -# CUDA.allowscalar(false) # Some scalar indexing is still done by DiffEqFlux - -# device = cpu # `gpu` not working yet -device = CUDA.functional() ? gpu : cpu - -# LOAD DATA -dataset = Cora() -classes = dataset.metadata["classes"] -g = mldataset2gnngraph(dataset) |> device -X = g.ndata.features -y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged -(; train_mask, val_mask, test_mask) = g.ndata -ytrain = y[:, train_mask] - -# Model and Data Configuration -nin = size(X, 1) -nhidden = 16 -nout = length(classes) -epochs = 40 - -# Define the Neural GDE -diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2]) - -node_chain = GNNChain(GCNConv(nhidden => nhidden, relu), - GCNConv(nhidden => nhidden, relu)) |> device - -node = NeuralODE(WithGraph(node_chain, g), - (0.0f0, 1.0f0), Tsit5(), save_everystep = false, - reltol = 1e-3, abstol = 1e-3, save_start = false) |> device - -model = GNNChain(GCNConv(nin => nhidden, relu), - node, - diffeqsol_to_array, - Dense(nhidden, nout)) |> device - -# # Training - -opt = Flux.setup(Adam(0.01), model) - -function eval_loss_accuracy(X, y, mask) - ŷ = model(g, X) - l = logitcrossentropy(ŷ[:, mask], y[:, mask]) - acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask])) - return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) -end - -# ## Training Loop -for epoch in 1:epochs - grad = gradient(model) do model - ŷ = model(g, X) - 
logitcrossentropy(ŷ[:, train_mask], ytrain) - end - Flux.update!(opt, model, grad[1]) - @show eval_loss_accuracy(X, y, train_mask) -end - -[.\examples\node_classification_cora.jl] -# An example of semi-supervised node classification - -using Flux -using Flux: onecold, onehotbatch -using Flux.Losses: logitcrossentropy -using GraphNeuralNetworks -using MLDatasets: Cora -using Statistics, Random -using CUDA -CUDA.allowscalar(false) - -function eval_loss_accuracy(X, y, mask, model, g) - ŷ = model(g, X) - l = logitcrossentropy(ŷ[:, mask], y[:, mask]) - acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask])) - return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) -end - -# arguments for the `train` function -Base.@kwdef mutable struct Args - η = 1.0f-3 # learning rate - epochs = 100 # number of epochs - seed = 17 # set seed > 0 for reproducibility - usecuda = true # if true use cuda (if available) - nhidden = 128 # dimension of hidden features - infotime = 10 # report every `infotime` epochs -end - -function train(; kws...) - args = Args(; kws...) 
- - args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - device = gpu - args.seed > 0 && CUDA.seed!(args.seed) - @info "Training on GPU" - else - device = cpu - @info "Training on CPU" - end - - # LOAD DATA - dataset = Cora() - classes = dataset.metadata["classes"] - g = mldataset2gnngraph(dataset) |> device - X = g.features - y = onehotbatch(g.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged - ytrain = y[:, g.train_mask] - - nin, nhidden, nout = size(X, 1), args.nhidden, length(classes) - - ## DEFINE MODEL - model = GNNChain(GCNConv(nin => nhidden, relu), - GCNConv(nhidden => nhidden, relu), - Dense(nhidden, nout)) |> device - - opt = Flux.setup(Adam(args.η), model) - - display(g) - - ## LOGGING FUNCTION - function report(epoch) - train = eval_loss_accuracy(X, y, g.train_mask, model, g) - test = eval_loss_accuracy(X, y, g.test_mask, model, g) - println("Epoch: $epoch Train: $(train) Test: $(test)") - end - - ## TRAINING - report(0) - for epoch in 1:(args.epochs) - grad = Flux.gradient(model) do model - ŷ = model(g, X) - logitcrossentropy(ŷ[:, g.train_mask], ytrain) - end - - Flux.update!(opt, model, grad[1]) - - epoch % args.infotime == 0 && report(epoch) - end -end - -train() - -[.\examples\traffic_prediction.jl] -# Example of using TGCN, a recurrent temporal graph convolutional network of the paper https://arxiv.org/pdf/1811.05320.pdf, for traffic prediction by training it on the METRLA dataset - -# Load packages -using Flux -using Flux.Losses: mae -using GraphNeuralNetworks -using MLDatasets: METRLA -using CUDA -using Statistics, Random -CUDA.allowscalar(false) - -# Import dataset function -function getdataset() - metrla = METRLA(; num_timesteps = 3) - g = metrla[1] - graph = GNNGraph(g.edge_index; edata = g.edge_data, g.num_nodes) - features = g.node_data.features - targets = g.node_data.targets - train_loader = zip(features[1:2000], targets[1:2000]) - test_loader = 
zip(features[2001:2288], targets[2001:2288]) - return graph, train_loader, test_loader -end - -# Loss and accuracy functions -lossfunction(ŷ, y) = Flux.mae(ŷ, y) -accuracy(ŷ, y) = 1 - Statistics.norm(y-ŷ)/Statistics.norm(y) - -function eval_loss_accuracy(model, graph, data_loader) - error = mean([lossfunction(model(graph,x), y) for (x, y) in data_loader]) - acc = mean([accuracy(model(graph,x), y) for (x, y) in data_loader]) - return (loss = round(error, digits = 4), acc = round(acc , digits = 4)) -end - -# Arguments for the train function -Base.@kwdef mutable struct Args - η = 1.0f-3 # learning rate - epochs = 100 # number of epochs - seed = 17 # set seed > 0 for reproducibility - usecuda = true # if true use cuda (if available) - nhidden = 100 # dimension of hidden features - infotime = 20 # report every `infotime` epochs -end - -# Train function -function train(; kws...) - args = Args(; kws...) - args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - device = gpu - args.seed > 0 && CUDA.seed!(args.seed) - @info "Training on GPU" - else - device = cpu - @info "Training on CPU" - end - - # Define model - model = GNNChain(TGCN(2 => args.nhidden), Dense(args.nhidden, 1)) |> device - - opt = Flux.setup(Adam(args.η), model) - - graph, train_loader, test_loader = getdataset() - graph = graph |> device - train_loader = train_loader |> device - test_loader = test_loader |> device - - function report(epoch) - train_loss, train_acc = eval_loss_accuracy(model, graph, train_loader) - test_loss, test_acc = eval_loss_accuracy(model, graph, test_loader) - println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") - end - - report(0) - for epoch in 1:(args.epochs) - for (x, y) in train_loader - x, y = (x, y) - grads = Flux.gradient(model) do model - ŷ = model(graph, x) - lossfunction(y,ŷ) - end - Flux.update!(opt, model, grads[1]) - end - - args.infotime > 0 && epoch % args.infotime == 0 && report(epoch) - - end - return model 
-end - -train() - - -[.\GNNGraphs\ext\GNNGraphsCUDAExt.jl] -module GNNGraphsCUDAExt - -using CUDA -using Random, Statistics, LinearAlgebra -using GNNGraphs -using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T - -const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} - -# Query - -GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1)) - -# Transform - -GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz) - - -# Utils - -GNNGraphs.iscuarray(x::AnyCuArray) = true - - -function sort_edge_index(u::AnyCuArray, v::AnyCuArray) - dev = get_device(u) - cdev = cpu_device() - u, v = u |> cdev, v |> cdev - #TODO proper cuda friendly implementation - sort_edge_index(u, v) |> dev -end - - -end #module - -[.\GNNGraphs\ext\GNNGraphsSimpleWeightedGraphsExt.jl] -module GNNGraphsSimpleWeightedGraphsExt - -using Graphs -using GNNGraphs -using SimpleWeightedGraphs - -function GNNGraphs.GNNGraph(g::T; kws...) where - {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}} - return GNNGraph(g.weights, kws...) -end - -end #module -[.\GNNGraphs\src\abstracttypes.jl] - -const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}} -const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} -const ADJMAT_T = AbstractMatrix -const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T - -const AVecI = AbstractVector{<:Integer} - -# All concrete graph types should be subtypes of AbstractGNNGraph{T}. -# GNNGraph and GNNHeteroGraph are the two concrete types. -abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end - -[.\GNNGraphs\src\chainrules.jl] -# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648 -# Remove when merged - -function ChainRulesCore.rrule(::Type{T}, ps::Pair...) 
where {T<:Dict} - ks = map(first, ps) - project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps) - function Dict_pullback(ȳ) - dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v - dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent())) - Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv) - end - return (NoTangent(), dps...) - end - return T(ps...), Dict_pullback -end - -[.\GNNGraphs\src\convert.jl] -### CONVERT_TO_COO REPRESENTATION ######## - -function to_coo(data::EDict; num_nodes = nothing, kws...) - graph = EDict{COO_T}() - _num_nodes = NDict{Int}() - num_edges = EDict{Int}() - for k in keys(data) - d = data[k] - @assert d isa Tuple - if length(d) == 2 - d = (d..., nothing) - end - if num_nodes !== nothing - n1 = get(num_nodes, k[1], nothing) - n2 = get(num_nodes, k[3], nothing) - else - n1 = nothing - n2 = nothing - end - g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) - graph[k] = g - num_edges[k] = nedges - _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) - _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) - end - return graph, _num_nodes, num_edges -end - -function to_coo(coo::COO_T; dir = :out, num_nodes = nothing, weighted = true, - hetero = false) - s, t, val = coo - - if isnothing(num_nodes) - ns = maximum(s) - nt = maximum(t) - num_nodes = hetero ? (ns, nt) : max(ns, nt) - elseif num_nodes isa Integer - ns = num_nodes - nt = num_nodes - elseif num_nodes isa Tuple - ns = isnothing(num_nodes[1]) ? maximum(s) : num_nodes[1] - nt = isnothing(num_nodes[2]) ? 
maximum(t) : num_nodes[2] - num_nodes = (ns, nt) - else - error("Invalid num_nodes $num_nodes") - end - @assert isnothing(val) || length(val) == length(s) - @assert length(s) == length(t) - if !isempty(s) - @assert minimum(s) >= 1 - @assert minimum(t) >= 1 - @assert maximum(s) <= ns - @assert maximum(t) <= nt - end - num_edges = length(s) - if !weighted - coo = (s, t, nothing) - end - return coo, num_nodes, num_edges -end - -function to_coo(A::SPARSE_T; dir = :out, num_nodes = nothing, weighted = true) - s, t, v = findnz(A) - if dir == :in - s, t = t, s - end - num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes - num_edges = length(s) - if !weighted - v = nothing - end - return (s, t, v), num_nodes, num_edges -end - -function _findnz_idx(A) - nz = findall(!=(0), A) # vec of cartesian indexes - s, t = ntuple(i -> map(t -> t[i], nz), 2) - return s, t, nz -end - -@non_differentiable _findnz_idx(A) - -function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true) - s, t, nz = _findnz_idx(A) - v = A[nz] - if dir == :in - s, t = t, s - end - num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes - num_edges = length(s) - if !weighted - v = nothing - end - return (s, t, v), num_nodes, num_edges -end - -function to_coo(adj_list::ADJLIST_T; dir = :out, num_nodes = nothing, weighted = true) - @assert dir ∈ [:out, :in] - num_nodes = length(adj_list) - num_edges = sum(length.(adj_list)) - @assert num_nodes > 0 - s = similar(adj_list[1], eltype(adj_list[1]), num_edges) - t = similar(adj_list[1], eltype(adj_list[1]), num_edges) - e = 0 - for i in 1:num_nodes - for j in adj_list[i] - e += 1 - s[e] = i - t[e] = j - end - end - @assert e == num_edges - if dir == :in - s, t = t, s - end - (s, t, nothing), num_nodes, num_edges -end - -### CONVERT TO ADJACENCY MATRIX ################ - -### DENSE #################### - -to_dense(A::AbstractSparseMatrix, x...; kws...) = to_dense(collect(A), x...; kws...) 
- -function to_dense(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing, - weighted = true) - @assert dir ∈ [:out, :in] - T = T === nothing ? eltype(A) : T - num_nodes = size(A, 1) - @assert num_nodes == size(A, 2) - # @assert all(x -> (x == 1) || (x == 0), A) - num_edges = numnonzeros(A) - if dir == :in - A = A' - end - if T != eltype(A) - A = T.(A) - end - if !weighted - A = map(x -> ifelse(x > 0, T(1), T(0)), A) - end - return A, num_nodes, num_edges -end - -function to_dense(adj_list::ADJLIST_T, T = nothing; dir = :out, num_nodes = nothing, - weighted = true) - @assert dir ∈ [:out, :in] - num_nodes = length(adj_list) - num_edges = sum(length.(adj_list)) - @assert num_nodes > 0 - T = T === nothing ? eltype(adj_list[1]) : T - A = fill!(similar(adj_list[1], T, (num_nodes, num_nodes)), 0) - if dir == :out - for (i, neigs) in enumerate(adj_list) - A[i, neigs] .= 1 - end - else - for (i, neigs) in enumerate(adj_list) - A[neigs, i] .= 1 - end - end - A, num_nodes, num_edges -end - -function to_dense(coo::COO_T, T = nothing; dir = :out, num_nodes = nothing, weighted = true) - # `dir` will be ignored since the input `coo` is always in source -> target format. - # The output will always be a adjmat in :out format (e.g. A[i,j] denotes from i to j) - s, t, val = coo - n::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes - if T === nothing - T = isnothing(val) ? 
eltype(s) : eltype(val) - end - if val === nothing || !weighted - val = ones_like(s, T) - end - if eltype(val) != T - val = T.(val) - end - - idxs = s .+ n .* (t .- 1) - - ## using scatter instead of indexing since there could be multiple edges - # A = fill!(similar(s, T, (n, n)), 0) - # v = vec(A) # vec view of A - # A[idxs] .= val # exploiting linear indexing - v = NNlib.scatter(+, val, idxs, dstsize = n^2) - A = reshape(v, (n, n)) - return A, n, length(s) -end - -### SPARSE ############# - -function to_sparse(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing, - weighted = true) - @assert dir ∈ [:out, :in] - num_nodes = size(A, 1) - @assert num_nodes == size(A, 2) - T = T === nothing ? eltype(A) : T - num_edges = A isa AbstractSparseMatrix ? nnz(A) : count(!=(0), A) - if dir == :in - A = A' - end - if T != eltype(A) - A = T.(A) - end - if !(A isa AbstractSparseMatrix) - A = sparse(A) - end - if !weighted - A = map(x -> ifelse(x > 0, T(1), T(0)), A) - end - return A, num_nodes, num_edges -end - -function to_sparse(adj_list::ADJLIST_T, T = nothing; dir = :out, num_nodes = nothing, - weighted = true) - coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes) - return to_sparse(coo; num_nodes) -end - -function to_sparse(coo::COO_T, T = nothing; dir = :out, num_nodes = nothing, - weighted = true) - s, t, eweight = coo - T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T - - if eweight === nothing || !weighted - eweight = fill!(similar(s, T), 1) - end - - num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes - A = sparse(s, t, eweight, num_nodes, num_nodes) - num_edges::Int = nnz(A) - if eltype(A) != T - A = T.(A) - end - return A, num_nodes, num_edges -end - -[.\GNNGraphs\src\datastore.jl] -""" - DataStore([n, data]) - DataStore([n,] k1 = x1, k2 = x2, ...) - -A container for feature arrays. The optional argument `n` enforces that -`numobs(x) == n` for each array contained in the datastore. 
- -At construction time, the `data` can be provided as any iterables of pairs -of symbols and arrays or as keyword arguments: - -```jldoctest -julia> ds = DataStore(3, x = rand(Float32, 2, 3), y = rand(Float32, 3)) -DataStore(3) with 2 elements: - y = 3-element Vector{Float32} - x = 2×3 Matrix{Float32} - -julia> ds = DataStore(3, Dict(:x => rand(Float32, 2, 3), :y => rand(Float32, 3))); # equivalent to above - -julia> ds = DataStore(3, (x = rand(Float32, 2, 3), y = rand(Float32, 30))) -ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3 -Stacktrace: - [1] DataStore(n::Int64, data::Dict{Symbol, Any}) - @ GNNGraphs ~/.julia/dev/GNNGraphs/datastore.jl:54 - [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float32}, Vector{Float32}}}) - @ GNNGraphs ~/.julia/dev/GNNGraphs/datastore.jl:73 - [3] top-level scope - @ REPL[13]:1 - -julia> ds = DataStore(x = randFloat32, 2, 3), y = rand(Float32, 30)) # no checks -DataStore() with 2 elements: - y = 30-element Vector{Float32} - x = 2×3 Matrix{Float32} - y = 30-element Vector{Float64} - x = 2×3 Matrix{Float64} -``` - -The `DataStore` has an interface similar to both dictionaries and named tuples. -Arrays can be accessed and added using either the indexing or the property syntax: - -```jldoctest -julia> ds = DataStore(x = ones(Float32, 2, 3), y = zeros(Float32, 3)) -DataStore() with 2 elements: - y = 3-element Vector{Float32} - x = 2×3 Matrix{Float32} - -julia> ds.x # same as `ds[:x]` -2×3 Matrix{Float32}: - 1.0 1.0 1.0 - 1.0 1.0 1.0 - -julia> ds.z = zeros(Float32, 3) # Add new feature array `z`. Same as `ds[:z] = rand(Float32, 3)` -3-element Vector{Float64}: -0.0 -0.0 -0.0 -``` - -The `DataStore` can be iterated over, and the keys and values can be accessed -using `keys(ds)` and `values(ds)`. 
`map(f, ds)` applies the function `f` -to each feature array: - -```jldoctest -julia> ds = DataStore(a = zeros(2), b = zeros(2)); - -julia> ds2 = map(x -> x .+ 1, ds) - -julia> ds2.a -2-element Vector{Float64}: - 1.0 - 1.0 -``` -""" -struct DataStore - _n::Int # either -1 or numobs(data) - _data::Dict{Symbol, Any} - - function DataStore(n::Int, data::Dict{Symbol, Any}) - if n >= 0 - for (k, v) in data - @assert numobs(v)==n "DataStore: data[$k] has $(numobs(v)) observations, but n = $n" - end - end - return new(n, data) - end -end - -@functor DataStore - -DataStore(data) = DataStore(-1, data) -DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol, Any}(pairs(data))) -DataStore(n::Int, data) = DataStore(n, Dict{Symbol, Any}(data)) - -DataStore(; kws...) = DataStore(-1; kws...) -DataStore(n::Int; kws...) = DataStore(n, Dict{Symbol, Any}(kws...)) - -getdata(ds::DataStore) = getfield(ds, :_data) -getn(ds::DataStore) = getfield(ds, :_n) -# setn!(ds::DataStore, n::Int) = setfield!(ds, :n, n) - -function Base.getproperty(ds::DataStore, s::Symbol) - if s === :_n - return getn(ds) - elseif s === :_data - return getdata(ds) - else - return getdata(ds)[s] - end -end - -function Base.getproperty(vds::Vector{DataStore}, s::Symbol) - if s === :_n - return [getn(ds) for ds in vds] - elseif s === :_data - return [getdata(ds) for ds in vds] - else - return [getdata(ds)[s] for ds in vds] - end -end - -function Base.setproperty!(ds::DataStore, s::Symbol, x) - @assert s != :_n "cannot set _n directly" - @assert s != :_data "cannot set _data directly" - if getn(ds) >= 0 - numobs(x) == getn(ds) || throw(DimensionMismatch("expected $(getn(ds)) object features but got $(numobs(x)).")) - end - return getdata(ds)[s] = x -end - -Base.getindex(ds::DataStore, s::Symbol) = getproperty(ds, s) -Base.setindex!(ds::DataStore, x, s::Symbol) = setproperty!(ds, s, x) - -function Base.show(io::IO, ds::DataStore) - len = length(ds) - n = getn(ds) - if n < 0 - print(io, "DataStore()") - else - 
print(io, "DataStore($(getn(ds)))") - end - if len > 0 - print(io, " with $(length(getdata(ds))) element") - len > 1 && print(io, "s") - print(io, ":") - for (k, v) in getdata(ds) - print(io, "\n $(k) = $(summary(v))") - end - else - print(io, " with no elements") - end -end - -Base.iterate(ds::DataStore) = iterate(getdata(ds)) -Base.iterate(ds::DataStore, state) = iterate(getdata(ds), state) -Base.keys(ds::DataStore) = keys(getdata(ds)) -Base.values(ds::DataStore) = values(getdata(ds)) -Base.length(ds::DataStore) = length(getdata(ds)) -Base.haskey(ds::DataStore, k) = haskey(getdata(ds), k) -Base.get(ds::DataStore, k, default) = get(getdata(ds), k, default) -Base.pairs(ds::DataStore) = pairs(getdata(ds)) -Base.:(==)(ds1::DataStore, ds2::DataStore) = getdata(ds1) == getdata(ds2) -Base.isempty(ds::DataStore) = isempty(getdata(ds)) -Base.delete!(ds::DataStore, k) = delete!(getdata(ds), k) - -function Base.map(f, ds::DataStore) - d = getdata(ds) - newd = Dict{Symbol, Any}(k => f(v) for (k, v) in d) - return DataStore(getn(ds), newd) -end - -MLUtils.numobs(ds::DataStore) = numobs(getdata(ds)) - -function MLUtils.getobs(ds::DataStore, i::Int) - newdata = getobs(getdata(ds), i) - return DataStore(-1, newdata) -end - -function MLUtils.getobs(ds::DataStore, - i::AbstractVector{T}) where {T <: Union{Integer, Bool}} - newdata = getobs(getdata(ds), i) - n = getn(ds) - if n >= 0 - if length(ds) > 0 - n = numobs(newdata) - else - # if newdata is empty, then we can't get the number of observations from it - n = T == Bool ? sum(i) : length(i) - end - end - if !(newdata isa Dict{Symbol, Any}) - newdata = Dict{Symbol, Any}(newdata) - end - return DataStore(n, newdata) -end - -function cat_features(ds1::DataStore, ds2::DataStore) - n1, n2 = getn(ds1), getn(ds2) - n1 = n1 >= 0 ? n1 : 1 - n2 = n2 >= 0 ? n2 : 1 - return DataStore(n1 + n2, cat_features(getdata(ds1), getdata(ds2))) -end - -function cat_features(dss::AbstractVector{DataStore}; kws...) 
- ns = getn.(dss) - ns = map(n -> n >= 0 ? n : 1, ns) - return DataStore(sum(ns), cat_features(getdata.(dss); kws...)) -end - -# DataStore is always already normalized -normalize_graphdata(ds::DataStore; kws...) = ds - -_gather(x::DataStore, i) = map(x -> _gather(x, i), x) - -function _scatter(aggr, src::DataStore, idx, n) - newdata = _scatter(aggr, getdata(src), idx, n) - if !(newdata isa Dict{Symbol, Any}) - newdata = Dict{Symbol, Any}(newdata) - end - return DataStore(n, newdata) -end - -function Base.hash(ds::D, h::UInt) where {D <: DataStore} - fs = (getfield(ds, k) for k in fieldnames(D)) - return foldl((h, f) -> hash(f, h), fs, init = hash(D, h)) -end - -[.\GNNGraphs\src\gatherscatter.jl] -_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x) -_gather(x::Dict, i) = Dict([k => _gather(v, i) for (k, v) in x]...) -_gather(x::Tuple, i) = map(x -> _gather(x, i), x) -_gather(x::AbstractArray, i) = NNlib.gather(x, i) -_gather(x::Nothing, i) = nothing - -_scatter(aggr, src::Nothing, idx, n) = nothing -_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src) -_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src) -_scatter(aggr, src::Dict, idx, n) = Dict([k => _scatter(aggr, v, idx, n) for (k, v) in src]...) - -function _scatter(aggr, - src::AbstractArray, - idx::AbstractVector{<:Integer}, - n::Integer) - dstsize = (size(src)[1:(end - 1)]..., n) - return NNlib.scatter(aggr, src, idx; dstsize) -end - -[.\GNNGraphs\src\generate.jl] -""" - rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...) - -Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges. - -If `bidirected=true` the reverse edge of each edge will be present. -If `bidirected=false` instead, `m` unrelated edges are generated. -In any case, the output graph will contain no self-loops or multi-edges. - -A vector can be passed as `edge_weight`. 
Its length has to be equal to `m` -in the directed case, and `m÷2` in the bidirected one. - -Pass a random number generator as the first argument to make the generation reproducible. - -Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. - -# Examples - -```jldoctest -julia> g = rand_graph(5, 4, bidirected=false) -GNNGraph: - num_nodes = 5 - num_edges = 4 - -julia> edge_index(g) -([1, 3, 3, 4], [5, 4, 5, 2]) - -# In the bidirected case, edge data will be duplicated on the reverse edges if needed. -julia> g = rand_graph(5, 4, edata=rand(Float32, 16, 2)) -GNNGraph: - num_nodes = 5 - num_edges = 4 - edata: - e => (16, 4) - -# Each edge has a reverse -julia> edge_index(g) -([1, 3, 3, 4], [3, 4, 1, 3]) -``` -""" -function rand_graph(n::Integer, m::Integer; seed=-1, kws...) - if seed != -1 - Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph) - rng = MersenneTwister(seed) - else - rng = Random.default_rng() - end - return rand_graph(rng, n, m; kws...) -end - -function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; - bidirected::Bool = true, - edge_weight::Union{AbstractVector, Nothing} = nothing, kws...) - if bidirected - @assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m." - s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false) - s, t = vcat(s, t), vcat(t, s) - if edge_weight !== nothing - edge_weight = vcat(edge_weight, edge_weight) - end - else - s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false) - end - return GNNGraph((s, t, edge_weight); num_nodes=n, kws...) -end - -""" - rand_heterograph([rng,] n, m; bidirected=false, kws...) - -Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges -specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs -specifing node/edge types and their numbers. 
- -Pass a random number generator as a first argument to make the generation reproducible. - -Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge. -Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)` -will be generated. - -Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor. - -# Examples - -```jldoctest -julia> g = rand_heterograph((:user => 10, :movie => 20), - (:user, :rate, :movie) => 30) -GNNHeteroGraph: - num_nodes: (:user => 10, :movie => 20) - num_edges: ((:user, :rate, :movie) => 30,) -``` -""" -function rand_heterograph end - -# for generic iterators of pairs -rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...) -rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...) - -function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...) - if seed != -1 - Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph) - rng = MersenneTwister(seed) - else - rng = Random.default_rng() - end - return rand_heterograph(rng, n, m; kws...) -end - -function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...) - if bidirected - return _rand_bidirected_heterograph(rng, n, m; kws...) - end - graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m)) - return GNNHeteroGraph(graphs; num_nodes = n, kws...) -end - -function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...) - for k in keys(m) - if reverse(k) ∈ keys(m) - @assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs." 
- else - m[reverse(k)] = m[k] - end - end - graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}() - for k in keys(m) - reverse(k) ∈ keys(graphs) && continue - s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) - graphs[k] = s, t, val - graphs[reverse(k)] = t, s, val - end - return GNNHeteroGraph(graphs; num_nodes = n, kws...) -end - - -""" - rand_bipartite_heterograph([rng,] - (n1, n2), (m12, m21); - bidirected = true, - node_t = (:A, :B), - edge_t = :to, - kws...) - -Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph. -The graph will have two types of nodes, and edges will only connect nodes of different types. - -The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type. -The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2` -and vice versa. - -The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments, -which default to `(:A, :B)` and `:to` respectively. - -If `bidirected=true` (default), the reverse edge of each edge will be present. In this case -`m12 == m21` is required. - -A random number generator can be passed as the first argument to make the generation reproducible. - -Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor. - -See [`rand_heterograph`](@ref) for a more general version. - -# Examples - -```julia-repl -julia> g = rand_bipartite_heterograph((10, 15), 20) -GNNHeteroGraph: - num_nodes: (:A => 10, :B => 15) - num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20) - -julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false) -GNNHeteroGraph: - num_nodes: Dict(:item => 15, :user => 10) - num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20) -``` -""" -rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...) 
- -function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true, - node_t = (:A, :B), edge_t::Symbol = :to, kws...) - if m isa Integer - m12 = m21 = m - else - m12, m21 = m - end - - return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2), - Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21); - bidirected, kws...) -end - -""" - knn_graph(points::AbstractMatrix, - k::Int; - graph_indicator = nothing, - self_loops = false, - dir = :in, - kws...) - -Create a `k`-nearest neighbor graph where each node is linked -to its `k` closest `points`. - -# Arguments - -- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes. -- `k`: The number of neighbors considered in the kNN algorithm. -- `graph_indicator`: Either nothing or a vector containing the graph assignment of each node, - in which case the returned graph will be a batch of graphs. -- `self_loops`: If `true`, consider the node itself among its `k` nearest neighbors, in which - case the graph will contain self-loops. -- `dir`: The direction of the edges. If `dir=:in` edges go from the `k` - neighbors to the central node. If `dir=:out` we have the opposite - direction. -- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. - -# Examples - -```jldoctest -julia> n, k = 10, 3; - -julia> x = rand(Float32, 3, n); - -julia> g = knn_graph(x, k) -GNNGraph: - num_nodes = 10 - num_edges = 30 - -julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2]; - -julia> g = knn_graph(x, k; graph_indicator) -GNNGraph: - num_nodes = 10 - num_edges = 30 - num_graphs = 2 - -``` -""" -function knn_graph(points::AbstractMatrix, k::Int; - graph_indicator = nothing, - self_loops = false, - dir = :in, - kws...) 
- if graph_indicator !== nothing - d, n = size(points) - @assert graph_indicator isa AbstractVector{<:Integer} - @assert length(graph_indicator) == n - # All graphs in the batch must have at least k nodes. - cm = StatsBase.countmap(graph_indicator) - @assert all(values(cm) .>= k) - - # Make sure that the distance between points in different graphs - # is always larger than any distance within the same graph. - points = points .- minimum(points) - points = points ./ maximum(points) - dummy_feature = 2d .* reshape(graph_indicator, 1, n) - points = vcat(points, dummy_feature) - end - - kdtree = NearestNeighbors.KDTree(points) - if !self_loops - k += 1 - end - sortres = false - idxs, dists = NearestNeighbors.knn(kdtree, points, k, sortres) - - g = GNNGraph(idxs; dir, graph_indicator, kws...) - if !self_loops - g = remove_self_loops(g) - end - return g -end - -""" - radius_graph(points::AbstractMatrix, - r::AbstractFloat; - graph_indicator = nothing, - self_loops = false, - dir = :in, - kws...) - -Create a graph where each node is linked -to its neighbors within a given distance `r`. - -# Arguments - -- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes. -- `r`: The radius. -- `graph_indicator`: Either nothing or a vector containing the graph assignment of each node, - in which case the returned graph will be a batch of graphs. -- `self_loops`: If `true`, consider the node itself among its neighbors, in which - case the graph will contain self-loops. -- `dir`: The direction of the edges. If `dir=:in` edges go from the - neighbors to the central node. If `dir=:out` we have the opposite - direction. -- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. 
- -# Examples - -```jldoctest -julia> n, r = 10, 0.75; - -julia> x = rand(Float32, 3, n); - -julia> g = radius_graph(x, r) -GNNGraph: - num_nodes = 10 - num_edges = 46 - -julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2]; - -julia> g = radius_graph(x, r; graph_indicator) -GNNGraph: - num_nodes = 10 - num_edges = 20 - num_graphs = 2 - -``` -# References -Section B paragraphs 1 and 2 of the paper [Dynamic Hidden-Variable Network Models](https://arxiv.org/pdf/2101.00414.pdf) -""" -function radius_graph(points::AbstractMatrix, r::AbstractFloat; - graph_indicator = nothing, - self_loops = false, - dir = :in, - kws...) - if graph_indicator !== nothing - d, n = size(points) - @assert graph_indicator isa AbstractVector{<:Integer} - @assert length(graph_indicator) == n - - # Make sure that the distance between points in different graphs - # is always larger than r. - dummy_feature = 2r .* reshape(graph_indicator, 1, n) - points = vcat(points, dummy_feature) - end - - balltree = NearestNeighbors.BallTree(points) - - sortres = false - idxs = NearestNeighbors.inrange(balltree, points, r, sortres) - - g = GNNGraph(idxs; dir, graph_indicator, kws...) - if !self_loops - g = remove_self_loops(g) - end - return g -end - -""" - rand_temporal_radius_graph(number_nodes::Int, - number_snapshots::Int, - speed::AbstractFloat, - r::AbstractFloat; - self_loops = false, - dir = :in, - kws...) - -Create a random temporal graph given `number_nodes` nodes and `number_snapshots` snapshots. -First, the positions of the nodes are randomly generated in the unit square. Two nodes are connected if their distance is less than a given radius `r`. -Each following snapshot is obtained by applying the same construction to new positions obtained as follows. -For each snapshot, the new positions of the points are determined by applying random independent displacement vectors to the previous positions. 
The direction of the displacement is chosen uniformly at random and its length is chosen uniformly in `[0, speed]`. Then the connections are recomputed. -If a point happens to move outside the boundary, its position is updated as if it had bounced off the boundary. - -# Arguments - -- `number_nodes`: The number of nodes of each snapshot. -- `number_snapshots`: The number of snapshots. -- `speed`: The speed to update the nodes. -- `r`: The radius of connection. -- `self_loops`: If `true`, consider the node itself among its neighbors, in which - case the graph will contain self-loops. -- `dir`: The direction of the edges. If `dir=:in` edges go from the - neighbors to the central node. If `dir=:out` we have the opposite - direction. -- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor of each snapshot. - -# Example - -```jldoctest -julia> n, snaps, s, r = 10, 5, 0.1, 1.5; - -julia> tg = rand_temporal_radius_graph(n,snaps,s,r) # complete graph at each snapshot -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [90, 90, 90, 90, 90] - num_snapshots: 5 -``` - -""" -function rand_temporal_radius_graph(number_nodes::Int, - number_snapshots::Int, - speed::AbstractFloat, - r::AbstractFloat; - self_loops = false, - dir = :in, - kws...) - points=rand(2, number_nodes) - tg = Vector{GNNGraph}(undef, number_snapshots) - for t in 1:number_snapshots - tg[t] = radius_graph(points, r; graph_indicator = nothing, self_loops, dir, kws...) 
- for i in 1:number_nodes - ρ = 2 * speed * rand() - speed - theta=2*pi*rand() - points[1,i]=1-abs(1-(abs(points[1,i]+ρ*cos(theta)))) - points[2,i]=1-abs(1-(abs(points[2,i]+ρ*sin(theta)))) - end - end - return TemporalSnapshotsGNNGraph(tg) -end - - -function _hyperbolic_distance(nodeA::Array{Float64, 1},nodeB::Array{Float64, 1}; ζ::Real) - if nodeA != nodeB - a = cosh(ζ * nodeA[1]) * cosh(ζ * nodeB[1]) - b = sinh(ζ * nodeA[1]) * sinh(ζ * nodeB[1]) - c = cos(pi - abs(pi - abs(nodeA[2] - nodeB[2]))) - d = acosh(a - (b * c)) / ζ - else - d = 0.0 - end - return d -end - -""" - rand_temporal_hyperbolic_graph(number_nodes::Int, - number_snapshots::Int; - α::Real, - R::Real, - speed::Real, - ζ::Real=1, - self_loop = false, - kws...) - -Create a random temporal graph given `number_nodes` nodes and `number_snapshots` snapshots. -First, the positions of the nodes are generated with a quasi-uniform distribution (depending on the parameter `α`) in hyperbolic space within a disk of radius `R`. Two nodes are connected if their hyperbolic distance is less than `R`. Each following snapshot is created in order to keep the same initial distribution. - -# Arguments - -- `number_nodes`: The number of nodes of each snapshot. -- `number_snapshots`: The number of snapshots. -- `α`: The parameter that controls the position of the points. If `α=ζ`, the points are uniformly distributed on the disk of radius `R`. If `α>ζ`, the points are more concentrated in the center of the disk. If `α<ζ`, the points are more concentrated at the boundary of the disk. -- `R`: The radius of the disk and of connection. -- `speed`: The speed to update the nodes. -- `ζ`: The parameter that controls the curvature of the disk. -- `self_loops`: If `true`, consider the node itself among its neighbors, in which - case the graph will contain self-loops. -- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor of each snapshot. 
- -# Example - -```jldoctest -julia> n, snaps, α, R, speed, ζ = 10, 5, 1.0, 4.0, 0.1, 1.0; - -julia> thg = rand_temporal_hyperbolic_graph(n, snaps; α, R, speed, ζ) -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [44, 46, 48, 42, 38] - num_snapshots: 5 -``` - -# References -Section D of the paper [Dynamic Hidden-Variable Network Models](https://arxiv.org/pdf/2101.00414.pdf) and the paper -[Hyperbolic Geometry of Complex Networks](https://arxiv.org/pdf/1006.5169.pdf) -""" -function rand_temporal_hyperbolic_graph(number_nodes::Int, - number_snapshots::Int; - α::Real, - R::Real, - speed::Real, - ζ::Real=1, - self_loop = false, - kws...) - @assert number_snapshots > 1 "The number of snapshots must be greater than 1" - @assert α > 0 "α must be greater than 0" - - probabilities = rand(number_nodes) - - points = Array{Float64}(undef,2,number_nodes) - points[1,:].= (1/α) * acosh.(1 .+ (cosh(α * R) - 1) * probabilities) - points[2,:].= 2 * pi * rand(number_nodes) - - tg = Vector{GNNGraph}(undef, number_snapshots) - - for time in 1:number_snapshots - adj = zeros(number_nodes,number_nodes) - for i in 1:number_nodes - for j in 1:number_nodes - if !self_loop && i==j - continue - elseif _hyperbolic_distance(points[:,i],points[:,j]; ζ) <= R - adj[i,j] = adj[j,i] = 1 - end - end - end - tg[time] = GNNGraph(adj) - - probabilities .= probabilities .+ (2 * speed * rand(number_nodes) .- speed) - probabilities[probabilities.>1] .= 1 .- (probabilities[probabilities .> 1] .% 1) - probabilities[probabilities.<0] .= abs.(probabilities[probabilities .< 0]) - - points[1,:].= (1/α) * acosh.(1 .+ (cosh(α * R) - 1) * probabilities) - points[2,:].= points[2,:] .+ (2 * speed * rand(number_nodes) .- speed) - end - return TemporalSnapshotsGNNGraph(tg) -end - -[.\GNNGraphs\src\gnngraph.jl] -#=================================== -Define GNNGraph type as a subtype of Graphs.AbstractGraph. 
-For the core methods to be implemented by any AbstractGraph, see -https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type -https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types -=============================================# - -""" - GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir]) - GNNGraph(g::GNNGraph; [ndata, edata, gdata]) - -A type representing a graph structure that also stores -feature arrays associated to nodes, edges, and the graph itself. - -The feature arrays are stored in the fields `ndata`, `edata`, and `gdata` -as [`DataStore`](@ref) objects offering a convenient dictionary-like -and namedtuple-like interface. The features can be passed at construction -time or added later. - -A `GNNGraph` can be constructed out of different `data` objects -expressing the connections inside the graph. The internal representation type -is determined by `graph_type`. - -When constructed from another `GNNGraph`, the internal graph representation -is preserved and shared. The node/edge/graph features are retained -as well, unless explicitely set by the keyword arguments -`ndata`, `edata`, and `gdata`. - -A `GNNGraph` can also represent multiple graphs batched togheter -(see [`MLUtils.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)). -The field `g.graph_indicator` contains the graph membership -of each node. - -`GNNGraph`s are always directed graphs, therefore each edge is defined -by a source node and a target node (see [`edge_index`](@ref)). -Self loops (edges connecting a node to itself) and multiple edges -(more than one edge between the same pair of nodes) are supported. - -A `GNNGraph` is a Graphs.jl's `AbstractGraph`, therefore it supports most -functionality from that library. - -# Arguments - -- `data`: Some data representing the graph topology. Possible type are - - An adjacency matrix - - An adjacency list. 
- - A tuple containing the source and target vectors (COO representation) - - A Graphs.jl' graph. -- `graph_type`: A keyword argument that specifies - the underlying representation used by the GNNGraph. - Currently supported values are - - `:coo`. Graph represented as a tuple `(source, target)`, such that the `k`-th edge - connects the node `source[k]` to node `target[k]`. - Optionally, also edge weights can be given: `(source, target, weights)`. - - `:sparse`. A sparse adjacency matrix representation. - - `:dense`. A dense adjacency matrix representation. - Defaults to `:coo`, currently the most supported type. -- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`. - Possible values are `:out` and `:in`. Default `:out`. -- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`. -- `graph_indicator`: For batched graphs, a vector containing the graph assignment of each node. Default `nothing`. -- `ndata`: Node features. An array or named tuple of arrays whose last dimension has size `num_nodes`. -- `edata`: Edge features. An array or named tuple of arrays whose last dimension has size `num_edges`. -- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`. 
- -# Examples - -```julia -using GraphNeuralNetworks - -# Construct from adjacency list representation -data = [[2,3], [1,4,5], [1], [2,5], [2,4]] -g = GNNGraph(data) - -# Number of nodes, edges, and batched graphs -g.num_nodes # 5 -g.num_edges # 10 -g.num_graphs # 1 - -# Same graph in COO representation -s = [1,1,2,2,2,3,4,4,5,5] -t = [2,3,1,4,5,3,2,5,2,4] -g = GNNGraph(s, t) - -# From a Graphs' graph -g = GNNGraph(erdos_renyi(100, 20)) - -# Add 2 node feature arrays at creation time -g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes))) - -# Add 1 edge feature array, after the graph creation -g.edata.z = rand(16, g.num_edges) - -# Add node features and edge features with default names `x` and `e` -g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges)) - -g.ndata.x # or just g.x -g.edata.e # or just g.e - -# Collect edges' source and target nodes. -# Both source and target are vectors of length num_edges -source, target = edge_index(g) -``` -A `GNNGraph` can be sent to the GPU using e.g. 
Flux's `gpu` function: -``` -# Send to gpu -using Flux, CUDA -g = g |> Flux.gpu -``` -""" -struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T} - graph::T - num_nodes::Int - num_edges::Int - num_graphs::Int - graph_indicator::Union{Nothing, AVecI} # vector of ints or nothing - ndata::DataStore - edata::DataStore - gdata::DataStore -end - -@functor GNNGraph - -function GNNGraph(data::D; - num_nodes = nothing, - graph_indicator = nothing, - graph_type = :coo, - dir = :out, - ndata = nothing, - edata = nothing, - gdata = nothing) where {D <: Union{COO_T, ADJMAT_T, ADJLIST_T}} - @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" - @assert dir ∈ [:in, :out] - - if graph_type == :coo - graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) - elseif graph_type == :dense - graph, num_nodes, num_edges = to_dense(data; num_nodes, dir) - elseif graph_type == :sparse - graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir) - end - - num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1 - - ndata = normalize_graphdata(ndata, default_name = :x, n = num_nodes) - edata = normalize_graphdata(edata, default_name = :e, n = num_edges, - duplicate_if_needed = true) - - # don't force the shape of the data when there is only one graph - gdata = normalize_graphdata(gdata, default_name = :u, - n = num_graphs > 1 ? num_graphs : -1) - - GNNGraph(graph, - num_nodes, num_edges, num_graphs, - graph_indicator, - ndata, edata, gdata) -end - -GNNGraph(; kws...) = GNNGraph(0; kws...) - -function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer} - s, t = T[], T[] - return GNNGraph(s, t; num_nodes, kws...) -end - -Base.zero(::Type{G}) where {G <: GNNGraph} = G(0) - -# COO convenience constructors -function GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) - GNNGraph((s, t, v); kws...) -end -GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...) 
- -# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...) - -function GNNGraph(g::AbstractGraph; edge_weight = nothing, kws...) - s = Graphs.src.(Graphs.edges(g)) - t = Graphs.dst.(Graphs.edges(g)) - w = edge_weight - if !Graphs.is_directed(g) - # add reverse edges since GNNGraph is directed - s, t = [s; t], [t; s] - if !isnothing(w) - @assert length(w) == Graphs.ne(g) "edge_weight must have length equal to the number of undirected edges" - w = [w; w] - end - end - num_nodes::Int = Graphs.nv(g) - GNNGraph((s, t, w); num_nodes = num_nodes, kws...) -end - -function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata, - graph_type = nothing) - ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes) - edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges, - duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs) - - if !isnothing(graph_type) - if graph_type == :coo - graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes) - elseif graph_type == :dense - graph, num_nodes, num_edges = to_dense(g.graph; g.num_nodes) - elseif graph_type == :sparse - graph, num_nodes, num_edges = to_sparse(g.graph; g.num_nodes) - end - @assert num_nodes == g.num_nodes - @assert num_edges == g.num_edges - else - graph = g.graph - end - return GNNGraph(graph, - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - ndata, edata, gdata) -end - -""" - copy(g::GNNGraph; deep=false) - -Create a copy of `g`. If `deep` is `true`, then copy will be a deep copy (equivalent to `deepcopy(g)`), -otherwise it will be a shallow copy with the same underlying graph data. 
-""" -function Base.copy(g::GNNGraph; deep = false) - if deep - GNNGraph(deepcopy(g.graph), - g.num_nodes, g.num_edges, g.num_graphs, - deepcopy(g.graph_indicator), - deepcopy(g.ndata), deepcopy(g.edata), deepcopy(g.gdata)) - else - GNNGraph(g.graph, - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) - end -end - -function print_feature(io::IO, feature) - if !isempty(feature) - if length(keys(feature)) == 1 - k = first(keys(feature)) - v = first(values(feature)) - print(io, "$(k): $(dims2string(size(v)))") - else - print(io, "(") - for (i, (k, v)) in enumerate(pairs(feature)) - print(io, "$k: $(dims2string(size(v)))") - if i == length(feature) - print(io, ")") - else - print(io, ", ") - end - end - end - end -end - -function print_all_features(io::IO, feat1, feat2, feat3) - n1 = length(feat1) - n2 = length(feat2) - n3 = length(feat3) - if n1 == 0 && n2 == 0 && n3 == 0 - print(io, "no") - elseif n1 != 0 && (n2 != 0 || n3 != 0) - print_feature(io, feat1) - print(io, ", ") - elseif n2 == 0 && n3 == 0 - print_feature(io, feat1) - end - if n2 != 0 && n3 != 0 - print_feature(io, feat2) - print(io, ", ") - elseif n2 != 0 && n3 == 0 - print_feature(io, feat2) - end - print_feature(io, feat3) -end - -function Base.show(io::IO, g::GNNGraph) - print(io, "GNNGraph($(g.num_nodes), $(g.num_edges)) with ") - print_all_features(io, g.ndata, g.edata, g.gdata) - print(io, " data") -end - -function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph) - if get(io, :compact, false) - print(io, "GNNGraph($(g.num_nodes), $(g.num_edges)) with ") - print_all_features(io, g.ndata, g.edata, g.gdata) - print(io, " data") - else - print(io, - "GNNGraph:\n num_nodes: $(g.num_nodes)\n num_edges: $(g.num_edges)") - g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)") - if !isempty(g.ndata) - print(io, "\n ndata:") - for k in keys(g.ndata) - print(io, "\n\t$k = $(shortsummary(g.ndata[k]))") - end - end - if !isempty(g.edata) - print(io, "\n 
edata:") - for k in keys(g.edata) - print(io, "\n\t$k = $(shortsummary(g.edata[k]))") - end - end - if !isempty(g.gdata) - print(io, "\n gdata:") - for k in keys(g.gdata) - print(io, "\n\t$k = $(shortsummary(g.gdata[k]))") - end - end - end -end - -MLUtils.numobs(g::GNNGraph) = g.num_graphs -MLUtils.getobs(g::GNNGraph, i) = getgraph(g, i) - -######################### - -function Base.:(==)(g1::GNNGraph, g2::GNNGraph) - g1 === g2 && return true - for k in fieldnames(typeof(g1)) - k === :graph_indicator && continue - getfield(g1, k) != getfield(g2, k) && return false - end - return true -end - -function Base.hash(g::T, h::UInt) where {T <: GNNGraph} - fs = (getfield(g, k) for k in fieldnames(T) if k !== :graph_indicator) - return foldl((h, f) -> hash(f, h), fs, init = hash(T, h)) -end - -function Base.getproperty(g::GNNGraph, s::Symbol) - if s in fieldnames(GNNGraph) - return getfield(g, s) - end - if (s in keys(g.ndata)) + (s in keys(g.edata)) + (s in keys(g.gdata)) > 1 - throw(ArgumentError("Ambiguous property name $s")) - end - if s in keys(g.ndata) - return g.ndata[s] - elseif s in keys(g.edata) - return g.edata[s] - elseif s in keys(g.gdata) - return g.gdata[s] - else - throw(ArgumentError("$(s) is not a field of GNNGraph")) - end -end - -[.\GNNGraphs\src\GNNGraphs.jl] -module GNNGraphs - -using SparseArrays -using Functors: @functor -import Graphs -using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, - has_self_loops, is_directed -import NearestNeighbors -import NNlib -import StatsBase -import KrylovKit -using ChainRulesCore -using LinearAlgebra, Random, Statistics -import MLUtils -using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like -import Functors -using MLDataDevices: get_device, cpu_device, CPUDevice - -include("chainrules.jl") # hacks for differentiability - -include("datastore.jl") -export DataStore - -include("abstracttypes.jl") -export AbstractGNNGraph - -include("gnngraph.jl") -export GNNGraph, - 
node_features, - edge_features, - graph_features - -include("gnnheterograph.jl") -export GNNHeteroGraph, - num_edge_types, - num_node_types, - edge_type_subgraph - -include("temporalsnapshotsgnngraph.jl") -export TemporalSnapshotsGNNGraph, - add_snapshot, - # add_snapshot!, - remove_snapshot - # remove_snapshot! - -include("query.jl") -export adjacency_list, - edge_index, - get_edge_weight, - graph_indicator, - has_multi_edges, - is_directed, - is_bidirected, - normalized_laplacian, - scaled_laplacian, - laplacian_lambda_max, -# from Graphs - adjacency_matrix, - degree, - has_self_loops, - has_isolated_nodes, - inneighbors, - outneighbors, - khop_adj - -include("transform.jl") -export add_nodes, - add_edges, - add_self_loops, - getgraph, - negative_sample, - rand_edge_split, - remove_self_loops, - remove_edges, - remove_multi_edges, - set_edge_weight, - to_bidirected, - to_unidirected, - random_walk_pe, - perturb_edges, - remove_nodes, - ppr_diffusion, -# from MLUtils - batch, - unbatch, -# from SparseArrays - blockdiag - -include("generate.jl") -export rand_graph, - rand_heterograph, - rand_bipartite_heterograph, - knn_graph, - radius_graph, - rand_temporal_radius_graph, - rand_temporal_hyperbolic_graph - -include("sampling.jl") -export sample_neighbors - -include("operators.jl") -# Base.intersect - -include("convert.jl") -include("utils.jl") -export sort_edge_index, color_refinement - -include("gatherscatter.jl") -# _gather, _scatter - -include("mldatasets.jl") -export mldataset2gnngraph - -end #module - -[.\GNNGraphs\src\gnnheterograph.jl] - -const EType = Tuple{Symbol, Symbol, Symbol} -const NType = Symbol -const EDict{T} = Dict{EType, T} -const NDict{T} = Dict{NType, T} - -""" - GNNHeteroGraph(data; [ndata, edata, gdata, num_nodes]) - GNNHeteroGraph(pairs...; [ndata, edata, gdata, num_nodes]) - -A type representing a heterogeneous graph structure. -It is similar to [`GNNGraph`](@ref) but nodes and edges are of different types. 
- -# Constructor Arguments - -- `data`: A dictionary or an iterable object that maps `(source_type, edge_type, target_type)` - triples to `(source, target)` index vectors (or to `(source, target, weight)` if also edge weights are present). -- `pairs`: Passing multiple relations as pairs is equivalent to passing `data=Dict(pairs...)`. -- `ndata`: Node features. A dictionary of arrays or named tuple of arrays. - The size of the last dimension of each array must be given by `g.num_nodes`. -- `edata`: Edge features. A dictionary of arrays or named tuple of arrays. Default `nothing`. - The size of the last dimension of each array must be given by `g.num_edges`. Default `nothing`. -- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`. Default `nothing`. -- `num_nodes`: The number of nodes for each type. If not specified, inferred from `data`. Default `nothing`. - -# Fields - -- `graph`: A dictionary that maps (source_type, edge_type, target_type) triples to (source, target) index vectors. -- `num_nodes`: The number of nodes for each type. -- `num_edges`: The number of edges for each type. -- `ndata`: Node features. -- `edata`: Edge features. -- `gdata`: Graph features. -- `ntypes`: The node types. -- `etypes`: The edge types. 
- -# Examples - -```julia -julia> using GraphNeuralNetworks - -julia> nA, nB = 10, 20; - -julia> num_nodes = Dict(:A => nA, :B => nB); - -julia> edges1 = (rand(1:nA, 20), rand(1:nB, 20)) -([4, 8, 6, 3, 4, 7, 2, 7, 3, 2, 3, 4, 9, 4, 2, 9, 10, 1, 3, 9], [6, 4, 20, 8, 16, 7, 12, 16, 5, 4, 6, 20, 11, 19, 17, 9, 12, 2, 18, 12]) - -julia> edges2 = (rand(1:nB, 30), rand(1:nA, 30)) -([17, 5, 2, 4, 5, 3, 8, 7, 9, 7 … 19, 8, 20, 7, 16, 2, 9, 15, 8, 13], [1, 1, 3, 1, 1, 3, 2, 7, 4, 4 … 7, 10, 6, 3, 4, 9, 1, 5, 8, 5]) - -julia> data = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2); - -julia> hg = GNNHeteroGraph(data; num_nodes) -GNNHeteroGraph: - num_nodes: (:A => 10, :B => 20) - num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) - -julia> hg.num_edges -Dict{Tuple{Symbol, Symbol, Symbol}, Int64} with 2 entries: -(:A, :rel1, :B) => 20 -(:B, :rel2, :A) => 30 - -# Let's add some node features -julia> ndata = Dict(:A => (x = rand(2, nA), y = rand(3, num_nodes[:A])), - :B => rand(10, nB)); - -julia> hg = GNNHeteroGraph(data; num_nodes, ndata) -GNNHeteroGraph: - num_nodes: (:A => 10, :B => 20) - num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) - ndata: - :A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64}) - :B => x = 10×20 Matrix{Float64} - -# Access features of nodes of type :A -julia> hg.ndata[:A].x -2×10 Matrix{Float64}: - 0.825882 0.0797502 0.245813 0.142281 0.231253 0.685025 0.821457 0.888838 0.571347 0.53165 - 0.631286 0.316292 0.705325 0.239211 0.533007 0.249233 0.473736 0.595475 0.0623298 0.159307 -``` - -See also [`GNNGraph`](@ref) for a homogeneous graph type and [`rand_heterograph`](@ref) for a function to generate random heterographs. 
-""" -struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T} - graph::EDict{T} - num_nodes::NDict{Int} - num_edges::EDict{Int} - num_graphs::Int - graph_indicator::Union{Nothing, NDict} - ndata::NDict{DataStore} - edata::EDict{DataStore} - gdata::DataStore - ntypes::Vector{NType} - etypes::Vector{EType} -end - -@functor GNNHeteroGraph - -GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...) -GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...) - -GNNHeteroGraph() = GNNHeteroGraph(Dict{Tuple{Symbol,Symbol,Symbol}, Any}()) - -function GNNHeteroGraph(data::Dict; kws...) - all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`")) - return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...) -end - -function GNNHeteroGraph(data::EDict; - num_nodes = nothing, - graph_indicator = nothing, - graph_type = :coo, - dir = :out, - ndata = nothing, - edata = nothing, - gdata = (;)) - @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" - @assert dir ∈ [:in, :out] - @assert graph_type==:coo "only :coo graph_type is supported for now" - - if num_nodes !== nothing - num_nodes = Dict(num_nodes) - end - - ntypes = union([[k[1] for k in keys(data)]; [k[3] for k in keys(data)]]) - etypes = collect(keys(data)) - - if graph_type == :coo - graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) - elseif graph_type == :dense - graph, num_nodes, num_edges = to_dense(data; num_nodes, dir) - elseif graph_type == :sparse - graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir) - end - - num_graphs = !isnothing(graph_indicator) ? 
- maximum([maximum(gi) for gi in values(graph_indicator)]) : 1 - - - if length(keys(graph)) == 0 - ndata = Dict{Symbol, DataStore}() - edata = Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}() - gdata = DataStore() - else - ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) - edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, - duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs) - end - - return GNNHeteroGraph(graph, - num_nodes, num_edges, num_graphs, - graph_indicator, - ndata, edata, gdata, - ntypes, etypes) -end - -function show_sorted_dict(io::IO, d::Dict, compact::Bool) - # if compact - print(io, "Dict") - # end - print(io, "(") - if !isempty(d) - _keys = sort!(collect(keys(d))) - for key in _keys[1:end-1] - print(io, "$(_str(key)) => $(d[key]), ") - end - print(io, "$(_str(_keys[end])) => $(d[_keys[end]])") - end - # if length(d) == 1 - # print(io, ",") - # end - print(io, ")") -end - -function Base.show(io::IO, g::GNNHeteroGraph) - print(io, "GNNHeteroGraph(") - show_sorted_dict(io, g.num_nodes, true) - print(io, ", ") - show_sorted_dict(io, g.num_edges, true) - print(io, ")") -end - -function Base.show(io::IO, ::MIME"text/plain", g::GNNHeteroGraph) - if get(io, :compact, false) - print(io, "GNNHeteroGraph(") - show_sorted_dict(io, g.num_nodes, true) - print(io, ", ") - show_sorted_dict(io, g.num_edges, true) - print(io, ")") - else - print(io, "GNNHeteroGraph:\n num_nodes: ") - show_sorted_dict(io, g.num_nodes, false) - print(io, "\n num_edges: ") - show_sorted_dict(io, g.num_edges, false) - g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)") - if !isempty(g.ndata) && !all(isempty, values(g.ndata)) - print(io, "\n ndata:") - for k in sort(collect(keys(g.ndata))) - isempty(g.ndata[k]) && continue - print(io, "\n\t", _str(k), " => $(shortsummary(g.ndata[k]))") - end - end - if !isempty(g.edata) && !all(isempty, values(g.edata)) - print(io, "\n 
edata:") - for k in sort(collect(keys(g.edata))) - isempty(g.edata[k]) && continue - print(io, "\n\t$k => $(shortsummary(g.edata[k]))") - end - end - if !isempty(g.gdata) - print(io, "\n gdata:\n\t") - shortsummary(io, g.gdata) - end - end -end - -_str(s::Symbol) = ":$s" -_str(s) = "$s" - -MLUtils.numobs(g::GNNHeteroGraph) = g.num_graphs -# MLUtils.getobs(g::GNNHeteroGraph, i) = getgraph(g, i) - - -""" - num_edge_types(g) - -Return the number of edge types in the graph. For [`GNNGraph`](@ref)s, this is always 1. -For [`GNNHeteroGraph`](@ref)s, this is the number of unique edge types. -""" -num_edge_types(g::GNNGraph) = 1 - -num_edge_types(g::GNNHeteroGraph) = length(g.etypes) - -""" - num_node_types(g) - -Return the number of node types in the graph. For [`GNNGraph`](@ref)s, this is always 1. -For [`GNNHeteroGraph`](@ref)s, this is the number of unique node types. -""" -num_node_types(g::GNNGraph) = 1 - -num_node_types(g::GNNHeteroGraph) = length(g.ntypes) - -""" - edge_type_subgraph(g::GNNHeteroGraph, edge_ts) - -Return a subgraph of `g` that contains only the edges of type `edge_ts`. -Edge types can be specified as a single edge type (i.e. a tuple containing 3 symbols) or a vector of edge types. -""" -edge_type_subgraph(g::GNNHeteroGraph, edge_t::EType) = edge_type_subgraph(g, [edge_t]) - -function edge_type_subgraph(g::GNNHeteroGraph, edge_ts::AbstractVector{<:EType}) - for edge_t in edge_ts - @assert edge_t in g.etypes "Edge type $(edge_t) not found in graph" - end - node_ts = _ntypes_from_edges(edge_ts) - graph = Dict([edge_t => g.graph[edge_t] for edge_t in edge_ts]...) - num_nodes = Dict([node_t => g.num_nodes[node_t] for node_t in node_ts]...) - num_edges = Dict([edge_t => g.num_edges[edge_t] for edge_t in edge_ts]...) - if g.graph_indicator === nothing - graph_indicator = nothing - else - graph_indicator = Dict([node_t => g.graph_indicator[node_t] for node_t in node_ts]...) 
- end - ndata = Dict([node_t => g.ndata[node_t] for node_t in node_ts if node_t in keys(g.ndata)]...) - edata = Dict([edge_t => g.edata[edge_t] for edge_t in edge_ts if edge_t in keys(g.edata)]...) - - return GNNHeteroGraph(graph, num_nodes, num_edges, g.num_graphs, - graph_indicator, ndata, edata, g.gdata, - node_ts, edge_ts) -end - -# TODO this is not correct but Zygote cannot differentiate -# through dictionary generation -# @non_differentiable edge_type_subgraph(::Any...) - -function _ntypes_from_edges(edge_ts::AbstractVector{<:EType}) - ntypes = Symbol[] - for edge_t in edge_ts - node1_t, _, node2_t = edge_t - !in(node1_t, ntypes) && push!(ntypes, node1_t) - !in(node2_t, ntypes) && push!(ntypes, node2_t) - end - return ntypes -end - -@non_differentiable _ntypes_from_edges(::Any...) - -function Base.getindex(g::GNNHeteroGraph, node_t::NType) - return g.ndata[node_t] -end - -Base.getindex(g::GNNHeteroGraph, n1_t::Symbol, rel::Symbol, n2_t::Symbol) = g[(n1_t, rel, n2_t)] - -function Base.getindex(g::GNNHeteroGraph, edge_t::EType) - return g.edata[edge_t] -end - -[.\GNNGraphs\src\mldatasets.jl] -# We load a Graph Dataset from MLDatasets without explicitly depending on it - -""" - mldataset2gnngraph(dataset) - -Convert a graph dataset from the package MLDatasets.jl into one or many [`GNNGraph`](@ref)s. 
- -# Examples - -```jldoctest -julia> using MLDatasets, GraphNeuralNetworks - -julia> mldataset2gnngraph(Cora()) -GNNGraph: - num_nodes = 2708 - num_edges = 10556 - ndata: - features => 1433×2708 Matrix{Float32} - targets => 2708-element Vector{Int64} - train_mask => 2708-element BitVector - val_mask => 2708-element BitVector - test_mask => 2708-element BitVector -``` -""" -function mldataset2gnngraph(dataset::D) where {D} - @assert hasproperty(dataset, :graphs) - graphs = mlgraph2gnngraph.(dataset.graphs) - if length(graphs) == 1 - return graphs[1] - else - return graphs - end -end - -function mlgraph2gnngraph(g::G) where {G} - @assert hasproperty(g, :num_nodes) - @assert hasproperty(g, :edge_index) - @assert hasproperty(g, :node_data) - @assert hasproperty(g, :edge_data) - return GNNGraph(g.edge_index; ndata = g.node_data, edata = g.edge_data, g.num_nodes) -end - -[.\GNNGraphs\src\operators.jl] -# 2 or more args graph operators -"""" - intersect(g1::GNNGraph, g2::GNNGraph) - -Intersect two graphs by keeping only the common edges. -""" -function Base.intersect(g1::GNNGraph, g2::GNNGraph) - @assert g1.num_nodes == g2.num_nodes - @assert graph_type_symbol(g1) == graph_type_symbol(g2) - graph_type = graph_type_symbol(g1) - num_nodes = g1.num_nodes - - idx1, _ = edge_encoding(edge_index(g1)..., num_nodes) - idx2, _ = edge_encoding(edge_index(g2)..., num_nodes) - idx = intersect(idx1, idx2) - s, t = edge_decoding(idx, num_nodes) - return GNNGraph(s, t; num_nodes, graph_type) -end - -[.\GNNGraphs\src\query.jl] - -""" - edge_index(g::GNNGraph) - -Return a tuple containing two vectors, respectively storing -the source and target nodes for each edges in `g`. 
- -```julia -s, t = edge_index(g) -``` -""" -edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2] - -edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2] - -""" - edge_index(g::GNNHeteroGraph, [edge_t]) - -Return a tuple containing two vectors, respectively storing the source and target nodes -for each edges in `g` of type `edge_t = (src_t, rel_t, trg_t)`. - -If `edge_t` is not provided, it will error if `g` has more than one edge type. -""" -edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2] -edge_index(g::GNNHeteroGraph{<:COO_T}) = only(g.graph)[2][1:2] - -get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3] - -get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][3] - -get_edge_weight(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][3] - -Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...) - -Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)} - -# """ -# eltype(g::GNNGraph) -# -# Type of nodes in `g`, -# an integer type like `Int`, `Int32`, `Uint16`, .... -# """ -function Base.eltype(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - w = get_edge_weight(g) - return w !== nothing ? eltype(w) : eltype(s) -end - -Base.eltype(g::GNNGraph{<:ADJMAT_T}) = eltype(g.graph) - -function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer) - s, t = edge_index(g) - return any((s .== i) .& (t .== j)) -end - -Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i, j] != 0 - -""" - has_edge(g::GNNHeteroGraph, edge_t, i, j) - -Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in `g`. 
- -# Examples - -```jldoctest -julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) -GNNHeteroGraph: - num_nodes: (:A => 2, :B => 2) - num_edges: ((:A, :to, :B) => 4, (:B, :to, :A) => 0) - -julia> has_edge(g, (:A,:to,:B), 1, 1) -true - -julia> has_edge(g, (:B,:to,:A), 1, 1) -false -``` -""" -function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Integer) - s, t = edge_index(g, edge_t) - return any((s .== i) .& (t .== j)) -end - -graph_type_symbol(::GNNGraph{<:COO_T}) = :coo -graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse -graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense - -Graphs.nv(g::GNNGraph) = g.num_nodes -Graphs.ne(g::GNNGraph) = g.num_edges -Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes -Graphs.vertices(g::GNNGraph) = 1:(g.num_nodes) - - -""" - neighbors(g::GNNGraph, i::Integer; dir=:out) - -Return the neighbors of node `i` in the graph `g`. -If `dir=:out`, return the neighbors through outgoing edges. -If `dir=:in`, return the neighbors through incoming edges. - -See also [`outneighbors`](@ref Graphs.outneighbors), [`inneighbors`](@ref Graphs.inneighbors). -""" -function Graphs.neighbors(g::GNNGraph, i::Integer; dir::Symbol = :out) - @assert dir ∈ (:in, :out) - if dir == :out - outneighbors(g, i) - else - inneighbors(g, i) - end -end - -""" - outneighbors(g::GNNGraph, i::Integer) - -Return the neighbors of node `i` in the graph `g` through outgoing edges. - -See also [`neighbors`](@ref Graphs.neighbors) and [`inneighbors`](@ref Graphs.inneighbors). -""" -function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer) - s, t = edge_index(g) - return t[s .== i] -end - -function Graphs.outneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) - A = g.graph - return findall(!=(0), A[i, :]) -end - -""" - inneighbors(g::GNNGraph, i::Integer) - -Return the neighbors of node `i` in the graph `g` through incoming edges. - -See also [`neighbors`](@ref Graphs.neighbors) and [`outneighbors`](@ref Graphs.outneighbors). 
-""" -function Graphs.inneighbors(g::GNNGraph{<:COO_T}, i::Integer) - s, t = edge_index(g) - return s[t .== i] -end - -function Graphs.inneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) - A = g.graph - return findall(!=(0), A[:, i]) -end - -Graphs.is_directed(::GNNGraph) = true -Graphs.is_directed(::Type{<:GNNGraph}) = true - -""" - adjacency_list(g; dir=:out) - adjacency_list(g, nodes; dir=:out) - -Return the adjacency list representation (a vector of vectors) -of the graph `g`. - -Calling `a` the adjacency list, if `dir=:out` than -`a[i]` will contain the neighbors of node `i` through -outgoing edges. If `dir=:in`, it will contain neighbors from -incoming edges instead. - -If `nodes` is given, return the neighborhood of the nodes in `nodes` only. -""" -function adjacency_list(g::GNNGraph, nodes; dir = :out, with_eid = false) - @assert dir ∈ [:out, :in] - s, t = edge_index(g) - if dir == :in - s, t = t, s - end - T = eltype(s) - idict = 0 - dmap = Dict(n => (idict += 1) for n in nodes) - adjlist = [T[] for _ in 1:length(dmap)] - eidlist = [T[] for _ in 1:length(dmap)] - for (eid, (i, j)) in enumerate(zip(s, t)) - inew = get(dmap, i, 0) - inew == 0 && continue - push!(adjlist[inew], j) - push!(eidlist[inew], eid) - end - if with_eid - return adjlist, eidlist - else - return adjlist - end -end - -# function adjacency_list(g::GNNGraph, nodes; dir=:out) -# @assert dir ∈ [:out, :in] -# fneighs = dir == :out ? outneighbors : inneighbors -# return [fneighs(g, i) for i in nodes] -# end - -adjacency_list(g::GNNGraph; dir = :out) = adjacency_list(g, 1:(g.num_nodes); dir) - -""" - adjacency_matrix(g::GNNGraph, T=eltype(g); dir=:out, weighted=true) - -Return the adjacency matrix `A` for the graph `g`. - -If `dir=:out`, `A[i,j] > 0` denotes the presence of an edge from node `i` to node `j`. -If `dir=:in` instead, `A[i,j] > 0` denotes the presence of an edge from node `j` to node `i`. - -User may specify the eltype `T` of the returned matrix. 
- -If `weighted=true`, the `A` will contain the edge weights if any, otherwise the elements of `A` will be either 0 or 1. -""" -function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out, - weighted = true) - if iscuarray(g.graph[1]) - # Revisit after - # https://github.com/JuliaGPU/CUDA.jl/issues/1113 - A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted) - else - A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted) - end - @assert size(A) == (n, n) - return dir == :out ? A : A' -end - -function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g); - dir = :out, weighted = true) - @assert dir ∈ [:in, :out] - A = g.graph - if !weighted - A = binarize(A) - end - A = T != eltype(A) ? T.(A) : A - return dir == :out ? A : A' -end - -function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; - dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}} - A = adjacency_matrix(g, T; dir, weighted) - if !weighted - function adjacency_matrix_pullback_noweight(Δ) - return (NoTangent(), ZeroTangent(), NoTangent()) - end - return A, adjacency_matrix_pullback_noweight - else - function adjacency_matrix_pullback_weighted(Δ) - dg = Tangent{G}(; graph = Δ .* binarize(A)) - return (NoTangent(), dg, NoTangent()) - end - return A, adjacency_matrix_pullback_weighted - end -end - -function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; - dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}} - A = adjacency_matrix(g, T; dir, weighted) - w = get_edge_weight(g) - if !weighted || w === nothing - function adjacency_matrix_pullback_noweight(Δ) - return (NoTangent(), ZeroTangent(), NoTangent()) - end - return A, adjacency_matrix_pullback_noweight - else - function adjacency_matrix_pullback_weighted(Δ) - s, t = edge_index(g) - dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t))) - return (NoTangent(), dg, NoTangent()) - end - return A, 
adjacency_matrix_pullback_weighted - end -end - -function _get_edge_weight(g, edge_weight::Bool) - if edge_weight === true - return get_edge_weight(g) - elseif edge_weight === false - return nothing - end -end - -_get_edge_weight(g, edge_weight::AbstractVector) = edge_weight - -""" - degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true) - -Return a vector containing the degrees of the nodes in `g`. - -The gradient is propagated through this function only if `edge_weight` is `true` -or a vector. - -# Arguments - -- `g`: A graph. -- `T`: Element type of the returned vector. If `nothing`, is - chosen based on the graph type and will be an integer - if `edge_weight = false`. Default `nothing`. -- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges. - For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two. -- `edge_weight`: If `true` and the graph contains weighted edges, the degree will - be weighted. Set to `false` instead to just count the number of - outgoing/ingoing edges. - Finally, you can also pass a vector of weights to be used - instead of the graph's own weights. - Default `true`. 
- -""" -function Graphs.degree(g::GNNGraph{<:COO_T}, T::TT = nothing; dir = :out, - edge_weight = true) where { - TT <: Union{Nothing, Type{<:Number}}} - s, t = edge_index(g) - - ew = _get_edge_weight(g, edge_weight) - - T = if isnothing(T) - if !isnothing(ew) - eltype(ew) - else - eltype(s) - end - else - T - end - return _degree((s, t), T, dir, ew, g.num_nodes) -end - -# TODO:: Make efficient -Graphs.degree(g::GNNGraph, i::Union{Int, AbstractVector}; dir = :out) = degree(g; dir)[i] - -function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out, - edge_weight = true) where {TT<:Union{Nothing, Type{<:Number}}} - - # edge_weight=true or edge_weight=nothing act the same here - @assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations" - @assert dir ∈ (:in, :out, :both) - if T === nothing - Nt = eltype(g) - if edge_weight === false && !(Nt <: Integer) - T = Nt == Float32 ? Int32 : - Nt == Float16 ? Int16 : Int - else - T = Nt - end - end - A = adjacency_matrix(g) - return _degree(A, T, dir, edge_weight, g.num_nodes) -end - -""" - degree(g::GNNHeteroGraph, edge_type::EType; dir = :in) - -Return a vector containing the degrees of the nodes in `g` GNNHeteroGraph -given `edge_type`. - -# Arguments - -- `g`: A graph. -- `edge_type`: A tuple of symbols `(source_t, edge_t, target_t)` representing the edge type. -- `T`: Element type of the returned vector. If `nothing`, is - chosen based on the graph type. Default `nothing`. -- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges. - For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two. - Default `dir = :out`. - -""" -function Graphs.degree(g::GNNHeteroGraph, edge::EType, - T::TT = nothing; dir = :out) where { - TT <: Union{Nothing, Type{<:Number}}} - - s, t = edge_index(g, edge) - - T = isnothing(T) ? eltype(s) : T - - n_type = dir == :in ? 
g.ntypes[2] : g.ntypes[1] - - return _degree((s, t), T, dir, nothing, g.num_nodes[n_type]) -end - -function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_nodes::Int) - _degree((s, t), T, dir, ones_like(s, T), num_nodes) -end - -function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::AbstractVector, num_nodes::Int) - degs = zeros_like(s, T, num_nodes) - - if dir ∈ [:out, :both] - degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize = (num_nodes,)) - end - if dir ∈ [:in, :both] - degs = degs .+ NNlib.scatter(+, edge_weight, t, dstsize = (num_nodes,)) - end - return degs -end - -function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int) - if edge_weight === false - A = binarize(A) - end - A = eltype(A) != T ? T.(A) : A - return dir == :out ? vec(sum(A, dims = 2)) : - dir == :in ? vec(sum(A, dims = 1)) : - vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2)) -end - -function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes) - degs = _degree(graph, T, dir, edge_weight, num_nodes) - function _degree_pullback(Δ) - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) - end - return degs, _degree_pullback -end - -function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes) - degs = _degree(A, T, dir, edge_weight, num_nodes) - if edge_weight === false - function _degree_pullback_noweights(Δ) - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) - end - return degs, _degree_pullback_noweights - else - function _degree_pullback_weights(Δ) - # We propagate the gradient only to the non-zero elements - # of the adjacency matrix. 
- bA = binarize(A) - if dir == :in - dA = bA .* Δ' - elseif dir == :out - dA = Δ .* bA - else # dir == :both - dA = Δ .* bA + Δ' .* bA - end - return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent()) - end - return degs, _degree_pullback_weights - end -end - -""" - has_isolated_nodes(g::GNNGraph; dir=:out) - -Return true if the graph `g` contains nodes with out-degree (if `dir=:out`) -or in-degree (if `dir = :in`) equal to zero. -""" -function has_isolated_nodes(g::GNNGraph; dir = :out) - return any(iszero, degree(g; dir)) -end - -function Graphs.laplacian_matrix(g::GNNGraph, T::DataType = eltype(g); dir::Symbol = :out) - A = adjacency_matrix(g, T; dir = dir) - D = Diagonal(vec(sum(A; dims = 2))) - return D - A -end - -""" - normalized_laplacian(g, T=Float32; add_self_loops=false, dir=:out) - -Normalized Laplacian matrix of graph `g`. - -# Arguments - -- `g`: A `GNNGraph`. -- `T`: result element type. -- `add_self_loops`: add self-loops while calculating the matrix. -- `dir`: the edge directionality considered (:out, :in, :both). -""" -function normalized_laplacian(g::GNNGraph, T::DataType = Float32; - add_self_loops::Bool = false, dir::Symbol = :out) - Ã = normalized_adjacency(g, T; dir, add_self_loops) - return I - Ã -end - -function normalized_adjacency(g::GNNGraph, T::DataType = Float32; - add_self_loops::Bool = false, dir::Symbol = :out) - A = adjacency_matrix(g, T; dir = dir) - if add_self_loops - A = A + I - end - degs = vec(sum(A; dims = 2)) - ChainRulesCore.ignore_derivatives() do - @assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`." - end - inv_sqrtD = Diagonal(inv.(sqrt.(degs))) - return inv_sqrtD * A * inv_sqrtD -end - -@doc raw""" - scaled_laplacian(g, T=Float32; dir=:out) - -Scaled Laplacian matrix of graph `g`, -defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. - -# Arguments - -- `g`: A `GNNGraph`. -- `T`: result element type. 
-- `dir`: the edge directionality considered (:out, :in, :both). -""" -function scaled_laplacian(g::GNNGraph, T::DataType = Float32; dir = :out) - L = normalized_laplacian(g, T) - # @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices" - λmax = _eigmax(L) - return 2 / λmax * L - I -end - -# _eigmax(A) = eigmax(Symmetric(A)) # Doesn't work on sparse arrays -function _eigmax(A) - x0 = _rand_dense_vector(A) - KrylovKit.eigsolve(Symmetric(A), x0, 1, :LR)[1][1] # also eigs(A, x0, nev, mode) available -end - -_rand_dense_vector(A::AbstractMatrix{T}) where {T} = randn(float(T), size(A, 1)) - -# Eigenvalues for cuarray don't seem to be well supported. -# https://github.com/JuliaGPU/CUDA.jl/issues/154 -# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5 - -""" - graph_indicator(g::GNNGraph; edges=false) - -Return a vector containing the graph membership -(an integer from `1` to `g.num_graphs`) of each node in the graph. -If `edges=true`, return the graph membership of each edge instead. -""" -function graph_indicator(g::GNNGraph; edges = false) - if isnothing(g.graph_indicator) - gi = ones_like(edge_index(g)[1], Int, g.num_nodes) - else - gi = g.graph_indicator - end - if edges - s, t = edge_index(g) - return gi[s] - else - return gi - end -end - -""" - graph_indicator(g::GNNHeteroGraph, [node_t]) - -Return a Dict of vectors containing the graph membership -(an integer from `1` to `g.num_graphs`) of each node in the graph for each node type. -If `node_t` is provided, return the graph membership of each node of type `node_t` instead. - -See also [`batch`](@ref). 
-""" -function graph_indicator(g::GNNHeteroGraph) - return g.graph_indicator -end - -function graph_indicator(g::GNNHeteroGraph, node_t::Symbol) - @assert node_t ∈ g.ntypes - if isnothing(g.graph_indicator) - gi = ones_like(edge_index(g, first(g.etypes))[1], Int, g.num_nodes[node_t]) - else - gi = g.graph_indicator[node_t] - end - return gi -end - -function node_features(g::GNNGraph) - if isempty(g.ndata) - return nothing - elseif length(g.ndata) > 1 - @error "Multiple feature arrays, access directly through `g.ndata`" - else - return first(values(g.ndata)) - end -end - -function edge_features(g::GNNGraph) - if isempty(g.edata) - return nothing - elseif length(g.edata) > 1 - @error "Multiple feature arrays, access directly through `g.edata`" - else - return first(values(g.edata)) - end -end - -function graph_features(g::GNNGraph) - if isempty(g.gdata) - return nothing - elseif length(g.gdata) > 1 - @error "Multiple feature arrays, access directly through `g.gdata`" - else - return first(values(g.gdata)) - end -end - -""" - is_bidirected(g::GNNGraph) - -Check if the directed graph `g` essentially corresponds -to an undirected graph, i.e. if for each edge it also contains the -reverse edge. -""" -function is_bidirected(g::GNNGraph) - s, t = edge_index(g) - s1, t1 = sort_edge_index(s, t) - s2, t2 = sort_edge_index(t, s) - all((s1 .== s2) .& (t1 .== t2)) -end - -""" - has_self_loops(g::GNNGraph) - -Return `true` if `g` has any self loops. -""" -function Graphs.has_self_loops(g::GNNGraph) - s, t = edge_index(g) - any(s .== t) -end - -""" - has_multi_edges(g::GNNGraph) - -Return `true` if `g` has any multiple edges. -""" -function has_multi_edges(g::GNNGraph) - s, t = edge_index(g) - idxs, _ = edge_encoding(s, t, g.num_nodes) - length(union(idxs)) < length(idxs) -end - -""" - khop_adj(g::GNNGraph,k::Int,T::DataType=eltype(g); dir=:out, weighted=true) - -Return ``A^k`` where ``A`` is the adjacency matrix of the graph 'g'. 
- -""" -function khop_adj(g::GNNGraph, k::Int, T::DataType = eltype(g); dir = :out, weighted = true) - return (adjacency_matrix(g, T; dir, weighted))^k -end - -""" - laplacian_lambda_max(g::GNNGraph, T=Float32; add_self_loops=false, dir=:out) - -Return the largest eigenvalue of the normalized symmetric Laplacian of the graph `g`. - -If the graph is batched from multiple graphs, return the list of the largest eigenvalue for each graph. -""" -function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32; - add_self_loops::Bool = false, dir::Symbol = :out) - if g.num_graphs == 1 - return _eigmax(normalized_laplacian(g, T; add_self_loops, dir)) - else - eigenvalues = zeros(g.num_graphs) - for i in 1:(g.num_graphs) - eigenvalues[i] = _eigmax(normalized_laplacian(getgraph(g, i), T; add_self_loops, - dir)) - end - return eigenvalues - end -end - -@non_differentiable edge_index(x...) -@non_differentiable adjacency_list(x...) -@non_differentiable graph_indicator(x...) -@non_differentiable has_multi_edges(x...) -@non_differentiable Graphs.has_self_loops(x...) -@non_differentiable is_bidirected(x...) -@non_differentiable normalized_adjacency(x...) # TODO remove this in the future -@non_differentiable normalized_laplacian(x...) # TODO remove this in the future -@non_differentiable scaled_laplacian(x...) # TODO remove this in the future - -[.\GNNGraphs\src\sampling.jl] -""" - sample_neighbors(g, nodes, K=-1; dir=:in, replace=false, dropnodes=false) - -Sample neighboring edges of the given nodes and return the induced subgraph. -For each node, a number of inbound (or outbound when `dir = :out``) edges will be randomly chosen. -If `dropnodes=false`, the graph returned will then contain all the nodes in the original graph, -but only the sampled edges. - -The returned graph will contain an edge feature `EID` corresponding to the id of the edge -in the original graph. If `dropnodes=true`, it will also contain a node feature `NID` with -the node ids in the original graph. 
- -# Arguments - -- `g`. The graph. -- `nodes`. A list of node IDs to sample neighbors from. -- `K`. The maximum number of edges to be sampled for each node. - If -1, all the neighboring edges will be selected. -- `dir`. Determines whether to sample inbound (`:in`) or outbound (``:out`) edges (Default `:in`). -- `replace`. If `true`, sample with replacement. -- `dropnodes`. If `true`, the resulting subgraph will contain only the nodes involved in the sampled edges. - -# Examples - -```julia -julia> g = rand_graph(20, 100) -GNNGraph: - num_nodes = 20 - num_edges = 100 - -julia> sample_neighbors(g, 2:3) -GNNGraph: - num_nodes = 20 - num_edges = 9 - edata: - EID => (9,) - -julia> sg = sample_neighbors(g, 2:3, dropnodes=true) -GNNGraph: - num_nodes = 10 - num_edges = 9 - ndata: - NID => (10,) - edata: - EID => (9,) - -julia> sg.ndata.NID -10-element Vector{Int64}: - 2 - 3 - 17 - 14 - 18 - 15 - 16 - 20 - 7 - 10 - -julia> sample_neighbors(g, 2:3, 5, replace=true) -GNNGraph: - num_nodes = 20 - num_edges = 10 - edata: - EID => (10,) -``` -""" -function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; - dir = :in, replace = false, dropnodes = false) - @assert dir ∈ (:in, :out) - _, eidlist = adjacency_list(g, nodes; dir, with_eid = true) - for i in 1:length(eidlist) - if replace - k = K > 0 ? K : length(eidlist[i]) - else - k = K > 0 ? min(length(eidlist[i]), K) : length(eidlist[i]) - end - eidlist[i] = StatsBase.sample(eidlist[i], k; replace) - end - eids = reduce(vcat, eidlist) - s, t = edge_index(g) - w = get_edge_weight(g) - s = s[eids] - t = t[eids] - w = isnothing(w) ? nothing : w[eids] - - edata = getobs(g.edata, eids) - edata.EID = eids - - num_edges = length(eids) - - if !dropnodes - graph = (s, t, w) - - gnew = GNNGraph(graph, - g.num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) - else - nodes_other = dir == :in ? 
setdiff(s, nodes) : setdiff(t, nodes) - nodes_all = [nodes; nodes_other] - nodemap = Dict(n => i for (i, n) in enumerate(nodes_all)) - s = [nodemap[s] for s in s] - t = [nodemap[t] for t in t] - graph = (s, t, w) - graph_indicator = g.graph_indicator !== nothing ? g.graph_indicator[nodes_all] : - nothing - num_nodes = length(nodes_all) - ndata = getobs(g.ndata, nodes_all) - ndata.NID = nodes_all - - gnew = GNNGraph(graph, - num_nodes, num_edges, g.num_graphs, - graph_indicator, - ndata, edata, g.gdata) - end - return gnew -end - -[.\GNNGraphs\src\temporalsnapshotsgnngraph.jl] -""" - TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) - -A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref). - -`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object, -and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features. -The features can be passed at construction time or added later. - -# Constructor Arguments - -- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes. 
- -# Examples - -```julia -julia> using GraphNeuralNetworks - -julia> snapshots = [rand_graph(10,20) for i in 1:5]; - -julia> tg = TemporalSnapshotsGNNGraph(snapshots) -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [20, 20, 20, 20, 20] - num_snapshots: 5 - -julia> tg.tgdata.x = rand(4); # add temporal graph feature - -julia> tg # show temporal graph with new feature -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [20, 20, 20, 20, 20] - num_snapshots: 5 - tgdata: - x = 4-element Vector{Float64} -``` -""" -struct TemporalSnapshotsGNNGraph - num_nodes::AbstractVector{Int} - num_edges::AbstractVector{Int} - num_snapshots::Int - snapshots::AbstractVector{<:GNNGraph} - tgdata::DataStore -end - -function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) - @assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes" - return TemporalSnapshotsGNNGraph( - [s.num_nodes for s in snapshots], - [s.num_edges for s in snapshots], - length(snapshots), - snapshots, - DataStore() - ) -end - -function Base.:(==)(tsg1::TemporalSnapshotsGNNGraph, tsg2::TemporalSnapshotsGNNGraph) - tsg1 === tsg2 && return true - for k in fieldnames(typeof(tsg1)) - getfield(tsg1, k) != getfield(tsg2, k) && return false - end - return true -end - -function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int) - return tg.snapshots[t] -end - -function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector) - return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata) -end - -""" - add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) - -Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the snapshot `g` at time index `t`. 
- -# Examples - -```jldoctest -julia> using GraphNeuralNetworks - -julia> snapshots = [rand_graph(10, 20) for i in 1:5]; - -julia> tg = TemporalSnapshotsGNNGraph(snapshots) -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [20, 20, 20, 20, 20] - num_snapshots: 5 - -julia> new_tg = add_snapshot(tg, 3, rand_graph(10, 16)) # add a new snapshot at time 3 -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10, 10] - num_edges: [20, 20, 16, 20, 20, 20] - num_snapshots: 6 -``` -""" -function add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) - if tg.num_snapshots > 0 - @assert g.num_nodes == first(tg.num_nodes) "number of nodes must match" - end - @assert t <= tg.num_snapshots + 1 "cannot add snapshot at time $t, the temporal graph has only $(tg.num_snapshots) snapshots" - num_nodes = tg.num_nodes |> copy - num_edges = tg.num_edges |> copy - snapshots = tg.snapshots |> copy - num_snapshots = tg.num_snapshots + 1 - insert!(num_nodes, t, g.num_nodes) - insert!(num_edges, t, g.num_edges) - insert!(snapshots, t, g) - return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata) -end - -# """ -# add_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) - -# Add to `tg` the snapshot `g` at time index `t`. - -# See also [`add_snapshot`](@ref) for a non-mutating version. 
-# """ -# function add_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) -# if t > tg.num_snapshots + 1 -# error("cannot add snapshot at time $t, the temporal graph has only $(tg.num_snapshots) snapshots") -# end -# if tg.num_snapshots > 0 -# @assert g.num_nodes == first(tg.num_nodes) "number of nodes must match" -# end -# insert!(tg.num_nodes, t, g.num_nodes) -# insert!(tg.num_edges, t, g.num_edges) -# insert!(tg.snapshots, t, g) -# return tg -# end - -""" - remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int) - -Return a [`TemporalSnapshotsGNNGraph`](@ref) created starting from `tg` by removing the snapshot at time index `t`. - -# Examples - -```jldoctest -julia> using GraphNeuralNetworks - -julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)]; - -julia> tg = TemporalSnapshotsGNNGraph(snapshots) -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10] - num_edges: [20, 14, 22] - num_snapshots: 3 - -julia> new_tg = remove_snapshot(tg, 2) # remove snapshot at time 2 -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10] - num_edges: [20, 22] - num_snapshots: 2 -``` -""" -function remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int) - num_nodes = tg.num_nodes |> copy - num_edges = tg.num_edges |> copy - snapshots = tg.snapshots |> copy - num_snapshots = tg.num_snapshots - 1 - deleteat!(num_nodes, t) - deleteat!(num_edges, t) - deleteat!(snapshots, t) - return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata) -end - -# """ -# remove_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int) - -# Remove the snapshot at time index `t` from `tg` and return `tg`. - -# See [`remove_snapshot`](@ref) for a non-mutating version. 
-# """ -# function remove_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int) -# @assert t <= tg.num_snapshots "snapshot index $t out of bounds" -# tg.num_snapshots -= 1 -# deleteat!(tg.num_nodes, t) -# deleteat!(tg.num_edges, t) -# deleteat!(tg.snapshots, t) -# return tg -# end - -function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol) - if prop ∈ fieldnames(TemporalSnapshotsGNNGraph) - return getfield(tg, prop) - elseif prop == :ndata - return [s.ndata for s in tg.snapshots] - elseif prop == :edata - return [s.edata for s in tg.snapshots] - elseif prop == :gdata - return [s.gdata for s in tg.snapshots] - else - return [getproperty(s,prop) for s in tg.snapshots] - end -end - -function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph) - print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") - print_feature_t(io, tsg.tgdata) - print(io, " data") -end - -function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph) - if get(io, :compact, false) - print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") - print_feature_t(io, tsg.tgdata) - print(io, " data") - else - print(io, - "TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)") - if !isempty(tsg.tgdata) - print(io, "\n tgdata:") - for k in keys(tsg.tgdata) - print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))") - end - end - end -end - -function print_feature_t(io::IO, feature) - if !isempty(feature) - if length(keys(feature)) == 1 - k = first(keys(feature)) - v = first(values(feature)) - print(io, "$(k): $(dims2string(size(v)))") - else - print(io, "(") - for (i, (k, v)) in enumerate(pairs(feature)) - print(io, "$k: $(dims2string(size(v)))") - if i == length(feature) - print(io, ")") - else - print(io, ", ") - end - end - end - else - print(io, "no") - end -end - -@functor TemporalSnapshotsGNNGraph - -[.\GNNGraphs\src\transform.jl] - -""" - add_self_loops(g::GNNGraph) - -Return a graph with 
the same features as `g` -but also adding edges connecting the nodes to themselves. - -Nodes with already existing self-loops will obtain a second self-loop. - -If the graphs has edge weights, the new edges will have weight 1. -""" -function add_self_loops(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - @assert isempty(g.edata) - ew = get_edge_weight(g) - n = g.num_nodes - nodes = convert(typeof(s), [1:n;]) - s = [s; nodes] - t = [t; nodes] - if ew !== nothing - ew = [ew; fill!(similar(ew, n), 1)] - end - - return GNNGraph((s, t, ew), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -function add_self_loops(g::GNNGraph{<:ADJMAT_T}) - A = g.graph - @assert isempty(g.edata) - num_edges = g.num_edges + g.num_nodes - A = A + I - return GNNGraph(A, - g.num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -""" - add_self_loops(g::GNNHeteroGraph, edge_t::EType) - add_self_loops(g::GNNHeteroGraph) - -If the source node type is the same as the destination node type in `edge_t`, -return a graph with the same features as `g` but also add self-loops -of the specified type, `edge_t`. Otherwise, it returns `g` unchanged. - -Nodes with already existing self-loops of type `edge_t` will obtain -a second set of self-loops of the same type. - -If the graph has edge weights for edges of type `edge_t`, the new edges will have weight 1. - -If no edges of type `edge_t` exist, or all existing edges have no weight, -then all new self loops will have no weight. - -If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same. -This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type. 
-""" -function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) - - function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) - get(g.graph, edge_t, (nothing, nothing, nothing))[3] - end - - src_t, _, tgt_t = edge_t - (src_t === tgt_t) || - return g - - n = get(g.num_nodes, src_t, 0) - - if haskey(g.graph, edge_t) - s, t = g.graph[edge_t][1:2] - nodes = convert(typeof(s), [1:n;]) - s = [s; nodes] - t = [t; nodes] - else - if !isempty(g.graph) - T = typeof(first(values(g.graph))[1]) - nodes = convert(T, [1:n;]) - else - nodes = [1:n;] - end - s = nodes - t = nodes - end - - graph = g.graph |> copy - ew = get(g.graph, edge_t, (nothing, nothing, nothing))[3] - - if ew !== nothing - ew = [ew; fill!(similar(ew, n), 1)] - end - - graph[edge_t] = (s, t, ew) - edata = g.edata |> copy - ndata = g.ndata |> copy - ntypes = g.ntypes |> copy - etypes = g.etypes |> copy - num_nodes = g.num_nodes |> copy - num_edges = g.num_edges |> copy - num_edges[edge_t] = length(get(graph, edge_t, ([],[]))[1]) - - return GNNHeteroGraph(graph, - num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - ndata, edata, g.gdata, - ntypes, etypes) -end - -function add_self_loops(g::GNNHeteroGraph) - for edge_t in keys(g.graph) - g = add_self_loops(g, edge_t) - end - return g -end - -""" - remove_self_loops(g::GNNGraph) - -Return a graph constructed from `g` where self-loops (edges from a node to itself) -are removed. - -See also [`add_self_loops`](@ref) and [`remove_multi_edges`](@ref). -""" -function remove_self_loops(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - - mask_old_loops = s .!= t - s = s[mask_old_loops] - t = t[mask_old_loops] - edata = getobs(edata, mask_old_loops) - w = isnothing(w) ? 
nothing : getobs(w, mask_old_loops) - - GNNGraph((s, t, w), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) -end - -function remove_self_loops(g::GNNGraph{<:ADJMAT_T}) - @assert isempty(g.edata) - A = g.graph - A[diagind(A)] .= 0 - if A isa AbstractSparseMatrix - dropzeros!(A) - end - num_edges = numnonzeros(A) - return GNNGraph(A, - g.num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -""" - remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) - remove_edges(g::GNNGraph, p=0.5) - -Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability. - -# Arguments -- `g`: The input graph from which edges will be removed. -- `edges_to_remove`: Vector of edge indices to be removed. This argument is only required for the first method. -- `p`: Probability of removing each edge. This argument is only required for the second method and defaults to 0.5. - -# Returns -A new GNNGraph with the specified edges removed. - -# Example -```julia -julia> using GraphNeuralNetworks - -# Construct a GNNGraph -julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) -GNNGraph: - num_nodes: 3 - num_edges: 5 - -# Remove the second edge -julia> g_new = remove_edges(g, [2]); - -julia> g_new -GNNGraph: - num_nodes: 3 - num_edges: 4 - -# Remove edges with a probability of 0.5 -julia> g_new = remove_edges(g, 0.5); - -julia> g_new -GNNGraph: - num_nodes: 3 - num_edges: 2 -``` -""" -function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - - mask_to_keep = trues(length(s)) - - mask_to_keep[edges_to_remove] .= false - - s = s[mask_to_keep] - t = t[mask_to_keep] - edata = getobs(edata, mask_to_keep) - w = isnothing(w) ? 
nothing : getobs(w, mask_to_keep) - - return GNNGraph((s, t, w), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) -end - - -function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) - num_edges = g.num_edges - edges_to_remove = filter(_ -> rand() < p, 1:num_edges) - return remove_edges(g, edges_to_remove) -end - -""" - remove_multi_edges(g::GNNGraph; aggr=+) - -Remove multiple edges (also called parallel edges or repeated edges) from graph `g`. -Possible edge features are aggregated according to `aggr`, that can take value -`+`,`min`, `max` or `mean`. - -See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref). -""" -function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - num_edges = g.num_edges - idxs, idxmax = edge_encoding(s, t, g.num_nodes) - - perm = sortperm(idxs) - idxs = idxs[perm] - s, t = s[perm], t[perm] - edata = getobs(edata, perm) - w = isnothing(w) ? nothing : getobs(w, perm) - idxs = [-1; idxs] - mask = idxs[2:end] .> idxs[1:(end - 1)] - if !all(mask) - s, t = s[mask], t[mask] - idxs = similar(s, num_edges) - idxs .= 1:num_edges - idxs .= idxs .- cumsum(.!mask) - num_edges = length(s) - w = _scatter(aggr, w, idxs, num_edges) - edata = _scatter(aggr, edata, idxs, num_edges) - end - - return GNNGraph((s, t, w), - g.num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) -end - -""" - remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector) - -Remove specified nodes, and their associated edges, from a GNNGraph. This operation reindexes the remaining nodes to maintain a continuous sequence of node indices, starting from 1. Similarly, edges are reindexed to account for the removal of edges connected to the removed nodes. - -# Arguments -- `g`: The input graph from which nodes (and their edges) will be removed. -- `nodes_to_remove`: Vector of node indices to be removed. 
- -# Returns -A new GNNGraph with the specified nodes and all edges associated with these nodes removed. - -# Example -```julia -using GraphNeuralNetworks - -g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) - -# Remove nodes with indices 2 and 3, for example -g_new = remove_nodes(g, [2, 3]) - -# g_new now does not contain nodes 2 and 3, and any edges that were connected to these nodes. -println(g_new) -``` -""" -function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector) - nodes_to_remove = sort(union(nodes_to_remove)) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - ndata = g.ndata - - function find_edges_to_remove(nodes, nodes_to_remove) - return findall(node_id -> begin - idx = searchsortedlast(nodes_to_remove, node_id) - idx >= 1 && idx <= length(nodes_to_remove) && nodes_to_remove[idx] == node_id - end, nodes) - end - - edges_to_remove_s = find_edges_to_remove(s, nodes_to_remove) - edges_to_remove_t = find_edges_to_remove(t, nodes_to_remove) - edges_to_remove = union(edges_to_remove_s, edges_to_remove_t) - - mask_edges_to_keep = trues(length(s)) - mask_edges_to_keep[edges_to_remove] .= false - s = s[mask_edges_to_keep] - t = t[mask_edges_to_keep] - - w = isnothing(w) ? nothing : getobs(w, mask_edges_to_keep) - - for node in sort(nodes_to_remove, rev=true) - s[s .> node] .-= 1 - t[t .> node] .-= 1 - end - - nodes_to_keep = setdiff(1:g.num_nodes, nodes_to_remove) - ndata = getobs(ndata, nodes_to_keep) - edata = getobs(edata, mask_edges_to_keep) - - num_nodes = g.num_nodes - length(nodes_to_remove) - - return GNNGraph((s, t, w), - num_nodes, length(s), g.num_graphs, - g.graph_indicator, - ndata, edata, g.gdata) -end - -""" - remove_nodes(g::GNNGraph, p) - -Returns a new graph obtained by dropping nodes from `g` with independent probabilities `p`. 
- -# Examples - -```julia -julia> g = GNNGraph([1, 1, 2, 2, 3, 4], [1, 2, 3, 1, 3, 1]) -GNNGraph: - num_nodes: 4 - num_edges: 6 - -julia> g_new = remove_nodes(g, 0.5) -GNNGraph: - num_nodes: 2 - num_edges: 2 -``` -""" -function remove_nodes(g::GNNGraph, p::AbstractFloat) - nodes_to_remove = filter(_ -> rand() < p, 1:g.num_nodes) - return remove_nodes(g, nodes_to_remove) -end - -""" - add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata]) - add_edges(g::GNNGraph, (s, t); [edata]) - add_edges(g::GNNGraph, (s, t, w); [edata]) - -Add to graph `g` the edges with source nodes `s` and target nodes `t`. -Optionally, pass the edge weight `w` and the features `edata` for the new edges. -Returns a new graph sharing part of the underlying data with `g`. - -If the `s` or `t` contain nodes that are not already present in the graph, -they are added to the graph as well. - -# Examples - -```jldoctest -julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; - -julia> w = Float32[1.0, 2.0, 3.0, 4.0, 5.0]; - -julia> g = GNNGraph((s, t, w)) -GNNGraph: - num_nodes: 4 - num_edges: 5 - -julia> add_edges(g, ([2, 3], [4, 1], [10.0, 20.0])) -GNNGraph: - num_nodes: 4 - num_edges: 7 -``` -```jldoctest -julia> g = GNNGraph() -GNNGraph: - num_nodes: 0 - num_edges: 0 - -julia> add_edges(g, [1,2], [2,3]) -GNNGraph: - num_nodes: 3 - num_edges: 2 -``` -""" -add_edges(g::GNNGraph{<:COO_T}, snew::AbstractVector, tnew::AbstractVector; kws...) = add_edges(g, (snew, tnew, nothing); kws...) -add_edges(g, data::Tuple{<:AbstractVector, <:AbstractVector}; kws...) = add_edges(g, (data..., nothing); kws...) 
- -function add_edges(g::GNNGraph{<:COO_T}, data::COO_T; edata = nothing) - snew, tnew, wnew = data - @assert length(snew) == length(tnew) - @assert isnothing(wnew) || length(wnew) == length(snew) - if length(snew) == 0 - return g - end - @assert minimum(snew) >= 1 - @assert minimum(tnew) >= 1 - num_new = length(snew) - edata = normalize_graphdata(edata, default_name = :e, n = num_new) - edata = cat_features(g.edata, edata) - - s, t = edge_index(g) - s = [s; snew] - t = [t; tnew] - w = get_edge_weight(g) - w = cat_features(w, wnew, g.num_edges, num_new) - - num_nodes = max(maximum(snew), maximum(tnew), g.num_nodes) - if num_nodes > g.num_nodes - ndata_new = normalize_graphdata((;), default_name = :x, n = num_nodes - g.num_nodes) - ndata = cat_features(g.ndata, ndata_new) - else - ndata = g.ndata - end - - return GNNGraph((s, t, w), - num_nodes, length(s), g.num_graphs, - g.graph_indicator, - ndata, edata, g.gdata) -end - -""" - add_edges(g::GNNHeteroGraph, edge_t, s, t; [edata, num_nodes]) - add_edges(g::GNNHeteroGraph, edge_t => (s, t); [edata, num_nodes]) - add_edges(g::GNNHeteroGraph, edge_t => (s, t, w); [edata, num_nodes]) - -Add to heterograph `g` edges of type `edge_t` with source node vector `s` and target node vector `t`. -Optionally, pass the edge weights `w` or the features `edata` for the new edges. -`edge_t` is a triplet of symbols `(src_t, rel_t, dst_t)`. - -If the edge type is not already present in the graph, it is added. -If it involves new node types, they are added to the graph as well. -In this case, a dictionary or named tuple of `num_nodes` can be passed to specify the number of nodes of the new types, -otherwise the number of nodes is inferred from the maximum node id in `s` and `t`. -""" -add_edges(g::GNNHeteroGraph{<:COO_T}, edge_t::EType, snew::AbstractVector, tnew::AbstractVector; kws...) = add_edges(g, edge_t => (snew, tnew, nothing); kws...) 
-add_edges(g::GNNHeteroGraph{<:COO_T}, data::Pair{EType, <:Tuple{<:AbstractVector, <:AbstractVector}}; kws...) = add_edges(g, data.first => (data.second..., nothing); kws...) - -function add_edges(g::GNNHeteroGraph{<:COO_T}, - data::Pair{EType, <:COO_T}; - edata = nothing, - num_nodes = Dict{Symbol,Int}()) - edge_t, (snew, tnew, wnew) = data - @assert length(snew) == length(tnew) - if length(snew) == 0 - return g - end - @assert minimum(snew) >= 1 - @assert minimum(tnew) >= 1 - - is_existing_rel = haskey(g.graph, edge_t) - - edata = normalize_graphdata(edata, default_name = :e, n = length(snew)) - _edata = g.edata |> copy - if haskey(_edata, edge_t) - _edata[edge_t] = cat_features(g.edata[edge_t], edata) - else - _edata[edge_t] = edata - end - - graph = g.graph |> copy - etypes = g.etypes |> copy - ntypes = g.ntypes |> copy - _num_nodes = g.num_nodes |> copy - ndata = g.ndata |> copy - if !is_existing_rel - for (node_t, st) in [(edge_t[1], snew), (edge_t[3], tnew)] - if node_t ∉ ntypes - push!(ntypes, node_t) - if haskey(num_nodes, node_t) - _num_nodes[node_t] = num_nodes[node_t] - else - _num_nodes[node_t] = maximum(st) - end - ndata[node_t] = DataStore(_num_nodes[node_t]) - end - end - push!(etypes, edge_t) - else - s, t = edge_index(g, edge_t) - snew = [s; snew] - tnew = [t; tnew] - w = get_edge_weight(g, edge_t) - wnew = cat_features(w, wnew, length(s), length(snew)) - end - - if maximum(snew) > _num_nodes[edge_t[1]] - ndata_new = normalize_graphdata((;), default_name = :x, n = maximum(snew) - _num_nodes[edge_t[1]]) - ndata[edge_t[1]] = cat_features(ndata[edge_t[1]], ndata_new) - _num_nodes[edge_t[1]] = maximum(snew) - end - if maximum(tnew) > _num_nodes[edge_t[3]] - ndata_new = normalize_graphdata((;), default_name = :x, n = maximum(tnew) - _num_nodes[edge_t[3]]) - ndata[edge_t[3]] = cat_features(ndata[edge_t[3]], ndata_new) - _num_nodes[edge_t[3]] = maximum(tnew) - end - - graph[edge_t] = (snew, tnew, wnew) - num_edges = g.num_edges |> copy - 
num_edges[edge_t] = length(graph[edge_t][1]) - - return GNNHeteroGraph(graph, - _num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - ndata, _edata, g.gdata, - ntypes, etypes) -end - - -""" - perturb_edges([rng], g::GNNGraph, perturb_ratio) - -Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`. -The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. -These new edges are added without creating self-loops. - -The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. -The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges. - -# Arguments - -- `g::GNNGraph`: The graph to be perturbed. -- `perturb_ratio`: The ratio of the number of new edges to add relative to the current number of edges in the graph. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges will be added as new random edges. -- `rng`: An optionalrandom number generator to ensure reproducible results. 
- -# Examples - -```julia -julia> g = GNNGraph((s, t, w)) -GNNGraph: - num_nodes: 4 - num_edges: 5 - -julia> perturbed_g = perturb_edges(g, 0.2) -GNNGraph: - num_nodes: 4 - num_edges: 6 -``` -""" -perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::AbstractFloat) = - perturb_edges(Random.default_rng(), g, perturb_ratio) - -function perturb_edges(rng::AbstractRNG, g::GNNGraph{<:COO_T}, perturb_ratio::AbstractFloat) - @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1" - - num_current_edges = g.num_edges - num_edges_to_add = ceil(Int, num_current_edges * perturb_ratio) - - if num_edges_to_add == 0 - return g - end - - num_nodes = g.num_nodes - @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" - - snew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes) - tnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, num_edges_to_add) .* num_nodes) - - mask_loops = snew .!= tnew - snew = snew[mask_loops] - tnew = tnew[mask_loops] - - while length(snew) < num_edges_to_add - n = num_edges_to_add - length(snew) - snewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes) - tnewnew = ceil.(Int, rand_like(rng, ones(num_nodes), Float32, n) .* num_nodes) - mask_new_loops = snewnew .!= tnewnew - snewnew = snewnew[mask_new_loops] - tnewnew = tnewnew[mask_new_loops] - snew = [snew; snewnew] - tnew = [tnew; tnewnew] - end - - return add_edges(g, (snew, tnew, nothing)) -end - - -### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). 
make it mutable -# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector} -# s, t = edge_index(g) -# @assert length(snew) == length(tnew) -# # TODO remove this constraint -# @assert get_edge_weight(g) === nothing - -# edata = normalize_graphdata(edata, default_name=:e, n=length(snew)) -# edata = cat_features(g.edata, edata) - -# s, t = edge_index(g) -# append!(s, snew) -# append!(t, tnew) -# g.num_edges += length(snew) -# return true -# end - -""" - to_bidirected(g) - -Adds a reverse edge for each edge in the graph, then calls -[`remove_multi_edges`](@ref) with `mean` aggregation to simplify the graph. - -See also [`is_bidirected`](@ref). - -# Examples - -```jldoctest -julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; - -julia> w = [1.0, 2.0, 3.0, 4.0, 5.0]; - -julia> e = [10.0, 20.0, 30.0, 40.0, 50.0]; - -julia> g = GNNGraph(s, t, w, edata = e) -GNNGraph: - num_nodes = 4 - num_edges = 5 - edata: - e => (5,) - -julia> g2 = to_bidirected(g) -GNNGraph: - num_nodes = 4 - num_edges = 7 - edata: - e => (7,) - -julia> edge_index(g2) -([1, 2, 2, 3, 3, 4, 4], [2, 1, 3, 2, 4, 3, 4]) - -julia> get_edge_weight(g2) -7-element Vector{Float64}: - 1.0 - 1.0 - 2.0 - 2.0 - 3.5 - 3.5 - 5.0 - -julia> g2.edata.e -7-element Vector{Float64}: - 10.0 - 10.0 - 20.0 - 20.0 - 35.0 - 35.0 - 50.0 -``` -""" -function to_bidirected(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - w = get_edge_weight(g) - snew = [s; t] - tnew = [t; s] - w = cat_features(w, w) - edata = cat_features(g.edata, g.edata) - - g = GNNGraph((snew, tnew, w), - g.num_nodes, length(snew), g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) - - return remove_multi_edges(g; aggr = mean) -end - -""" - to_unidirected(g::GNNGraph) - -Return a graph that for each multiple edge between two nodes in `g` -keeps only an edge in one direction. 
-""" -function to_unidirected(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - w = get_edge_weight(g) - idxs, _ = edge_encoding(s, t, g.num_nodes, directed = false) - snew, tnew = edge_decoding(idxs, g.num_nodes, directed = false) - - g = GNNGraph((snew, tnew, w), - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) - - return remove_multi_edges(g; aggr = mean) -end - -function Graphs.SimpleGraph(g::GNNGraph) - G = Graphs.SimpleGraph(g.num_nodes) - for e in Graphs.edges(g) - Graphs.add_edge!(G, e) - end - return G -end -function Graphs.SimpleDiGraph(g::GNNGraph) - G = Graphs.SimpleDiGraph(g.num_nodes) - for e in Graphs.edges(g) - Graphs.add_edge!(G, e) - end - return G -end - -""" - add_nodes(g::GNNGraph, n; [ndata]) - -Add `n` new nodes to graph `g`. In the -new graph, these nodes will have indexes from `g.num_nodes + 1` -to `g.num_nodes + n`. -""" -function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata = (;)) - ndata = normalize_graphdata(ndata, default_name = :x, n = n) - ndata = cat_features(g.ndata, ndata) - - GNNGraph(g.graph, - g.num_nodes + n, g.num_edges, g.num_graphs, - g.graph_indicator, - ndata, g.edata, g.gdata) -end - -""" - set_edge_weight(g::GNNGraph, w::AbstractVector) - -Set `w` as edge weights in the returned graph. -""" -function set_edge_weight(g::GNNGraph, w::AbstractVector) - s, t = edge_index(g) - @assert length(w) == length(s) - - return GNNGraph((s, t, w), - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph) - nv1, nv2 = g1.num_nodes, g2.num_nodes - if g1.graph isa COO_T - s1, t1 = edge_index(g1) - s2, t2 = edge_index(g2) - s = vcat(s1, nv1 .+ s2) - t = vcat(t1, nv1 .+ t2) - w = cat_features(get_edge_weight(g1), get_edge_weight(g2)) - graph = (s, t, w) - ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, nv1) : g1.graph_indicator - ind2 = isnothing(g2.graph_indicator) ? 
ones_like(s2, nv2) : g2.graph_indicator - elseif g1.graph isa ADJMAT_T - graph = blockdiag(g1.graph, g2.graph) - ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, nv1) : g1.graph_indicator - ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, nv2) : g2.graph_indicator - end - graph_indicator = vcat(ind1, g1.num_graphs .+ ind2) - - GNNGraph(graph, - nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs, - graph_indicator, - cat_features(g1.ndata, g2.ndata), - cat_features(g1.edata, g2.edata), - cat_features(g1.gdata, g2.gdata)) -end - -# PIRACY -function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix) - m1, n1 = size(A1) - @assert m1 == n1 - m2, n2 = size(A2) - @assert m2 == n2 - O1 = fill!(similar(A1, eltype(A1), (m1, n2)), 0) - O2 = fill!(similar(A1, eltype(A1), (m2, n1)), 0) - return [A1 O1 - O2 A2] -end - -""" - blockdiag(xs::GNNGraph...) - -Equivalent to [`MLUtils.batch`](@ref). -""" -function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) - g = g1 - for go in gothers - g = blockdiag(g, go) - end - return g -end - -""" - batch(gs::Vector{<:GNNGraph}) - -Batch together multiple `GNNGraph`s into a single one -containing the total number of original nodes and edges. - -Equivalent to [`SparseArrays.blockdiag`](@ref). -See also [`MLUtils.unbatch`](@ref). 
- -# Examples - -```jldoctest -julia> g1 = rand_graph(4, 6, ndata=ones(8, 4)) -GNNGraph: - num_nodes = 4 - num_edges = 6 - ndata: - x => (8, 4) - -julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7)) -GNNGraph: - num_nodes = 7 - num_edges = 4 - ndata: - x => (8, 7) - -julia> g12 = MLUtils.batch([g1, g2]) -GNNGraph: - num_nodes = 11 - num_edges = 10 - num_graphs = 2 - ndata: - x => (8, 11) - -julia> g12.ndata.x -8×11 Matrix{Float64}: - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 - 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -``` -""" -function MLUtils.batch(gs::AbstractVector{<:GNNGraph}) - Told = eltype(gs) - # try to restrict the eltype - gs = [g for g in gs] - if eltype(gs) != Told - return MLUtils.batch(gs) - else - return blockdiag(gs...) - end -end - -function MLUtils.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T} - v_num_nodes = [g.num_nodes for g in gs] - edge_indices = [edge_index(g) for g in gs] - nodesum = cumsum([0; v_num_nodes])[1:(end - 1)] - s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) - t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) - w = cat_features([get_edge_weight(g) for g in gs]) - graph = (s, t, w) - - function materialize_graph_indicator(g) - g.graph_indicator === nothing ? 
ones_like(s, g.num_nodes) : g.graph_indicator - end - - v_gi = materialize_graph_indicator.(gs) - v_num_graphs = [g.num_graphs for g in gs] - graphsum = cumsum([0; v_num_graphs])[1:(end - 1)] - v_gi = [ng .+ gi for (ng, gi) in zip(graphsum, v_gi)] - graph_indicator = cat_features(v_gi) - - GNNGraph(graph, - sum(v_num_nodes), - sum([g.num_edges for g in gs]), - sum(v_num_graphs), - graph_indicator, - cat_features([g.ndata for g in gs]), - cat_features([g.edata for g in gs]), - cat_features([g.gdata for g in gs])) -end - -function MLUtils.batch(g::GNNGraph) - throw(ArgumentError("Cannot batch a `GNNGraph` (containing $(g.num_graphs) graphs). Pass a vector of `GNNGraph`s instead.")) -end - - -function MLUtils.batch(gs::AbstractVector{<:GNNHeteroGraph}) - function edge_index_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) - if haskey(g.graph, edge_t) - g.graph[edge_t][1:2] - else - nothing - end - end - - function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) - get(g.graph, edge_t, (nothing, nothing, nothing))[3] - end - - @assert length(gs) > 0 - ntypes = union([g.ntypes for g in gs]...) - etypes = union([g.etypes for g in gs]...) 
- - v_num_nodes = Dict(node_t => [get(g.num_nodes, node_t, 0) for g in gs] for node_t in ntypes) - num_nodes = Dict(node_t => sum(v_num_nodes[node_t]) for node_t in ntypes) - num_edges = Dict(edge_t => sum(get(g.num_edges, edge_t, 0) for g in gs) for edge_t in etypes) - edge_indices = edge_indices = Dict(edge_t => [edge_index_nullable(g, edge_t) for g in gs] for edge_t in etypes) - nodesum = Dict(node_t => cumsum([0; v_num_nodes[node_t]])[1:(end - 1)] for node_t in ntypes) - graphs = [] - for edge_t in etypes - src_t, _, dst_t = edge_t - # @show edge_t edge_indices[edge_t] first(edge_indices[edge_t]) - # for ei in edge_indices[edge_t] - # @show ei[1] - # end - # # [ei[1] for (ii, ei) in enumerate(edge_indices[edge_t])] - s = cat_features([ei[1] .+ nodesum[src_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing]) - t = cat_features([ei[2] .+ nodesum[dst_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing]) - w = cat_features(filter(x -> x !== nothing, [get_edge_weight_nullable(g, edge_t) for g in gs])) - push!(graphs, edge_t => (s, t, w)) - end - graph = Dict(graphs...) 
- - #TODO relax this restriction - @assert all(g -> g.num_graphs == 1, gs) - - s = edge_index(gs[1], gs[1].etypes[1])[1] # grab any source vector - - function materialize_graph_indicator(g, node_t) - n = get(g.num_nodes, node_t, 0) - return ones_like(s, n) - end - v_gi = Dict(node_t => [materialize_graph_indicator(g, node_t) for g in gs] for node_t in ntypes) - v_num_graphs = [g.num_graphs for g in gs] - graphsum = cumsum([0; v_num_graphs])[1:(end - 1)] - v_gi = Dict(node_t => [ng .+ gi for (ng, gi) in zip(graphsum, v_gi[node_t])] for node_t in ntypes) - graph_indicator = Dict(node_t => cat_features(v_gi[node_t]) for node_t in ntypes) - - function data_or_else(data, types) - Dict(type => get(data, type, DataStore(0)) for type in types) - end - - return GNNHeteroGraph(graph, - num_nodes, - num_edges, - sum(v_num_graphs), - graph_indicator, - cat_features([data_or_else(g.ndata, ntypes) for g in gs]), - cat_features([data_or_else(g.edata, etypes) for g in gs]), - cat_features([g.gdata for g in gs]), - ntypes, etypes) -end - -""" - unbatch(g::GNNGraph) - -Opposite of the [`MLUtils.batch`](@ref) operation, returns -an array of the individual graphs batched together in `g`. - -See also [`MLUtils.batch`](@ref) and [`getgraph`](@ref). 
- -# Examples - -```jldoctest -julia> gbatched = MLUtils.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)]) -GNNGraph: - num_nodes = 19 - num_edges = 16 - num_graphs = 3 - -julia> MLUtils.unbatch(gbatched) -3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}: - GNNGraph: - num_nodes = 5 - num_edges = 6 - - GNNGraph: - num_nodes = 10 - num_edges = 8 - - GNNGraph: - num_nodes = 4 - num_edges = 2 -``` -""" -function MLUtils.unbatch(g::GNNGraph{T}) where {T <: COO_T} - g.num_graphs == 1 && return [g] - - nodemasks = _unbatch_nodemasks(g.graph_indicator, g.num_graphs) - num_nodes = length.(nodemasks) - cumnum_nodes = [0; cumsum(num_nodes)] - - s, t = edge_index(g) - w = get_edge_weight(g) - - edgemasks = _unbatch_edgemasks(s, t, g.num_graphs, cumnum_nodes) - num_edges = length.(edgemasks) - @assert sum(num_edges)==g.num_edges "Error in unbatching, likely the edges are not sorted (first edges belong to the first graphs, then edges in the second graph and so on)" - - function build_graph(i) - node_mask = nodemasks[i] - edge_mask = edgemasks[i] - snew = s[edge_mask] .- cumnum_nodes[i] - tnew = t[edge_mask] .- cumnum_nodes[i] - wnew = w === nothing ? nothing : w[edge_mask] - graph = (snew, tnew, wnew) - graph_indicator = nothing - ndata = getobs(g.ndata, node_mask) - edata = getobs(g.edata, edge_mask) - gdata = getobs(g.gdata, i) - - nedges = num_edges[i] - nnodes = num_nodes[i] - ngraphs = 1 - - return GNNGraph(graph, - nnodes, nedges, ngraphs, - graph_indicator, - ndata, edata, gdata) - end - - return [build_graph(i) for i in 1:(g.num_graphs)] -end - -function MLUtils.unbatch(g::GNNGraph) - return [getgraph(g, i) for i in 1:(g.num_graphs)] -end - -function _unbatch_nodemasks(graph_indicator, num_graphs) - @assert issorted(graph_indicator) "The graph_indicator vector must be sorted." 
- idxslast = [searchsortedlast(graph_indicator, i) for i in 1:num_graphs] - - nodemasks = [1:idxslast[1]] - for i in 2:num_graphs - push!(nodemasks, (idxslast[i - 1] + 1):idxslast[i]) - end - return nodemasks -end - -function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes) - edgemasks = [] - for i in 1:(num_graphs - 1) - lastedgeid = findfirst(s) do x - x > cumnum_nodes[i + 1] && x <= cumnum_nodes[i + 2] - end - firstedgeid = i == 1 ? 1 : last(edgemasks[i - 1]) + 1 - # if nothing make empty range - lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1 - - push!(edgemasks, firstedgeid:lastedgeid) - end - push!(edgemasks, (last(edgemasks[end]) + 1):length(s)) - return edgemasks -end - -@non_differentiable _unbatch_nodemasks(::Any...) -@non_differentiable _unbatch_edgemasks(::Any...) - -""" - getgraph(g::GNNGraph, i; nmap=false) - -Return the subgraph of `g` induced by those nodes `j` -for which `g.graph_indicator[j] == i` or, -if `i` is a collection, `g.graph_indicator[j] ∈ i`. -In other words, it extract the component graphs from a batched graph. - -If `nmap=true`, return also a vector `v` mapping the new nodes to the old ones. -The node `i` in the subgraph will correspond to the node `v[i]` in `g`. -""" -getgraph(g::GNNGraph, i::Int; kws...) = getgraph(g, [i]; kws...) - -function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap = false) - if g.graph_indicator === nothing - @assert i == [1] - if nmap - return g, 1:(g.num_nodes) - else - return g - end - end - - node_mask = g.graph_indicator .∈ Ref(i) - - nodes = (1:(g.num_nodes))[node_mask] - nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes)) - - graphmap = Dict(i => inew for (inew, i) in enumerate(i)) - graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]] - - s, t = edge_index(g) - w = get_edge_weight(g) - edge_mask = s .∈ Ref(nodes) - - if g.graph isa COO_T - s = [nodemap[i] for i in s[edge_mask]] - t = [nodemap[i] for i in t[edge_mask]] - w = isnothing(w) ? 
nothing : w[edge_mask] - graph = (s, t, w) - elseif g.graph isa ADJMAT_T - graph = g.graph[nodes, nodes] - end - - ndata = getobs(g.ndata, node_mask) - edata = getobs(g.edata, edge_mask) - gdata = getobs(g.gdata, i) - - num_edges = sum(edge_mask) - num_nodes = length(graph_indicator) - num_graphs = length(i) - - gnew = GNNGraph(graph, - num_nodes, num_edges, num_graphs, - graph_indicator, - ndata, edata, gdata) - - if nmap - return gnew, nodes - else - return gnew - end -end - -""" - negative_sample(g::GNNGraph; - num_neg_edges = g.num_edges, - bidirected = is_bidirected(g)) - -Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges. - -If `bidirected=true`, the output graph will be bidirected and there will be no -leakage from the origin graph. - -See also [`is_bidirected`](@ref). -""" -function negative_sample(g::GNNGraph; - max_trials = 3, - num_neg_edges = g.num_edges, - bidirected = is_bidirected(g)) - @assert g.num_graphs == 1 - # Consider self-loops as positive edges - # Construct new graph dropping features - g = add_self_loops(GNNGraph(edge_index(g), num_nodes = g.num_nodes)) - - s, t = edge_index(g) - n = g.num_nodes - dev = get_device(s) - cdev = cpu_device() - s, t = s |> cdev, t |> cdev - idx_pos, maxid = edge_encoding(s, t, n) - if bidirected - num_neg_edges = num_neg_edges ÷ 2 - pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge - else - pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge - end - # pneg * sample_prob * maxid == num_neg_edges - sample_prob = min(1, num_neg_edges / (pneg * maxid) * 1.1) - idx_neg = Int[] - for _ in 1:max_trials - rnd = randsubseq(1:maxid, sample_prob) - setdiff!(rnd, idx_pos) - union!(idx_neg, rnd) - if length(idx_neg) >= num_neg_edges - idx_neg = idx_neg[1:num_neg_edges] - break - end - end - s_neg, t_neg = edge_decoding(idx_neg, n) - if bidirected - s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg] - end - s_neg, t_neg = s_neg |> dev, t_neg |> dev - return 
GNNGraph(s_neg, t_neg, num_nodes = n)
end

"""
    rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g)) -> g1, g2

Randomly partition the edges in `g` to form two graphs, `g1`
and `g2`. Both will have the same number of nodes as `g`.
`g1` will contain a fraction `frac` of the original edges,
while `g2` will contain the rest.

If `bidirected = true` makes sure that an edge and its reverse go into the same split.
This option is supported only for bidirected graphs with no self-loops
and multi-edges.

`rand_edge_split` is typically used to create train/test splits in link prediction tasks.
"""
function rand_edge_split(g::GNNGraph, frac; bidirected = is_bidirected(g))
    s, t = edge_index(g)
    # In the bidirected case each undirected edge is stored twice (i→j and j→i),
    # so the number of logical edges to split is half the stored count.
    ne = bidirected ? g.num_edges ÷ 2 : g.num_edges
    eids = randperm(ne)
    size1 = round(Int, ne * frac)

    if !bidirected
        s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
        s2, t2 = s[eids[(size1 + 1):end]], t[eids[(size1 + 1):end]]
    else
        # Assumes g is bidirected with no self-loops and no multi-edges:
        # keep only one orientation (s < t), split, then restore both directions
        # so that an edge and its reverse always land in the same split.
        mask = s .< t
        s, t = s[mask], t[mask]
        s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
        s1, t1 = [s1; t1], [t1; s1]
        s2, t2 = s[eids[(size1 + 1):end]], t[eids[(size1 + 1):end]]
        s2, t2 = [s2; t2], [t2; s2]
    end
    g1 = GNNGraph(s1, t1, num_nodes = g.num_nodes)
    g2 = GNNGraph(s2, t2, num_nodes = g.num_nodes)
    return g1, g2
end

"""
    random_walk_pe(g, walk_length)

Return the random walk positional encoding from the paper [Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875) of the given graph `g` and the length of the walk `walk_length` as a matrix of size `(walk_length, g.num_nodes)`.
"""
function random_walk_pe(g::GNNGraph, walk_length::Int)
    adj = adjacency_matrix(g, Float32; dir = :out)
    # Allocate the output with the same storage family/device as `adj`
    # (dense even when adj is sparse). The previous version first allocated
    # a CPU `zeros(...)` matrix that was immediately overwritten — dead work.
    matrix = dense_zeros_like(adj, Float32, (walk_length, g.num_nodes))
    deg = sum(adj, dims = 2) |> vec
    deg_inv = inv.(deg)
    deg_inv[isinf.(deg_inv)] .= 0  # isolated nodes: 1/0 == Inf → treat as 0
    RW = adj * Diagonal(deg_inv)   # random-walk transition operator
    out = RW
    matrix[1, :] .= diag(RW)
    for i in 2:walk_length
        out = out * RW             # out == RW^i
        matrix[i, :] .= diag(out)
    end
    return matrix
end

dense_zeros_like(a::SparseMatrixCSC, T::Type, sz = size(a)) = zeros(T, sz)
dense_zeros_like(a::AbstractArray, T::Type, sz = size(a)) = fill!(similar(a, T, sz), 0)
dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)

# Transform a vector of CartesianIndex into a tuple of vectors of integers.
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)

@non_differentiable negative_sample(x...)
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable dense_zeros_like(x...)

"""
    ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph

Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix.
References paper: [The pagerank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422)


The function performs the following steps:
1. Constructs a modified adjacency matrix `A` using the graph's edge weights, where `A` is adjusted by `(α - 1) * A + I`, with `α` being the damping factor (`alpha_f32`) and `I` the identity matrix.
2. Normalizes `A` to ensure each column sums to 1, representing transition probabilities.
3. 
Applies the PPR formula `α * (I + (α - 1) * A)^-1` to compute the diffusion matrix. -4. Updates the original edge weights of the graph based on the PPR diffusion matrix, assigning new weights for each edge from the PPR matrix. - -# Arguments -- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available. -- `alpha_f32::Float32`: The damping factor used in PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`. - -# Returns -- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation. -""" -function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0) - s, t = edge_index(g) - w = get_edge_weight(g) - if isnothing(w) - w = ones(Float32, g.num_edges) - end - - N = g.num_nodes - - initial_A = sparse(t, s, w, N, N) - scaled_A = (Float32(alpha) - 1) * initial_A - - I_sparse = sparse(Diagonal(ones(Float32, N))) - A_sparse = I_sparse + scaled_A - - A_dense = Matrix(A_sparse) - - PPR = alpha * inv(A_dense) - - new_w = [PPR[dst, src] for (src, dst) in zip(s, t)] - - return GNNGraph((s, t, new_w), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -[.\GNNGraphs\src\utils.jl] -function check_num_nodes(g::GNNGraph, x::AbstractArray) - @assert g.num_nodes==size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)" - return true -end -function check_num_nodes(g::GNNGraph, x::Union{Tuple, NamedTuple}) - map(x -> check_num_nodes(g, x), x) - return true -end - -check_num_nodes(::GNNGraph, ::Nothing) = true - -function check_num_nodes(g::GNNGraph, x::Tuple) - @assert length(x) == 2 - check_num_nodes(g, x[1]) - check_num_nodes(g, x[2]) - return true -end - -# x = (Xsrc, Xdst) = (Xj, Xi) -function check_num_nodes(g::GNNHeteroGraph, x::Tuple) - @assert length(x) == 2 - @assert length(g.etypes) == 1 - nt1, _, nt2 = only(g.etypes) - if 
x[1] isa AbstractArray - @assert size(x[1], ndims(x[1])) == g.num_nodes[nt1] - end - if x[2] isa AbstractArray - @assert size(x[2], ndims(x[2])) == g.num_nodes[nt2] - end - return true -end - -function check_num_edges(g::GNNGraph, e::AbstractArray) - @assert g.num_edges==size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(g.num_edges)" - return true -end -function check_num_edges(g::AbstractGNNGraph, x::Union{Tuple, NamedTuple}) - map(x -> check_num_edges(g, x), x) - return true -end - -check_num_edges(::AbstractGNNGraph, ::Nothing) = true - -function check_num_edges(g::GNNHeteroGraph, e::AbstractArray) - num_edgs = only(g.num_edges)[2] - @assert only(num_edgs)==size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(num_edgs)" - return true -end - -sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) - -""" - sort_edge_index(ei::Tuple) -> u', v' - sort_edge_index(u, v) -> u', v' - -Return a sorted version of the tuple of vectors `ei = (u, v)`, -applying a common permutation to `u` and `v`. -The sorting is lexycographic, that is the pairs `(ui, vi)` -are sorted first according to the `ui` and then according to `vi`. 
"""
    sort_edge_index(u, v) -> u', v'

Apply a common permutation to `u` and `v` so that the pairs `(u[i], v[i])`
come out in lexicographic order: first by `u[i]`, then by `v[i]`.
"""
function sort_edge_index(u, v)
    edgepairs = collect(zip(u, v))
    perm = sortperm(edgepairs)  # isless is lexicographic for tuples
    return u[perm], v[perm]
end


cat_features(x1::Nothing, x2::Nothing) = nothing
cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1))
function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector})
    return cat(x1, x2, dims = 1)
end

# workaround for issue #98 #104
# See https://github.com/JuliaStrings/InlineStrings.jl/issues/21
# Remove when minimum supported version is julia v1.8
cat_features(x1::NamedTuple{(), Tuple{}}, x2::NamedTuple{(), Tuple{}}) = (;)
cat_features(xs::AbstractVector{NamedTuple{(), Tuple{}}}) = (;)

function cat_features(x1::NamedTuple, x2::NamedTuple)
    # Key sets must agree; concatenation proceeds key-by-key in x1's order.
    sort(collect(keys(x1))) == sort(collect(keys(x2))) ||
        @error "cannot concatenate feature data with different keys"

    return NamedTuple(k => cat_features(x1[k], x2[k]) for k in keys(x1))
end

function cat_features(x1::Dict{Symbol, T}, x2::Dict{Symbol, T}) where {T}
    sort(collect(keys(x1))) == sort(collect(keys(x2))) ||
        @error "cannot concatenate feature data with different keys"

    return Dict{Symbol, T}(k => cat_features(x1[k], x2[k]) for k in keys(x1))
end

function cat_features(x::Dict)
    return Dict(k => cat_features(v) for (k, v) in pairs(x))
end
-end - - -function cat_features(xs::AbstractVector{<:AbstractArray{T, N}}) where {T <: Number, N} - cat(xs...; dims = N) -end - -cat_features(xs::AbstractVector{Nothing}) = nothing -cat_features(xs::AbstractVector{<:Number}) = xs - -function cat_features(xs::AbstractVector{<:NamedTuple}) - symbols = [sort(collect(keys(x))) for x in xs] - all(y -> y == symbols[1], symbols) || - @error "cannot concatenate feature data with different keys" - length(xs) == 1 && return xs[1] - - # concatenate - syms = symbols[1] - NamedTuple(k => cat_features([x[k] for x in xs]) for k in syms) -end - -# function cat_features(xs::AbstractVector{Dict{Symbol, T}}) where {T} -# symbols = [sort(collect(keys(x))) for x in xs] -# all(y -> y == symbols[1], symbols) || -# @error "cannot concatenate feature data with different keys" -# length(xs) == 1 && return xs[1] - -# # concatenate -# syms = symbols[1] -# return Dict{Symbol, T}([k => cat_features([x[k] for x in xs]) for k in syms]...) -# end - -function cat_features(xs::AbstractVector{<:Dict}) - _allkeys = [sort(collect(keys(x))) for x in xs] - _keys = union(_allkeys...) - length(xs) == 1 && return xs[1] - - # concatenate - return Dict([k => cat_features([x[k] for x in xs if haskey(x, k)]) for k in _keys]...) -end - - -# Used to concatenate edge weights -cat_features(w1::Nothing, w2::Nothing, n1::Int, n2::Int) = nothing -cat_features(w1::AbstractVector, w2::Nothing, n1::Int, n2::Int) = cat_features(w1, ones_like(w1, n2)) -cat_features(w1::Nothing, w2::AbstractVector, n1::Int, n2::Int) = cat_features(ones_like(w2, n1), w2) -cat_features(w1::AbstractVector, w2::AbstractVector, n1::Int, n2::Int) = cat_features(w1, w2) - - -# Turns generic type into named tuple -normalize_graphdata(data::Nothing; n, kws...) = DataStore(n) - -function normalize_graphdata(data; default_name::Symbol, kws...) - normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...) 
-end - -function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false) - # This had to workaround two Zygote bugs with NamedTuples - # https://github.com/FluxML/Zygote.jl/issues/1071 - # https://github.com/FluxML/Zygote.jl/issues/1072 - - if n > 1 - @assert all(x -> x isa AbstractArray, data) "Non-array features provided." - end - - if n <= 1 - # If last array dimension is not 1, add a new dimension. - # This is mostly useful to reshape global feature vectors - # of size D to Dx1 matrices. - unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v - unsqz_last(v) = v - - data = map(unsqz_last, data) - end - - if n > 0 - if duplicate_if_needed - function duplicate(v) - if v isa AbstractArray && size(v)[end] == n ÷ 2 - v = cat(v, v, dims = ndims(v)) - end - return v - end - data = map(duplicate, data) - end - - for x in data - if x isa AbstractArray - @assert size(x)[end]==n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])." - end - end - end - - return DataStore(n, data) -end - -# For heterogeneous graphs -function normalize_heterographdata(data::Nothing; default_name::Symbol, ns::Dict, kws...) - Dict([k => normalize_graphdata(nothing; default_name = default_name, n, kws...) - for (k, n) in ns]...) -end - -normalize_heterographdata(data; kws...) = normalize_heterographdata(Dict(data); kws...) - -function normalize_heterographdata(data::Dict; default_name::Symbol, ns::Dict, kws...) - Dict([k => normalize_graphdata(get(data, k, nothing); default_name = default_name, n, kws...) - for (k, n) in ns]...) 
-end - -numnonzeros(a::AbstractSparseMatrix) = nnz(a) -numnonzeros(a::AbstractMatrix) = count(!=(0), a) - -## Map edges into a contiguous range of integers -function edge_encoding(s, t, n; directed = true, self_loops = true) - if directed && self_loops - maxid = n^2 - idx = (s .- 1) .* n .+ t - elseif !directed && self_loops - maxid = n * (n + 1) ÷ 2 - mask = s .> t - snew = copy(s) - tnew = copy(t) - snew[mask] .= t[mask] - tnew[mask] .= s[mask] - s, t = snew, tnew - - # idx = ∑_{i',i'=i'}^n 1 + ∑_{j',i<=j'<=j} 1 - # = ∑_{i',i'=i'}^n 1 + (j - i + 1) - # = ∑_{i',i' s) - elseif !directed && !self_loops - @assert all(s .!= t) - maxid = n * (n - 1) ÷ 2 - mask = s .> t - snew = copy(s) - tnew = copy(t) - snew[mask] .= t[mask] - tnew[mask] .= s[mask] - s, t = snew, tnew - - # idx(s,t) = ∑_{s',1<= s'= s) - elseif !directed && !self_loops - # Considering t = s + 1 in - # idx = @. (s - 1) * n - s * (s - 1) ÷ 2 + (t - s) - # and inverting for s we have - s = @. floor(Int, 1/2 + n - 1/2 * sqrt(9 - 4n + 4n^2 - 8*idx)) - # now we can find t - t = @. idx - (s - 1) * n + s * (s - 1) ÷ 2 + s - end - return s, t -end - -# for bipartite graphs -function edge_decoding(idx, n1, n2) - @assert all(1 .<= idx .<= n1 * n2) - s = (idx .- 1) .÷ n2 .+ 1 - t = (idx .- 1) .% n2 .+ 1 - return s, t -end - -function _rand_edges(rng, n::Int, m::Int; directed = true, self_loops = true) - idmax = if directed && self_loops - n^2 - elseif !directed && self_loops - n * (n + 1) ÷ 2 - elseif directed && !self_loops - n * (n - 1) - elseif !directed && !self_loops - n * (n - 1) ÷ 2 - end - idx = StatsBase.sample(rng, 1:idmax, m, replace = false) - s, t = edge_decoding(idx, n; directed, self_loops) - val = nothing - return s, t, val -end - -function _rand_edges(rng, (n1, n2), m) - idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false) - s, t = edge_decoding(idx, n1, n2) - val = nothing - return s, t, val -end - -binarize(x) = map(>(0), x) - -@non_differentiable binarize(x...) 
@non_differentiable edge_encoding(x...)
@non_differentiable edge_decoding(x...)

### PRINTING #####

# Print the short summary of `x` to `io`; print nothing at all when
# `x` has no summary (shortsummary(x) === nothing).
function shortsummary(io::IO, x)
    str = shortsummary(x)
    str === nothing && return
    print(io, str)
end

shortsummary(x) = summary(x)
shortsummary(x::Number) = "$x"

function shortsummary(x::NamedTuple)
    n = length(x)
    if n == 0
        return nothing
    elseif n == 1
        return "$(keys(x)[1]) = $(shortsummary(x[1]))"
    else
        return "(" * join(("$k = $(shortsummary(x[k]))" for k in keys(x)), ", ") * ")"
    end
end

function shortsummary(x::DataStore)
    length(x) == 0 && return nothing
    entries = ("$k = [$(shortsummary(x[k]))]" for k in keys(x))
    return "DataStore(" * join(entries, ", ") * ")"
end

# Render a size tuple the way Base does, e.g. (2, 2, 3) -> "2×2×3".
function dims2string(d)
    isempty(d) && return "0-dimensional"
    length(d) == 1 && return "$(d[1])-element"
    return join(map(string, d), '×')
end

@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
@non_differentiable normalize_graphdata(::Nothing)

iscuarray(x::AbstractArray) = false
@non_differentiable iscuarray(::Any)


@doc raw"""
    color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters

The color refinement algorithm for graph coloring.
Given a graph `g` and an initial coloring `x0`, the algorithm
iteratively refines the coloring until a fixed point is reached.

At each iteration the algorithm computes a hash of the coloring and the sorted list of colors
of the neighbors of each node. This hash is used to determine if the coloring has changed.

```math
x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))).
````

This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing.

# Arguments
- `g::GNNGraph`: The graph to color.
- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1.

# Returns
- `x::AbstractVector{<:Integer}`: The final coloring.
- `num_colors::Int`: The number of colors used.
-- `niters::Int`: The number of iterations until convergence. -""" -color_refinement(g::GNNGraph) = color_refinement(g, ones(Int, g.num_nodes)) - -function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer}) - @assert length(x0) == g.num_nodes - s, t = edge_index(g) - t, s = sort_edge_index(t, s) # sort by target - degs = degree(g, dir=:in) - x = x0 - - hashmap = Dict{UInt64, Int}() - x′ = zeros(Int, length(x0)) - niters = 0 - while true - xneigs = chunk(x[s], size=degs) - for (i, (xi, xineigs)) in enumerate(zip(x, xneigs)) - idx = hash((xi, sort(xineigs))) - x′[i] = get!(hashmap, idx, length(hashmap) + 1) - end - niters += 1 - x == x′ && break - x = x′ - end - num_colors = length(union(x)) - return x, num_colors, niters -end -[.\GNNGraphs\test\chainrules.jl] -@testset "dict constructor" begin - grad = gradient(1.) do x - d = Dict([:x => x, :y => 5]...) - return sum(d[:x].^2) - end[1] - - @test grad == 2 - - ## BROKEN Constructors - # grad = gradient(1.) do x - # d = Dict([(:x => x), (:y => 5)]) - # return sum(d[:x].^2) - # end[1] - - # @test grad == 2 - - - # grad = gradient(1.) 
do x - # d = Dict([(:x => x), (:y => 5)]) - # return sum(d[:x].^2) - # end[1] - - # @test grad == 2 -end - -[.\GNNGraphs\test\convert.jl] -if TEST_GPU - @testset "to_coo(dense) on gpu" begin - get_st(A) = GNNGraphs.to_coo(A)[1][1:2] - get_val(A) = GNNGraphs.to_coo(A)[1][3] - - A = cu([0 2 2; 2.0 0 2; 2 2 0]) - - y = get_val(A) - @test y isa CuVector{Float32} - @test Array(y) ≈ [2, 2, 2, 2, 2, 2] - - s, t = get_st(A) - @test s isa CuVector{<:Integer} - @test t isa CuVector{<:Integer} - @test Array(s) == [2, 3, 1, 3, 1, 2] - @test Array(t) == [1, 1, 2, 2, 3, 3] - - @test gradient(A -> sum(get_val(A)), A)[1] isa CuMatrix{Float32} - end -end - -[.\GNNGraphs\test\datastore.jl] - -@testset "constructor" begin - @test_throws AssertionError DataStore(10, (:x => rand(10), :y => rand(2, 4))) - - @testset "keyword args" begin - ds = DataStore(10, x = rand(10), y = rand(2, 10)) - @test size(ds.x) == (10,) - @test size(ds.y) == (2, 10) - - ds = DataStore(x = rand(10), y = rand(2, 10)) - @test size(ds.x) == (10,) - @test size(ds.y) == (2, 10) - end -end - -@testset "getproperty / setproperty!" begin - x = rand(10) - ds = DataStore(10, (:x => x, :y => rand(2, 10))) - @test ds.x == ds[:x] == x - @test_throws DimensionMismatch ds.z=rand(12) - ds.z = [1:10;] - @test ds.z == [1:10;] - vec = [DataStore(10, (:x => x,)), DataStore(10, (:x => x, :y => rand(2, 10)))] - @test vec.x == [x, x] - @test_throws KeyError vec.z - @test vec._n == [10, 10] - @test vec._data == [Dict(:x => x), Dict(:x => x, :y => vec[2].y)] -end - -@testset "setindex!" begin - ds = DataStore(10) - x = rand(10) - @test (ds[:x] = x) == x # Tests setindex! 
- @test ds.x == ds[:x] == x -end - -@testset "map" begin - ds = DataStore(10, (:x => rand(10), :y => rand(2, 10))) - ds2 = map(x -> x .+ 1, ds) - @test ds2.x == ds.x .+ 1 - @test ds2.y == ds.y .+ 1 - - @test_throws AssertionError ds2=map(x -> [x; x], ds) -end - -@testset "getdata / getn" begin - ds = DataStore(10, (:x => rand(10), :y => rand(2, 10))) - @test getdata(ds) == getfield(ds, :_data) - @test_throws KeyError ds.data - @test getn(ds) == getfield(ds, :_n) - @test_throws KeyError ds.n -end - -@testset "cat empty" begin - ds1 = DataStore(2, (:x => rand(2))) - ds2 = DataStore(1, (:x => rand(1))) - dsempty = DataStore(0, (:x => rand(0))) - - ds = GNNGraphs.cat_features(ds1, ds2) - @test getn(ds) == 3 - ds = GNNGraphs.cat_features(ds1, dsempty) - @test getn(ds) == 2 - - # issue #280 - g = GNNGraph([1], [2]) - h = add_edges(g, Int[], Int[]) # adds no edges - @test getn(g.edata) == 1 - @test getn(h.edata) == 1 -end - - -@testset "gradient" begin - ds = DataStore(10, (:x => rand(10), :y => rand(2, 10))) - - f1(ds) = sum(ds.x) - grad = gradient(f1, ds)[1] - @test grad._data[:x] ≈ ngradient(f1, ds)[1][:x] - - g = rand_graph(5, 2) - x = rand(2, 5) - grad = gradient(x -> sum(exp, GNNGraph(g, ndata = x).ndata.x), x)[1] - @test grad == exp.(x) -end - -@testset "functor" begin - ds = DataStore(10, (:x => zeros(10), :y => ones(2, 10))) - p, re = Functors.functor(ds) - @test p[1] === getn(ds) - @test p[2] === getdata(ds) - @test ds == re(p) - - ds2 = Functors.fmap(ds) do x - if x isa AbstractArray - x .+ 1 - else - x - end - end - @test ds isa DataStore - @test ds2.x == ds.x .+ 1 -end - -[.\GNNGraphs\test\generate.jl] -@testset "rand_graph" begin - n, m = 10, 20 - m2 = m ÷ 2 - x = rand(3, n) - e = rand(4, m2) - - g = rand_graph(n, m, ndata = x, edata = e, graph_type = GRAPH_T) - @test g.num_nodes == n - @test g.num_edges == m - @test g.ndata.x === x - if GRAPH_T == :coo - s, t = edge_index(g) - @test s[1:m2] == t[(m2 + 1):end] - @test t[1:m2] == s[(m2 + 1):end] - @test 
g.edata.e[:, 1:m2] == e - @test g.edata.e[:, (m2 + 1):end] == e - end - - rng = MersenneTwister(17) - g = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) - @test g.num_nodes == n - @test g.num_edges == m - - rng = MersenneTwister(17) - g2 = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) - @test edge_index(g2) == edge_index(g) - - ew = rand(m2) - rng = MersenneTwister(17) - g = rand_graph(rng, n, m, bidirected = true, graph_type = GRAPH_T, edge_weight = ew) - @test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo) - - ew = rand(m) - rng = MersenneTwister(17) - g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T, edge_weight = ew) - @test get_edge_weight(g) == ew broken=(GRAPH_T != :coo) -end - -@testset "knn_graph" begin - n, k = 10, 3 - x = rand(3, n) - g = knn_graph(x, k; graph_type = GRAPH_T) - @test g.num_nodes == 10 - @test g.num_edges == n * k - @test degree(g, dir = :in) == fill(k, n) - @test has_self_loops(g) == false - - g = knn_graph(x, k; dir = :out, self_loops = true, graph_type = GRAPH_T) - @test g.num_nodes == 10 - @test g.num_edges == n * k - @test degree(g, dir = :out) == fill(k, n) - @test has_self_loops(g) == true - - graph_indicator = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2] - g = knn_graph(x, k; graph_indicator, graph_type = GRAPH_T) - @test g.num_graphs == 2 - s, t = edge_index(g) - ne = n * k ÷ 2 - @test all(1 .<= s[1:ne] .<= 5) - @test all(1 .<= t[1:ne] .<= 5) - @test all(6 .<= s[(ne + 1):end] .<= 10) - @test all(6 .<= t[(ne + 1):end] .<= 10) -end - -@testset "radius_graph" begin - n, r = 10, 0.5 - x = rand(3, n) - g = radius_graph(x, r; graph_type = GRAPH_T) - @test g.num_nodes == 10 - @test has_self_loops(g) == false - - g = radius_graph(x, r; dir = :out, self_loops = true, graph_type = GRAPH_T) - @test g.num_nodes == 10 - @test has_self_loops(g) == true - - graph_indicator = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2] - g = radius_graph(x, r; graph_indicator, graph_type = GRAPH_T) - @test g.num_graphs == 2 - s, t = 
edge_index(g) - @test (s .> 5) == (t .> 5) -end - -@testset "rand_bipartite_heterograph" begin - g = rand_bipartite_heterograph((10, 15), (20, 20)) - @test g.num_nodes == Dict(:A => 10, :B => 15) - @test g.num_edges == Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20) - sA, tB = edge_index(g, (:A, :to, :B)) - for (s, t) in zip(sA, tB) - @test 1 <= s <= 10 - @test 1 <= t <= 15 - @test has_edge(g, (:A,:to,:B), s, t) - @test has_edge(g, (:B,:to,:A), t, s) - end - - g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) - @test has_edge(g, (:A,:to,:B), 1, 1) - @test !has_edge(g, (:B,:to,:A), 1, 1) -end - -@testset "rand_temporal_radius_graph" begin - number_nodes = 30 - number_snapshots = 5 - r = 0.1 - speed = 0.1 - tg = rand_temporal_radius_graph(number_nodes, number_snapshots, speed, r) - @test tg.num_nodes == [number_nodes for i in 1:number_snapshots] - @test tg.num_snapshots == number_snapshots - r2 = 0.95 - tg2 = rand_temporal_radius_graph(number_nodes, number_snapshots, speed, r2) - @test mean(mean(degree.(tg.snapshots)))<=mean(mean(degree.(tg2.snapshots))) -end - -@testset "rand_temporal_hyperbolic_graph" begin - @test GNNGraphs._hyperbolic_distance([1.0,1.0],[1.0,1.0];ζ=1)==0 - @test GNNGraphs._hyperbolic_distance([0.23,0.11],[0.98,0.55];ζ=1) == GNNGraphs._hyperbolic_distance([0.98,0.55],[0.23,0.11];ζ=1) - number_nodes = 30 - number_snapshots = 5 - α, R, speed, ζ = 1, 1, 0.1, 1 - - tg = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ) - @test tg.num_nodes == [number_nodes for i in 1:number_snapshots] - @test tg.num_snapshots == number_snapshots - R = 10 - tg1 = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ) - @test mean(mean(degree.(tg1.snapshots)))<=mean(mean(degree.(tg.snapshots))) -end - -[.\GNNGraphs\test\gnngraph.jl] -@testset "Constructor: adjacency matrix" begin - A = sprand(10, 10, 0.5) - sA, tA, vA = findnz(A) - - g = GNNGraph(A, graph_type = GRAPH_T) - s, t = edge_index(g) - v = 
get_edge_weight(g) - @test s == sA - @test t == tA - @test v == vA - - g = GNNGraph(Matrix(A), graph_type = GRAPH_T) - s, t = edge_index(g) - v = get_edge_weight(g) - @test s == sA - @test t == tA - @test v == vA - - g = GNNGraph([0 0 0 - 0 0 1 - 0 1 0], graph_type = GRAPH_T) - @test g.num_nodes == 3 - @test g.num_edges == 2 - - g = GNNGraph([0 1 0 - 1 0 0 - 0 0 0], graph_type = GRAPH_T) - @test g.num_nodes == 3 - @test g.num_edges == 2 -end - -@testset "Constructor: integer" begin - g = GNNGraph(10, graph_type = GRAPH_T) - @test g.num_nodes == 10 - @test g.num_edges == 0 - - g2 = rand_graph(10, 30, graph_type = GRAPH_T) - G = typeof(g2) - g = G(10) - @test g.num_nodes == 10 - @test g.num_edges == 0 - - g = GNNGraph(graph_type = GRAPH_T) - @test g.num_nodes == 0 -end - -@testset "symmetric graph" begin - s = [1, 1, 2, 2, 3, 3, 4, 4] - t = [2, 4, 1, 3, 2, 4, 1, 3] - adj_mat = [0 1 0 1 - 1 0 1 0 - 0 1 0 1 - 1 0 1 0] - adj_list_out = [[2, 4], [1, 3], [2, 4], [1, 3]] - adj_list_in = [[2, 4], [1, 3], [2, 4], [1, 3]] - - # core functionality - g = GNNGraph(s, t; graph_type = GRAPH_T) - if TEST_GPU - dev = CUDADevice() - g_gpu = g |> dev - end - - @test g.num_edges == 8 - @test g.num_nodes == 4 - @test nv(g) == g.num_nodes - @test ne(g) == g.num_edges - @test Tuple.(collect(edges(g))) |> sort == collect(zip(s, t)) |> sort - @test sort(outneighbors(g, 1)) == [2, 4] - @test sort(inneighbors(g, 1)) == [2, 4] - @test is_directed(g) == true - s1, t1 = sort_edge_index(edge_index(g)) - @test s1 == s - @test t1 == t - @test vertices(g) == 1:(g.num_nodes) - - @test sort.(adjacency_list(g; dir = :in)) == adj_list_in - @test sort.(adjacency_list(g; dir = :out)) == adj_list_out - - @testset "adjacency_matrix" begin - @test adjacency_matrix(g) == adj_mat - @test adjacency_matrix(g; dir = :in) == adj_mat - @test adjacency_matrix(g; dir = :out) == adj_mat - - if TEST_GPU - # See https://github.com/JuliaGPU/CUDA.jl/pull/1093 - mat_gpu = adjacency_matrix(g_gpu) - @test mat_gpu isa 
ACUMatrix{Int} - @test Array(mat_gpu) == adj_mat - end - end - - @testset "normalized_laplacian" begin - mat = normalized_laplacian(g) - if TEST_GPU - mat_gpu = normalized_laplacian(g_gpu) - @test mat_gpu isa ACUMatrix{Float32} - @test Array(mat_gpu) == mat - end - end - - @testset "scaled_laplacian" begin if TEST_GPU - mat = scaled_laplacian(g) - mat_gpu = scaled_laplacian(g_gpu) - @test mat_gpu isa ACUMatrix{Float32} - @test Array(mat_gpu) ≈ mat - end end - - @testset "constructors" begin - adjacency_matrix(g; dir = :out) == adj_mat - adjacency_matrix(g; dir = :in) == adj_mat - end - - if TEST_GPU - @testset "functor" begin - s_cpu, t_cpu = edge_index(g) - s_gpu, t_gpu = edge_index(g_gpu) - @test s_gpu isa CuVector{Int} - @test Array(s_gpu) == s_cpu - @test t_gpu isa CuVector{Int} - @test Array(t_gpu) == t_cpu - end - end -end - -@testset "asymmetric graph" begin - s = [1, 2, 3, 4] - t = [2, 3, 4, 1] - adj_mat_out = [0 1 0 0 - 0 0 1 0 - 0 0 0 1 - 1 0 0 0] - adj_list_out = [[2], [3], [4], [1]] - - adj_mat_in = [0 0 0 1 - 1 0 0 0 - 0 1 0 0 - 0 0 1 0] - adj_list_in = [[4], [1], [2], [3]] - - # core functionality - g = GNNGraph(s, t; graph_type = GRAPH_T) - if TEST_GPU - dev = CUDADevice() #TODO replace with `gpu_device()` - g_gpu = g |> dev - end - - @test g.num_edges == 4 - @test g.num_nodes == 4 - @test length(edges(g)) == 4 - @test sort(outneighbors(g, 1)) == [2] - @test sort(inneighbors(g, 1)) == [4] - @test is_directed(g) == true - @test is_directed(typeof(g)) == true - s1, t1 = sort_edge_index(edge_index(g)) - @test s1 == s - @test t1 == t - - # adjacency - @test adjacency_matrix(g) == adj_mat_out - @test adjacency_list(g) == adj_list_out - @test adjacency_matrix(g, dir = :out) == adj_mat_out - @test adjacency_list(g, dir = :out) == adj_list_out - @test adjacency_matrix(g, dir = :in) == adj_mat_in - @test adjacency_list(g, dir = :in) == adj_list_in -end - -@testset "zero" begin - g = rand_graph(4, 6, graph_type = GRAPH_T) - G = typeof(g) - @test zero(G) == 
G(0) -end - -@testset "Graphs.jl constructor" begin - lg = random_regular_graph(10, 4) - @test !Graphs.is_directed(lg) - g = GNNGraph(lg) - @test g.num_edges == 2 * ne(lg) # g in undirected - @test Graphs.is_directed(g) - for e in Graphs.edges(lg) - i, j = src(e), dst(e) - @test has_edge(g, i, j) - @test has_edge(g, j, i) - end - - @testset "SimpleGraph{Int32}" begin - g = GNNGraph(SimpleGraph{Int32}(6), graph_type = GRAPH_T) - @test g.num_nodes == 6 - end -end - -@testset "Features" begin - g = GNNGraph(sprand(10, 10, 0.3), graph_type = GRAPH_T) - - # default names - X = rand(10, g.num_nodes) - E = rand(10, g.num_edges) - U = rand(10, g.num_graphs) - - g = GNNGraph(g, ndata = X, edata = E, gdata = U) - @test g.ndata.x === X - @test g.edata.e === E - @test g.gdata.u === U - @test g.x === g.ndata.x - @test g.e === g.edata.e - @test g.u === g.gdata.u - - # Check no args - g = GNNGraph(g) - @test g.ndata.x === X - @test g.edata.e === E - @test g.gdata.u === U - - # multiple features names - g = GNNGraph(g, ndata = (x2 = 2X, g.ndata...), edata = (e2 = 2E, g.edata...), - gdata = (u2 = 2U, g.gdata...)) - @test g.ndata.x === X - @test g.edata.e === E - @test g.gdata.u === U - @test g.ndata.x2 ≈ 2X - @test g.edata.e2 ≈ 2E - @test g.gdata.u2 ≈ 2U - @test g.x === g.ndata.x - @test g.e === g.edata.e - @test g.u === g.gdata.u - @test g.x2 === g.ndata.x2 - @test g.e2 === g.edata.e2 - @test g.u2 === g.gdata.u2 - - # Dimension checks - @test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata = rand(29), - graph_type = GRAPH_T) - @test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata = rand(2, 29), - graph_type = GRAPH_T) - @test_throws AssertionError GNNGraph(erdos_renyi(10, 30), - edata = (; x = rand(30), y = rand(29)), - graph_type = GRAPH_T) - - # Copy features on reverse edge - e = rand(30) - g = GNNGraph(erdos_renyi(10, 30), edata = e, graph_type = GRAPH_T) - @test g.edata.e == [e; e] - - # non-array global - g = rand_graph(10, 30, gdata = "ciao", graph_type 
= GRAPH_T) - @test g.gdata.u == "ciao" - - # vectors stays vectors - g = rand_graph(10, 30, ndata = rand(10), - edata = rand(30), - gdata = (u = rand(2), z = rand(1), q = 1), - graph_type = GRAPH_T) - @test size(g.ndata.x) == (10,) - @test size(g.edata.e) == (30,) - @test size(g.gdata.u) == (2, 1) - @test size(g.gdata.z) == (1,) - @test g.gdata.q === 1 - - # Error for non-array ndata - @test_throws AssertionError rand_graph(10, 30, ndata = "ciao", graph_type = GRAPH_T) - @test_throws AssertionError rand_graph(10, 30, ndata = 1, graph_type = GRAPH_T) - - # Error for Ambiguous getproperty - g = rand_graph(10, 20, ndata = rand(2, 10), edata = (; x = rand(3, 20)), - graph_type = GRAPH_T) - @test size(g.ndata.x) == (2, 10) - @test size(g.edata.x) == (3, 20) - @test_throws ArgumentError g.x -end - -@testset "MLUtils and DataLoader compat" begin - n, m, num_graphs = 10, 30, 50 - X = rand(10, n) - E = rand(10, m) - U = rand(10, 1) - data = [rand_graph(n, m, ndata = X, edata = E, gdata = U, graph_type = GRAPH_T) - for _ in 1:num_graphs] - g = MLUtils.batch(data) - - @testset "batch then pass to dataloader" begin - @test MLUtils.getobs(g, 3) == getgraph(g, 3) - @test MLUtils.getobs(g, 3:5) == getgraph(g, 3:5) - @test MLUtils.numobs(g) == g.num_graphs - - d = MLUtils.DataLoader(g, batchsize = 2, shuffle = false) - @test first(d) == getgraph(g, 1:2) - end - - @testset "pass to dataloader and no automatic collation" begin - @test MLUtils.getobs(data, 3) == data[3] - @test MLUtils.getobs(data, 3:5) isa Vector{<:GNNGraph} - @test MLUtils.getobs(data, 3:5) == [data[3], data[4], data[5]] - @test MLUtils.numobs(data) == g.num_graphs - - d = MLUtils.DataLoader(data, batchsize = 2, shuffle = false) - @test first(d) == [data[1], data[2]] - end -end - -@testset "Graphs.jl integration" begin - g = GNNGraph(erdos_renyi(10, 20), graph_type = GRAPH_T) - @test g isa Graphs.AbstractGraph -end - -@testset "==" begin - g1 = rand_graph(5, 6, ndata = rand(5), edata = rand(6), graph_type = 
GRAPH_T) - @test g1 == g1 - @test g1 == deepcopy(g1) - @test g1 !== deepcopy(g1) - - g2 = GNNGraph(g1, graph_type = GRAPH_T) - @test g1 == g2 - @test g1 === g2 # this is true since GNNGraph is immutable - - g2 = GNNGraph(g1, ndata = rand(5), graph_type = GRAPH_T) - @test g1 != g2 - @test g1 !== g2 - - g2 = GNNGraph(g1, edata = rand(6), graph_type = GRAPH_T) - @test g1 != g2 - @test g1 !== g2 -end - -@testset "hash" begin - g1 = rand_graph(5, 6, ndata = rand(5), edata = rand(6), graph_type = GRAPH_T) - @test hash(g1) == hash(g1) - @test hash(g1) == hash(deepcopy(g1)) - @test hash(g1) == hash(GNNGraph(g1, ndata = g1.ndata, graph_type = GRAPH_T)) - @test hash(g1) == hash(GNNGraph(g1, ndata = g1.ndata, graph_type = GRAPH_T)) - @test hash(g1) != hash(GNNGraph(g1, ndata = rand(5), graph_type = GRAPH_T)) - @test hash(g1) != hash(GNNGraph(g1, edata = rand(6), graph_type = GRAPH_T)) -end - -@testset "copy" begin - g1 = rand_graph(10, 4, ndata = rand(2, 10), graph_type = GRAPH_T) - g2 = copy(g1) - @test g1 === g2 # shallow copies are identical for immutable objects - - g2 = copy(g1, deep = true) - @test g1 == g2 - @test g1 !== g2 -end - -## Cannot test this because DataStore is not an ordered collection -## Uncomment when/if it will be based on OrderedDict -# @testset "show" begin -# @test sprint(show, rand_graph(10, 20)) == "GNNGraph(10, 20) with no data" -# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10))) == "GNNGraph(10, 20) with x: 5×10 data" -# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20), gdata=(q=rand(1, 1), p=rand(3, 1)))) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20, (q: 1×1, p: 3×1) data" -# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5, 10),))) == "GNNGraph(10, 20) with a: 5×10 data" -# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10), edata=rand(2, 20))) == "GNNGraph(10, 20) with x: 5×10, e: 2×20 data" -# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10), gdata=rand(1, 1))) == 
"GNNGraph(10, 20) with x: 5×10, u: 1×1 data" -# @test sprint(show, rand_graph(10, 20, ndata=rand(5, 10), edata=(e=rand(2, 20), f=rand(2, 20), h=rand(3, 20)), gdata=rand(1, 1))) == "GNNGraph(10, 20) with x: 5×10, (e: 2×20, f: 2×20, h: 3×20), u: 1×1 data" -# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20))) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20 data" -# @test sprint(show, rand_graph(10, 20, ndata=(a=rand(5,5, 10), b=rand(3,2, 10)), edata=rand(2, 20))) == "GNNGraph(10, 20) with (a: 5×5×10, b: 3×2×10), e: 2×20 data" -# end - -# @testset "show plain/text compact true" begin -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => true) == "GNNGraph(10, 20) with no data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10 data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20), gdata=(q=rand(1, 1), p=rand(3, 1))); context=:compact => true) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20, (q: 1×1, p: 3×1) data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10),)); context=:compact => true) == "GNNGraph(10, 20) with a: 5×10 data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=rand(2, 20)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10, e: 2×20 data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), gdata=rand(1, 1)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10, u: 1×1 data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=(e=rand(2, 20), f=rand(2, 20), h=rand(3, 20)), gdata=rand(1, 1)); context=:compact => true) == "GNNGraph(10, 20) with x: 5×10, (e: 2×20, f: 2×20, h: 3×20), u: 1×1 data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 
10), b=rand(3, 10)), edata=rand(2, 20)); context=:compact => true) == "GNNGraph(10, 20) with (a: 5×10, b: 3×10), e: 2×20 data" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5,5, 10), b=rand(3,2, 10)), edata=rand(2, 20)); context=:compact => true) == "GNNGraph(10, 20) with (a: 5×5×10, b: 3×2×10), e: 2×20 data" -# end - -# @testset "show plain/text compact false" begin -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20), gdata=(q=rand(1, 1), p=rand(3, 1))); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×10 Matrix{Float64}\n\tb = 3×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}\n gdata:\n\tq = 1×1 Matrix{Float64}\n\tp = 3×1 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10),)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×10 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=rand(2, 20)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), gdata=rand(1, 1)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}\n gdata:\n\tu = 1×1 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=rand(5, 10), edata=(e=rand(2, 20), f=rand(2, 20), h=rand(3, 20)), gdata=rand(1, 1)); context=:compact => false) == 
"GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\tx = 5×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}\n\tf = 2×20 Matrix{Float64}\n\th = 3×20 Matrix{Float64}\n gdata:\n\tu = 1×1 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 10), b=rand(3, 10)), edata=rand(2, 20)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×10 Matrix{Float64}\n\tb = 3×10 Matrix{Float64}\n edata:\n\te = 2×20 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), rand_graph(10, 20, ndata=(a=rand(5, 5, 10), b=rand(3, 2, 10)), edata=rand(2, 20)); context=:compact => false) == "GNNGraph:\n num_nodes: 10\n num_edges: 20\n ndata:\n\ta = 5×5×10 Array{Float64, 3}\n\tb = 3×2×10 Array{Float64, 3}\n edata:\n\te = 2×20 Matrix{Float64}" -# end - -[.\GNNGraphs\test\gnnheterograph.jl] - - -@testset "Empty constructor" begin - g = GNNHeteroGraph() - @test isempty(g.num_nodes) - g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4])) - @test g.num_nodes[:user] == 3 - @test g.num_nodes[:actor] == 9 - @test g.num_edges[(:user, :like, :actor)] == 5 -end - -@testset "Constructor from pairs" begin - hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3,4], [3,2,1,5])) - @test hg.num_nodes == Dict(:A => 4, :B => 5) - @test hg.num_edges == Dict((:A, :e1, :B) => 4) - - hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3], [3,2,1]), - (:A, :e2, :C) => ([1,2,3], [4,5,6])) - @test hg.num_nodes == Dict(:A => 3, :B => 3, :C => 6) - @test hg.num_edges == Dict((:A, :e1, :B) => 3, (:A, :e2, :C) => 3) -end - -@testset "Generation" begin - hg = rand_heterograph(Dict(:A => 10, :B => 20), - Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10)) - - @test hg.num_nodes == Dict(:A => 10, :B => 20) - @test hg.num_edges == Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10) - @test hg.graph_indicator === nothing - @test hg.num_graphs == 1 - @test hg.ndata isa Dict{Symbol, DataStore} - @test hg.edata isa Dict{Tuple{Symbol, Symbol, 
Symbol}, DataStore} - @test isempty(hg.gdata) - @test sort(hg.ntypes) == [:A, :B] - @test sort(hg.etypes) == [(:A, :rel1, :B), (:B, :rel2, :A)] - -end - -@testset "features" begin - hg = rand_heterograph(Dict(:A => 10, :B => 20), - Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), - ndata = Dict(:A => rand(2, 10), - :B => (x = rand(3, 20), y = rand(4, 20))), - edata = Dict((:A, :rel1, :B) => rand(5, 30)), - gdata = 1) - - @test size(hg.ndata[:A].x) == (2, 10) - @test size(hg.ndata[:B].x) == (3, 20) - @test size(hg.ndata[:B].y) == (4, 20) - @test size(hg.edata[(:A, :rel1, :B)].e) == (5, 30) - @test hg.gdata == DataStore(u = 1) - -end - -@testset "indexing syntax" begin - g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7])) - g[:movie].z = rand(Float32, 64, 13); - g[:user, :rate, :movie].e = rand(Float32, 64, 4); - g[:user].x = rand(Float32, 64, 3); - @test size(g.ndata[:user].x) == (64, 3) - @test size(g.ndata[:movie].z) == (64, 13) - @test size(g.edata[(:user, :rate, :movie)].e) == (64, 4) -end - - -@testset "simplified constructor" begin - hg = rand_heterograph((:A => 10, :B => 20), - ((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), - ndata = (:A => rand(2, 10), - :B => (x = rand(3, 20), y = rand(4, 20))), - edata = (:A, :rel1, :B) => rand(5, 30), - gdata = 1) - - @test hg.num_nodes == Dict(:A => 10, :B => 20) - @test hg.num_edges == Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10) - @test hg.graph_indicator === nothing - @test hg.num_graphs == 1 - @test size(hg.ndata[:A].x) == (2, 10) - @test size(hg.ndata[:B].x) == (3, 20) - @test size(hg.ndata[:B].y) == (4, 20) - @test size(hg.edata[(:A, :rel1, :B)].e) == (5, 30) - @test hg.gdata == DataStore(u = 1) - - nA, nB = 10, 20 - edges1 = rand(1:nA, 20), rand(1:nB, 20) - edges2 = rand(1:nB, 30), rand(1:nA, 30) - hg = GNNHeteroGraph(((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2)) - @test hg.num_edges == Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) - - nA, nB = 10, 20 - edges1 = 
rand(1:nA, 20), rand(1:nB, 20) - edges2 = rand(1:nB, 30), rand(1:nA, 30) - hg = GNNHeteroGraph(((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2); - num_nodes = (:A => nA, :B => nB)) - @test hg.num_nodes == Dict(:A => 10, :B => 20) - @test hg.num_edges == Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) -end - -@testset "num_edge_types / num_node_types" begin - hg = rand_heterograph((:A => 10, :B => 20), - ((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), - ndata = (:A => rand(2, 10), - :B => (x = rand(3, 20), y = rand(4, 20))), - edata = (:A, :rel1, :B) => rand(5, 30), - gdata = 1) - @test num_edge_types(hg) == 2 - @test num_node_types(hg) == 2 - - g = rand_graph(10, 20) - @test num_edge_types(g) == 1 - @test num_node_types(g) == 1 -end - -@testset "numobs" begin - hg = rand_heterograph((:A => 10, :B => 20), - ((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10), - ndata = (:A => rand(2, 10), - :B => (x = rand(3, 20), y = rand(4, 20))), - edata = (:A, :rel1, :B) => rand(5, 30), - gdata = 1) - @test MLUtils.numobs(hg) == 1 -end - -@testset "get/set node features" begin - d, n = 3, 5 - g = rand_bipartite_heterograph((n, 2*n), 15) - g[:A].x = rand(Float32, d, n) - g[:B].y = rand(Float32, d, 2*n) - - @test size(g[:A].x) == (d, n) - @test size(g[:B].y) == (d, 2*n) -end - -@testset "add_edges" begin - d, n = 3, 5 - g = rand_bipartite_heterograph((n, 2 * n), 15) - s, t = [1, 2, 3], [3, 2, 1] - ## Keep the same ntypes - construct with args - g1 = add_edges(g, (:A, :rel1, :B), s, t) - @test num_node_types(g1) == 2 - @test num_edge_types(g1) == 3 - for i in eachindex(s, t) - @test has_edge(g1, (:A, :rel1, :B), s[i], t[i]) - end - # no change to num_nodes - @test g1.num_nodes[:A] == n - @test g1.num_nodes[:B] == 2n - - ## Keep the same ntypes - construct with a pair - g2 = add_edges(g, (:A, :rel1, :B) => (s, t)) - @test num_node_types(g2) == 2 - @test num_edge_types(g2) == 3 - for i in eachindex(s, t) - @test has_edge(g2, (:A, :rel1, :B), s[i], t[i]) - end - # no change to 
num_nodes - @test g2.num_nodes[:A] == n - @test g2.num_nodes[:B] == 2n - - ## New ntype with num_nodes (applies only to the new ntype) and edata - edata = rand(Float32, d, length(s)) - g3 = add_edges(g, - (:A, :rel1, :C) => (s, t); - num_nodes = Dict(:A => 1, :B => 1, :C => 10), - edata) - @test num_node_types(g3) == 3 - @test num_edge_types(g3) == 3 - for i in eachindex(s, t) - @test has_edge(g3, (:A, :rel1, :C), s[i], t[i]) - end - # added edata - @test g3.edata[(:A, :rel1, :C)].e == edata - # no change to existing num_nodes - @test g3.num_nodes[:A] == n - @test g3.num_nodes[:B] == 2n - # new num_nodes added as per kwarg - @test g3.num_nodes[:C] == 10 -end - -@testset "add self loops" begin - g1 = GNNHeteroGraph((:A, :to, :B) => ([1,2,3,4], [3,2,1,5])) - g2 = add_self_loops(g1, (:A, :to, :B)) - @test g2.num_edges[(:A, :to, :B)] === g1.num_edges[(:A, :to, :B)] - g1 = GNNHeteroGraph((:A, :to, :A) => ([1,2,3,4], [3,2,1,5])) - g2 = add_self_loops(g1, (:A, :to, :A)) - @test g2.num_edges[(:A, :to, :A)] === g1.num_edges[(:A, :to, :A)] + g1.num_nodes[(:A)] -end - -## Cannot test this because DataStore is not an ordered collection -## Uncomment when/if it will be based on OrderedDict -# @testset "show" begin -# num_nodes = Dict(:A => 10, :B => 20); -# edges1 = rand(1:num_nodes[:A], 20), rand(1:num_nodes[:B], 20) -# edges2 = rand(1:num_nodes[:B], 30), rand(1:num_nodes[:A], 30) -# eindex = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2) -# ndata = Dict(:A => (x = rand(2, num_nodes[:A]), y = rand(3, num_nodes[:A])),:B => rand(10, num_nodes[:B])) -# edata= Dict((:A, :rel1, :B) => (x = rand(2, 20), y = rand(3, 20)),(:B, :rel2, :A) => rand(10, 30)) -# hg1 = GNNHeteroGraph(eindex; num_nodes) -# hg2 = GNNHeteroGraph(eindex; num_nodes, ndata,edata) -# hg3 = GNNHeteroGraph(eindex; num_nodes, ndata) -# @test sprint(show, hg1) == "GNNHeteroGraph(Dict(:A => 10, :B => 20), Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30))" -# @test sprint(show, hg2) == sprint(show, hg1) -# 
@test sprint(show, MIME("text/plain"), hg1; context=:compact => true) == "GNNHeteroGraph(Dict(:A => 10, :B => 20), Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30))" -# @test sprint(show, MIME("text/plain"), hg2; context=:compact => true) == sprint(show, MIME("text/plain"), hg1;context=:compact => true) -# @test sprint(show, MIME("text/plain"), hg1; context=:compact => false) == "GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)" -# @test sprint(show, MIME("text/plain"), hg2; context=:compact => false) == "GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}\n edata:\n\t(:A, :rel1, :B) => (x = 2×20 Matrix{Float64}, y = 3×20 Matrix{Float64})\n\t(:B, :rel2, :A) => e = 10×30 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), hg3; context=:compact => false) =="GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}" -# @test sprint(show, MIME("text/plain"), hg2; context=:compact => false) != sprint(show, MIME("text/plain"), hg3; context=:compact => false) -# end - -[.\GNNGraphs\test\mldatasets.jl] -dataset = Cora() -classes = dataset.metadata["classes"] -gml = dataset[1] -g = mldataset2gnngraph(dataset) -@test g isa GNNGraph -@test g.num_nodes == gml.num_nodes -@test g.num_edges == gml.num_edges -@test edge_index(g) === gml.edge_index - -[.\GNNGraphs\test\operators.jl] -@testset "intersect" begin - g = rand_graph(10, 20, graph_type = GRAPH_T) - @test intersect(g, g).num_edges == 20 -end - -[.\GNNGraphs\test\query.jl] -@testset "is_bidirected" begin - g = rand_graph(10, 20, bidirected = true, graph_type = GRAPH_T) - @test is_bidirected(g) - - g = rand_graph(10, 20, bidirected = false, 
graph_type = GRAPH_T) - @test !is_bidirected(g) -end - -@testset "has_multi_edges" begin if GRAPH_T == :coo - s = [1, 1, 2, 3] - t = [2, 2, 2, 4] - g = GNNGraph(s, t, graph_type = GRAPH_T) - @test has_multi_edges(g) - - s = [1, 2, 2, 3] - t = [2, 1, 2, 4] - g = GNNGraph(s, t, graph_type = GRAPH_T) - @test !has_multi_edges(g) -end end - -@testset "edges" begin - g = rand_graph(4, 10, graph_type = GRAPH_T) - @test edgetype(g) <: Graphs.Edge - for e in edges(g) - @test e isa Graphs.Edge - end -end - -@testset "has_isolated_nodes" begin - s = [1, 2, 3] - t = [2, 3, 2] - g = GNNGraph(s, t, graph_type = GRAPH_T) - @test has_isolated_nodes(g) == false - @test has_isolated_nodes(g, dir = :in) == true -end - -@testset "has_self_loops" begin - s = [1, 1, 2, 3] - t = [2, 2, 2, 4] - g = GNNGraph(s, t, graph_type = GRAPH_T) - @test has_self_loops(g) - - s = [1, 1, 2, 3] - t = [2, 2, 3, 4] - g = GNNGraph(s, t, graph_type = GRAPH_T) - @test !has_self_loops(g) -end - -@testset "degree" begin - @testset "unweighted" begin - s = [1, 1, 2, 3] - t = [2, 2, 2, 4] - g = GNNGraph(s, t, graph_type = GRAPH_T) - - @test degree(g) isa Vector{Int} - @test degree(g) == degree(g; dir = :out) == [2, 1, 1, 0] # default is outdegree - @test degree(g; dir = :in) == [0, 3, 0, 1] - @test degree(g; dir = :both) == [2, 4, 1, 1] - @test eltype(degree(g, Float32)) == Float32 - - if TEST_GPU - dev = CUDADevice() #TODO replace with `gpu_device()` - g_gpu = g |> dev - d = degree(g) - d_gpu = degree(g_gpu) - @test d_gpu isa CuVector{Int} - @test Array(d_gpu) == d - end - end - - @testset "weighted" begin - # weighted degree - s = [1, 1, 2, 3] - t = [2, 2, 2, 4] - eweight = Float32[0.1, 2.1, 1.2, 1] - g = GNNGraph((s, t, eweight), graph_type = GRAPH_T) - @test degree(g) ≈ [2.2, 1.2, 1.0, 0.0] - d = degree(g, edge_weight = false) - if GRAPH_T == :coo - @test d == [2, 1, 1, 0] - else - # Adjacency matrix representation cannot disambiguate multiple edges - # and edge weights - @test d == [1, 1, 1, 0] - end - 
@test eltype(d) <: Integer - @test degree(g, edge_weight = 2 * eweight) ≈ [4.4, 2.4, 2.0, 0.0] broken = (GRAPH_T != :coo) - - if TEST_GPU - dev = CUDADevice() #TODO replace with `gpu_device()` - g_gpu = g |> dev - d = degree(g) - d_gpu = degree(g_gpu) - @test d_gpu isa CuVector{Float32} - @test Array(d_gpu) ≈ d - end - @testset "gradient" begin - gw = gradient(eweight) do w - g = GNNGraph((s, t, w), graph_type = GRAPH_T) - sum(degree(g, edge_weight = false)) - end[1] - - @test gw === nothing - - gw = gradient(eweight) do w - g = GNNGraph((s, t, w), graph_type = GRAPH_T) - sum(degree(g, edge_weight = true)) - end[1] - - @test gw isa AbstractVector{Float32} - @test gw isa Vector{Float32} broken = (GRAPH_T == :sparse) - @test gw ≈ ones(Float32, length(gw)) - - gw = gradient(eweight) do w - g = GNNGraph((s, t, w), graph_type = GRAPH_T) - sum(degree(g, dir=:both, edge_weight=true)) - end[1] - - @test gw isa AbstractVector{Float32} - @test gw isa Vector{Float32} broken = (GRAPH_T == :sparse) - @test gw ≈ 2 * ones(Float32, length(gw)) - - grad = gradient(g) do g - sum(degree(g, edge_weight=false)) - end[1] - @test grad === nothing - - grad = gradient(g) do g - sum(degree(g, edge_weight=true)) - end[1] - - if GRAPH_T == :coo - @test grad.graph[3] isa Vector{Float32} - @test grad.graph[3] ≈ ones(Float32, length(gw)) - else - if GRAPH_T == :sparse - @test grad.graph isa AbstractSparseMatrix{Float32} - end - @test grad.graph isa AbstractMatrix{Float32} - - @test grad.graph ≈ [0.0 1.0 0.0 0.0 - 0.0 1.0 0.0 0.0 - 0.0 0.0 0.0 1.0 - 0.0 0.0 0.0 0.0] - end - - @testset "directed, degree dir=$dir" for dir in [:in, :out, :both] - g = rand_graph(10, 30, bidirected=false) - w = rand(Float32, 30) - s, t = edge_index(g) - - grad = gradient(w) do w - g = GNNGraph((s, t, w), graph_type = GRAPH_T) - sum(tanh.(degree(g; dir, edge_weight=true))) - end[1] - - ngrad = ngradient(w) do w - g = GNNGraph((s, t, w), graph_type = GRAPH_T) - sum(tanh.(degree(g; dir, edge_weight=true))) - end[1] - - 
@test grad ≈ ngrad - end - - @testset "heterognn, degree" begin - g = GNNHeteroGraph((:A, :to, :B) => ([1,1,2,3], [7,13,5,7])) - @test degree(g, (:A, :to, :B), dir = :out) == [2, 1, 1] - @test degree(g, (:A, :to, :B), dir = :in) == [0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1] - @test degree(g, (:A, :to, :B)) == [2, 1, 1] - end - end - end -end - -@testset "laplacian_matrix" begin - g = rand_graph(10, 30, graph_type = GRAPH_T) - A = adjacency_matrix(g) - D = Diagonal(vec(sum(A, dims = 2))) - L = laplacian_matrix(g) - @test eltype(L) == eltype(g) - @test L ≈ D - A -end - -@testset "laplacian_lambda_max" begin - s = [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] - t = [2, 3, 4, 5, 1, 5, 1, 2, 3, 4] - g = GNNGraph(s, t) - @test laplacian_lambda_max(g) ≈ Float32(1.809017) - data1 = [g for i in 1:5] - gall1 = MLUtils.batch(data1) - @test laplacian_lambda_max(gall1) ≈ [Float32(1.809017) for i in 1:5] - data2 = [rand_graph(10, 20) for i in 1:3] - gall2 = MLUtils.batch(data2) - @test length(laplacian_lambda_max(gall2, add_self_loops=true)) == 3 -end - -@testset "adjacency_matrix" begin - a = sprand(5, 5, 0.5) - abin = map(x -> x > 0 ? 
1 : 0, a) - - g = GNNGraph(a, graph_type = GRAPH_T) - A = adjacency_matrix(g, Float32) - @test A ≈ a - @test eltype(A) == Float32 - - Abin = adjacency_matrix(g, Float32, weighted = false) - @test Abin ≈ abin - @test eltype(Abin) == Float32 - - @testset "gradient" begin - s = [1, 2, 3] - t = [2, 3, 1] - w = [0.1, 0.1, 0.2] - gw = gradient(w) do w - g = GNNGraph(s, t, w, graph_type = GRAPH_T) - A = adjacency_matrix(g, weighted = false) - sum(A) - end[1] - @test gw === nothing - - gw = gradient(w) do w - g = GNNGraph(s, t, w, graph_type = GRAPH_T) - A = adjacency_matrix(g, weighted = true) - sum(A) - end[1] - - @test gw == [1, 1, 1] - end - - @testset "khop_adj" begin - s = [1, 2, 3] - t = [2, 3, 1] - w = [0.1, 0.1, 0.2] - g = GNNGraph(s, t, w) - @test khop_adj(g, 2) == adjacency_matrix(g) * adjacency_matrix(g) - @test khop_adj(g, 2, Int8; weighted = false) == sparse([0 0 1; 1 0 0; 0 1 0]) - @test khop_adj(g, 2, Int8; dir = in, weighted = false) == - sparse([0 0 1; 1 0 0; 0 1 0]') - @test khop_adj(g, 1) == adjacency_matrix(g) - @test eltype(khop_adj(g, 4)) == Float64 - @test eltype(khop_adj(g, 10, Float32)) == Float32 - end -end - -if GRAPH_T == :coo - @testset "HeteroGraph" begin - @testset "graph_indicator" begin - gs = [rand_heterograph(Dict(:user => 10, :movie => 20, :actor => 30), - Dict((:user,:like,:movie) => 10, - (:actor,:rate,:movie)=>20)) for _ in 1:3] - g = MLUtils.batch(gs) - @test graph_indicator(g) == Dict(:user => [repeat([1], 10); repeat([2], 10); repeat([3], 10)], - :movie => [repeat([1], 20); repeat([2], 20); repeat([3], 20)], - :actor => [repeat([1], 30); repeat([2], 30); repeat([3], 30)]) - @test graph_indicator(g, :movie) == [repeat([1], 20); repeat([2], 20); repeat([3], 20)] - end - end -end - - -[.\GNNGraphs\test\runtests.jl] -using CUDA, cuDNN -using GNNGraphs -using GNNGraphs: getn, getdata -using Functors -using LinearAlgebra, Statistics, Random -using NNlib -import MLUtils -import StatsBase -using SparseArrays -using Graphs -using Zygote 
-using Test -using MLDatasets -using InlineStrings # not used but with the import we test #98 and #104 -using SimpleWeightedGraphs -using MLDataDevices: gpu_device, cpu_device, get_device -using MLDataDevices: CUDADevice - -CUDA.allowscalar(false) - -const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}} - -ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets - -include("test_utils.jl") - -tests = [ - "chainrules", - "datastore", - "gnngraph", - "convert", - "transform", - "operators", - "generate", - "query", - "sampling", - "gnnheterograph", - "temporalsnapshotsgnngraph", - "mldatasets", - "ext/SimpleWeightedGraphs" -] - -!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") - -for graph_type in (:coo, :dense, :sparse) - @info "Testing graph format :$graph_type" - global GRAPH_T = graph_type - global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) - # global GRAPH_T = :sparse - # global TEST_GPU = false - - @testset "$t" for t in tests - include("$t.jl") - end -end - -[.\GNNGraphs\test\sampling.jl] -if GRAPH_T == :coo - @testset "sample_neighbors" begin - # replace = false - dir = :in - nodes = 2:3 - g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T) - sg = sample_neighbors(g, nodes; dir) - @test sg.num_nodes == 10 - @test sg.num_edges == sum(degree(g, i; dir) for i in nodes) - @test size(sg.edata.EID) == (sg.num_edges,) - @test length(union(sg.edata.EID)) == length(sg.edata.EID) - adjlist = adjacency_list(g; dir) - s, t = edge_index(sg) - @test all(t .∈ Ref(nodes)) - for i in nodes - @test sort(neighbors(sg, i; dir)) == sort(neighbors(g, i; dir)) - end - - # replace = true - dir = :out - nodes = 2:3 - K = 2 - g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T) - sg = sample_neighbors(g, nodes, K; dir, replace = true) - @test sg.num_nodes == 10 - @test sg.num_edges == sum(K for i in nodes) - @test size(sg.edata.EID) == (sg.num_edges,) - adjlist = adjacency_list(g; dir) - s, t = edge_index(sg) - 
@test all(s .∈ Ref(nodes)) - for i in nodes - @test issubset(neighbors(sg, i; dir), adjlist[i]) - end - - # dropnodes = true - dir = :in - nodes = 2:3 - g = rand_graph(10, 40, bidirected = false, graph_type = GRAPH_T) - g = GNNGraph(g, ndata = (x1 = rand(10),), edata = (e1 = rand(40),)) - sg = sample_neighbors(g, nodes; dir, dropnodes = true) - @test sg.num_edges == sum(degree(g, i; dir) for i in nodes) - @test size(sg.edata.EID) == (sg.num_edges,) - @test size(sg.ndata.NID) == (sg.num_nodes,) - @test sg.edata.e1 == g.edata.e1[sg.edata.EID] - @test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID] - @test length(union(sg.ndata.NID)) == length(sg.ndata.NID) - end -end -[.\GNNGraphs\test\temporalsnapshotsgnngraph.jl] -@testset "Constructor array TemporalSnapshotsGNNGraph" begin - snapshots = [rand_graph(10, 20) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - @test tsg.num_nodes == [10 for i in 1:5] - @test tsg.num_edges == [20 for i in 1:5] - wrsnapshots = [rand_graph(10,20), rand_graph(12,22)] - @test_throws AssertionError TemporalSnapshotsGNNGraph(wrsnapshots) -end - -@testset "==" begin - snapshots = [rand_graph(10, 20) for i in 1:5] - tsg1 = TemporalSnapshotsGNNGraph(snapshots) - tsg2 = TemporalSnapshotsGNNGraph(snapshots) - @test tsg1 == tsg2 - tsg3 = TemporalSnapshotsGNNGraph(snapshots[1:3]) - @test tsg1 != tsg3 - @test tsg1 !== tsg3 -end - -@testset "getindex" begin - snapshots = [rand_graph(10, 20) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - @test tsg[3] == snapshots[3] - @test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata) -end - -@testset "getproperty" begin - x = rand(10) - snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - @test tsg.tgdata == DataStore() - @test tsg.x == tsg.ndata.x == [x for i in 1:5] - @test_throws KeyError tsg.ndata.w - @test_throws ArgumentError tsg.w -end - -@testset "add/remove_snapshot" begin - snapshots = 
[rand_graph(10, 20) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - g = rand_graph(10, 20) - tsg = add_snapshot(tsg, 3, g) - @test tsg.num_nodes == [10 for i in 1:6] - @test tsg.num_edges == [20 for i in 1:6] - @test tsg.snapshots[3] == g - tsg = remove_snapshot(tsg, 3) - @test tsg.num_nodes == [10 for i in 1:5] - @test tsg.num_edges == [20 for i in 1:5] - @test tsg.snapshots == snapshots -end - -@testset "add/remove_snapshot" begin - snapshots = [rand_graph(10, 20) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - g = rand_graph(10, 20) - tsg2 = add_snapshot(tsg, 3, g) - @test tsg2.num_nodes == [10 for i in 1:6] - @test tsg2.num_edges == [20 for i in 1:6] - @test tsg2.snapshots[3] == g - @test tsg2.num_snapshots == 6 - @test tsg.num_nodes == [10 for i in 1:5] - @test tsg.num_edges == [20 for i in 1:5] - @test tsg.snapshots[2] === tsg2.snapshots[2] - @test tsg.snapshots[3] === tsg2.snapshots[4] - @test length(tsg.snapshots) == 5 - @test tsg.num_snapshots == 5 - - tsg21 = add_snapshot(tsg2, 7, g) - @test tsg21.num_snapshots == 7 - - tsg3 = remove_snapshot(tsg, 3) - @test tsg3.num_nodes == [10 for i in 1:4] - @test tsg3.num_edges == [20 for i in 1:4] - @test tsg3.snapshots == snapshots[[1,2,4,5]] -end - - -# @testset "add/remove_snapshot!" 
begin -# snapshots = [rand_graph(10, 20) for i in 1:5] -# tsg = TemporalSnapshotsGNNGraph(snapshots) -# g = rand_graph(10, 20) -# tsg2 = add_snapshot!(tsg, 3, g) -# @test tsg2.num_nodes == [10 for i in 1:6] -# @test tsg2.num_edges == [20 for i in 1:6] -# @test tsg2.snapshots[3] == g -# @test tsg2.num_snapshots == 6 -# @test tsg2 === tsg - -# tsg3 = remove_snapshot!(tsg, 3) -# @test tsg3.num_nodes == [10 for i in 1:4] -# @test tsg3.num_edges == [20 for i in 1:4] -# @test length(tsg3.snapshots) === 4 -# @test tsg3 === tsg -# end - -@testset "show" begin - snapshots = [rand_graph(10, 20) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data" - @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data" - @test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5" - tsg.tgdata.x=rand(4) - @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data" -end - -if TEST_GPU - @testset "gpu" begin - snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - tsg.tgdata.x = rand(5) - dev = CUDADevice() #TODO replace with `gpu_device()` - tsg = tsg |> dev - @test tsg.snapshots[1].ndata.x isa CuArray - @test tsg.snapshots[end].ndata.x isa CuArray - @test tsg.tgdata.x isa CuArray - @test tsg.num_nodes isa CuArray - @test tsg.num_edges isa CuArray - end -end - -[.\GNNGraphs\test\test_utils.jl] -using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA -CUDA.allowscalar(false) - -function ngradient(f, x...) - fdm = central_fdm(5, 1) - return FiniteDifferences.grad(fdm, f, x...) 
-end - -const rule_config = Zygote.ZygoteRuleConfig() - -# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed -function FiniteDifferences.to_vec(x::Integer) - Integer_from_vec(v) = x - return Int[x], Integer_from_vec -end - -# Test that forward pass on cpu and gpu are the same. -# Tests also gradient on cpu and gpu comparing with -# finite difference methods. -# Test gradients with respects to layer weights and to input. -# If `g` has edge features, it is assumed that the layer can -# use them in the forward pass as `l(g, x, e)`. -# Test also gradient with respect to `e`. -function test_layer(l, g::GNNGraph; atol = 1e-5, rtol = 1e-5, - exclude_grad_fields = [], - verbose = false, - test_gpu = TEST_GPU, - outsize = nothing, - outtype = :node) - - # TODO these give errors, probably some bugs in ChainRulesTestUtils - # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false) - # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false) - - isnothing(node_features(g)) && error("Plese add node data to the input graph") - fdm = central_fdm(5, 1) - - x = node_features(g) - e = edge_features(g) - use_edge_feat = !isnothing(e) - - x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad - xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g]) - - f(l, g::GNNGraph) = l(g) - f(l, g::GNNGraph, x, e) = use_edge_feat ? 
l(g, x, e) : l(g, x) - - loss(l, g::GNNGraph) = - if outtype == :node - sum(node_features(f(l, g))) - elseif outtype == :edge - sum(edge_features(f(l, g))) - elseif outtype == :graph - sum(graph_features(f(l, g))) - elseif outtype == :node_edge - gnew = f(l, g) - sum(node_features(gnew)) + sum(edge_features(gnew)) - end - - function loss(l, g::GNNGraph, x, e) - y = f(l, g, x, e) - if outtype == :node_edge - return sum(y[1]) + sum(y[2]) - else - return sum(y) - end - end - - # TEST OUTPUT - y = f(l, g, x, e) - if outtype == :node_edge - @assert y isa Tuple - @test eltype(y[1]) == eltype(x) - @test eltype(y[2]) == eltype(e) - @test all(isfinite, y[1]) - @test all(isfinite, y[2]) - if !isnothing(outsize) - @test size(y[1]) == outsize[1] - @test size(y[2]) == outsize[2] - end - else - @test eltype(y) == eltype(x) - @test all(isfinite, y) - if !isnothing(outsize) - @test size(y) == outsize - end - end - - # test same output on different graph formats - gcoo = GNNGraph(g, graph_type = :coo) - ycoo = f(l, gcoo, x, e) - if outtype == :node_edge - @test ycoo[1] ≈ y[1] - @test ycoo[2] ≈ y[2] - else - @test ycoo ≈ y - end - - g′ = f(l, g) - if outtype == :node - @test g′.ndata.x ≈ y - elseif outtype == :edge - @test g′.edata.e ≈ y - elseif outtype == :graph - @test g′.gdata.u ≈ y - elseif outtype == :node_edge - @test g′.ndata.x ≈ y[1] - @test g′.edata.e ≈ y[2] - else - @error "wrong outtype $outtype" - end - if test_gpu - ygpu = f(lgpu, ggpu, xgpu, egpu) - if outtype == :node_edge - @test ygpu[1] isa CuArray - @test eltype(ygpu[1]) == eltype(xgpu) - @test Array(ygpu[1]) ≈ y[1] - @test ygpu[2] isa CuArray - @test eltype(ygpu[2]) == eltype(xgpu) - @test Array(ygpu[2]) ≈ y[2] - else - @test ygpu isa CuArray - @test eltype(ygpu) == eltype(xgpu) - @test Array(ygpu) ≈ y - end - end - - # TEST x INPUT GRADIENT - x̄ = gradient(x -> loss(l, g, x, e), x)[1] - x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1] - @test eltype(x̄) == eltype(x) - @test x̄≈x̄_fd 
atol=atol rtol=rtol - - if test_gpu - x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1] - @test x̄gpu isa CuArray - @test eltype(x̄gpu) == eltype(x) - @test Array(x̄gpu)≈x̄ atol=atol rtol=rtol - end - - # TEST e INPUT GRADIENT - if e !== nothing - verbose && println("Test e gradient cpu") - ē = gradient(e -> loss(l, g, x, e), e)[1] - ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1] - @test eltype(ē) == eltype(e) - @test ē≈ē_fd atol=atol rtol=rtol - - if test_gpu - verbose && println("Test e gradient gpu") - ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1] - @test ēgpu isa CuArray - @test eltype(ēgpu) == eltype(ē) - @test Array(ēgpu)≈ē atol=atol rtol=rtol - end - end - - # TEST LAYER GRADIENT - l(g, x, e) - l̄ = gradient(l -> loss(l, g, x, e), l)[1] - l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1] - test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) - - if test_gpu - l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1] - test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose) - end - - # TEST LAYER GRADIENT - l(g) - l̄ = gradient(l -> loss(l, g), l)[1] - test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) - - return true -end - -function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5, - exclude_grad_fields = [], - verbose = false) - l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue - l̄fd = l̄fd isa Base.RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue - - for f in fieldnames(typeof(l)) - f ∈ exclude_grad_fields && continue - verbose && println("Test gradient of field $f...") - x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f) - test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose) - verbose && println("... 
field $f done!") - end - return true -end - -function test_approx_structs(x, g::Nothing, gfd; atol, rtol, kws...) - # finite diff gradients has to be zero if present - @test !(gfd isa AbstractArray) || isapprox(gfd, fill!(similar(gfd), 0); atol, rtol) -end - -function test_approx_structs(x::Union{AbstractArray, Number}, - g::Union{AbstractArray, Number}, gfd; atol, rtol, kws...) - @test eltype(g) == eltype(x) - if x isa CuArray - @test g isa CuArray - g = Array(g) - end - @test g≈gfd atol=atol rtol=rtol -end - -""" - to32(m) - -Convert the `eltype` of model's float parameters to `Float32`. -Preserves integer arrays. -""" -to32(m) = _paramtype(Float32, m) - -""" - to64(m) - -Convert the `eltype` of model's float parameters to `Float64`. -Preserves integer arrays. -""" -to64(m) = _paramtype(Float64, m) - -struct GNNEltypeAdaptor{T} end - -Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where T = convert(AbstractArray{T}, x) -Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Integer}) where T = x -Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x) - -_paramtype(::Type{T}, m) where T = fmap(adapt(GNNEltypeAdaptor{T}()), m) - -[.\GNNGraphs\test\transform.jl] -@testset "add self-loops" begin - A = [1 1 0 0 - 0 0 1 0 - 0 0 0 1 - 1 0 0 0] - A2 = [2 1 0 0 - 0 1 1 0 - 0 0 1 1 - 1 0 0 1] - - g = GNNGraph(A; graph_type = GRAPH_T) - fg2 = add_self_loops(g) - @test adjacency_matrix(g) == A - @test g.num_edges == sum(A) - @test adjacency_matrix(fg2) == A2 - @test fg2.num_edges == sum(A2) -end - -@testset "batch" begin - g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10), - graph_type = GRAPH_T) - g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T) - g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T) - - g12 = MLUtils.batch([g1, g2]) - g12b = blockdiag(g1, g2) - @test g12 == g12b - - g123 = 
MLUtils.batch([g1, g2, g3]) - @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] - - # Allow wider eltype - g123 = MLUtils.batch(GNNGraph[g1, g2, g3]) - @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] - - - s, t = edge_index(g123) - @test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]] - @test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]] - @test node_features(g123)[:, 11:14] ≈ node_features(g2) - - # scalar graph features - g1 = GNNGraph(g1, gdata = rand()) - g2 = GNNGraph(g2, gdata = rand()) - g3 = GNNGraph(g3, gdata = rand()) - g123 = MLUtils.batch([g1, g2, g3]) - @test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u] - - # Batch of batches - g123123 = MLUtils.batch([g123, g123]) - @test g123123.graph_indicator == - [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)] - @test g123123.num_graphs == 6 -end - -@testset "unbatch" begin - g1 = rand_graph(10, 20, graph_type = GRAPH_T) - g2 = rand_graph(5, 10, graph_type = GRAPH_T) - g12 = MLUtils.batch([g1, g2]) - gs = MLUtils.unbatch([g1, g2]) - @test length(gs) == 2 - @test gs[1].num_nodes == 10 - @test gs[1].num_edges == 20 - @test gs[1].num_graphs == 1 - @test gs[2].num_nodes == 5 - @test gs[2].num_edges == 10 - @test gs[2].num_graphs == 1 -end - -@testset "batch/unbatch roundtrip" begin - n = 20 - c = 3 - ngraphs = 10 - gs = [rand_graph(n, c * n, ndata = rand(2, n), edata = rand(3, c * n), - graph_type = GRAPH_T) - for _ in 1:ngraphs] - gall = MLUtils.batch(gs) - gs2 = MLUtils.unbatch(gall) - @test gs2[1] == gs[1] - @test gs2[end] == gs[end] -end - -@testset "getgraph" begin - g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10), - graph_type = GRAPH_T) - g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T) - g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T) - g = MLUtils.batch([g1, g2, g3]) - - g2b, nodemap = 
getgraph(g, 2, nmap = true) - s, t = edge_index(g2b) - @test s == edge_index(g2)[1] - @test t == edge_index(g2)[2] - @test node_features(g2b) ≈ node_features(g2) - - g2c = getgraph(g, 2) - @test g2c isa GNNGraph{typeof(g.graph)} - - g1b, nodemap = getgraph(g1, 1, nmap = true) - @test g1b === g1 - @test nodemap == 1:(g1.num_nodes) -end - -@testset "remove_edges" begin - if GRAPH_T == :coo - s = [1, 1, 2, 3] - t = [2, 3, 4, 5] - w = [0.1, 0.2, 0.3, 0.4] - edata = ['a', 'b', 'c', 'd'] - g = GNNGraph(s, t, w, edata = edata, graph_type = GRAPH_T) - - # single edge removal - gnew = remove_edges(g, [1]) - new_s, new_t = edge_index(gnew) - @test gnew.num_edges == 3 - @test new_s == s[2:end] - @test new_t == t[2:end] - - # multiple edge removal - gnew = remove_edges(g, [1,2,4]) - new_s, new_t = edge_index(gnew) - new_w = get_edge_weight(gnew) - new_edata = gnew.edata.e - @test gnew.num_edges == 1 - @test new_s == [2] - @test new_t == [4] - @test new_w == [0.3] - @test new_edata == ['c'] - - # drop with probability - gnew = remove_edges(g, Float32(1.0)) - @test gnew.num_edges == 0 - - gnew = remove_edges(g, Float32(0.0)) - @test gnew.num_edges == g.num_edges - end -end - -@testset "add_edges" begin - if GRAPH_T == :coo - s = [1, 1, 2, 3] - t = [2, 3, 4, 5] - g = GNNGraph(s, t, graph_type = GRAPH_T) - snew = [1] - tnew = [4] - gnew = add_edges(g, snew, tnew) - @test gnew.num_edges == 5 - @test sort(inneighbors(gnew, 4)) == [1, 2] - - gnew2 = add_edges(g, (snew, tnew)) - @test gnew2 == gnew - @test get_edge_weight(gnew2) === nothing - - g = GNNGraph(s, t, edata = (e1 = rand(2, 4), e2 = rand(3, 4)), graph_type = GRAPH_T) - # @test_throws ErrorException add_edges(g, snew, tnew) - gnew = add_edges(g, snew, tnew, edata = (e1 = ones(2, 1), e2 = zeros(3, 1))) - @test all(gnew.edata.e1[:, 5] .== 1) - @test all(gnew.edata.e2[:, 5] .== 0) - - @testset "adding new nodes" begin - g = GNNGraph() - g = add_edges(g, ([1,3], [2, 1])) - @test g.num_nodes == 3 - @test g.num_edges == 2 - @test 
sort(inneighbors(g, 1)) == [3] - @test sort(outneighbors(g, 1)) == [2] - end - @testset "also add weights" begin - s = [1, 1, 2, 3] - t = [2, 3, 4, 5] - w = [1.0, 2.0, 3.0, 4.0] - snew = [1] - tnew = [4] - wnew = [5.] - - g = GNNGraph((s, t), graph_type = GRAPH_T) - gnew = add_edges(g, (snew, tnew, wnew)) - @test get_edge_weight(gnew) == [ones(length(s)); wnew] - - g = GNNGraph((s, t, w), graph_type = GRAPH_T) - gnew = add_edges(g, (snew, tnew, wnew)) - @test get_edge_weight(gnew) == [w; wnew] - end - end -end - -@testset "perturb_edges" begin if GRAPH_T == :coo - s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1] - g = GNNGraph((s, t)) - rng = MersenneTwister(42) - g_per = perturb_edges(rng, g, 0.5) - @test g_per.num_edges == 8 -end end - -@testset "remove_nodes" begin if GRAPH_T == :coo - #single node - s = [1, 1, 2, 3] - t = [2, 3, 4, 5] - eweights = [0.1, 0.2, 0.3, 0.4] - ndata = [1.0, 2.0, 3.0, 4.0, 5.0] - edata = ['a', 'b', 'c', 'd'] - - g = GNNGraph(s, t, eweights, ndata = ndata, edata = edata, graph_type = GRAPH_T) - - gnew = remove_nodes(g, [1]) - - snew = [1, 2] - tnew = [3, 4] - eweights_new = [0.3, 0.4] - ndata_new = [2.0, 3.0, 4.0, 5.0] - edata_new = ['c', 'd'] - - stest, ttest = edge_index(gnew) - eweightstest = get_edge_weight(gnew) - ndatatest = gnew.ndata.x - edatatest = gnew.edata.e - - - @test gnew.num_edges == 2 - @test gnew.num_nodes == 4 - @test snew == stest - @test tnew == ttest - @test eweights_new == eweightstest - @test ndata_new == ndatatest - @test edata_new == edatatest - - # multiple nodes - s = [1, 5, 2, 3] - t = [2, 3, 4, 5] - eweights = [0.1, 0.2, 0.3, 0.4] - ndata = [1.0, 2.0, 3.0, 4.0, 5.0] - edata = ['a', 'b', 'c', 'd'] - - g = GNNGraph(s, t, eweights, ndata = ndata, edata = edata, graph_type = GRAPH_T) - - gnew = remove_nodes(g, [1,4]) - snew = [3,2] - tnew = [2,3] - eweights_new = [0.2,0.4] - ndata_new = [2.0,3.0,5.0] - edata_new = ['b','d'] - - stest, ttest = edge_index(gnew) - eweightstest = get_edge_weight(gnew) - ndatatest = 
gnew.ndata.x - edatatest = gnew.edata.e - - @test gnew.num_edges == 2 - @test gnew.num_nodes == 3 - @test snew == stest - @test tnew == ttest - @test eweights_new == eweightstest - @test ndata_new == ndatatest - @test edata_new == edatatest -end end - -@testset "remove_nodes(g, p)" begin - if GRAPH_T == :coo - Random.seed!(42) - s = [1, 1, 2, 3] - t = [2, 3, 4, 5] - g = GNNGraph(s, t, graph_type = GRAPH_T) - - gnew = remove_nodes(g, 0.5) - @test gnew.num_nodes == 3 - - gnew = remove_nodes(g, 1.0) - @test gnew.num_nodes == 0 - - gnew = remove_nodes(g, 0.0) - @test gnew.num_nodes == 5 - end -end - -@testset "add_nodes" begin if GRAPH_T == :coo - g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T) - gnew = add_nodes(g, 5, ndata = ones(2, 5)) - @test gnew.num_nodes == g.num_nodes + 5 - @test gnew.num_edges == g.num_edges - @test gnew.num_graphs == g.num_graphs - @test all(gnew.ndata.x[:, 7:11] .== 1) -end end - -@testset "remove_self_loops" begin if GRAPH_T == :coo # add_edges and set_edge_weight only implemented for coo - g = rand_graph(10, 20, graph_type = GRAPH_T) - g1 = add_edges(g, [1:5;], [1:5;]) - @test g1.num_edges == g.num_edges + 5 - g2 = remove_self_loops(g1) - @test g2.num_edges == g.num_edges - @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) - - # with edge features and weights - g1 = GNNGraph(g1, edata = (e1 = ones(3, g1.num_edges), e2 = 2 * ones(g1.num_edges))) - g1 = set_edge_weight(g1, 3 * ones(g1.num_edges)) - g2 = remove_self_loops(g1) - @test g2.num_edges == g.num_edges - @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) - @test size(get_edge_weight(g2)) == (g2.num_edges,) - @test size(g2.edata.e1) == (3, g2.num_edges) - @test size(g2.edata.e2) == (g2.num_edges,) -end end - -@testset "remove_multi_edges" begin if GRAPH_T == :coo - g = rand_graph(10, 20, graph_type = GRAPH_T) - s, t = edge_index(g) - g1 = add_edges(g, s[1:5], t[1:5]) - @test g1.num_edges == g.num_edges + 5 - g2 = 
remove_multi_edges(g1, aggr = +) - @test g2.num_edges == g.num_edges - @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) - - # Default aggregation is + - g1 = GNNGraph(g1, edata = (e1 = ones(3, g1.num_edges), e2 = 2 * ones(g1.num_edges))) - g1 = set_edge_weight(g1, 3 * ones(g1.num_edges)) - g2 = remove_multi_edges(g1) - @test g2.num_edges == g.num_edges - @test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g)) - @test count(g2.edata.e1[:, i] == 2 * ones(3) for i in 1:(g2.num_edges)) == 5 - @test count(g2.edata.e2[i] == 4 for i in 1:(g2.num_edges)) == 5 - w2 = get_edge_weight(g2) - @test count(w2[i] == 6 for i in 1:(g2.num_edges)) == 5 -end end - -@testset "negative_sample" begin if GRAPH_T == :coo - n, m = 10, 30 - g = rand_graph(n, m, bidirected = true, graph_type = GRAPH_T) - - # check bidirected=is_bidirected(g) default - gneg = negative_sample(g, num_neg_edges = 20) - @test gneg.num_nodes == g.num_nodes - @test gneg.num_edges == 20 - @test is_bidirected(gneg) - @test intersect(g, gneg).num_edges == 0 -end end - -@testset "rand_edge_split" begin if GRAPH_T == :coo - n, m = 100, 300 - - g = rand_graph(n, m, bidirected = true, graph_type = GRAPH_T) - # check bidirected=is_bidirected(g) default - g1, g2 = rand_edge_split(g, 0.9) - @test is_bidirected(g1) - @test is_bidirected(g2) - @test intersect(g1, g2).num_edges == 0 - @test g1.num_edges + g2.num_edges == g.num_edges - @test g2.num_edges < 50 - - g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T) - # check bidirected=is_bidirected(g) default - g1, g2 = rand_edge_split(g, 0.9) - @test !is_bidirected(g1) - @test !is_bidirected(g2) - @test intersect(g1, g2).num_edges == 0 - @test g1.num_edges + g2.num_edges == g.num_edges - @test g2.num_edges < 50 - - g1, g2 = rand_edge_split(g, 0.9, bidirected = false) - @test !is_bidirected(g1) - @test !is_bidirected(g2) - @test intersect(g1, g2).num_edges == 0 - @test g1.num_edges + g2.num_edges == g.num_edges - @test 
g2.num_edges < 50 -end end - -@testset "set_edge_weight" begin - g = rand_graph(10, 20, graph_type = GRAPH_T) - w = rand(20) - - gw = set_edge_weight(g, w) - @test get_edge_weight(gw) == w - - # now from weighted graph - s, t = edge_index(g) - g2 = GNNGraph(s, t, rand(20), graph_type = GRAPH_T) - gw2 = set_edge_weight(g2, w) - @test get_edge_weight(gw2) == w -end - -@testset "to_bidirected" begin if GRAPH_T == :coo - s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4] - w = [1.0, 2.0, 3.0, 4.0, 5.0] - e = [10.0, 20.0, 30.0, 40.0, 50.0] - g = GNNGraph(s, t, w, edata = e) - - g2 = to_bidirected(g) - @test g2.num_nodes == g.num_nodes - @test g2.num_edges == 7 - @test is_bidirected(g2) - @test !has_multi_edges(g2) - - s2, t2 = edge_index(g2) - w2 = get_edge_weight(g2) - @test s2 == [1, 2, 2, 3, 3, 4, 4] - @test t2 == [2, 1, 3, 2, 4, 3, 4] - @test w2 == [1, 1, 2, 2, 3.5, 3.5, 5] - @test g2.edata.e == [10.0, 10.0, 20.0, 20.0, 35.0, 35.0, 50.0] -end end - -@testset "to_unidirected" begin if GRAPH_T == :coo - s = [1, 2, 3, 4, 4] - t = [2, 3, 4, 3, 4] - w = [1.0, 2.0, 3.0, 4.0, 5.0] - e = [10.0, 20.0, 30.0, 40.0, 50.0] - g = GNNGraph(s, t, w, edata = e) - - g2 = to_unidirected(g) - @test g2.num_nodes == g.num_nodes - @test g2.num_edges == 4 - @test !has_multi_edges(g2) - - s2, t2 = edge_index(g2) - w2 = get_edge_weight(g2) - @test s2 == [1, 2, 3, 4] - @test t2 == [2, 3, 4, 4] - @test w2 == [1, 2, 3.5, 5] - @test g2.edata.e == [10.0, 20.0, 35.0, 50.0] -end end - -@testset "Graphs.Graph from GNNGraph" begin - g = rand_graph(10, 20, graph_type = GRAPH_T) - - G = Graphs.Graph(g) - @test nv(G) == g.num_nodes - @test ne(G) == g.num_edges ÷ 2 - - DG = Graphs.DiGraph(g) - @test nv(DG) == g.num_nodes - @test ne(DG) == g.num_edges -end - -@testset "random_walk_pe" begin - s = [1, 2, 2, 3] - t = [2, 1, 3, 2] - ndata = [-1, 0, 1] - g = GNNGraph(s, t, graph_type = GRAPH_T, ndata = ndata) - output = random_walk_pe(g, 3) - @test output == [0.0 0.0 0.0 - 0.5 1.0 0.5 - 0.0 0.0 0.0] -end - -@testset 
"HeteroGraphs" begin - @testset "batch" begin - gs = [rand_bipartite_heterograph((10, 15), 20) for _ in 1:5] - g = MLUtils.batch(gs) - @test g.num_nodes[:A] == 50 - @test g.num_nodes[:B] == 75 - @test g.num_edges[(:A,:to,:B)] == 100 - @test g.num_edges[(:B,:to,:A)] == 100 - @test g.num_graphs == 5 - @test g.graph_indicator == Dict(:A => vcat([fill(i, 10) for i in 1:5]...), - :B => vcat([fill(i, 15) for i in 1:5]...)) - - for gi in gs - gi.ndata[:A].x = ones(2, 10) - gi.ndata[:A].y = zeros(10) - gi.edata[(:A,:to,:B)].e = fill(2, 20) - gi.gdata.u = 7 - end - g = MLUtils.batch(gs) - @test g.ndata[:A].x == ones(2, 50) - @test g.ndata[:A].y == zeros(50) - @test g.edata[(:A,:to,:B)].e == fill(2, 100) - @test g.gdata.u == fill(7, 5) - - # Allow for wider eltype - g = MLUtils.batch(GNNHeteroGraph[g for g in gs]) - @test g.ndata[:A].x == ones(2, 50) - @test g.ndata[:A].y == zeros(50) - @test g.edata[(:A,:to,:B)].e == fill(2, 100) - @test g.gdata.u == fill(7, 5) - end - - @testset "batch non-similar edge types" begin - gs = [rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20)), - rand_heterograph((:A => 10, :B => 15), ((:A, :to1, :B) => 5, (:B, :to2, :B) => 16)), - rand_heterograph((:B => 15, :C => 5), ((:C, :to1, :B) => 5, (:B, :to2, :C) => 21)), - rand_heterograph((:A => 10, :B => 10, :C => 10), ((:A, :to1, :C) => 5, (:A, :to1, :B) => 5)), - rand_heterograph((:C => 20), ((:C, :to3, :C) => 10)) - ] - g = MLUtils.batch(gs) - - @test g.num_nodes[:A] == 10 + 10 + 10 - @test g.num_nodes[:B] == 14 + 15 + 15 + 10 - @test g.num_nodes[:C] == 5 + 10 + 20 - @test g.num_edges[(:A,:to1,:A)] == 5 - @test g.num_edges[(:A,:to1,:B)] == 20 + 5 + 5 - @test g.num_edges[(:A,:to1,:C)] == 5 - - @test g.num_edges[(:B,:to2,:B)] == 16 - @test g.num_edges[(:B,:to2,:C)] == 21 - - @test g.num_edges[(:C,:to1,:B)] == 5 - @test g.num_edges[(:C,:to3,:C)] == 10 - @test length(keys(g.num_edges)) == 7 - @test g.num_graphs == 5 - - function ndata_if_key(g, key, subkey, value) - 
if haskey(g.ndata, key) - g.ndata[key][subkey] = reduce(hcat, fill(value, g.num_nodes[key])) - end - end - - function edata_if_key(g, key, subkey, value) - if haskey(g.edata, key) - g.edata[key][subkey] = reduce(hcat, fill(value, g.num_edges[key])) - end - end - - for gi in gs - ndata_if_key(gi, :A, :x, [0]) - ndata_if_key(gi, :A, :y, ones(2)) - ndata_if_key(gi, :B, :x, ones(3)) - ndata_if_key(gi, :C, :y, zeros(4)) - edata_if_key(gi, (:A,:to1,:B), :x, [0]) - gi.gdata.u = 7 - end - - g = MLUtils.batch(gs) - - @test g.ndata[:A].x == reduce(hcat, fill(0, 10 + 10 + 10)) - @test g.ndata[:A].y == ones(2, 10 + 10 + 10) - @test g.ndata[:B].x == ones(3, 14 + 15 + 15 + 10) - @test g.ndata[:C].y == zeros(4, 5 + 10 + 20) - - @test g.edata[(:A,:to1,:B)].x == reduce(hcat, fill(0, 20 + 5 + 5)) - - @test g.gdata.u == fill(7, 5) - - # Allow for wider eltype - g = MLUtils.batch(GNNHeteroGraph[g for g in gs]) - @test g.ndata[:A].x == reduce(hcat, fill(0, 10 + 10 + 10)) - @test g.ndata[:A].y == ones(2, 10 + 10 + 10) - @test g.ndata[:B].x == ones(3, 14 + 15 + 15 + 10) - @test g.ndata[:C].y == zeros(4, 5 + 10 + 20) - - @test g.edata[(:A,:to1,:B)].x == reduce(hcat, fill(0, 20 + 5 + 5)) - - @test g.gdata.u == fill(7, 5) - end - - @testset "add_edges" begin - hg = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) - hg = add_edges(hg, (:B,:to,:A), [1, 1], [1,2]) - @test hg.num_edges == Dict((:A,:to,:B) => 4, (:B,:to,:A) => 2) - @test has_edge(hg, (:B,:to,:A), 1, 1) - @test has_edge(hg, (:B,:to,:A), 1, 2) - @test !has_edge(hg, (:B,:to,:A), 2, 1) - @test !has_edge(hg, (:B,:to,:A), 2, 2) - - @testset "new nodes" begin - hg = rand_bipartite_heterograph((2, 2), 3) - hg = add_edges(hg, (:C,:rel,:B) => ([1, 3], [1,2])) - @test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3) - @test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2) - s, t = edge_index(hg, (:C,:rel,:B)) - @test s == [1, 3] - @test t == [1, 2] - - hg = add_edges(hg, (:D,:rel,:F) => ([1, 3], 
[1,2])) - @test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3, :D => 3, :F => 2) - @test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2, (:D,:rel,:F) => 2) - s, t = edge_index(hg, (:D,:rel,:F)) - @test s == [1, 3] - @test t == [1, 2] - end - - @testset "also add weights" begin - hg = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7], [0.1, 0.2, 0.3, 0.4])) - hgnew = add_edges(hg, (:user, :like, :actor) => ([1, 2], [3, 4], [0.5, 0.6])) - @test hgnew.num_nodes[:user] == 3 - @test hgnew.num_nodes[:movie] == 13 - @test hgnew.num_nodes[:actor] == 4 - @test hgnew.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 2) - @test get_edge_weight(hgnew, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4] - @test get_edge_weight(hgnew, (:user, :like, :actor)) == [0.5, 0.6] - - hgnew2 = add_edges(hgnew, (:user, :like, :actor) => ([6, 7], [8, 10], [0.7, 0.8])) - @test hgnew2.num_nodes[:user] == 7 - @test hgnew2.num_nodes[:movie] == 13 - @test hgnew2.num_nodes[:actor] == 10 - @test hgnew2.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 4) - @test get_edge_weight(hgnew2, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4] - @test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8] - end - end - - @testset "add self-loops heterographs" begin - g = rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20)) - # Case in which haskey(g.graph, edge_t) passes - g = add_self_loops(g, (:A, :to1, :A)) - - @test g.num_edges[(:A, :to1, :A)] == 5 + 10 - @test g.num_edges[(:A, :to1, :B)] == 20 - # This test should not use length(keys(g.num_edges)) since that may be undefined behavior - @test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 2 - - # Case in which haskey(g.graph, edge_t) fails - g = add_self_loops(g, (:A, :to3, :A)) - - @test g.num_edges[(:A, :to1, :A)] == 5 + 10 - @test g.num_edges[(:A, :to1, :B)] == 20 - @test g.num_edges[(:A, :to3, :A)] == 10 - 
@test sum(1 for k in keys(g.num_edges) if g.num_edges[k] != 0) == 3 - - # Case with edge weights - g = GNNHeteroGraph(Dict((:A, :to1, :A) => ([1, 2, 3], [3, 2, 1], [2, 2, 2]), (:A, :to2, :B) => ([1, 4, 5], [1, 2, 3]))) - n = g.num_nodes[:A] - g = add_self_loops(g, (:A, :to1, :A)) - - @test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n)) - end -end - -@testset "ppr_diffusion" begin - if GRAPH_T == :coo - s = [1, 1, 2, 3] - t = [2, 3, 4, 5] - eweights = [0.1, 0.2, 0.3, 0.4] - - g = GNNGraph(s, t, eweights) - - g_new = ppr_diffusion(g) - w_new = get_edge_weight(g_new) - - check_ew = Float32[0.012749999 - 0.025499998 - 0.038249996 - 0.050999995] - - @test w_new ≈ check_ew - end -end -[.\GNNGraphs\test\utils.jl] -@testset "edge encoding/decoding" begin - # not is_bidirected - n = 5 - s = [1, 1, 2, 3, 3, 4, 5] - t = [1, 3, 1, 1, 2, 5, 5] - - # directed=true - idx, maxid = GNNGraphs.edge_encoding(s, t, n) - @test maxid == n^2 - @test idx == [1, 3, 6, 11, 12, 20, 25] - - sdec, tdec = GNNGraphs.edge_decoding(idx, n) - @test sdec == s - @test tdec == t - - n1, m1 = 10, 30 - g = rand_graph(n1, m1) - s1, t1 = edge_index(g) - idx, maxid = GNNGraphs.edge_encoding(s1, t1, n1) - sdec, tdec = GNNGraphs.edge_decoding(idx, n1) - @test sdec == s1 - @test tdec == t1 - - # directed=false - idx, maxid = GNNGraphs.edge_encoding(s, t, n, directed = false) - @test maxid == n * (n + 1) ÷ 2 - @test idx == [1, 3, 2, 3, 7, 14, 15] - - mask = s .> t - snew = copy(s) - tnew = copy(t) - snew[mask] .= t[mask] - tnew[mask] .= s[mask] - sdec, tdec = GNNGraphs.edge_decoding(idx, n, directed = false) - @test sdec == snew - @test tdec == tnew - - n1, m1 = 6, 8 - g = rand_graph(n1, m1) - s1, t1 = edge_index(g) - idx, maxid = GNNGraphs.edge_encoding(s1, t1, n1, directed = false) - sdec, tdec = GNNGraphs.edge_decoding(idx, n1, directed = false) - mask = s1 .> t1 - snew = copy(s1) - tnew = copy(t1) - snew[mask] .= t1[mask] - tnew[mask] .= s1[mask] - @test sdec == snew - @test tdec == tnew - - 
@testset "directed=false, self_loops=false" begin - n = 5 - edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] - s = [e[1] for e in edges] - t = [e[2] for e in edges] - g = GNNGraph(s, t) - idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) - @test idxmax == n * (n - 1) ÷ 2 - @test idx == 1:idxmax - - snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) - @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] - @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] - end - - @testset "directed=false, self_loops=false" begin - n = 5 - edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] - s = [e[1] for e in edges] - t = [e[2] for e in edges] - - idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) - @test idxmax == n * (n - 1) ÷ 2 - @test idx == 1:idxmax - - snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) - @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] - @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] - end - - @testset "directed=true, self_loops=false" begin - n = 5 - edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] - s = [e[1] for e in edges] - t = [e[2] for e in edges] - - idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=true, self_loops=false) - @test idxmax == n^2 - n - @test idx == [1, 9, 3, 4, 6, 7, 8, 11, 12, 16] - snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=true, self_loops=false) - @test snew == s - @test tnew == t - end -end - -@testset "color_refinement" begin - rng = MersenneTwister(17) - g = rand_graph(rng, 10, 20, graph_type = GRAPH_T) - x0 = ones(Int, 10) - x, ncolors, niters = color_refinement(g, x0) - @test ncolors == 8 - @test niters == 2 - @test x == [4, 5, 6, 7, 8, 5, 8, 9, 10, 11] - - x2, _, _ = color_refinement(g) - @test x2 == x -end -[.\GNNGraphs\test\ext\SimpleWeightedGraphs.jl] -@testset "simple_weighted_graph" begin - srcs = [1, 2, 1] - dsts = [2, 3, 
3] - wts = [0.5, 0.8, 2.0] - g = SimpleWeightedGraph(srcs, dsts, wts) - gd = SimpleWeightedDiGraph(srcs, dsts, wts) - gnn_g = GNNGraph(g) - gnn_gd = GNNGraph(gd) - @test get_edge_weight(gnn_g) == [0.5, 2, 0.5, 0.8, 2.0, 0.8] - @test get_edge_weight(gnn_gd) == [0.5, 2, 0.8] -end - -[.\GNNlib\ext\GNNlibCUDAExt.jl] -module GNNlibCUDAExt - -using CUDA -using Random, Statistics, LinearAlgebra -using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj -using GNNGraphs: GNNGraph, COO_T, SPARSE_T - -###### PROPAGATE SPECIALIZATIONS #################### - -## COPY_XJ - -## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), - xi, xj::AnyCuMatrix, e) - propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) -end - -## E_MUL_XJ - -## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), - xi, xj::AnyCuMatrix, e::AbstractVector) - propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e) -end - -## W_MUL_XJ - -## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), - xi, xj::AnyCuMatrix, e::Nothing) - propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) -end - -# function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) -# A = adjacency_matrix(g, weighted=false) -# D = compute_degree(A) -# return xj * A * D -# end - -# # Zygote bug. 
Error with sparse matrix without nograd -# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) - -# Flux.Zygote.@nograd compute_degree - -end #module - -[.\GNNlib\src\GNNlib.jl] -module GNNlib - -using Statistics: mean -using LinearAlgebra, Random -using MLUtils: zeros_like -using NNlib -using NNlib: scatter, gather -using DataStructures: nlargest -using ChainRulesCore: @non_differentiable -using GNNGraphs -using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, - check_num_nodes, check_num_edges, - EType, NType # for heteroconvs - -include("utils.jl") -export reduce_nodes, - reduce_edges, - softmax_nodes, - softmax_edges, - broadcast_nodes, - broadcast_edges, - softmax_edge_neighbors - -include("msgpass.jl") -export apply_edges, - aggregate_neighbors, - propagate, - copy_xj, - copy_xi, - xi_dot_xj, - xi_sub_xj, - xj_sub_xi, - e_mul_xj, - w_mul_xj - -## The following methods are defined but not exported - -include("layers/basic.jl") -export dot_decoder - -include("layers/conv.jl") -export agnn_conv, - cg_conv, - cheb_conv, - d_conv, - edge_conv, - egnn_conv, - gat_conv, - gatv2_conv, - gated_graph_conv, - gcn_conv, - gin_conv, - gmm_conv, - graph_conv, - megnet_conv, - nn_conv, - res_gated_graph_conv, - sage_conv, - sg_conv, - tag_conv, - transformer_conv - -include("layers/temporalconv.jl") -export a3tgcn_conv - -include("layers/pool.jl") -export global_pool, - global_attention_pool, - set2set_pool, - topk_pool, - topk_index - -# include("layers/heteroconv.jl") # no functional part at the moment - -end #module - -[.\GNNlib\src\msgpass.jl] -""" - propagate(fmsg, g, aggr; [xi, xj, e]) - propagate(fmsg, g, aggr xi, xj, e=nothing) - -Performs message passing on graph `g`. Takes care of materializing the node features on each edge, -applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}`` -(depending on the return value of `fmsg`, an array or a named tuple of -arrays with last dimension's size `g.num_nodes`). 
- -It can be decomposed in two steps: - -```julia -m = apply_edges(fmsg, g, xi, xj, e) -m̄ = aggregate_neighbors(g, aggr, m) -``` - -GNN layers typically call `propagate` in their forward pass, -providing as input `f` a closure. - -# Arguments - -- `g`: A `GNNGraph`. -- `xi`: An array or a named tuple containing arrays whose last dimension's size - is `g.num_nodes`. It will be appropriately materialized on the - target node of each edge (see also [`edge_index`](@ref)). -- `xj`: As `xj`, but to be materialized on edges' sources. -- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. -- `fmsg`: A generic function that will be passed over to [`apply_edges`](@ref). - Has to take as inputs the edge-materialized `xi`, `xj`, and `e` - (arrays or named tuples of arrays whose last dimension' size is the size of - a batch of edges). Its output has to be an array or a named tuple of arrays - with the same batch size. If also `layer` is passed to propagate, - the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` - instead of `fmsg(xi, xj, e)`. -- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`. - -# Examples - -```julia -using GraphNeuralNetworks, Flux - -struct GNNConv <: GNNLayer - W - b - σ -end - -Flux.@layer GNNConv - -function GNNConv(ch::Pair{Int,Int}, σ=identity) - in, out = ch - W = Flux.glorot_uniform(out, in) - b = zeros(Float32, out) - GNNConv(W, b, σ) -end - -function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix) - message(xi, xj, e) = l.W * xj - m̄ = propagate(message, g, +, xj=x) - return l.σ.(m̄ .+ l.bias) -end - -l = GNNConv(10 => 20) -l(g, x) -``` - -See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref). 
-""" -function propagate end - -function propagate(f, g::AbstractGNNGraph, aggr; xi = nothing, xj = nothing, e = nothing) - propagate(f, g, aggr, xi, xj, e) -end - -function propagate(f, g::AbstractGNNGraph, aggr, xi, xj, e = nothing) - m = apply_edges(f, g, xi, xj, e) - m̄ = aggregate_neighbors(g, aggr, m) - return m̄ -end - -## APPLY EDGES - -""" - apply_edges(fmsg, g; [xi, xj, e]) - apply_edges(fmsg, g, xi, xj, e=nothing) - -Returns the message from node `j` to node `i` applying -the message function `fmsg` on the edges in graph `g`. -In the message-passing scheme, the incoming messages -from the neighborhood of `i` will later be aggregated -in order to update the features of node `i` (see [`aggregate_neighbors`](@ref)). - -The function `fmsg` operates on batches of edges, therefore -`xi`, `xj`, and `e` are tensors whose last dimension -is the batch size, or can be named tuples of -such tensors. - -# Arguments - -- `g`: An `AbstractGNNGraph`. -- `xi`: An array or a named tuple containing arrays whose last dimension's size - is `g.num_nodes`. It will be appropriately materialized on the - target node of each edge (see also [`edge_index`](@ref)). -- `xj`: As `xi`, but now to be materialized on each edge's source node. -- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. -- `fmsg`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`. - These are arrays (or named tuples of arrays) whose last dimension' size is the size of - a batch of edges. The output of `f` has to be an array (or a named tuple of arrays) - with the same batch size. If also `layer` is passed to propagate, - the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` - instead of `fmsg(xi, xj, e)`. - -See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref). 
-""" -function apply_edges end - -function apply_edges(f, g::AbstractGNNGraph; xi = nothing, xj = nothing, e = nothing) - apply_edges(f, g, xi, xj, e) -end - -function apply_edges(f, g::AbstractGNNGraph, xi, xj, e = nothing) - check_num_nodes(g, (xj, xi)) - check_num_edges(g, e) - s, t = edge_index(g) # for heterographs, errors if more than one edge type - xi = GNNGraphs._gather(xi, t) # size: (D, num_nodes) -> (D, num_edges) - xj = GNNGraphs._gather(xj, s) - m = f(xi, xj, e) - return m -end - -## AGGREGATE NEIGHBORS -@doc raw""" - aggregate_neighbors(g, aggr, m) - -Given a graph `g`, edge features `m`, and an aggregation -operator `aggr` (e.g `+, min, max, mean`), returns the new node -features -```math -\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i} -``` - -Neighborhood aggregation is the second step of [`propagate`](@ref), -where it comes after [`apply_edges`](@ref). -""" -function aggregate_neighbors(g::GNNGraph, aggr, m) - check_num_edges(g, m) - s, t = edge_index(g) - return GNNGraphs._scatter(aggr, m, t, g.num_nodes) -end - -function aggregate_neighbors(g::GNNHeteroGraph, aggr, m) - check_num_edges(g, m) - s, t = edge_index(g) - dest_node_t = only(g.etypes)[3] - return GNNGraphs._scatter(aggr, m, t, g.num_nodes[dest_node_t]) -end - -### MESSAGE FUNCTIONS ### -""" - copy_xj(xi, xj, e) = xj -""" -copy_xj(xi, xj, e) = xj - -""" - copy_xi(xi, xj, e) = xi -""" -copy_xi(xi, xj, e) = xi - -""" - xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1) -""" -xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims = 1) - -""" - xi_sub_xj(xi, xj, e) = xi .- xj -""" -xi_sub_xj(xi, xj, e) = xi .- xj - -""" - xj_sub_xi(xi, xj, e) = xj .- xi -""" -xj_sub_xi(xi, xj, e) = xj .- xi - -""" - e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj - -Reshape `e` into broadcast compatible shape with `xj` -(by prepending singleton dimensions) then perform -broadcasted multiplication. 
-""" -function e_mul_xj(xi, xj::AbstractArray{Tj, Nj}, - e::AbstractArray{Te, Ne}) where {Tj, Te, Nj, Ne} - @assert Ne <= Nj - e = reshape(e, ntuple(_ -> 1, Nj - Ne)..., size(e)...) - return e .* xj -end - -""" - w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj - -Similar to [`e_mul_xj`](@ref) but specialized on scalar edge features (weights). -""" -w_mul_xj(xi, xj::AbstractArray, w::Nothing) = xj # same as copy_xj if no weights - -function w_mul_xj(xi, xj::AbstractArray{Tj, Nj}, w::AbstractVector) where {Tj, Nj} - w = reshape(w, ntuple(_ -> 1, Nj - 1)..., length(w)) - return w .* xj -end - -###### PROPAGATE SPECIALIZATIONS #################### -## See also the methods defined in the package extensions. - -## COPY_XJ - -function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e) - A = adjacency_matrix(g, weighted = false) - return xj * A -end - -## E_MUL_XJ - -# for weighted convolution -function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, - e::AbstractVector) - g = set_edge_weight(g, e) - A = adjacency_matrix(g, weighted = true) - return xj * A -end - - -## W_MUL_XJ - -# for weighted convolution -function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, - e::Nothing) - A = adjacency_matrix(g, weighted = true) - return xj * A -end - - -# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) -# A = adjacency_matrix(g, weighted=false) -# D = compute_degree(A) -# return xj * A * D -# end - -# # Zygote bug. Error with sparse matrix without nograd -# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) - -# Flux.Zygote.@nograd compute_degree - -[.\GNNlib\src\utils.jl] -ofeltype(x, y) = convert(float(eltype(x)), y) - -""" - reduce_nodes(aggr, g, x) - -For a batched graph `g`, return the graph-wise aggregation of the node -features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. 
-The returned array will have last dimension `g.num_graphs`. - -See also: [`reduce_edges`](@ref). -""" -function reduce_nodes(aggr, g::GNNGraph, x) - @assert size(x)[end] == g.num_nodes - indexes = graph_indicator(g) - return NNlib.scatter(aggr, x, indexes) -end - -""" - reduce_nodes(aggr, indicator::AbstractVector, x) - -Return the graph-wise aggregation of the node features `x` given the -graph indicator `indicator`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. - -See also [`graph_indicator`](@ref). -""" -function reduce_nodes(aggr, indicator::AbstractVector, x) - return NNlib.scatter(aggr, x, indicator) -end - -""" - reduce_edges(aggr, g, e) - -For a batched graph `g`, return the graph-wise aggregation of the edge -features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. -The returned array will have last dimension `g.num_graphs`. -""" -function reduce_edges(aggr, g::GNNGraph, e) - @assert size(e)[end] == g.num_edges - s, t = edge_index(g) - indexes = graph_indicator(g)[s] - return NNlib.scatter(aggr, e, indexes) -end - -""" - softmax_nodes(g, x) - -Graph-wise softmax of the node features `x`. -""" -function softmax_nodes(g::GNNGraph, x) - @assert size(x)[end] == g.num_nodes - gi = graph_indicator(g) - max_ = gather(scatter(max, x, gi), gi) - num = exp.(x .- max_) - den = reduce_nodes(+, g, num) - den = gather(den, gi) - return num ./ den -end - -""" - softmax_edges(g, e) - -Graph-wise softmax of the edge features `e`. -""" -function softmax_edges(g::GNNGraph, e) - @assert size(e)[end] == g.num_edges - gi = graph_indicator(g, edges = true) - max_ = gather(scatter(max, e, gi), gi) - num = exp.(e .- max_) - den = reduce_edges(+, g, num) - den = gather(den, gi) - return num ./ (den .+ eps(eltype(e))) -end - -@doc raw""" - softmax_edge_neighbors(g, e) - -Softmax over each node's neighborhood of the edge features `e`. 
- -```math -\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}} - {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}. -``` -""" -function softmax_edge_neighbors(g::AbstractGNNGraph, e) - if g isa GNNHeteroGraph - for (key, value) in g.num_edges - @assert size(e)[end] == value - end - else - @assert size(e)[end] == g.num_edges - end - s, t = edge_index(g) - max_ = gather(scatter(max, e, t), t) - num = exp.(e .- max_) - den = gather(scatter(+, num, t), t) - return num ./ den -end - -""" - broadcast_nodes(g, x) - -Graph-wise broadcast array `x` of size `(*, g.num_graphs)` -to size `(*, g.num_nodes)`. -""" -function broadcast_nodes(g::GNNGraph, x) - @assert size(x)[end] == g.num_graphs - gi = graph_indicator(g) - return gather(x, gi) -end - -""" - broadcast_edges(g, x) - -Graph-wise broadcast array `x` of size `(*, g.num_graphs)` -to size `(*, g.num_edges)`. -""" -function broadcast_edges(g::GNNGraph, x) - @assert size(x)[end] == g.num_graphs - gi = graph_indicator(g, edges = true) - return gather(x, gi) -end - -expand_srcdst(g::AbstractGNNGraph, x) = throw(ArgumentError("Invalid input type, expected matrix or tuple of matrices.")) -expand_srcdst(g::AbstractGNNGraph, x::AbstractMatrix) = (x, x) -expand_srcdst(g::AbstractGNNGraph, x::Tuple{<:AbstractMatrix, <:AbstractMatrix}) = x - -# Replacement for Base.Fix1 to allow for multiple arguments -struct Fix1{F,X} - f::F - x::X -end - -(f::Fix1)(y...) = f.f(f.x, y...) 

[.\GNNlib\src\layers\basic.jl]
# Edge-wise dot-product decoder: score for each edge (i,j) is xi' * xj.
function dot_decoder(g, x)
    return apply_edges(xi_dot_xj, g, xi = x, xj = x)
end

[.\GNNlib\src\layers\conv.jl]
####################### GCNConv ######################################

# External edge weights cannot be combined with adjacency-matrix graph storage.
check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
    throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))

# Validate that one weight per edge was supplied.
function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector)
    # FIX: use value inequality `!=` instead of identity `!==`.
    # Both sides are Ints, so `!==` happened to behave like `!=` (egal on isbits
    # values compares bits), but equality is the intended comparison here.
    if length(edge_weight) != g.num_edges
        throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"))
    end
end

# No edge weights given: nothing to validate.
check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing

# Forward pass of graph convolution (GCN-style): normalize, aggregate
# neighbors, apply the weight matrix and the activation.
# `edge_weight` optionally weights each edge; `norm_fn` maps degrees to
# normalization coefficients; `conv_weight` optionally overrides `l.weight`
# (must have the same size).
function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where
        {EW <: Union{Nothing, AbstractVector}, CW <: Union{Nothing, AbstractMatrix}, F}
    check_gcnconv_input(g, edge_weight)
    if conv_weight === nothing
        weight = l.weight
    else
        weight = conv_weight
        if size(weight) != size(l.weight)
            throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))"))
        end
    end

    if l.add_self_loops
        g = add_self_loops(g)
        if edge_weight !== nothing
            # Pad weights with ones
            # TODO for ADJMAT_T the new edges are not generally at the end
            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
            @assert length(edge_weight) == g.num_edges
        end
    end
    Dout, Din = size(weight)
    if Dout < Din && !(g isa GNNHeteroGraph)
        # multiply before convolution if it is more convenient, otherwise multiply after
        # (this works only for homogenous graph)
        x = weight * x
    end

    xj, xi = expand_srcdst(g, x) # expand only after potential multiplication
    T = eltype(xi)

    if g isa GNNHeteroGraph
        din = degree(g, g.etypes[1], T; dir = :in)
        dout = degree(g, g.etypes[1], T; dir = :out)

        cout = norm_fn(dout)
        cin = norm_fn(din)
    else
        if edge_weight !== nothing
            d = degree(g, T; dir = :in, edge_weight)
        else
            d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
        end
        cin = cout = norm_fn(d)
    end
    # Symmetric-style normalization: scale sources by cout, targets by cin.
    xj = xj .* cout'
    if edge_weight !== nothing
        x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight)
    elseif l.use_edge_weight
        x = propagate(w_mul_xj, g, +, xj = xj)
    else
        x = propagate(copy_xj, g, +, xj = xj)
    end
    x = x .* cin'
    if Dout >= Din || g isa GNNHeteroGraph
        x = weight * x
    end
    return l.σ.(x .+ l.bias)
end

# when we also have edge_weight we need to convert the graph to COO
function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where
        {EW <: Union{Nothing, AbstractVector}, CW <: Union{Nothing, AbstractMatrix}, F}
    g = GNNGraph(g, graph_type = :coo)
    return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight)
end

####################### ChebConv ######################################

function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T}
    check_num_nodes(g, X)
    @assert size(X, 1) == size(l.weight, 2) "Input feature size must match input channel size."
- - L̃ = scaled_laplacian(g, eltype(X)) - - Z_prev = X - Z = X * L̃ - Y = view(l.weight, :, :, 1) * Z_prev - Y = Y .+ view(l.weight, :, :, 2) * Z - for k in 3:(l.k) - Z, Z_prev = 2 * Z * L̃ - Z_prev, Z - Y = Y .+ view(l.weight, :, :, k) * Z - end - return Y .+ l.bias -end - -####################### GraphConv ###################################### - -function graph_conv(l, g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.weight1 * xi .+ l.weight2 * m - return l.σ.(x .+ l.bias) -end - -####################### GATConv ###################################### - -function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) - check_num_nodes(g, x) - @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" - @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" - - xj, xi = expand_srcdst(g, x) - - if l.add_self_loops - @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." 
- g = add_self_loops(g) - end - - _, chout = l.channel - heads = l.heads - - Wxi = Wxj = l.dense_x(xj) - Wxi = Wxj = reshape(Wxj, chout, heads, :) - - if xi !== xj - Wxi = l.dense_x(xi) - Wxi = reshape(Wxi, chout, heads, :) - end - - # a hand-written message passing - message = Fix1(gat_message, l) - m = apply_edges(message, g, Wxi, Wxj, e) - α = softmax_edge_neighbors(g, m.logα) - α = dropout(α, l.dropout) - β = α .* m.Wxj - x = aggregate_neighbors(g, +, β) - - if !l.concat - x = mean(x, dims = 2) - end - x = reshape(x, :, size(x, 3)) # return a matrix - x = l.σ.(x .+ l.bias) - - return x -end - -function gat_message(l, Wxi, Wxj, e) - _, chout = l.channel - heads = l.heads - - if e === nothing - Wxx = vcat(Wxi, Wxj) - else - We = l.dense_e(e) - We = reshape(We, chout, heads, :) # chout × nheads × nnodes - Wxx = vcat(Wxi, Wxj, We) - end - aWW = sum(l.a .* Wxx, dims = 1) # 1 × nheads × nedges - slope = convert(eltype(aWW), l.negative_slope) - logα = leakyrelu.(aWW, slope) - return (; logα, Wxj) -end - -####################### GATv2Conv ###################################### - -function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) - check_num_nodes(g, x) - @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" - @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" - - xj, xi = expand_srcdst(g, x) - - if l.add_self_loops - @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." 
- g = add_self_loops(g) - end - _, out = l.channel - heads = l.heads - - Wxi = reshape(l.dense_i(xi), out, heads, :) # out × heads × nnodes - Wxj = reshape(l.dense_j(xj), out, heads, :) # out × heads × nnodes - - message = Fix1(gatv2_message, l) - m = apply_edges(message, g, Wxi, Wxj, e) - α = softmax_edge_neighbors(g, m.logα) - α = dropout(α, l.dropout) - β = α .* m.Wxj - x = aggregate_neighbors(g, +, β) - - if !l.concat - x = mean(x, dims = 2) - end - x = reshape(x, :, size(x, 3)) - x = l.σ.(x .+ l.bias) - return x -end - -function gatv2_message(l, Wxi, Wxj, e) - _, out = l.channel - heads = l.heads - - Wx = Wxi + Wxj # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?" - if e !== nothing - Wx += reshape(l.dense_e(e), out, heads, :) - end - slope = convert(eltype(Wx), l.negative_slope) - logα = sum(l.a .* leakyrelu.(Wx, slope), dims = 1) # 1 × heads × nedges - return (; logα, Wxj) -end - -####################### GatedGraphConv ###################################### - -function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix) - check_num_nodes(g, x) - m, n = size(x) - @assert m <= l.dims "number of input features must be less or equal to output features." 
- if m < l.dims - xpad = zeros_like(x, (l.dims - m, n)) - x = vcat(x, xpad) - end - h = x - for i in 1:(l.num_layers) - m = view(l.weight, :, :, i) * h - m = propagate(copy_xj, g, l.aggr; xj = m) - # in gru forward, hidden state is first argument, input is second - h, _ = l.gru(h, m) - end - return h -end - -####################### EdgeConv ###################################### - -function edge_conv(l, g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - message = Fix1(edge_conv_message, l) - x = propagate(message, g, l.aggr; xi, xj, e = nothing) - return x -end - -edge_conv_message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi)) - -####################### GINConv ###################################### - -function gin_conv(l, g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - m = propagate(copy_xj, g, l.aggr, xj = xj) - - return l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) -end - -####################### NNConv ###################################### - -function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e) - check_num_nodes(g, x) - message = Fix1(nn_conv_message, l) - m = propagate(message, g, l.aggr, xj = x, e = e) - return l.σ.(l.weight * x .+ m .+ l.bias) -end - -function nn_conv_message(l, xi, xj, e) - nin, nedges = size(xj) - W = reshape(l.nn(e), (:, nin, nedges)) - xj = reshape(xj, (nin, 1, nedges)) # needed by batched_mul - m = NNlib.batched_mul(W, xj) - return reshape(m, :, nedges) -end - -####################### SAGEConv ###################################### - -function sage_conv(l, g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) - return x -end - -####################### ResGatedConv ###################################### - -function res_gated_graph_conv(l, g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) 
.* xj.Vx - - Ax = l.A * xi - Bx = l.B * xj - Vx = l.V * xj - - m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx)) - - return l.σ.(l.U * xi .+ m .+ l.bias) -end - -####################### CGConv ###################################### - -function cg_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - if e !== nothing - check_num_edges(g, e) - end - - message = Fix1(cg_message, l) - m = propagate(message, g, +, xi = xi, xj = xj, e = e) - - if l.residual - if size(x, 1) == size(m, 1) - m += x - else - @warn "number of output features different from number of input features, residual not applied." - end - end - - return m -end - -function cg_message(l, xi, xj, e) - if e !== nothing - z = vcat(xi, xj, e) - else - z = vcat(xi, xj) - end - return l.dense_f(z) .* l.dense_s(z) -end - -####################### AGNNConv ###################################### - -function agnn_conv(l, g::GNNGraph, x::AbstractMatrix) - check_num_nodes(g, x) - if l.add_self_loops - g = add_self_loops(g) - end - - xn = x ./ sqrt.(sum(x .^ 2, dims = 1)) - cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn) - α = softmax_edge_neighbors(g, l.β .* cos_dist) - - x = propagate(g, +; xj = x, e = α) do xi, xj, α - α .* xj - end - - return x -end - -####################### MegNetConv ###################################### - -function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) - check_num_nodes(g, x) - - ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e - l.ϕe(vcat(xi, xj, e)) - end - - xᵉ = aggregate_neighbors(g, l.aggr, ē) - - x̄ = l.ϕv(vcat(x, xᵉ)) - - return x̄, ē -end - -####################### GMMConv ###################################### - -function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) - (nin, ein), out = l.ch #Notational Simplicity - - @assert (ein == size(e)[1]&&g.num_edges == size(e)[2]) "Pseudo-cordinate dimension is not equal to (ein,num_edge)" - - 
num_edges = g.num_edges - w = reshape(e, (ein, 1, num_edges)) - mu = reshape(l.mu, (ein, l.K, 1)) - - w = @. ((w - mu)^2) / 2 - w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1)) - w = exp.(sum(w, dims = 1)) # (1, K, num_edge) - - xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes) - - m = propagate(e_mul_xj, g, mean, xj = xj, e = w) - m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes) - - m = l.σ(m .+ l.bias) - - if l.residual - if size(x, 1) == size(m, 1) - m += x - else - @warn "Residual not applied : output feature is not equal to input_feature" - end - end - - return m -end - -####################### SGCConv ###################################### - -# this layer is not stable enough to be supported by GNNHeteroGraph type -# due to it's looping mechanism -function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T}, - edge_weight::EW = nothing) where - {T, EW <: Union{Nothing, AbstractVector}} - @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" - - if edge_weight !== nothing - @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" - end - - if l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - edge_weight = [edge_weight; onse_like(edge_weight, g.num_nodes)] - @assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(l.weight) - if Dout < Din - x = l.weight * x - end - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) - else - d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) - end - c = 1 ./ sqrt.(d) - for iter in 1:(l.k) - x = x .* c' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = x) - else - x = propagate(copy_xj, g, +, xj = x) - end - x = x .* c' - end - if Dout >= Din - x = l.weight * x - end - return (x .+ 
l.bias) -end - -# when we also have edge_weight we need to convert the graph to COO -function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector) - g = GNNGraph(g; graph_type=:coo) - return sgc_conv(l, g, x, edge_weight) -end - -####################### EGNNGConv ###################################### - -function egnn_conv(l, g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing) - if l.num_features.edge > 0 - @assert e!==nothing "Edge features must be provided." - end - @assert size(h, 1)==l.num_features.in "Input features must match layer input size." - - x_diff = apply_edges(xi_sub_xj, g, x, x) - sqnorm_xdiff = sum(x_diff .^ 2, dims = 1) - x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6) - - message = Fix1(egnn_message, l) - msg = apply_edges(message, g, - xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff)) - h_aggr = aggregate_neighbors(g, +, msg.h) - x_aggr = aggregate_neighbors(g, mean, msg.x) - - hnew = l.ϕh(vcat(h, h_aggr)) - if l.residual - h = h .+ hnew - else - h = hnew - end - x = x .+ x_aggr - return h, x -end - -function egnn_message(l, xi, xj, e) - if l.num_features.edge > 0 - f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e) - else - f = vcat(xi.h, xj.h, e.sqnorm_xdiff) - end - - msg_h = l.ϕe(f) - msg_x = l.ϕx(msg_h) .* e.x_diff - return (; x = msg_x, h = msg_h) -end - -######################## SGConv ###################################### - -# this layer is not stable enough to be supported by GNNHeteroGraph type -# due to it's looping mechanism -function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T}, - edge_weight::EW = nothing) where - {T, EW <: Union{Nothing, AbstractVector}} - @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" - - if edge_weight !== nothing - @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" - end - - if 
l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)] - @assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(l.weight) - if Dout < Din - x = l.weight * x - end - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) - else - d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) - end - c = 1 ./ sqrt.(d) - for iter in 1:(l.k) - x = x .* c' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = x) - else - x = propagate(copy_xj, g, +, xj = x) - end - x = x .* c' - end - if Dout >= Din - x = l.weight * x - end - return (x .+ l.bias) -end - -# when we also have edge_weight we need to convert the graph to COO -function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector) - g = GNNGraph(g; graph_type=:coo) - return sg_conv(l, g, x, edge_weight) -end - -######################## TransformerConv ###################################### - -function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing} = nothing) - check_num_nodes(g, x) - - if l.add_self_loops - g = add_self_loops(g) - end - - out = l.channels[2] - heads = l.heads - W1x = !isnothing(l.W1) ? l.W1(x) : nothing - W2x = reshape(l.W2(x), out, heads, :) - W3x = reshape(l.W3(x), out, heads, :) - W4x = reshape(l.W4(x), out, heads, :) - W6e = !isnothing(l.W6) ? 
reshape(l.W6(e), out, heads, :) : nothing - - message_uij = Fix1(transformer_message_uij, l) - m = apply_edges(message_uij, g; xi = (; W3x), xj = (; W4x), e = (; W6e)) - α = softmax_edge_neighbors(g, m) - α_val = propagate(transformer_message_main, g, +; - xi = (; W3x), xj = (; W2x), e = (; W6e, α)) - - h = α_val - if l.concat - h = reshape(h, out * heads, :) # concatenate heads - else - h = mean(h, dims = 2) # average heads - h = reshape(h, out, :) - end - - if !isnothing(W1x) # root_weight - if !isnothing(l.W5) # gating - β = l.W5(vcat(h, W1x, h .- W1x)) - h = β .* W1x + (1.0f0 .- β) .* h - else - h += W1x - end - end - - if l.skip_connection - @assert size(h, 1)==size(x, 1) "In-channels must correspond to out-channels * heads if skip_connection is used" - h += x - end - if !isnothing(l.BN1) - h = l.BN1(h) - end - - if !isnothing(l.FF) - h1 = h - h = l.FF(h) - if l.skip_connection - h += h1 - end - if !isnothing(l.BN2) - h = l.BN2(h) - end - end - - return h -end - -# TODO remove l dependence -function transformer_message_uij(l, xi, xj, e) - key = xj.W4x - if !isnothing(e.W6e) - key += e.W6e - end - uij = sum(xi.W3x .* key, dims = 1) ./ l.sqrt_out - return uij -end - -function transformer_message_main(xi, xj, e) - val = xj.W2x - if !isnothing(e.W6e) - val += e.W6e - end - return e.α .* val -end - - -######################## TAGConv ###################################### - -function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T}, - edge_weight::EW = nothing) where - {T, EW <: Union{Nothing, AbstractVector}} - @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" - - if edge_weight !== nothing - @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" - end - - if l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)] - 
@assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(l.weight) - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) - else - d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) - end - c = 1 ./ sqrt.(d) - - sum_pow = 0 - sum_total = 0 - for iter in 1:(l.k) - x = x .* c' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = x) - else - x = propagate(copy_xj, g, +, xj = x) - end - x = x .* c' - - # On the first iteration, initialize sum_pow with the first propagated features - # On subsequent iterations, accumulate propagated features - if iter == 1 - sum_pow = x - sum_total = l.weight * sum_pow - else - sum_pow += x - # Weighted sum of features for each power of adjacency matrix - # This applies the weight matrix to the accumulated sum of propagated features - sum_total += l.weight * sum_pow - end - end - - return (sum_total .+ l.bias) -end - -# when we also have edge_weight we need to convert the graph to COO -function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector) - g = GNNGraph(g; graph_type = :coo) - return l(g, x, edge_weight) -end - -######################## DConv ###################################### - -function d_conv(l, g::GNNGraph, x::AbstractMatrix) - #A = adjacency_matrix(g, weighted = true) - s, t = edge_index(g) - gt = GNNGraph(t, s, get_edge_weight(g)) - deg_out = degree(g; dir = :out) - deg_in = degree(g; dir = :in) - deg_out = Diagonal(deg_out) - deg_in = Diagonal(deg_in) - - h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x - - T0 = x - if l.k > 1 - # T1_in = T0 * deg_in * A' - #T1_out = T0 * deg_out' * A - T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out') - T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in) - h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out - end - for i in 2:l.k - T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in) - T2_in 
= 2 * T2_in - T0 - T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out') - T2_out = 2 * T2_out - T0 - h = h .+ l.weights[1,i,:,:] * T2_in .+ l.weights[2,i,:,:] * T2_out - T1_in = T2_in - T1_out = T2_out - end - return h .+ l.bias -end -[.\GNNlib\src\layers\pool.jl] - - -function global_pool(l, g::GNNGraph, x::AbstractArray) - return reduce_nodes(l.aggr, g, x) -end - -function global_attention_pool(l, g::GNNGraph, x::AbstractArray) - α = softmax_nodes(g, l.fgate(x)) - feats = α .* l.ffeat(x) - u = reduce_nodes(+, g, feats) - return u -end - -function topk_pool(t, X::AbstractArray) - y = t.p' * X / norm(t.p) - idx = topk_index(y, t.k) - t.Ã .= view(t.A, idx, idx) - X_ = view(X, :, idx) .* σ.(view(y, idx)') - return X_ -end - -function topk_index(y::AbstractVector, k::Int) - v = nlargest(k, y) - return collect(1:length(y))[y .>= v[end]] -end - -topk_index(y::Adjoint, k::Int) = topk_index(y', k) - -function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) - n_in = size(x, 1) - qstar = zeros_like(x, (2*n_in, g.num_graphs)) - for t in 1:l.num_iters - q = l.lstm(qstar) # [n_in, n_graphs] - qn = broadcast_nodes(g, q) # [n_in, n_nodes] - α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes] - r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs] - qstar = vcat(q, r) # [2*n_in, n_graphs] - end - return qstar -end - -[.\GNNlib\src\layers\temporalconv.jl] -function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) - h = a3tgcn.tgcn(g, x) - e = a3tgcn.dense1(h) - e = a3tgcn.dense2(e) - a = softmax(e, dims = 3) - c = sum(a .* h , dims = 3) - if length(size(c)) == 3 - c = dropdims(c, dims = 3) - end - return c -end - - -[.\GNNlib\test\msgpass_tests.jl] -@testitem "msgpass" setup=[SharedTestSetup] begin - #TODO test all graph types - GRAPH_T = :coo - in_channel = 10 - out_channel = 5 - num_V = 6 - num_E = 14 - T = Float32 - - adj = [0 1 0 0 0 0 - 1 0 0 1 1 1 - 0 0 0 0 0 1 - 0 1 0 0 1 0 - 0 1 0 1 0 1 - 0 1 1 0 1 0] - - X = rand(T, in_channel, num_V) - E = rand(T, in_channel, 
num_E) - - g = GNNGraph(adj, graph_type = GRAPH_T) - - @testset "propagate" begin - function message(xi, xj, e) - @test xi === nothing - @test e === nothing - ones(T, out_channel, size(xj, 2)) - end - - m = propagate(message, g, +, xj = X) - - @test size(m) == (out_channel, num_V) - - @testset "isolated nodes" begin - x1 = rand(1, 6) - g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6) - y1 = propagate((xi, xj, e) -> xj, g, +, xj = x1) - @test size(y1) == (1, 6) - end - end - - @testset "apply_edges" begin - m = apply_edges(g, e = E) do xi, xj, e - @test xi === nothing - @test xj === nothing - ones(out_channel, size(e, 2)) - end - - @test m == ones(out_channel, num_E) - - # With NamedTuple input - m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e - @test xi === nothing - @test xj.b == 2 * xj.a - @test size(xj.a, 2) == size(xj.b, 2) == size(e, 2) - ones(out_channel, size(e, 2)) - end - - # NamedTuple output - m = apply_edges(g, e = E) do xi, xj, e - @test xi === nothing - @test xj === nothing - (; a = ones(out_channel, size(e, 2))) - end - - @test m.a == ones(out_channel, num_E) - - @testset "sizecheck" begin - x = rand(3, g.num_nodes - 1) - @test_throws AssertionError apply_edges(copy_xj, g, xj = x) - @test_throws AssertionError apply_edges(copy_xj, g, xi = x) - - x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1)) - @test_throws AssertionError apply_edges(copy_xj, g, xj = x) - @test_throws AssertionError apply_edges(copy_xj, g, xi = x) - - e = rand(3, g.num_edges - 1) - @test_throws AssertionError apply_edges(copy_xj, g, e = e) - end - end - - @testset "copy_xj" begin - n = 128 - A = sprand(n, n, 0.1) - Adj = map(x -> x > 0 ? 
1 : 0, A) - X = rand(10, n) - - g = GNNGraph(A, ndata = X, graph_type = GRAPH_T) - - function spmm_copyxj_fused(g) - propagate(copy_xj, - g, +; xj = g.ndata.x) - end - - function spmm_copyxj_unfused(g) - propagate((xi, xj, e) -> xj, - g, +; xj = g.ndata.x) - end - - @test spmm_copyxj_unfused(g) ≈ X * Adj - @test spmm_copyxj_fused(g) ≈ X * Adj - end - - @testset "e_mul_xj and w_mul_xj for weighted conv" begin - n = 128 - A = sprand(n, n, 0.1) - Adj = map(x -> x > 0 ? 1 : 0, A) - X = rand(10, n) - - g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T) - - function spmm_unfused(g) - propagate((xi, xj, e) -> reshape(e, 1, :) .* xj, - g, +; xj = g.ndata.x, e = g.edata.e) - end - function spmm_fused(g) - propagate(e_mul_xj, - g, +; xj = g.ndata.x, e = g.edata.e) - end - - function spmm_fused2(g) - propagate(w_mul_xj, - g, +; xj = g.ndata.x) - end - - @test spmm_unfused(g) ≈ X * A - @test spmm_fused(g) ≈ X * A - @test spmm_fused2(g) ≈ X * A - end - - @testset "aggregate_neighbors" begin - @testset "sizecheck" begin - m = rand(2, g.num_edges - 1) - @test_throws AssertionError aggregate_neighbors(g, +, m) - - m = (a = rand(2, g.num_edges + 1), b = nothing) - @test_throws AssertionError aggregate_neighbors(g, +, m) - end - end - -end -[.\GNNlib\test\runtests.jl] -using GNNlib -using Test -using ReTestItems -using Random, Statistics - -runtests(GNNlib) - -[.\GNNlib\test\shared_testsetup.jl] -@testsetup module SharedTestSetup - -import Reexport: @reexport - -@reexport using GNNlib -@reexport using GNNGraphs -@reexport using NNlib -@reexport using MLUtils -@reexport using SparseArrays -@reexport using Test, Random, Statistics - -end -[.\GNNlib\test\utils_tests.jl] -@testitem "utils" setup=[SharedTestSetup] begin - # TODO test all graph types - GRAPH_T = :coo - De, Dx = 3, 2 - g = MLUtils.batch([rand_graph(10, 60, bidirected=true, - ndata = rand(Dx, 10), - edata = rand(De, 30), - graph_type = GRAPH_T) for i in 1:5]) - x = g.ndata.x - e = g.edata.e - - @testset 
"reduce_nodes" begin - r = reduce_nodes(mean, g, x) - @test size(r) == (Dx, g.num_graphs) - @test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2) - - r2 = reduce_nodes(mean, graph_indicator(g), x) - @test r2 == r - end - - @testset "reduce_edges" begin - r = reduce_edges(mean, g, e) - @test size(r) == (De, g.num_graphs) - @test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2) - end - - @testset "softmax_nodes" begin - r = softmax_nodes(g, x) - @test size(r) == size(x) - @test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2) - end - - @testset "softmax_edges" begin - r = softmax_edges(g, e) - @test size(r) == size(e) - @test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2) - end - - @testset "broadcast_nodes" begin - z = rand(4, g.num_graphs) - r = broadcast_nodes(g, z) - @test size(r) == (4, g.num_nodes) - @test r[:, 1] ≈ z[:, 1] - @test r[:, 10] ≈ z[:, 1] - @test r[:, 11] ≈ z[:, 2] - end - - @testset "broadcast_edges" begin - z = rand(4, g.num_graphs) - r = broadcast_edges(g, z) - @test size(r) == (4, g.num_edges) - @test r[:, 1] ≈ z[:, 1] - @test r[:, 60] ≈ z[:, 1] - @test r[:, 61] ≈ z[:, 2] - end - - @testset "softmax_edge_neighbors" begin - s = [1, 2, 3, 4] - t = [5, 5, 6, 6] - g2 = GNNGraph(s, t) - e2 = randn(Float32, 3, g2.num_edges) - z = softmax_edge_neighbors(g2, e2) - @test size(z) == size(e2) - @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2) - @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2) - end -end - - -[.\GNNLux\src\GNNLux.jl] -module GNNLux -using ConcreteStructs: @concrete -using NNlib: NNlib, sigmoid, relu, swish -using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, - initialparameters, initialstates, parameterlength, statelength -using Lux: Lux, Chain, Dense, GRUCell, - glorot_uniform, zeros32, - StatefulLuxLayer -using Reexport: @reexport -using Random: AbstractRNG -using GNNlib: GNNlib -@reexport using GNNGraphs - -include("layers/basic.jl") -export 
GNNLayer, - GNNContainerLayer, - GNNChain - -include("layers/conv.jl") -export AGNNConv, - CGConv, - ChebConv, - EdgeConv, - EGNNConv, - DConv, - GATConv, - GATv2Conv, - GatedGraphConv, - GCNConv, - GINConv, - # GMMConv, - GraphConv, - # MEGNetConv, - NNConv, - # ResGatedGraphConv, - # SAGEConv, - SGConv - # TAGConv, - # TransformerConv - - -end #module - -[.\GNNLux\src\layers\basic.jl] -""" - abstract type GNNLayer <: AbstractExplicitLayer end - -An abstract type from which graph neural network layers are derived. -It is Derived from Lux's `AbstractExplicitLayer` type. - -See also [`GNNChain`](@ref GNNLux.GNNChain). -""" -abstract type GNNLayer <: AbstractExplicitLayer end - -abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end - -@concrete struct GNNChain <: GNNContainerLayer{(:layers,)} - layers <: NamedTuple -end - -GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...) - -function GNNChain(; kw...) - :layers in Base.keys(kw) && - throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) - nt = NamedTuple{keys(kw)}(values(kw)) - nt = map(_wrapforchain, nt) - return GNNChain(nt) -end - -_wrapforchain(l::AbstractExplicitLayer) = l -_wrapforchain(l) = Lux.WrappedFunction(l) - -Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers)) -Base.getindex(c::GNNChain, i::Int) = c.layers[i] -Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) - -function Base.getproperty(c::GNNChain, name::Symbol) - hasfield(typeof(c), name) && return getfield(c, name) - layers = getfield(c, :layers) - hasfield(typeof(layers), name) && return getfield(layers, name) - throw(ArgumentError("$(typeof(c)) has no field or layer $name")) -end - -Base.length(c::GNNChain) = length(c.layers) -Base.lastindex(c::GNNChain) = lastindex(c.layers) -Base.firstindex(c::GNNChain) = firstindex(c.layers) - -LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end]) - 
-(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st) - -function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times - newst = (;) - for (name, l) in pairs(layers) - x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name)) - newst = merge(newst, (; name => s′)) - end - return x, newst -end - -_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;) -_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st) -_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) -_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) - -[.\GNNLux\src\layers\conv.jl] -_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false -_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple() -_getstate(s::StatefulLuxLayer{true}) = s.st -_getstate(s::StatefulLuxLayer{false}) = s.st_any - - -@concrete struct GCNConv <: GNNLayer - in_dims::Int - out_dims::Int - use_bias::Bool - add_self_loops::Bool - use_edge_weight::Bool - init_weight - init_bias - σ -end - -function GCNConv(ch::Pair{Int, Int}, σ = identity; - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias::Bool = true, - add_self_loops::Bool = true, - use_edge_weight::Bool = false, - allow_fast_activation::Bool = true) - in_dims, out_dims = ch - σ = allow_fast_activation ? NNlib.fast_act(σ) : σ - return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv) - weight = l.init_weight(rng, l.out_dims, l.in_dims) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weight, bias) - else - return (; weight) - end -end - -LuxCore.parameterlength(l::GCNConv) = l.use_bias ? 
l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims -LuxCore.outputsize(d::GCNConv) = (d.out_dims,) - -function Base.show(io::IO, l::GCNConv) - print(io, "GCNConv(", l.in_dims, " => ", l.out_dims) - l.σ == identity || print(io, ", ", l.σ) - l.use_bias || print(io, ", use_bias=false") - l.add_self_loops || print(io, ", add_self_loops=false") - !l.use_edge_weight || print(io, ", use_edge_weight=true") - print(io, ")") -end - -(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) = - l(g, x, edge_weight, ps, st; conv_weight, norm_fn) - -function (l::GCNConv)(g, x, edge_weight, ps, st; - norm_fn = d -> 1 ./ sqrt.(d), - conv_weight=nothing, ) - - m = (; ps.weight, bias = _getbias(ps), - l.add_self_loops, l.use_edge_weight, l.σ) - y = GNNlib.gcn_conv(m, g, x, edge_weight, norm_fn, conv_weight) - return y, st -end - -@concrete struct ChebConv <: GNNLayer - in_dims::Int - out_dims::Int - use_bias::Bool - k::Int - init_weight - init_bias -end - -function ChebConv(ch::Pair{Int, Int}, k::Int; - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias::Bool = true) - in_dims, out_dims = ch - return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias) -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv) - weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weight, bias) - else - return (; weight) - end -end - -LuxCore.parameterlength(l::ChebConv) = l.use_bias ? 
l.in_dims * l.out_dims * l.k + l.out_dims : - l.in_dims * l.out_dims * l.k -LuxCore.statelength(d::ChebConv) = 0 -LuxCore.outputsize(d::ChebConv) = (d.out_dims,) - -function Base.show(io::IO, l::ChebConv) - print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k) - l.use_bias || print(io, ", use_bias=false") - print(io, ")") -end - -function (l::ChebConv)(g, x, ps, st) - m = (; ps.weight, bias = _getbias(ps), l.k) - y = GNNlib.cheb_conv(m, g, x) - return y, st - -end - -@concrete struct GraphConv <: GNNLayer - in_dims::Int - out_dims::Int - use_bias::Bool - init_weight - init_bias - σ - aggr -end - -function GraphConv(ch::Pair{Int, Int}, σ = identity; - aggr = +, - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias::Bool = true, - allow_fast_activation::Bool = true) - in_dims, out_dims = ch - σ = allow_fast_activation ? NNlib.fast_act(σ) : σ - return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr) -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv) - weight1 = l.init_weight(rng, l.out_dims, l.in_dims) - weight2 = l.init_weight(rng, l.out_dims, l.in_dims) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weight1, weight2, bias) - else - return (; weight1, weight2) - end -end - -function LuxCore.parameterlength(l::GraphConv) - if l.use_bias - return 2 * l.in_dims * l.out_dims + l.out_dims - else - return 2 * l.in_dims * l.out_dims - end -end - -LuxCore.statelength(d::GraphConv) = 0 -LuxCore.outputsize(d::GraphConv) = (d.out_dims,) - -function Base.show(io::IO, l::GraphConv) - print(io, "GraphConv(", l.in_dims, " => ", l.out_dims) - (l.σ == identity) || print(io, ", ", l.σ) - (l.aggr == +) || print(io, ", aggr=", l.aggr) - l.use_bias || print(io, ", use_bias=false") - print(io, ")") -end - -function (l::GraphConv)(g, x, ps, st) - m = (; ps.weight1, ps.weight2, bias = _getbias(ps), - l.σ, l.aggr) - return GNNlib.graph_conv(m, g, x), st -end - - -@concrete struct AGNNConv <: GNNLayer - 
init_beta <: AbstractVector - add_self_loops::Bool - trainable::Bool -end - -function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true) - return AGNNConv([init_beta], add_self_loops, trainable) -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::AGNNConv) - if l.trainable - return (; β = l.init_beta) - else - return (;) - end -end - -LuxCore.parameterlength(l::AGNNConv) = l.trainable ? 1 : 0 -LuxCore.statelength(d::AGNNConv) = 0 - -function Base.show(io::IO, l::AGNNConv) - print(io, "AGNNConv(", l.init_beta) - l.add_self_loops || print(io, ", add_self_loops=false") - l.trainable || print(io, ", trainable=false") - print(io, ")") -end - -function (l::AGNNConv)(g, x::AbstractMatrix, ps, st) - β = l.trainable ? ps.β : l.init_beta - m = (; β, l.add_self_loops) - return GNNlib.agnn_conv(m, g, x), st -end - -@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)} - in_dims::NTuple{2, Int} - out_dims::Int - dense_f - dense_s - residual::Bool - init_weight - init_bias -end - -CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) 
- -function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, - use_bias = true, init_weight = glorot_uniform, init_bias = zeros32, - allow_fast_activation = true) - (nin, ein), out = ch - dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation) - dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation) - return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias) -end - -LuxCore.outputsize(l::CGConv) = (l.out_dims,) - -(l::CGConv)(g, x, ps, st) = l(g, x, nothing, ps, st) - -function (l::CGConv)(g, x, e, ps, st) - dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f)) - dense_s = StatefulLuxLayer{true}(l.dense_s, ps.dense_s, _getstate(st, :dense_s)) - m = (; dense_f, dense_s, l.residual) - return GNNlib.cg_conv(m, g, x, e), st -end - -@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)} - nn <: AbstractExplicitLayer - aggr -end - -EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) - -function Base.show(io::IO, l::EdgeConv) - print(io, "EdgeConv(", l.nn) - print(io, ", aggr=", l.aggr) - print(io, ")") -end - - -function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps, st) - m = (; nn, l.aggr) - y = GNNlib.edge_conv(m, g, x) - stnew = _getstate(nn) - return y, stnew -end - - -@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)} - ϕe - ϕx - ϕh - num_features - residual::Bool -end - -function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false) - return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) -end - -#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py -function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1], - residual = false) - (in_size, edge_feat_size), out_size = ch - act_fn = swish - - # +1 for the radial feature: ||x_i - x_j||^2 - ϕe = 
Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn), - Dense(hidden_size => hidden_size, act_fn)) - - ϕh = Chain(Dense(in_size + hidden_size => hidden_size, swish), - Dense(hidden_size => out_size)) - - ϕx = Chain(Dense(hidden_size => hidden_size, swish), - Dense(hidden_size => 1, use_bias = false)) - - num_features = (in = in_size, edge = edge_feat_size, out = out_size, - hidden = hidden_size) - if residual - @assert in_size==out_size "Residual connection only possible if in_size == out_size" - end - return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) -end - -LuxCore.outputsize(l::EGNNConv) = (l.num_features.out,) - -(l::EGNNConv)(g, h, x, ps, st) = l(g, h, x, nothing, ps, st) - -function (l::EGNNConv)(g, h, x, e, ps, st) - ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) - ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx)) - ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh)) - m = (; ϕe, ϕx, ϕh, l.residual, l.num_features) - return GNNlib.egnn_conv(m, g, h, x, e), st -end - -function Base.show(io::IO, l::EGNNConv) - ne = l.num_features.edge - nin = l.num_features.in - nout = l.num_features.out - nh = l.num_features.hidden - print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh") - if l.residual - print(io, ", residual=true") - end - print(io, ")") -end - -@concrete struct DConv <: GNNLayer - in_dims::Int - out_dims::Int - k::Int - init_weight - init_bias - use_bias::Bool -end - -function DConv(ch::Pair{Int, Int}, k::Int; - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias = true) - in, out = ch - return DConv(in, out, k, init_weight, init_bias, use_bias) -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::DConv) - weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weights, bias) - else - return (; weights) - end -end - -LuxCore.outputsize(l::DConv) = (l.out_dims,) -LuxCore.parameterlength(l::DConv) = l.use_bias 
? 2 * l.in_dims * l.out_dims * l.k + l.out_dims : - 2 * l.in_dims * l.out_dims * l.k - -function (l::DConv)(g, x, ps, st) - m = (; ps.weights, bias = _getbias(ps), l.k) - return GNNlib.d_conv(m, g, x), st -end - -function Base.show(io::IO, l::DConv) - print(io, "DConv($(l.in_dims) => $(l.out_dims), k=$(l.k))") -end - -@concrete struct GATConv <: GNNLayer - dense_x - dense_e - init_weight - init_bias - use_bias::Bool - σ - negative_slope - channel::Pair{NTuple{2, Int}, Int} - heads::Int - concat::Bool - add_self_loops::Bool - dropout -end - - -GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) - -function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; - heads::Int = 1, concat::Bool = true, negative_slope = 0.2, - init_weight = glorot_uniform, init_bias = zeros32, - use_bias::Bool = true, - add_self_loops = true, dropout=0.0) - (in, ein), out = ch - if add_self_loops - @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." - end - - dense_x = Dense(in => out * heads, use_bias = false) - dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing - negative_slope = convert(Float32, negative_slope) - return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias, - σ, negative_slope, ch, heads, concat, add_self_loops, dropout) -end - -LuxCore.outputsize(l::GATConv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],) -##TODO: parameterlength - -function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv) - (in, ein), out = l.channel - dense_x = LuxCore.initialparameters(rng, l.dense_x) - a = l.init_weight(ein > 0 ? 3out : 2out, l.heads) - ps = (; dense_x, a) - if ein > 0 - ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e)) - end - if l.use_bias - ps = (ps..., bias = l.init_bias(rng, l.concat ? 
out * l.heads : out)) - end - return ps -end - -(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st) - -function (l::GATConv)(g, x, e, ps, st) - dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x)) - dense_e = l.dense_e === nothing ? nothing : - StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e)) - - m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ, - ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope) - return GNNlib.gat_conv(m, g, x, e), st -end - -function Base.show(io::IO, l::GATConv) - (in, ein), out = l.channel - print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads) - l.σ == identity || print(io, ", ", l.σ) - print(io, ", negative_slope=", l.negative_slope) - print(io, ")") -end - -@concrete struct GATv2Conv <: GNNLayer - dense_i - dense_j - dense_e - init_weight - init_bias - use_bias::Bool - σ - negative_slope - channel::Pair{NTuple{2, Int}, Int} - heads::Int - concat::Bool - add_self_loops::Bool - dropout -end - -function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) - GATv2Conv((ch[1], 0) => ch[2], args...; kws...) -end - -function GATv2Conv(ch::Pair{NTuple{2, Int}, Int}, - σ = identity; - heads::Int = 1, - concat::Bool = true, - negative_slope = 0.2, - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias::Bool = true, - add_self_loops = true, - dropout=0.0) - - (in, ein), out = ch - - if add_self_loops - @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." 
- end - - dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias) - dense_j = Dense(in => out * heads; use_bias = false, init_weight) - if ein > 0 - dense_e = Dense(ein => out * heads; use_bias = false, init_weight) - else - dense_e = nothing - end - return GATv2Conv(dense_i, dense_j, dense_e, - init_weight, init_bias, use_bias, - σ, negative_slope, - ch, heads, concat, add_self_loops, dropout) -end - - -LuxCore.outputsize(l::GATv2Conv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],) -##TODO: parameterlength - -function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv) - (in, ein), out = l.channel - dense_i = LuxCore.initialparameters(rng, l.dense_i) - dense_j = LuxCore.initialparameters(rng, l.dense_j) - a = l.init_weight(out, l.heads) - ps = (; dense_i, dense_j, a) - if ein > 0 - ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e)) - end - if l.use_bias - ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out)) - end - return ps -end - -(l::GATv2Conv)(g, x, ps, st) = l(g, x, nothing, ps, st) - -function (l::GATv2Conv)(g, x, e, ps, st) - dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i)) - dense_j = StatefulLuxLayer{true}(l.dense_j, ps.dense_j, _getstate(st, :dense_j)) - dense_e = l.dense_e === nothing ? nothing : - StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e)) - - m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ, - ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope) - return GNNlib.gatv2_conv(m, g, x, e), st -end - -function Base.show(io::IO, l::GATv2Conv) - (in, ein), out = l.channel - print(io, "GATv2Conv(", ein == 0 ? 
in : (in, ein), " => ", out ÷ l.heads) - l.σ == identity || print(io, ", ", l.σ) - print(io, ", negative_slope=", l.negative_slope) - print(io, ")") -end - -@concrete struct SGConv <: GNNLayer - in_dims::Int - out_dims::Int - k::Int - use_bias::Bool - add_self_loops::Bool - use_edge_weight::Bool - init_weight - init_bias -end - -function SGConv(ch::Pair{Int, Int}, k = 1; - init_weight = glorot_uniform, - init_bias = zeros32, - use_bias::Bool = true, - add_self_loops::Bool = true, - use_edge_weight::Bool = false) - in_dims, out_dims = ch - return SGConv(in_dims, out_dims, k, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias) -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv) - weight = l.init_weight(rng, l.out_dims, l.in_dims) - if l.use_bias - bias = l.init_bias(rng, l.out_dims) - return (; weight, bias) - else - return (; weight) - end -end - -LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims -LuxCore.outputsize(d::SGConv) = (d.out_dims,) - -function Base.show(io::IO, l::SGConv) - print(io, "SGConv(", l.in_dims, " => ", l.out_dims) - l.k || print(io, ", ", l.k) - l.use_bias || print(io, ", use_bias=false") - l.add_self_loops || print(io, ", add_self_loops=false") - !l.use_edge_weight || print(io, ", use_edge_weight=true") - print(io, ")") -end - -(l::SGConv)(g, x, ps, st) = l(g, x, nothing, ps, st) - -function (l::SGConv)(g, x, edge_weight, ps, st) - m = (; ps.weight, bias = _getbias(ps), - l.add_self_loops, l.use_edge_weight, l.k) - y = GNNlib.sg_conv(m, g, x, edge_weight) - return y, st -end - -@concrete struct GatedGraphConv <: GNNLayer - gru - init_weight - dims::Int - num_layers::Int - aggr -end - - -function GatedGraphConv(dims::Int, num_layers::Int; - aggr = +, init_weight = glorot_uniform) - gru = GRUCell(dims => dims) - return GatedGraphConv(gru, init_weight, dims, num_layers, aggr) -end - -LuxCore.outputsize(l::GatedGraphConv) = (l.dims,) - -function 
LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv) - gru = LuxCore.initialparameters(rng, l.gru) - weight = l.init_weight(rng, l.dims, l.dims, l.num_layers) - return (; gru, weight) -end - -LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers - - -function (l::GatedGraphConv)(g, x, ps, st) - gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) - fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style - m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims) - return GNNlib.gated_graph_conv(m, g, x), st -end - -function Base.show(io::IO, l::GatedGraphConv) - print(io, "GatedGraphConv($(l.dims), $(l.num_layers)") - print(io, ", aggr=", l.aggr) - print(io, ")") -end - -@concrete struct GINConv <: GNNContainerLayer{(:nn,)} - nn <: AbstractExplicitLayer - ϵ <: Real - aggr -end - -GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) - -function (l::GINConv)(g, x, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps, st) - m = (; nn, l.ϵ, l.aggr) - y = GNNlib.gin_conv(m, g, x) - stnew = _getstate(nn) - return y, stnew -end - -function Base.show(io::IO, l::GINConv) - print(io, "GINConv($(l.nn)") - print(io, ", $(l.ϵ)") - print(io, ")") -end - -@concrete struct NNConv <: GNNContainerLayer{(:nn,)} - nn <: AbstractExplicitLayer - aggr - in_dims::Int - out_dims::Int - use_bias::Bool - add_self_loops::Bool - use_edge_weight::Bool - init_weight - init_bias - σ -end - - -function NNConv(ch::Pair{Int, Int}, nn, σ = identity; - aggr = +, - init_bias = zeros32, - use_bias::Bool = true, - init_weight = glorot_uniform, - add_self_loops::Bool = true, - use_edge_weight::Bool = false, - allow_fast_activation::Bool = true) - in_dims, out_dims = ch - σ = allow_fast_activation ? 
NNlib.fast_act(σ) : σ - return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) -end - -function (l::NNConv)(g, x, edge_weight, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps, st) - - m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), - l.add_self_loops, l.use_edge_weight, l.σ) - y = GNNlib.nn_conv(m, g, x, edge_weight) - stnew = _getstate(nn) - return y, stnew -end - -LuxCore.outputsize(d::NNConv) = (d.out_dims,) - -function Base.show(io::IO, l::NNConv) - print(io, "NNConv($(l.nn)") - print(io, ", $(l.ϵ)") - l.σ == identity || print(io, ", ", l.σ) - l.use_bias || print(io, ", use_bias=false") - l.add_self_loops || print(io, ", add_self_loops=false") - !l.use_edge_weight || print(io, ", use_edge_weight=true") - print(io, ")") -end - -[.\GNNLux\test\runtests.jl] -using Test -using Lux -using GNNLux -using Random, Statistics - -using ReTestItems -# using Pkg, Preferences, Test -# using InteractiveUtils, Hwloc - -runtests(GNNLux) - -[.\GNNLux\test\shared_testsetup.jl] -@testsetup module SharedTestSetup - -import Reexport: @reexport - -@reexport using Test -@reexport using GNNLux -@reexport using Lux -@reexport using StableRNGs -@reexport using Random, Statistics - -using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme - -export test_lux_layer - -function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; - outputsize=nothing, sizey=nothing, container=false, - atol=1.0f-2, rtol=1.0f-2) - - if container - @test l isa GNNContainerLayer - else - @test l isa GNNLayer - end - - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) - @test LuxCore.statelength(l) == LuxCore.statelength(st) - - y, st′ = l(g, x, ps, st) - @test eltype(y) == eltype(x) - if outputsize !== nothing - @test LuxCore.outputsize(l) == outputsize - end - if sizey !== nothing - @test size(y) == sizey - elseif outputsize 
!== nothing - @test size(y) == (outputsize..., g.num_nodes) - end - - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) -end - -end -[.\GNNLux\test\layers\basic_tests.jl] -@testitem "layers/basic" setup=[SharedTestSetup] begin - rng = StableRNG(17) - g = rand_graph(10, 40) - x = randn(rng, Float32, 3, 10) - - @testset "GNNLayer" begin - @test GNNLayer <: LuxCore.AbstractExplicitLayer - end - - @testset "GNNContainerLayer" begin - @test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer - end - - @testset "GNNChain" begin - @test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} - c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) - test_lux_layer(rng, c, g, x, outputsize=(3,), container=true) - end -end - -[.\GNNLux\test\layers\conv_tests.jl] -@testitem "layers/conv" setup=[SharedTestSetup] begin - rng = StableRNG(1234) - g = rand_graph(10, 40) - in_dims = 3 - out_dims = 5 - x = randn(rng, Float32, in_dims, 10) - - @testset "GCNConv" begin - l = GCNConv(in_dims => out_dims, tanh) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "ChebConv" begin - l = ChebConv(in_dims => out_dims, 2) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "GraphConv" begin - l = GraphConv(in_dims => out_dims, tanh) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "AGNNConv" begin - l = AGNNConv(init_beta=1.0f0) - test_lux_layer(rng, l, g, x, sizey=(in_dims, 10)) - end - - @testset "EdgeConv" begin - nn = Chain(Dense(2*in_dims => 2, tanh), Dense(2 => out_dims)) - l = EdgeConv(nn, aggr = +) - test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) - end - - @testset "CGConv" begin - l = CGConv(in_dims => in_dims, residual = true) - test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true) - end - - @testset "DConv" begin - l = DConv(in_dims => out_dims, 
2) - test_lux_layer(rng, l, g, x, outputsize=(5,)) - end - - @testset "EGNNConv" begin - hin = 6 - hout = 7 - hidden = 8 - l = EGNNConv(hin => hout, hidden) - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - h = randn(rng, Float32, hin, g.num_nodes) - (hnew, xnew), stnew = l(g, h, x, ps, st) - @test size(hnew) == (hout, g.num_nodes) - @test size(xnew) == (in_dims, g.num_nodes) - end - - @testset "GATConv" begin - x = randn(rng, Float32, 6, 10) - - l = GATConv(6 => 8, heads=2) - test_lux_layer(rng, l, g, x, outputsize=(16,)) - - l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5) - test_lux_layer(rng, l, g, x, outputsize=(8,)) - - #TODO test edge - end - - @testset "GATv2Conv" begin - x = randn(rng, Float32, 6, 10) - - l = GATv2Conv(6 => 8, heads=2) - test_lux_layer(rng, l, g, x, outputsize=(16,)) - - l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5) - test_lux_layer(rng, l, g, x, outputsize=(8,)) - - #TODO test edge - end - - @testset "SGConv" begin - l = SGConv(in_dims => out_dims, 2) - test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) - end - - @testset "GatedGraphConv" begin - l = GatedGraphConv(in_dims, 3) - test_lux_layer(rng, l, g, x, outputsize=(in_dims,)) - end - - @testset "GINConv" begin - nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) - l = GINConv(nn, 0.5) - test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) - end - - @testset "NNConv" begin - edim = 10 - nn = Dense(edim, out_dims * in_dims) - l = NNConv(in_dims => out_dims, nn, tanh, bias = true, aggr = +) - test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) - end -end - -[.\perf\bench_gnn.jl] -using SparseArrays -using GraphNeuralNetworks -using BenchmarkTools -import Random: seed! 
-using LinearAlgebra - -n = 1024 -seed!(0) -A = sprand(n, n, 0.01) -b = rand(1, n) -B = rand(100, n) - -g = GNNGraph(A, - ndata = (; b = b, B = B), - edata = (; A = reshape(A.nzval, 1, :)), - graph_type = :coo) - -function spmv(g) - propagate((xi, xj, e) -> e .* xj, # same as e_mul_xj - g, +; xj = g.ndata.b, e = g.edata.A) -end - -function spmm1(g) - propagate((xi, xj, e) -> e .* xj, # same as e_mul_xj - g, +; xj = g.ndata.B, e = g.edata.A) -end -function spmm2(g) - propagate(e_mul_xj, - g, +; xj = g.ndata.B, e = vec(g.edata.A)) -end - -# @assert isequal(spmv(g), b * A) # true -# @btime spmv(g) # ~5 ms -# @btime b * A # ~32 us - -@assert isequal(spmm1(g), B * A) # true -@assert isequal(spmm2(g), B * A) # true -@btime spmm1(g) # ~9 ms -@btime spmm2(g) # ~9 ms -@btime B * A # ~400 us - -function spmm_copyxj_fused(g) - propagate(copy_xj, - g, +; xj = g.ndata.B) -end - -function spmm_copyxj_unfused(g) - propagate((xi, xj, e) -> xj, - g, +; xj = g.ndata.B) -end - -Adj = map(x -> x > 0 ? 1 : 0, A) -@assert spmm_copyxj_unfused(g) ≈ B * Adj -@assert spmm_copyxj_fused(g) ≈ B * Adj # bug fixed in #107 - -@btime spmm_copyxj_fused(g) # 268.614 μs (22 allocations: 1.13 MiB) -@btime spmm_copyxj_unfused(g) # 4.263 ms (52855 allocations: 12.23 MiB) -@btime B * Adj # 196.135 μs (2 allocations: 800.05 KiB) - -println() - -[.\perf\neural_ode_mnist.jl] -# Load the packages -using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations -using Flux: onehotbatch, onecold -using Flux.Losses: logitcrossentropy -using Flux -using Statistics: mean -using MLDatasets -using CUDA -# CUDA.allowscalar(false) # Some scalar indexing is still done by DiffEqFlux - -# device = cpu # `gpu` not working yet -device = CUDA.functional() ? 
gpu : cpu - -# LOAD DATA -X, y = MNIST(:train)[:] -y = onehotbatch(y, 0:9) - -# Define the Neural GDE -diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2]) - -nin, nhidden, nout = 28 * 28, 100, 10 -epochs = 10 - -node_chain = Chain(Dense(nhidden => nhidden, tanh), - Dense(nhidden => nhidden)) |> device - -node = NeuralODE(node_chain, - (0.0f0, 1.0f0), Tsit5(), save_everystep = false, - reltol = 1e-3, abstol = 1e-3, save_start = false) |> device - -model = Chain(Flux.flatten, - Dense(nin => nhidden, relu), - node, - diffeqsol_to_array, - Dense(nhidden, nout)) |> device - -# # Training - -# ## Optimizer -opt = Flux.setup(Adam(0.01), model) - -function eval_loss_accuracy(X, y) - ŷ = model(X) - l = logitcrossentropy(ŷ, y) - acc = mean(onecold(ŷ) .== onecold(y)) - return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) -end - -# ## Training Loop -for epoch in 1:epochs - grad = gradient(model) do model - ŷ = model(X) - logitcrossentropy(ŷ, y) - end - Flux.update!(opt, model, grad[1]) - @show eval_loss_accuracy(X, y) -end - -[.\perf\node_classification_cora_geometricflux.jl] -# An example of semi-supervised node classification - -using Flux -using Flux: onecold, onehotbatch -using Flux.Losses: logitcrossentropy -using GeometricFlux, GraphSignals -using MLDatasets: Cora -using Statistics, Random -using CUDA -CUDA.allowscalar(false) - -function eval_loss_accuracy(X, y, ids, model) - ŷ = model(X) - l = logitcrossentropy(ŷ[:, ids], y[:, ids]) - acc = mean(onecold(ŷ[:, ids]) .== onecold(y[:, ids])) - return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) -end - -# arguments for the `train` function -Base.@kwdef mutable struct Args - η = 1.0f-3 # learning rate - epochs = 100 # number of epochs - seed = 17 # set seed > 0 for reproducibility - usecuda = true # if true use cuda (if available) - nhidden = 128 # dimension of hidden features - infotime = 10 # report every `infotime` epochs -end - -function train(; kws...) - args = Args(; kws...) 
- - args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - device = gpu - args.seed > 0 && CUDA.seed!(args.seed) - @info "Training on GPU" - else - device = cpu - @info "Training on CPU" - end - - # LOAD DATA - data = Cora.dataset() - g = FeaturedGraph(data.adjacency_list) |> device - X = data.node_features |> device - y = onehotbatch(data.node_labels, 1:(data.num_classes)) |> device - train_ids = data.train_indices |> device - val_ids = data.val_indices |> device - test_ids = data.test_indices |> device - ytrain = y[:, train_ids] - - nin, nhidden, nout = size(X, 1), args.nhidden, data.num_classes - - ## DEFINE MODEL - model = Chain(GCNConv(g, nin => nhidden, relu), - Dropout(0.5), - GCNConv(g, nhidden => nhidden, relu), - Dense(nhidden, nout)) |> device - - opt = Flux.setup(Adam(args.η), model) - - @info g - - ## LOGGING FUNCTION - function report(epoch) - train = eval_loss_accuracy(X, y, train_ids, model) - test = eval_loss_accuracy(X, y, test_ids, model) - println("Epoch: $epoch Train: $(train) Test: $(test)") - end - - ## TRAINING - report(0) - for epoch in 1:(args.epochs) - grad = Flux.gradient(model) do model - ŷ = model(X) - logitcrossentropy(ŷ[:, train_ids], ytrain) - end - - Flux.update!(opt, model, grad[1]) - - epoch % args.infotime == 0 && report(epoch) - end -end - -train(usecuda = false) - -[.\perf\perf.jl] -using Flux, GraphNeuralNetworks, Graphs, BenchmarkTools, CUDA -using DataFrames, Statistics, JLD2, SparseArrays -CUDA.device!(2) -CUDA.allowscalar(false) - -BenchmarkTools.ratio(::Missing, x) = Inf -BenchmarkTools.ratio(x, ::Missing) = 0.0 -BenchmarkTools.ratio(::Missing, ::Missing) = missing - -function run_single_benchmark(N, c, D, CONV; gtype = :lg) - data = erdos_renyi(N, c / (N - 1), seed = 17) - X = randn(Float32, D, N) - - g = GNNGraph(data; ndata = X, graph_type = gtype) - g_gpu = g |> gpu - - m = CONV(D => D) - ps = Flux.params(m) - - m_gpu = m |> gpu - ps_gpu = Flux.params(m_gpu) - - res = Dict() - - 
res["CPU_FWD"] = @benchmark $m($g) - res["CPU_GRAD"] = @benchmark gradient(() -> sum($m($g).ndata.x), $ps) - - try - res["GPU_FWD"] = @benchmark CUDA.@sync($m_gpu($g_gpu)) teardown=(GC.gc(); CUDA.reclaim()) - catch - res["GPU_FWD"] = missing - end - - try - res["GPU_GRAD"] = @benchmark CUDA.@sync(gradient(() -> sum($m_gpu($g_gpu).ndata.x), - $ps_gpu)) teardown=(GC.gc(); CUDA.reclaim()) - catch - res["GPU_GRAD"] = missing - end - - return res -end - -""" - run_benchmarks(; - Ns = [10, 100, 1000, 10000], - c = 6, - D = 100, - layers = [GCNConv, GraphConv, GATConv] - ) - -Benchmark GNN layers on Erdos-Renyi random graphs -with average degree `c`. Benchmarks are performed for each graph size in the list `Ns`. -`D` is the number of node features. -""" -function run_benchmarks(; - Ns = [10, 100, 1000, 10000], - c = 6, - D = 100, - layers = [GCNConv, GATConv], - gtypes = [:coo, :sparse, :dense]) - df = DataFrame(N = Int[], c = Float64[], layer = String[], gtype = Symbol[], - time_cpu = Any[], time_gpu = Any[]) |> allowmissing - - for gtype in gtypes - for N in Ns - println("## GRAPH_TYPE = $gtype N = $N") - for CONV in layers - res = run_single_benchmark(N, c, D, CONV; gtype) - row = (; layer = "$CONV", - N = N, - c = c, - gtype = gtype, - time_cpu = ismissing(res["CPU"]) ? missing : median(res["CPU"]), - time_gpu = ismissing(res["GPU"]) ? 
missing : median(res["GPU"])) - push!(df, row) - end - end - end - - df.gpu_to_cpu = ratio.(df.time_gpu, df.time_cpu) - sort!(df, [:layer, :N, :c, :gtype]) - return df -end - -# df = run_benchmarks() -# for g in groupby(df, :layer); println(g, "\n"); end - -# @save "perf/perf_master_20210803_carlo.jld2" dfmaster=df -## or -# @save "perf/perf_pr.jld2" dfpr=df - -function compare(dfpr, dfmaster; on = [:N, :c, :gtype, :layer]) - df = outerjoin(dfpr, dfmaster; on = on, makeunique = true, - renamecols = :_pr => :_master) - df.pr_to_master_cpu = ratio.(df.time_cpu_pr, df.time_cpu_master) - df.pr_to_master_gpu = ratio.(df.time_gpu_pr, df.time_gpu_master) - return df[:, [:N, :c, :gtype, :layer, :pr_to_master_cpu, :pr_to_master_gpu]] -end - -# @load "perf/perf_pr.jld2" dfpr -# @load "perf/perf_master.jld2" dfmaster -# compare(dfpr, dfmaster) - -[.\src\deprecations.jl] - -# V1.0 deprecations -# TODO doe some reason this is not working -# @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight) -# @deprecate (l::GNNLayer)(gs::AbstractVector{<:GNNGraph}, args...; kws...) l(MLUtils.batch(gs), args...; kws...) 
-[.\src\GraphNeuralNetworks.jl] -module GraphNeuralNetworks - -using Statistics: mean -using LinearAlgebra, Random -using Flux -using Flux: glorot_uniform, leakyrelu, GRUCell, batch -using MacroTools: @forward -using NNlib -using NNlib: scatter, gather -using ChainRulesCore -using Reexport -using MLUtils: zeros_like - -using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, - check_num_nodes, check_num_edges, - EType, NType # for heteroconvs - -@reexport using GNNGraphs -@reexport using GNNlib - -include("layers/basic.jl") -export GNNLayer, - GNNChain, - WithGraph, - DotDecoder - -include("layers/conv.jl") -export AGNNConv, - CGConv, - ChebConv, - DConv, - EdgeConv, - EGNNConv, - GATConv, - GATv2Conv, - GatedGraphConv, - GCNConv, - GINConv, - GMMConv, - GraphConv, - MEGNetConv, - NNConv, - ResGatedGraphConv, - SAGEConv, - SGConv, - TAGConv, - TransformerConv - -include("layers/heteroconv.jl") -export HeteroGraphConv - -include("layers/temporalconv.jl") -export TGCN, - A3TGCN, - GConvLSTM, - GConvGRU, - DCGRU - -include("layers/pool.jl") -export GlobalPool, - GlobalAttentionPool, - Set2Set, - TopKPool, - topk_index - -include("deprecations.jl") - -end - -[.\src\layers\basic.jl] -""" - abstract type GNNLayer end - -An abstract type from which graph neural network layers are derived. - -See also [`GNNChain`](@ref). -""" -abstract type GNNLayer end - -# Forward pass with graph-only input. -# To be specialized by layers also needing edge features as input (e.g. NNConv). -(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) - -""" - WithGraph(model, g::GNNGraph; traingraph=false) - -A type wrapping the `model` and tying it to the graph `g`. -In the forward pass, can only take feature arrays as inputs, -returning `model(g, x...; kws...)`. - -If `traingraph=false`, the graph's parameters won't be part of -the `trainable` parameters in the gradient updates. 
- -# Examples - -```julia -g = GNNGraph([1,2,3], [2,3,1]) -x = rand(Float32, 2, 3) -model = SAGEConv(2 => 3) -wg = WithGraph(model, g) -# No need to feed the graph to `wg` -@assert wg(x) == model(g, x) - -g2 = GNNGraph([1,1,2,3], [2,4,1,1]) -x2 = rand(Float32, 2, 4) -# WithGraph will ignore the internal graph if fed with a new one. -@assert wg(g2, x2) == model(g2, x2) -``` -""" -struct WithGraph{M, G <: GNNGraph} - model::M - g::G - traingraph::Bool -end - -WithGraph(model, g::GNNGraph; traingraph = false) = WithGraph(model, g, traingraph) - -Flux.@layer :expand WithGraph -Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model) - -(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...) -(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...) - -""" - GNNChain(layers...) - GNNChain(name = layer, ...) - -Collects multiple layers / functions to be called in sequence -on given input graph and input node features. - -It allows to compose layers in a sequential fashion as `Flux.Chain` -does, propagating the output of each layer to the next one. -In addition, `GNNChain` handles the input graph as well, providing it -as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type. - -`GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`, -and if names are given, `m[:name] == m[1]` etc. 
- -# Examples - -```jldoctest -julia> using Flux, GraphNeuralNetworks - -julia> m = GNNChain(GCNConv(2=>5), - BatchNorm(5), - x -> relu.(x), - Dense(5, 4)) -GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)) - -julia> x = randn(Float32, 2, 3); - -julia> g = rand_graph(3, 6) -GNNGraph: - num_nodes = 3 - num_edges = 6 - -julia> m(g, x) -4×3 Matrix{Float32}: - -0.795592 -0.795592 -0.795592 - -0.736409 -0.736409 -0.736409 - 0.994925 0.994925 0.994925 - 0.857549 0.857549 0.857549 - -julia> m2 = GNNChain(enc = m, - dec = DotDecoder()) -GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder()) - -julia> m2(g, x) -1×6 Matrix{Float32}: - 2.90053 2.90053 2.90053 2.90053 2.90053 2.90053 - -julia> m2[:enc](g, x) == m(g, x) -true -``` -""" -struct GNNChain{T <: Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer - layers::T -end - -Flux.@layer :expand GNNChain - -GNNChain(xs...) = GNNChain(xs) - -function GNNChain(; kw...) - :layers in Base.keys(kw) && - throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) - isempty(kw) && return GNNChain(()) - GNNChain(values(kw)) -end - -@forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last, - Base.iterate, Base.lastindex, Base.keys, Base.firstindex - -(c::GNNChain)(g::GNNGraph, x) = _applychain(c.layers, g, x) -(c::GNNChain)(g::GNNGraph) = _applychain(c.layers, g) - -## TODO see if this is faster for small chains -## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180 -# @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N} -# symbols = vcat(:x, [gensym() for _ in 1:N]) -# calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N] -# Expr(:block, calls...) 
-# end -# _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), x) - -function _applychain(layers, g::GNNGraph, x) # type-unstable path, helps compile times - for l in layers - x = _applylayer(l, g, x) - end - return x -end - -function _applychain(layers, g::GNNGraph) # type-unstable path, helps compile times - for l in layers - g = _applylayer(l, g) - end - return g -end - -# # explicit input -_applylayer(l, g::GNNGraph, x) = l(x) -_applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) - -# input from graph -_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata = l(node_features(g))) -_applylayer(l::GNNLayer, g::GNNGraph) = l(g) - -# # Handle Flux.Parallel -function _applylayer(l::Parallel, g::GNNGraph) - GNNGraph(g, ndata = _applylayer(l, g, node_features(g))) -end - -function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) - closures = map(f -> (x -> _applylayer(f, g, x)), l.layers) - return Parallel(l.connection, closures)(x) -end - -Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]) -function Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) - GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) -end - -function Base.show(io::IO, c::GNNChain) - print(io, "GNNChain(") - _show_layers(io, c.layers) - print(io, ")") -end - -_show_layers(io, layers::Tuple) = join(io, layers, ", ") -function _show_layers(io, layers::NamedTuple) - join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") -end -function _show_layers(io, layers::AbstractVector) - (print(io, "["); join(io, layers, ", "); print(io, "]")) -end - -""" - DotDecoder() - -A graph neural network layer that -for given input graph `g` and node features `x`, -returns the dot product `x_i ⋅ xj` on each edge. 
- -# Examples - -```jldoctest -julia> g = rand_graph(5, 6) -GNNGraph: - num_nodes = 5 - num_edges = 6 - -julia> dotdec = DotDecoder() -DotDecoder() - -julia> dotdec(g, rand(2, 5)) -1×6 Matrix{Float64}: - 0.345098 0.458305 0.106353 0.345098 0.458305 0.106353 -``` -""" -struct DotDecoder <: GNNLayer end - -(::DotDecoder)(g, x) = GNNlib.dot_decoder(g, x) - -[.\src\layers\conv.jl] -@doc raw""" - GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight]) - -Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907). - -Performs the operation -```math -\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j -``` -where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees. - -If the input graph has weighted edges and `use_edge_weight=true`, than ``a_{ij}`` will be computed as -```math -a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}} -``` - -The input to the layer is a node feature array `X` of size `(num_features, num_nodes)` -and optionally an edge weight vector. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `σ`: Activation function. Default `identity`. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. -- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). - If `add_self_loops=true` the new weights will be set to 1. - This option is ignored if the `edge_weight` is explicitly provided in the forward pass. - Default `false`. 
- -# Forward - - (::GCNConv)(g::GNNGraph, x, edge_weight = nothing; norm_fn = d -> 1 ./ sqrt.(d), conv_weight = nothing) -> AbstractMatrix - -Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, -and optionally an edge weight vector. Returns a node feature matrix of size -`[out, num_nodes]`. - -The `norm_fn` parameter allows for custom normalization of the graph convolution operation by passing a function as argument. -By default, it computes ``\frac{1}{\sqrt{d}}`` i.e the inverse square root of the degree (`d`) of each node in the graph. -If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix instead of the weights stored in the model. - -# Examples - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -g = GNNGraph(s, t) -x = randn(Float32, 3, g.num_nodes) - -# create layer -l = GCNConv(3 => 5) - -# forward pass -y = l(g, x) # size: 5 × num_nodes - -# convolution with edge weights and custom normalization function -w = [1.1, 0.1, 2.3, 0.5] -custom_norm_fn(d) = 1 ./ sqrt.(d + 1) # Custom normalization function -y = l(g, x, w; norm_fn = custom_norm_fn) - -# Edge weights can also be embedded in the graph. -g = GNNGraph(s, t, w) -l = GCNConv(3 => 5, use_edge_weight=true) -y = l(g, x) # same as l(g, x, w) -``` -""" -struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer - weight::W - bias::B - σ::F - add_self_loops::Bool - use_edge_weight::Bool -end - -Flux.@layer GCNConv - -function GCNConv(ch::Pair{Int, Int}, σ = identity; - init = glorot_uniform, - bias::Bool = true, - add_self_loops = true, - use_edge_weight = false) - in, out = ch - W = init(out, in) - b = bias ? 
Flux.create_bias(W, true, out) : false - GCNConv(W, b, σ, add_self_loops, use_edge_weight) -end - - -function (l::GCNConv)(g, x, edge_weight = nothing; - norm_fn = d -> 1 ./ sqrt.(d), - conv_weight = nothing) - - return GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) -end - - -function Base.show(io::IO, l::GCNConv) - out, in = size(l.weight) - print(io, "GCNConv($in => $out") - l.σ == identity || print(io, ", ", l.σ) - print(io, ")") -end - -@doc raw""" - ChebConv(in => out, k; bias=true, init=glorot_uniform) - -Chebyshev spectral graph convolutional layer from -paper [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375). - -Implements - -```math -X' = \sum^{K-1}_{k=0} W^{(k)} Z^{(k)} -``` - -where ``Z^{(k)}`` is the ``k``-th term of Chebyshev polynomials, and can be calculated by the following recursive form: - -```math -\begin{aligned} -Z^{(0)} &= X \\ -Z^{(1)} &= \hat{L} X \\ -Z^{(k)} &= 2 \hat{L} Z^{(k-1)} - Z^{(k-2)} -\end{aligned} -``` - -with ``\hat{L}`` the [`scaled_laplacian`](@ref). - -# Arguments - -- `in`: The dimension of input features. -- `out`: The dimension of output features. -- `k`: The order of Chebyshev polynomial. -- `bias`: Add learnable bias. -- `init`: Weights' initializer. - -# Examples - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -g = GNNGraph(s, t) -x = randn(Float32, 3, g.num_nodes) - -# create layer -l = ChebConv(3 => 5, 5) - -# forward pass -y = l(g, x) # size: 5 × num_nodes -``` -""" -struct ChebConv{W <: AbstractArray{<:Number, 3}, B} <: GNNLayer - weight::W - bias::B - k::Int -end - -function ChebConv(ch::Pair{Int, Int}, k::Int; - init = glorot_uniform, bias::Bool = true) - in, out = ch - W = init(out, in, k) - b = bias ? 
Flux.create_bias(W, true, out) : false - ChebConv(W, b, k) -end - -Flux.@layer ChebConv - -(l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x) - -function Base.show(io::IO, l::ChebConv) - out, in, k = size(l.weight) - print(io, "ChebConv(", in, " => ", out) - print(io, ", k=", k) - print(io, ")") -end - -@doc raw""" - GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform) - -Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244). - -Performs: -```math -\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j -``` - -where the aggregation type is selected by `aggr`. - -# Arguments - -- `in`: The dimension of input features. -- `out`: The dimension of output features. -- `σ`: Activation function. -- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). -- `bias`: Add learnable bias. -- `init`: Weights' initializer. - -# Examples - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) -x = randn(Float32, 3, g.num_nodes) - -# create layer -l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) - -# forward pass -y = l(g, x) -``` -""" -struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer - weight1::W - weight2::W - bias::B - σ::F - aggr::A -end - -Flux.@layer GraphConv - -function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, - init = glorot_uniform, bias::Bool = true) - in, out = ch - W1 = init(out, in) - W2 = init(out, in) - b = bias ? 
Flux.create_bias(W1, true, out) : false - GraphConv(W1, W2, b, σ, aggr) -end - -(l::GraphConv)(g, x) = GNNlib.graph_conv(l, g, x) - -function Base.show(io::IO, l::GraphConv) - in_channel = size(l.weight1, ndims(l.weight1)) - out_channel = size(l.weight1, ndims(l.weight1) - 1) - print(io, "GraphConv(", in_channel, " => ", out_channel) - l.σ == identity || print(io, ", ", l.σ) - print(io, ", aggr=", l.aggr) - print(io, ")") -end - -@doc raw""" - GATConv(in => out, [σ; heads, concat, init, bias, negative_slope, add_self_loops]) - GATConv((in, ein) => out, ...) - -Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.org/abs/1710.10903). - -Implements the operation -```math -\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j -``` -where the attention coefficients ``\alpha_{ij}`` are given by -```math -\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i; W \mathbf{x}_j])) -``` -with ``z_i`` a normalization factor. - -In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass -and the attention coefficients will be calculated as -```math -\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W_e \mathbf{e}_{j\to i}; W \mathbf{x}_i; W \mathbf{x}_j])) -``` - -# Arguments - -- `in`: The dimension of input node features. -- `ein`: The dimension of input edge features. Default 0 (i.e. no edge features passed in the forward). -- `out`: The dimension of output node features. -- `σ`: Activation function. Default `identity`. -- `bias`: Learn the additive bias if true. Default `true`. -- `heads`: Number attention heads. Default `1`. -- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`. -- `negative_slope`: The parameter of LeakyReLU.Default `0.2`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. 
-- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`. - -# Examples - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) -x = randn(Float32, 3, g.num_nodes) - -# create layer -l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; heads=2, concat=true) - -# forward pass -y = l(g, x) -``` -""" -struct GATConv{DX<:Dense,DE<:Union{Dense, Nothing},DV,T,A<:AbstractMatrix,F,B} <: GNNLayer - dense_x::DX - dense_e::DE - bias::B - a::A - σ::F - negative_slope::T - channel::Pair{NTuple{2, Int}, Int} - heads::Int - concat::Bool - add_self_loops::Bool - dropout::DV -end - -Flux.@layer GATConv -Flux.trainable(l::GATConv) = (; l.dense_x, l.dense_e, l.bias, l.a) - -GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) - -function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; - heads::Int = 1, concat::Bool = true, negative_slope = 0.2, - init = glorot_uniform, bias::Bool = true, add_self_loops = true, dropout=0.0) - (in, ein), out = ch - if add_self_loops - @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." - end - - dense_x = Dense(in, out * heads, bias = false) - dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing - b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false - a = init(ein > 0 ? 3out : 2out, heads) - negative_slope = convert(Float32, negative_slope) - GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout) -end - -(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) - -(l::GATConv)(g, x, e = nothing) = GNNlib.gat_conv(l, g, x, e) - -function Base.show(io::IO, l::GATConv) - (in, ein), out = l.channel - print(io, "GATConv(", ein == 0 ? 
in : (in, ein), " => ", out ÷ l.heads) - l.σ == identity || print(io, ", ", l.σ) - print(io, ", negative_slope=", l.negative_slope) - print(io, ")") -end - -@doc raw""" - GATv2Conv(in => out, [σ; heads, concat, init, bias, negative_slope, add_self_loops]) - GATv2Conv((in, ein) => out, ...) - - -GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491). - -Implements the operation -```math -\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j -``` -where the attention coefficients ``\alpha_{ij}`` are given by -```math -\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU(W_2 \mathbf{x}_i + W_1 \mathbf{x}_j)) -``` -with ``z_i`` a normalization factor. - -In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass -and the attention coefficients will be calculated as -```math -\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU(W_3 \mathbf{e}_{j\to i} + W_2 \mathbf{x}_i + W_1 \mathbf{x}_j)). -``` - -# Arguments - -- `in`: The dimension of input node features. -- `ein`: The dimension of input edge features. Default 0 (i.e. no edge features passed in the forward). -- `out`: The dimension of output node features. -- `σ`: Activation function. Default `identity`. -- `bias`: Learn the additive bias if true. Default `true`. -- `heads`: Number attention heads. Default `1`. -- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`. -- `negative_slope`: The parameter of LeakyReLU.Default `0.2`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. -- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`. 
- -# Examples -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -ein = 3 -g = GNNGraph(s, t) -x = randn(Float32, 3, g.num_nodes) - -# create layer -l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false) - -# edge features -e = randn(Float32, ein, length(s)) - -# forward pass -y = l(g, x, e) -``` -""" -struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer - dense_i::A1 - dense_j::A2 - dense_e::A3 - bias::B - a::C - σ::F - negative_slope::T - channel::Pair{NTuple{2, Int}, Int} - heads::Int - concat::Bool - add_self_loops::Bool - dropout::DV -end - -Flux.@layer GATv2Conv -Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a) - -function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) - GATv2Conv((ch[1], 0) => ch[2], args...; kws...) -end - -function GATv2Conv(ch::Pair{NTuple{2, Int}, Int}, - σ = identity; - heads::Int = 1, - concat::Bool = true, - negative_slope = 0.2, - init = glorot_uniform, - bias::Bool = true, - add_self_loops = true, - dropout=0.0) - (in, ein), out = ch - - if add_self_loops - @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." - end - - dense_i = Dense(in, out * heads; bias = bias, init = init) - dense_j = Dense(in, out * heads; bias = false, init = init) - if ein > 0 - dense_e = Dense(ein, out * heads; bias = false, init = init) - else - dense_e = nothing - end - b = bias ? Flux.create_bias(dense_i.weight, true, concat ? 
out * heads : out) : false - a = init(out, heads) - return GATv2Conv(dense_i, dense_j, dense_e, - b, a, σ, negative_slope, ch, heads, concat, - add_self_loops, dropout) -end - -(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) - -(l::GATv2Conv)(g, x, e=nothing) = GNNlib.gatv2_conv(l, g, x, e) - -function Base.show(io::IO, l::GATv2Conv) - (in, ein), out = l.channel - print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads) - l.σ == identity || print(io, ", ", l.σ) - print(io, ", negative_slope=", l.negative_slope) - print(io, ")") -end - -@doc raw""" - GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform) - -Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493). - -Implements the recursion -```math -\begin{aligned} -\mathbf{h}^{(0)}_i &= [\mathbf{x}_i; \mathbf{0}] \\ -\mathbf{h}^{(l)}_i &= GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j) -\end{aligned} -``` - -where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing through GRU. The dimension of input ``\mathbf{x}_i`` needs to be less or equal to `out`. - -# Arguments - -- `out`: The dimension of output features. -- `num_layers`: The number of recursion steps. -- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). -- `init`: Weight initialization function. 
- -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -out_channel = 5 -num_layers = 3 -g = GNNGraph(s, t) - -# create layer -l = GatedGraphConv(out_channel, num_layers) - -# forward pass -y = l(g, x) -``` -""" -struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer - weight::W - gru::R - dims::Int - num_layers::Int - aggr::A -end - -Flux.@layer GatedGraphConv - -function GatedGraphConv(dims::Int, num_layers::Int; - aggr = +, init = glorot_uniform) - w = init(dims, dims, num_layers) - gru = GRUCell(dims => dims) - GatedGraphConv(w, gru, dims, num_layers, aggr) -end - - -(l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H) - -function Base.show(io::IO, l::GatedGraphConv) - print(io, "GatedGraphConv($(l.dims), $(l.num_layers)") - print(io, ", aggr=", l.aggr) - print(io, ")") -end - -@doc raw""" - EdgeConv(nn; aggr=max) - -Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829). - -Performs the operation -```math -\mathbf{x}_i' = \square_{j \in N(i)}\, nn([\mathbf{x}_i; \mathbf{x}_j - \mathbf{x}_i]) -``` - -where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron. - -# Arguments - -- `nn`: A (possibly learnable) function. -- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). 
- -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) - -# create layer -l = EdgeConv(Dense(2 * in_channel, out_channel), aggr = +) - -# forward pass -y = l(g, x) -``` -""" -struct EdgeConv{NN, A} <: GNNLayer - nn::NN - aggr::A -end - -Flux.@layer :expand EdgeConv - -EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) - -(l::EdgeConv)(g, x) = GNNlib.edge_conv(l, g, x) - -function Base.show(io::IO, l::EdgeConv) - print(io, "EdgeConv(", l.nn) - print(io, ", aggr=", l.aggr) - print(io, ")") -end - -@doc raw""" - GINConv(f, ϵ; aggr=+) - -Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf). - -Implements the graph convolution -```math -\mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right) -``` -where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron. - -# Arguments - -- `f`: A (possibly learnable) function acting on node features. -- `ϵ`: Weighting factor. - -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) - -# create dense layer -nn = Dense(in_channel, out_channel) - -# create layer -l = GINConv(nn, 0.01f0, aggr = mean) - -# forward pass -y = l(g, x) -``` -""" -struct GINConv{R <: Real, NN, A} <: GNNLayer - nn::NN - ϵ::R - aggr::A -end - -Flux.@layer :expand GINConv -Flux.trainable(l::GINConv) = (nn = l.nn,) - -GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) - -(l::GINConv)(g, x) = GNNlib.gin_conv(l, g, x) - -function Base.show(io::IO, l::GINConv) - print(io, "GINConv($(l.nn)") - print(io, ", $(l.ϵ)") - print(io, ")") -end - -@doc raw""" - NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform) - -The continuous kernel-based convolutional operator from the -[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper. 
-This convolution is also known as the edge-conditioned convolution from the -[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper. - -Performs the operation - -```math -\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j -``` - -where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron). -Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`, -the function `f` will return an batched matrices array whose size is `(out, in, num_edges)`. -For convenience, also functions returning a single `(out*in, num_edges)` matrix are allowed. - -# Arguments - -- `in`: The dimension of input features. -- `out`: The dimension of output features. -- `f`: A (possibly learnable) function acting on edge features. -- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). -- `σ`: Activation function. -- `bias`: Add learnable bias. -- `init`: Weights' initializer. - -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -edim = 10 -g = GNNGraph(s, t) - -# create dense layer -nn = Dense(edim => out_channel * in_channel) - -# create layer -l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +) - -# forward pass -y = l(g, x) -``` -""" -struct NNConv{W, B, NN, F, A} <: GNNLayer - weight::W - bias::B - nn::NN - σ::F - aggr::A -end - -Flux.@layer :expand NNConv - -function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true, - init = glorot_uniform) - in, out = ch - W = init(out, in) - b = bias ? 
Flux.create_bias(W, true, out) : false - return NNConv(W, b, nn, σ, aggr) -end - -(l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e) - -(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) - -function Base.show(io::IO, l::NNConv) - out, in = size(l.weight) - print(io, "NNConv($in => $out") - print(io, ", aggr=", l.aggr) - print(io, ")") -end - -@doc raw""" - SAGEConv(in => out, σ=identity; aggr=mean, bias=true, init=glorot_uniform) - -GraphSAGE convolution layer from paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf). - -Performs: -```math -\mathbf{x}_i' = W \cdot [\mathbf{x}_i; \square_{j \in \mathcal{N}(i)} \mathbf{x}_j] -``` - -where the aggregation type is selected by `aggr`. - -# Arguments - -- `in`: The dimension of input features. -- `out`: The dimension of output features. -- `σ`: Activation function. -- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). -- `bias`: Add learnable bias. -- `init`: Weights' initializer. - -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) - -# create layer -l = SAGEConv(in_channel => out_channel, tanh, bias = false, aggr = +) - -# forward pass -y = l(g, x) -``` -""" -struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer - weight::W - bias::B - σ::F - aggr::A -end - -Flux.@layer SAGEConv - -function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, - init = glorot_uniform, bias::Bool = true) - in, out = ch - W = init(out, 2 * in) - b = bias ? 
Flux.create_bias(W, true, out) : false - SAGEConv(W, b, σ, aggr) -end - -(l::SAGEConv)(g, x) = GNNlib.sage_conv(l, g, x) - -function Base.show(io::IO, l::SAGEConv) - out_channel, in_channel = size(l.weight) - print(io, "SAGEConv(", in_channel ÷ 2, " => ", out_channel) - l.σ == identity || print(io, ", ", l.σ) - print(io, ", aggr=", l.aggr) - print(io, ")") -end - -@doc raw""" - ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true) - -The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper. - -The layer's forward pass is given by - -```math -\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big), -``` -where the edge gates ``\eta_{ij}`` are given by - -```math -\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j). -``` - -# Arguments - -- `in`: The dimension of input features. -- `out`: The dimension of output features. -- `act`: Activation function. -- `init`: Weight matrices' initializing function. -- `bias`: Learn an additive bias if true. - -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) - -# create layer -l = ResGatedGraphConv(in_channel => out_channel, tanh, bias = true) - -# forward pass -y = l(g, x) -``` -""" -struct ResGatedGraphConv{W, B, F} <: GNNLayer - A::W - B::W - U::W - V::W - bias::B - σ::F -end - -Flux.@layer ResGatedGraphConv - -function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; - init = glorot_uniform, bias::Bool = true) - in, out = ch - A = init(out, in) - B = init(out, in) - U = init(out, in) - V = init(out, in) - b = bias ? 
Flux.create_bias(A, true, out) : false - return ResGatedGraphConv(A, B, U, V, b, σ) -end - -(l::ResGatedGraphConv)(g, x) = GNNlib.res_gated_graph_conv(l, g, x) - -function Base.show(io::IO, l::ResGatedGraphConv) - out_channel, in_channel = size(l.A) - print(io, "ResGatedGraphConv(", in_channel, " => ", out_channel) - l.σ == identity || print(io, ", ", l.σ) - print(io, ")") -end - -@doc raw""" - CGConv((in, ein) => out, act=identity; bias=true, init=glorot_uniform, residual=false) - CGConv(in => out, ...) - -The crystal graph convolutional layer from the paper -[Crystal Graph Convolutional Neural Networks for an Accurate and -Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf). -Performs the operation - -```math -\mathbf{x}_i' = \mathbf{x}_i + \sum_{j\in N(i)}\sigma(W_f \mathbf{z}_{ij} + \mathbf{b}_f)\, act(W_s \mathbf{z}_{ij} + \mathbf{b}_s) -``` - -where ``\mathbf{z}_{ij}`` is the node and edge features concatenation -``[\mathbf{x}_i; \mathbf{x}_j; \mathbf{e}_{j\to i}]`` -and ``\sigma`` is the sigmoid function. -The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same -as the input size. - -# Arguments - -- `in`: The dimension of input node features. -- `ein`: The dimension of input edge features. -If `ein` is not given, assumes that no edge features are passed as input in the forward pass. -- `out`: The dimension of output node features. -- `act`: Activation function. -- `bias`: Add learnable bias. -- `init`: Weights' initializer. -- `residual`: Add a residual connection. 
- -# Examples - -```julia -g = rand_graph(5, 6) -x = rand(Float32, 2, g.num_nodes) -e = rand(Float32, 3, g.num_edges) - -l = CGConv((2, 3) => 4, tanh) -y = l(g, x, e) # size: (4, num_nodes) - -# No edge features -l = CGConv(2 => 4, tanh) -y = l(g, x) # size: (4, num_nodes) -``` -""" -struct CGConv{D1, D2} <: GNNLayer - ch::Pair{NTuple{2, Int}, Int} - dense_f::D1 - dense_s::D2 - residual::Bool -end - -Flux.@layer CGConv - -CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) - -function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, - bias = true, init = glorot_uniform) - (nin, ein), out = ch - dense_f = Dense(2nin + ein, out, sigmoid; bias, init) - dense_s = Dense(2nin + ein, out, act; bias, init) - return CGConv(ch, dense_f, dense_s, residual) -end - -(l::CGConv)(g, x, e = nothing) = GNNlib.cg_conv(l, g, x, e) - - -(l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) - -function Base.show(io::IO, l::CGConv) - print(io, "CGConv($(l.ch)") - l.dense_s.σ == identity || print(io, ", ", l.dense_s.σ) - print(io, ", residual=$(l.residual)") - print(io, ")") -end - -@doc raw""" - AGNNConv(; init_beta=1.0f0, trainable=true, add_self_loops=true) - -Attention-based Graph Neural Network layer from paper [Attention-based -Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735). - -The forward pass is given by -```math -\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} \mathbf{x}_j -``` -where the attention coefficients ``\alpha_{ij}`` are given by -```math -\alpha_{ij} =\frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}} - {\sum_{j'}e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}} -``` -with the cosine distance defined by -```math -\cos(\mathbf{x}_i, \mathbf{x}_j) = - \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert} -``` -and ``\beta`` a trainable parameter if `trainable=true`. 
- -# Arguments - -- `init_beta`: The initial value of ``\beta``. Default 1.0f0. -- `trainable`: If true, ``\beta`` is trainable. Default `true`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. - -# Examples: - -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -g = GNNGraph(s, t) - -# create layer -l = AGNNConv(init_beta=2.0f0) - -# forward pass -y = l(g, x) -``` -""" -struct AGNNConv{A <: AbstractVector} <: GNNLayer - β::A - add_self_loops::Bool - trainable::Bool -end - -Flux.@layer AGNNConv - -Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;) - -function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true) - AGNNConv([init_beta], add_self_loops, trainable) -end - -(l::AGNNConv)(g, x) = GNNlib.agnn_conv(l, g, x) - -@doc raw""" - MEGNetConv(ϕe, ϕv; aggr=mean) - MEGNetConv(in => out; aggr=mean) - -Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf) -paper. In the forward pass, takes as inputs node features `x` and edge features `e` and returns -updated features `x'` and `e'` according to - -```math -\begin{aligned} -\mathbf{e}_{i\to j}' = \phi_e([\mathbf{x}_i;\, \mathbf{x}_j;\, \mathbf{e}_{i\to j}]),\\ -\mathbf{x}_{i}' = \phi_v([\mathbf{x}_i;\, \square_{j\in \mathcal{N}(i)}\,\mathbf{e}_{j\to i}']). -\end{aligned} -``` - -`aggr` defines the aggregation to be performed. - -If the neural networks `ϕe` and `ϕv` are not provided, they will be constructed from -the `in` and `out` arguments instead as multi-layer perceptron with one hidden layer and `relu` -activations. 
- -# Examples - -```julia -g = rand_graph(10, 30) -x = randn(Float32, 3, 10) -e = randn(Float32, 3, 30) -m = MEGNetConv(3 => 3) -x′, e′ = m(g, x, e) -``` -""" -struct MEGNetConv{TE, TV, A} <: GNNLayer - ϕe::TE - ϕv::TV - aggr::A -end - -Flux.@layer :expand MEGNetConv - -MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) - -function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) - nin, nout = ch - ϕe = Chain(Dense(3nin, nout, relu), - Dense(nout, nout)) - - ϕv = Chain(Dense(nin + nout, nout, relu), - Dense(nout, nout)) - - return MEGNetConv(ϕe, ϕv; aggr) -end - -function (l::MEGNetConv)(g::GNNGraph) - x, e = l(g, node_features(g), edge_features(g)) - return GNNGraph(g, ndata = x, edata = e) -end - -(l::MEGNetConv)(g, x, e) = GNNlib.megnet_conv(l, g, x, e) - -@doc raw""" - GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false) - -Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402) -Performs the operation -```math -\mathbf{x}_i' = \mathbf{x}_i + \frac{1}{|N(i)|} \sum_{j\in N(i)}\frac{1}{K}\sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j -``` -where ``w^a_{k}(e^a)`` for feature `a` and kernel `k` is given by -```math -w^a_{k}(e^a) = \exp(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k(e^a - \mu^a_k)) -``` -``\Theta_k, \mu^a_k, (\Sigma^{-1})^a_k`` are learnable parameters. - -The input to the layer is a node feature array `x` of size `(num_features, num_nodes)` and -edge pseudo-coordinate array `e` of size `(num_features, num_edges)` -The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same -as the input size. - -# Arguments - -- `in`: Number of input node features. -- `ein`: Number of input edge features. -- `out`: Number of output features. -- `σ`: Activation function. Default `identity`. -- `K`: Number of kernels. Default `1`. 
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `residual`: Residual connection. Default `false`.

# Examples

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s,t)
nin, ein, out, K = 4, 10, 7, 8
x = randn(Float32, nin, g.num_nodes)
e = randn(Float32, ein, g.num_edges)

# create layer
l = GMMConv((nin, ein) => out, K=K)

# forward pass
l(g, x, e)
```
"""
struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer
    mu::A         # kernel means, size (ein, K)
    sigma_inv::A  # kernel inverse covariances (diagonal), size (ein, K)
    bias::B
    σ::F
    ch::Pair{NTuple{2, Int}, Int}
    K::Int
    dense_x::Dense  # maps node features to K stacked output blocks
    residual::Bool
end

Flux.@layer GMMConv

function GMMConv(ch::Pair{NTuple{2, Int}, Int},
                 σ = identity;
                 K::Int = 1,
                 bias::Bool = true,
                 init = Flux.glorot_uniform,
                 residual = false)
    (nin, ein), out = ch
    mu = init(ein, K)
    sigma_inv = init(ein, K)
    b = bias ? Flux.create_bias(mu, true, out) : false
    # Single dense layer producing all K kernel outputs at once (no bias,
    # the shared bias `b` is applied after kernel aggregation).
    dense_x = Dense(nin, out * K, bias = false)
    GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
end

(l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e)

(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))

function Base.show(io::IO, l::GMMConv)
    (nin, ein), out = l.ch
    print(io, "GMMConv((", nin, ",", ein, ")=>", out)
    # Fix: GMMConv stores its activation in `σ`; it has no `dense_s` field,
    # so the previous `l.dense_s.σ` (copied from CGConv) threw on display.
    l.σ == identity || print(io, ", σ=", l.σ)
    print(io, ", K=", l.K)
    # Show `residual` only when it differs from the default (`false`);
    # the previous condition was inverted and printed the default instead.
    l.residual && print(io, ", residual=", l.residual)
    print(io, ")")
end

@doc raw"""
    SGConv(in => out, k=1; [bias, init, add_self_loops, use_edge_weight])

SGC layer from [Simplifying Graph Convolutional Networks](https://arxiv.org/pdf/1902.07153.pdf)
Performs operation
```math
H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta
```
where ``\tilde{A}`` is ``A + I``.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `k` : Number of hops k. Default `1`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
    If `add_self_loops=true` the new weights will be set to 1. Default `false`.

# Examples

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)

# create layer
l = SGConv(3 => 5; add_self_loops = true)

# forward pass
y = l(g, x) # size: 5 × num_nodes

# convolution with edge weights
w = [1.1, 0.1, 2.3, 0.5]
y = l(g, x, w)

# Edge weights can also be embedded in the graph.
g = GNNGraph(s, t, w)
l = SGConv(3 => 5, add_self_loops = true, use_edge_weight=true)
y = l(g, x) # same as l(g, x, w)
```
"""
struct SGConv{A <: AbstractMatrix, B} <: GNNLayer
    weight::A
    bias::B
    k::Int
    add_self_loops::Bool
    use_edge_weight::Bool
end

Flux.@layer SGConv

function SGConv(ch::Pair{Int, Int}, k = 1;
                init = glorot_uniform,
                bias::Bool = true,
                add_self_loops = true,
                use_edge_weight = false)
    in, out = ch
    W = init(out, in)
    b = bias ? Flux.create_bias(W, true, out) : false
    return SGConv(W, b, k, add_self_loops, use_edge_weight)
end

(l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight)

function Base.show(io::IO, l::SGConv)
    out, in = size(l.weight)
    print(io, "SGConv($in => $out")
    # Show the hop count only when it differs from the default (`1`).
    l.k == 1 || print(io, ", ", l.k)
    print(io, ")")
end

@doc raw"""
    TAGConv(in => out, k=3; bias=true, init=glorot_uniform, add_self_loops=true, use_edge_weight=false)

TAGConv layer from [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/pdf/1710.10370.pdf).
This layer extends the idea of graph convolutions by applying filters that adapt to the topology of the data.
-It performs the operation: - -```math -H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k} -``` - -where `A` is the adjacency matrix of the graph, `D` is the degree matrix, `X` is the input feature matrix, and ``{\Theta}_{k}`` is a unique weight matrix for each hop `k`. - -# Arguments -- `in`: Number of input features. -- `out`: Number of output features. -- `k`: Maximum number of hops to consider. Default is `3`. -- `bias`: Whether to include a learnable bias term. Default is `true`. -- `init`: Initialization function for the weights. Default is `glorot_uniform`. -- `add_self_loops`: Whether to add self-loops to the adjacency matrix. Default is `true`. -- `use_edge_weight`: If `true`, edge weights are considered in the computation (if available). Default is `false`. - -# Examples - -```julia -# Example graph data -s = [1, 1, 2, 3] -t = [2, 3, 1, 1] -g = GNNGraph(s, t) # Create a graph -x = randn(Float32, 3, g.num_nodes) # Random features for each node - -# Create a TAGConv layer -l = TAGConv(3 => 5, k=3; add_self_loops=true) - -# Apply the TAGConv layer -y = l(g, x) # Output size: 5 × num_nodes -``` -""" -struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer - weight::A - bias::B - k::Int - add_self_loops::Bool - use_edge_weight::Bool -end - -Flux.@layer TAGConv - -function TAGConv(ch::Pair{Int, Int}, k = 3; - init = glorot_uniform, - bias::Bool = true, - add_self_loops = true, - use_edge_weight = false) - in, out = ch - W = init(out, in) - b = bias ? 
Flux.create_bias(W, true, out) : false - return TAGConv(W, b, k, add_self_loops, use_edge_weight) -end - -(l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight) - -function Base.show(io::IO, l::TAGConv) - out, in = size(l.weight) - print(io, "TAGConv($in => $out") - l.k == 1 || print(io, ", ", l.k) - print(io, ")") -end - -@doc raw""" - EGNNConv((in, ein) => out; hidden_size=2in, residual=false) - EGNNConv(in => out; hidden_size=2in, residual=false) - -Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph -Neural Networks](https://arxiv.org/abs/2102.09844). - -The layer performs the following operation: - -```math -\begin{aligned} -\mathbf{m}_{j\to i} &=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\ -\mathbf{x}_i' &= \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\ -\mathbf{m}_i &= C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\ -\mathbf{h}_i' &= \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i) -\end{aligned} -``` -where ``\mathbf{h}_i``, ``\mathbf{x}_i``, ``\mathbf{e}_{j\to i}`` are invariant node features, equivariant node -features, and edge features respectively. ``\phi_e``, ``\phi_h``, and -``\phi_x`` are two-layer MLPs. `C` is a constant for normalization, -computed as ``1/|\mathcal{N}(i)|``. - - -# Constructor Arguments - -- `in`: Number of input features for `h`. -- `out`: Number of output features for `h`. -- `ein`: Number of input edge features. -- `hidden_size`: Hidden representation size. -- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`. - -# Forward Pass - - l(g, x, h, e=nothing) - -## Forward Pass Arguments: - -- `g` : The graph. -- `x` : Matrix of equivariant node coordinates. -- `h` : Matrix of invariant node features. -- `e` : Matrix of invariant edge features. Default `nothing`. - -Returns updated `h` and `x`. 
- -# Examples - -```julia -g = rand_graph(10, 10) -h = randn(Float32, 5, g.num_nodes) -x = randn(Float32, 3, g.num_nodes) -egnn = EGNNConv(5 => 6, 10) -hnew, xnew = egnn(g, h, x) -``` -""" -struct EGNNConv{TE, TX, TH, NF} <: GNNLayer - ϕe::TE - ϕx::TX - ϕh::TH - num_features::NF - residual::Bool -end - -Flux.@layer EGNNConv - -function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false) - return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) -end - -#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py -function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1], - residual = false) - (in_size, edge_feat_size), out_size = ch - act_fn = swish - - # +1 for the radial feature: ||x_i - x_j||^2 - ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn), - Dense(hidden_size => hidden_size, act_fn)) - - ϕh = Chain(Dense(in_size + hidden_size, hidden_size, swish), - Dense(hidden_size, out_size)) - - ϕx = Chain(Dense(hidden_size, hidden_size, swish), - Dense(hidden_size, 1, bias = false)) - - num_features = (in = in_size, edge = edge_feat_size, out = out_size, - hidden = hidden_size) - if residual - @assert in_size==out_size "Residual connection only possible if in_size == out_size" - end - return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) -end - -(l::EGNNConv)(g, h, x, e = nothing) = GNNlib.egnn_conv(l, g, h, x, e) - -function Base.show(io::IO, l::EGNNConv) - ne = l.num_features.edge - nin = l.num_features.in - nout = l.num_features.out - nh = l.num_features.hidden - print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh") - if l.residual - print(io, ", residual=true") - end - print(io, ")") -end - -@doc raw""" - TransformerConv((in, ein) => out; [heads, concat, init, add_self_loops, bias_qkv, - bias_root, root_weight, gating, skip_connection, batch_norm, ff_channels])) - -The transformer-like multi head attention convolutional operator 
from the
[Masked Label Prediction: Unified Message Passing Model for Semi-Supervised
Classification](https://arxiv.org/abs/2009.03509) paper, which also considers
edge features.
It further contains options to also be configured as the transformer-like convolutional operator from the
[Attention, Learn to Solve Routing Problems!](https://arxiv.org/abs/1803.08475) paper,
including a successive feed-forward network as well as skip layers and batch normalization.

The layer's basic forward pass is given by
```math
x_i' = W_1x_i + \sum_{j\in N(i)} \alpha_{ij} (W_2 x_j + W_6e_{ij})
```
where the attention scores are
```math
\alpha_{ij} = \mathrm{softmax}\left(\frac{(W_3x_i)^T(W_4x_j+
W_6e_{ij})}{\sqrt{d}}\right).
```

Optionally, a combination of the aggregated value with transformed root node features
by a gating mechanism via
```math
x'_i = \beta_i W_1 x_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
\alpha_{i,j} W_2 x_j \right)}_{=m_i}
```
with
```math
\beta_i = \textrm{sigmoid}(W_5^{\top} [ W_1 x_i, m_i, W_1 x_i - m_i ]).
```
can be performed.

# Arguments

- `in`: Dimension of input features, which also corresponds to the dimension of
  the output features.
- `ein`: Dimension of the edge features; if 0, no edge features will be used.
- `out`: Dimension of the output.
- `heads`: Number of heads in output. Default `1`.
- `concat`: Concatenate layer output or not. If not, layer output is averaged
  over the heads. Default `true`.
- `init`: Weight matrices' initializing function. Default `glorot_uniform`.
- `add_self_loops`: Add self loops to the input graph. Default `false`.
- `bias_qkv`: If set, bias is used in the key, query and value transformations for nodes.
  Default `true`.
- `bias_root`: If set, the layer will also learn an additive bias for the root when root
  weight is used. Default `true`.
- `root_weight`: If set, the layer will add the transformed root node features
  to the output. Default `true`.
-- `gating`: If set, will combine aggregation and transformed root node features by a - gating mechanism. Default `false`. -- `skip_connection`: If set, a skip connection will be made from the input and - added to the output. Default `false`. -- `batch_norm`: If set, a batch normalization will be applied to the output. Default `false`. -- `ff_channels`: If positive, a feed-forward NN is appended, with the first having the given - number of hidden nodes; this NN also gets a skip connection and batch normalization - if the respective parameters are set. Default: `0`. - -# Examples - -```julia -N, in_channel, out_channel = 4, 3, 5 -ein, heads = 2, 3 -g = GNNGraph([1,1,2,4], [2,3,1,1]) -l = TransformerConv((in_channel, ein) => in_channel; heads, gating = true, bias_qkv = true) -x = rand(Float32, in_channel, N) -e = rand(Float32, ein, g.num_edges) -l(g, x, e) -``` -""" -struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLayer - W1::TW1 - W2::TW2 - W3::TW3 - W4::TW4 - W5::TW5 - W6::TW6 - FF::TFF - BN1::TBN1 - BN2::TBN2 - channels::Pair{NTuple{2, Int}, Int} - heads::Int - add_self_loops::Bool - concat::Bool - skip_connection::Bool - sqrt_out::Float32 -end - -Flux.@layer TransformerConv - -function Flux.trainable(l::TransformerConv) - (; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2) -end - -function TransformerConv(ch::Pair{Int, Int}, args...; kws...) - TransformerConv((ch[1], 0) => ch[2], args...; kws...) -end - -function TransformerConv(ch::Pair{NTuple{2, Int}, Int}; - heads::Int = 1, - concat::Bool = true, - init = glorot_uniform, - add_self_loops::Bool = false, - bias_qkv = true, - bias_root::Bool = true, - root_weight::Bool = true, - gating::Bool = false, - skip_connection::Bool = false, - batch_norm::Bool = false, - ff_channels::Int = 0) - (in, ein), out = ch - - if add_self_loops - @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported." - end - - W1 = root_weight ? 
- Dense(in, out * (concat ? heads : 1); bias = bias_root, init = init) : nothing - W2 = Dense(in => out * heads; bias = bias_qkv, init = init) - W3 = Dense(in => out * heads; bias = bias_qkv, init = init) - W4 = Dense(in => out * heads; bias = bias_qkv, init = init) - out_mha = out * (concat ? heads : 1) - W5 = gating ? Dense(3 * out_mha => 1, sigmoid; bias = false, init = init) : nothing - W6 = ein > 0 ? Dense(ein => out * heads; bias = bias_qkv, init = init) : nothing - FF = ff_channels > 0 ? - Chain(Dense(out_mha => ff_channels, relu), - Dense(ff_channels => out_mha)) : nothing - BN1 = batch_norm ? BatchNorm(out_mha) : nothing - BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing - - return TransformerConv(W1, W2, W3, W4, W5, W6, FF, BN1, BN2, - ch, heads, add_self_loops, concat, skip_connection, - Float32(√out)) -end - -(l::TransformerConv)(g, x, e = nothing) = GNNlib.transformer_conv(l, g, x, e) - -function (l::TransformerConv)(g::GNNGraph) - GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) -end - -function Base.show(io::IO, l::TransformerConv) - (in, ein), out = l.channels - print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))") -end - -""" - DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true) - -Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). - -# Arguments - -- `ch`: Pair of input and output dimensions. -- `k`: Number of diffusion steps. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `bias`: Add learnable bias. Default `true`. 
- -# Examples -``` -julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10)); - -julia> dconv = DConv(2 => 4, 4) -DConv(2 => 4, 4) - -julia> y = dconv(g, g.ndata.x); - -julia> size(y) -(4, 10) -``` -""" -struct DConv <: GNNLayer - in::Int - out::Int - weights::AbstractArray - bias::AbstractArray - k::Int -end - -Flux.@layer DConv - -function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true) - in, out = ch - weights = init(2, k, out, in) - b = bias ? Flux.create_bias(weights, true, out) : false - return DConv(in, out, weights, b, k) -end - -(l::DConv)(g, x) = GNNlib.d_conv(l, g, x) - -function Base.show(io::IO, l::DConv) - print(io, "DConv($(l.in) => $(l.out), $(l.k))") -end - -[.\src\layers\heteroconv.jl] -@doc raw""" - HeteroGraphConv(itr; aggr = +) - HeteroGraphConv(pairs...; aggr = +) - -A convolutional layer for heterogeneous graphs. - -The `itr` argument is an iterator of `pairs` of the form `edge_t => layer`, where `edge_t` is a -3-tuple of the form `(src_node_type, edge_type, dst_node_type)`, and `layer` is a -convolutional layers for homogeneous graphs. - -Each convolution is applied to the corresponding relation. -Since a node type can be involved in multiple relations, the single convolution outputs -have to be aggregated using the `aggr` function. The default is to sum the outputs. - -# Forward Arguments - -* `g::GNNHeteroGraph`: The input graph. -* `x::Union{NamedTuple,Dict}`: The input node features. The keys are node types and the - values are node feature tensors. 
- -# Examples - -```jldoctest -julia> g = rand_bipartite_heterograph((10, 15), 20) -GNNHeteroGraph: - num_nodes: Dict(:A => 10, :B => 15) - num_edges: Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20) - -julia> x = (A = rand(Float32, 64, 10), B = rand(Float32, 64, 15)); - -julia> layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu), - (:B, :to, :A) => GraphConv(64 => 32, relu)); - -julia> y = layer(g, x); # output is a named tuple - -julia> size(y.A) == (32, 10) && size(y.B) == (32, 15) -true -``` -""" -struct HeteroGraphConv - etypes::Vector{EType} - layers::Vector{<:GNNLayer} - aggr::Function -end - -Flux.@layer HeteroGraphConv - -HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr) -HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr) - -function HeteroGraphConv(itr; aggr = +) - etypes = [k[1] for k in itr] - layers = [k[2] for k in itr] - return HeteroGraphConv(etypes, layers, aggr) -end - -function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::Union{NamedTuple,Dict}) - function forw(l, et) - sg = edge_type_subgraph(g, et) - node1_t, _, node2_t = et - return l(sg, (x[node1_t], x[node2_t])) - end - outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)] - dst_ntypes = [et[3] for et in hgc.etypes] - return _reduceby_node_t(hgc.aggr, outs, dst_ntypes) -end - -function _reduceby_node_t(aggr, outs, ntypes) - function _reduce(node_t) - idxs = findall(x -> x == node_t, ntypes) - if length(idxs) == 0 - return nothing - elseif length(idxs) == 1 - return outs[idxs[1]] - else - return foldl(aggr, outs[i] for i in idxs) - end - end - # workaround to provide the aggregation once per unique node type, - # gradient is not needed - unique_ntypes = ChainRulesCore.ignore_derivatives() do - unique(ntypes) - end - vals = [_reduce(node_t) for node_t in unique_ntypes] - return NamedTuple{tuple(unique_ntypes...)}(vals) -end - -function Base.show(io::IO, hgc::HeteroGraphConv) - if get(io, :compact, false) - print(io, 
"HeteroGraphConv(aggr=$(hgc.aggr))") - else - println(io, "HeteroGraphConv(aggr=$(hgc.aggr)):") - for (i, (et,layer)) in enumerate(zip(hgc.etypes, hgc.layers)) - print(io, " $(et => layer)") - if i < length(hgc.etypes) - print(io, "\n") - end - end - end -end - -[.\src\layers\pool.jl] -@doc raw""" - GlobalPool(aggr) - -Global pooling layer for graph neural networks. -Takes a graph and feature nodes as inputs -and performs the operation - -```math -\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i -``` - -where ``V`` is the set of nodes of the input graph and -the type of aggregation represented by ``\square`` is selected by the `aggr` argument. -Commonly used aggregations are `mean`, `max`, and `+`. - -See also [`reduce_nodes`](@ref). - -# Examples - -```julia -using Flux, GraphNeuralNetworks, Graphs - -pool = GlobalPool(mean) - -g = GNNGraph(erdos_renyi(10, 4)) -X = rand(32, 10) -pool(g, X) # => 32x1 matrix - - -g = Flux.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5]) -X = rand(32, 50) -pool(g, X) # => 32x5 matrix -``` -""" -struct GlobalPool{F} <: GNNLayer - aggr::F -end - -(l::GlobalPool)(g::GNNGraph, x::AbstractArray) = GNNlib.global_pool(l, g, x) - -(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) - -@doc raw""" - GlobalAttentionPool(fgate, ffeat=identity) - -Global soft attention layer from the [Gated Graph Sequence Neural -Networks](https://arxiv.org/abs/1511.05493) paper - -```math -\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i) -``` - -where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref) -operation: - -```math -\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}} - {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}. -``` - -# Arguments - -- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``. - It is tipically expressed by a neural network. - -- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``. - It is tipically expressed by a neural network. 
- -# Examples - -```julia -chin = 6 -chout = 5 - -fgate = Dense(chin, 1) -ffeat = Dense(chin, chout) -pool = GlobalAttentionPool(fgate, ffeat) - -g = Flux.batch([GNNGraph(random_regular_graph(10, 4), - ndata=rand(Float32, chin, 10)) - for i=1:3]) - -u = pool(g, g.ndata.x) - -@assert size(u) == (chout, g.num_graphs) -``` -""" -struct GlobalAttentionPool{G, F} - fgate::G - ffeat::F -end - -Flux.@layer GlobalAttentionPool - -GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity) - -(l::GlobalAttentionPool)(g, x) = GNNlib.global_attention_pool(l, g, x) - -(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) - -""" - TopKPool(adj, k, in_channel) - -Top-k pooling layer. - -# Arguments - -- `adj`: Adjacency matrix of a graph. -- `k`: Top-k nodes are selected to pool together. -- `in_channel`: The dimension of input channel. -""" -struct TopKPool{T, S} - A::AbstractMatrix{T} - k::Int - p::AbstractVector{S} - Ã::AbstractMatrix{T} -end - -function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform) - TopKPool(adj, k, init(in_channel), similar(adj, k, k)) -end - -(t::TopKPool)(x::AbstractArray) = topk_pool(t, x) - - -@doc raw""" - Set2Set(n_in, n_iters, n_layers = 1) - -Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391). - -For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times: -```math -\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*) -\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)} -\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i -\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}] -``` -where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, -input size `2*n_in` and output size `n_in`. 
- -Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`. -``` -""" -struct Set2Set{L} <: GNNLayer - lstm::L - num_iters::Int -end - -Flux.@layer Set2Set - -function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) - @assert n_layers >= 1 - n_out = 2 * n_in - - if n_layers == 1 - lstm = LSTM(n_out => n_in) - else - layers = [LSTM(n_out => n_in)] - for _ in 2:n_layers - push!(layers, LSTM(n_in => n_in)) - end - lstm = Chain(layers...) - end - - return Set2Set(lstm, n_iters) -end - -function (l::Set2Set)(g, x) - Flux.reset!(l.lstm) - return GNNlib.set2set_pool(l, g, x) -end - -(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) - -[.\src\layers\temporalconv.jl] -# Adapting Flux.Recur to work with GNNGraphs -function (m::Flux.Recur)(g::GNNGraph, x) - m.state, y = m.cell(m.state, g, x) - return y -end - -function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T - h = [m(g, x_t) for x_t in Flux.eachlastdim(x)] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], sze[2], length(h)) -end - -struct TGCNCell <: GNNLayer - conv::GCNConv - gru::Flux.GRUv3Cell - state0 - in::Int - out::Int -end - -Flux.@layer TGCNCell - -function TGCNCell(ch::Pair{Int, Int}; - bias::Bool = true, - init = Flux.glorot_uniform, - init_state = Flux.zeros32, - add_self_loops = false, - use_edge_weight = true) - in, out = ch - conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops, - use_edge_weight) - gru = Flux.GRUv3Cell(out, out) - state0 = init_state(out,1) - return TGCNCell(conv, gru, state0, in,out) -end - -function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray) - x̃ = tgcn.conv(g, x) - h, x̃ = tgcn.gru(h, x̃) - return h, x̃ -end - -function Base.show(io::IO, tgcn::TGCNCell) - print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))") -end - -""" - TGCN(in => out; [bias, init, init_state, add_self_loops, use_edge_weight]) - -Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper 
[T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). - -Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. -- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). - If `add_self_loops=true` the new weights will be set to 1. - This option is ignored if the `edge_weight` is explicitly provided in the forward pass. - Default `false`. -# Examples - -```jldoctest -julia> tgcn = TGCN(2 => 6) -Recur( - TGCNCell( - GCNConv(2 => 6, σ), # 18 parameters - GRUv3Cell(6 => 6), # 240 parameters - Float32[0.0; 0.0; … ; 0.0; 0.0;;], # 6 parameters (all zero) - 2, - 6, - ), -) # Total: 8 trainable arrays, 264 parameters, - # plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB. - -julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5); - -julia> y = tgcn(g, x); - -julia> size(y) -(6, 5) - -julia> Flux.reset!(tgcn); - -julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20 -(6, 5, 20) -``` - -!!! warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. -""" -TGCN(ch; kwargs...) 
= Flux.Recur(TGCNCell(ch; kwargs...)) - -Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0) - -# make TGCN compatible with GNNChain -(l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x) -_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g) - - -""" - A3TGCN(in => out; [bias, init, init_state, add_self_loops, use_edge_weight]) - -Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph -Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf). - -Performs a TGCN layer, followed by a soft attention layer. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. -- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). - If `add_self_loops=true` the new weights will be set to 1. - This option is ignored if the `edge_weight` is explicitly provided in the forward pass. - Default `false`. -# Examples - -```jldoctest -julia> a3tgcn = A3TGCN(2 => 6) -A3TGCN(2 => 6) - -julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5); - -julia> y = a3tgcn(g,x); - -julia> size(y) -(6, 5) - -julia> Flux.reset!(a3tgcn); - -julia> y = a3tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)); - -julia> size(y) -(6, 5) -``` - -!!! warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. 
-""" -struct A3TGCN <: GNNLayer - tgcn::Flux.Recur{TGCNCell} - dense1::Dense - dense2::Dense - in::Int - out::Int -end - -Flux.@layer A3TGCN - -function A3TGCN(ch::Pair{Int, Int}, - bias::Bool = true, - init = Flux.glorot_uniform, - init_state = Flux.zeros32, - add_self_loops = false, - use_edge_weight = true) - in, out = ch - tgcn = TGCN(in => out; bias, init, init_state, add_self_loops, use_edge_weight) - dense1 = Dense(out, out) - dense2 = Dense(out, out) - return A3TGCN(tgcn, dense1, dense2, in, out) -end - -function (a3tgcn::A3TGCN)(g::GNNGraph, x::AbstractArray) - h = a3tgcn.tgcn(g, x) - e = a3tgcn.dense1(h) - e = a3tgcn.dense2(e) - a = softmax(e, dims = 3) - c = sum(a .* h , dims = 3) - if length(size(c)) == 3 - c = dropdims(c, dims = 3) - end - return c -end - -function Base.show(io::IO, a3tgcn::A3TGCN) - print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") -end - -struct GConvGRUCell <: GNNLayer - conv_x_r::ChebConv - conv_h_r::ChebConv - conv_x_z::ChebConv - conv_h_z::ChebConv - conv_x_h::ChebConv - conv_h_h::ChebConv - k::Int - state0 - in::Int - out::Int -end - -Flux.@layer GConvGRUCell - -function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; - bias::Bool = true, - init = Flux.glorot_uniform, - init_state = Flux.zeros32) - in, out = ch - # reset gate - conv_x_r = ChebConv(in => out, k; bias, init) - conv_h_r = ChebConv(out => out, k; bias, init) - # update gate - conv_x_z = ChebConv(in => out, k; bias, init) - conv_h_z = ChebConv(out => out, k; bias, init) - # new gate - conv_x_h = ChebConv(in => out, k; bias, init) - conv_h_h = ChebConv(out => out, k; bias, init) - state0 = init_state(out, n) - return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out) -end - -function (ggru::GConvGRUCell)(h, g::GNNGraph, x) - r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h) - r = Flux.sigmoid_fast(r) - z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h) - z = Flux.sigmoid_fast(z) - h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* 
h) - h̃ = Flux.tanh_fast(h̃) - h = (1 .- z) .* h̃ .+ z .* h - return h, h -end - -function Base.show(io::IO, ggru::GConvGRUCell) - print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))") -end - -""" - GConvGRU(in => out, k, n; [bias, init, init_state]) - -Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). - -Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `k`: Chebyshev polynomial order. -- `n`: Number of nodes in the graph. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. - -# Examples - -```jldoctest -julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes); - -julia> y = ggru(g1, x1); - -julia> size(y) -(5, 5) - -julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -julia> z = ggru(g2, x2); - -julia> size(z) -(5, 5, 30) -``` -""" -GConvGRU(ch, k, n; kwargs...) 
= Flux.Recur(GConvGRUCell(ch, k, n; kwargs...)) -Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0) - -(l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x) -_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g) - -struct GConvLSTMCell <: GNNLayer - conv_x_i::ChebConv - conv_h_i::ChebConv - w_i - b_i - conv_x_f::ChebConv - conv_h_f::ChebConv - w_f - b_f - conv_x_c::ChebConv - conv_h_c::ChebConv - w_c - b_c - conv_x_o::ChebConv - conv_h_o::ChebConv - w_o - b_o - k::Int - state0 - in::Int - out::Int -end - -Flux.@layer GConvLSTMCell - -function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; - bias::Bool = true, - init = Flux.glorot_uniform, - init_state = Flux.zeros32) - in, out = ch - # input gate - conv_x_i = ChebConv(in => out, k; bias, init) - conv_h_i = ChebConv(out => out, k; bias, init) - w_i = init(out, 1) - b_i = bias ? Flux.create_bias(w_i, true, out) : false - # forget gate - conv_x_f = ChebConv(in => out, k; bias, init) - conv_h_f = ChebConv(out => out, k; bias, init) - w_f = init(out, 1) - b_f = bias ? Flux.create_bias(w_f, true, out) : false - # cell state - conv_x_c = ChebConv(in => out, k; bias, init) - conv_h_c = ChebConv(out => out, k; bias, init) - w_c = init(out, 1) - b_c = bias ? Flux.create_bias(w_c, true, out) : false - # output gate - conv_x_o = ChebConv(in => out, k; bias, init) - conv_h_o = ChebConv(out => out, k; bias, init) - w_o = init(out, 1) - b_o = bias ? 
Flux.create_bias(w_o, true, out) : false - state0 = (init_state(out, n), init_state(out, n)) - return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, - conv_x_f, conv_h_f, w_f, b_f, - conv_x_c, conv_h_c, w_c, b_c, - conv_x_o, conv_h_o, w_o, b_o, - k, state0, in, out) -end - -function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x) - # input gate - i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i - i = Flux.sigmoid_fast(i) - # forget gate - f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f - f = Flux.sigmoid_fast(f) - # cell state - c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c) - # output gate - o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o - o = Flux.sigmoid_fast(o) - h = o .* Flux.tanh_fast(c) - return (h,c), h -end - -function Base.show(io::IO, gclstm::GConvLSTMCell) - print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))") -end - -""" - GConvLSTM(in => out, k, n; [bias, init, init_state]) - -Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). - -Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `k`: Chebyshev polynomial order. -- `n`: Number of nodes in the graph. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. 
- -# Examples - -```jldoctest -julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes); - -julia> y = gclstm(g1, x1); - -julia> size(y) -(5, 5) - -julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -julia> z = gclstm(g2, x2); - -julia> size(z) -(5, 5, 30) -``` -""" -GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...)) -Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) - -(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) -_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) - -struct DCGRUCell - in::Int - out::Int - state0 - k::Int - dconv_u::DConv - dconv_r::DConv - dconv_c::DConv -end - -Flux.@layer DCGRUCell - -function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) - in, out = ch - dconv_u = DConv((in + out) => out, k; bias=bias, init=init) - dconv_r = DConv((in + out) => out, k; bias=bias, init=init) - dconv_c = DConv((in + out) => out, k; bias=bias, init=init) - state0 = init_state(out, n) - return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) -end - -function (dcgru::DCGRUCell)(h, g::GNNGraph, x) - h̃ = vcat(x, h) - z = dcgru.dconv_u(g, h̃) - z = NNlib.sigmoid_fast.(z) - r = dcgru.dconv_r(g, h̃) - r = NNlib.sigmoid_fast.(r) - ĥ = vcat(x, h .* r) - c = dcgru.dconv_c(g, ĥ) - c = tanh.(c) - h = z.* h + (1 .- z) .* c - return h, h -end - -function Base.show(io::IO, dcgru::DCGRUCell) - print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") -end - -""" - DCGRU(in => out, k, n; [bias, init, init_state]) - -Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural -Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). 
- -Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `k`: Diffusion step. -- `n`: Number of nodes in the graph. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# Examples - -```jldoctest -julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); - -julia> y = dcgru(g1, x1); - -julia> size(y) -(5, 5) - -julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -julia> z = dcgru(g2, x2); - -julia> size(z) -(5, 5, 30) -``` -""" -DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) -Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) - -(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) -_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) - -function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::ChebConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::GATConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::GATv2Conv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::GatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::CGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::SGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function 
(l::TransformerConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::GCNConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::ResGatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::SAGEConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -function (l::GraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) - return l.(tg.snapshots, x) -end - -[.\test\runtests.jl] -using CUDA -using GraphNeuralNetworks -using GNNGraphs: sort_edge_index -using GNNGraphs: getn, getdata -using Functors -using Flux -using Flux: gpu -using LinearAlgebra, Statistics, Random -using NNlib -import MLUtils -using SparseArrays -using Graphs -using Zygote -using Test -using MLDatasets -using InlineStrings # not used but with the import we test #98 and #104 - -CUDA.allowscalar(false) - -const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}} - -ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets - -include("test_utils.jl") - -tests = [ - "layers/basic", - "layers/conv", - "layers/heteroconv", - "layers/temporalconv", - "layers/pool", - "examples/node_classification_cora", -] - -!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") - -# @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) -for graph_type in (:coo, :dense, :sparse) - - @info "Testing graph format :$graph_type" - global GRAPH_T = graph_type - global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) - - @testset "$t" for t in tests - startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI - include("$t.jl") - end -end - -[.\test\test_utils.jl] -using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA -CUDA.allowscalar(false) - -function ngradient(f, x...) 
- fdm = central_fdm(5, 1) - return FiniteDifferences.grad(fdm, f, x...) -end - -const rule_config = Zygote.ZygoteRuleConfig() - -# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed -function FiniteDifferences.to_vec(x::Integer) - Integer_from_vec(v) = x - return Int[x], Integer_from_vec -end - -# Test that forward pass on cpu and gpu are the same. -# Tests also gradient on cpu and gpu comparing with -# finite difference methods. -# Test gradients with respects to layer weights and to input. -# If `g` has edge features, it is assumed that the layer can -# use them in the forward pass as `l(g, x, e)`. -# Test also gradient with respect to `e`. -function test_layer(l, g::GNNGraph; atol = 1e-5, rtol = 1e-5, - exclude_grad_fields = [], - verbose = false, - test_gpu = TEST_GPU, - outsize = nothing, - outtype = :node) - - # TODO these give errors, probably some bugs in ChainRulesTestUtils - # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false) - # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false) - - isnothing(node_features(g)) && error("Plese add node data to the input graph") - fdm = central_fdm(5, 1) - - x = node_features(g) - e = edge_features(g) - use_edge_feat = !isnothing(e) - - x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad - xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g]) - - f(l, g::GNNGraph) = l(g) - f(l, g::GNNGraph, x, e) = use_edge_feat ? 
l(g, x, e) : l(g, x) - - loss(l, g::GNNGraph) = - if outtype == :node - sum(node_features(f(l, g))) - elseif outtype == :edge - sum(edge_features(f(l, g))) - elseif outtype == :graph - sum(graph_features(f(l, g))) - elseif outtype == :node_edge - gnew = f(l, g) - sum(node_features(gnew)) + sum(edge_features(gnew)) - end - - function loss(l, g::GNNGraph, x, e) - y = f(l, g, x, e) - if outtype == :node_edge - return sum(y[1]) + sum(y[2]) - else - return sum(y) - end - end - - # TEST OUTPUT - y = f(l, g, x, e) - if outtype == :node_edge - @assert y isa Tuple - @test eltype(y[1]) == eltype(x) - @test eltype(y[2]) == eltype(e) - @test all(isfinite, y[1]) - @test all(isfinite, y[2]) - if !isnothing(outsize) - @test size(y[1]) == outsize[1] - @test size(y[2]) == outsize[2] - end - else - @test eltype(y) == eltype(x) - @test all(isfinite, y) - if !isnothing(outsize) - @test size(y) == outsize - end - end - - # test same output on different graph formats - gcoo = GNNGraph(g, graph_type = :coo) - ycoo = f(l, gcoo, x, e) - if outtype == :node_edge - @test ycoo[1] ≈ y[1] - @test ycoo[2] ≈ y[2] - else - @test ycoo ≈ y - end - - g′ = f(l, g) - if outtype == :node - @test g′.ndata.x ≈ y - elseif outtype == :edge - @test g′.edata.e ≈ y - elseif outtype == :graph - @test g′.gdata.u ≈ y - elseif outtype == :node_edge - @test g′.ndata.x ≈ y[1] - @test g′.edata.e ≈ y[2] - else - @error "wrong outtype $outtype" - end - if test_gpu - ygpu = f(lgpu, ggpu, xgpu, egpu) - if outtype == :node_edge - @test ygpu[1] isa CuArray - @test eltype(ygpu[1]) == eltype(xgpu) - @test Array(ygpu[1]) ≈ y[1] - @test ygpu[2] isa CuArray - @test eltype(ygpu[2]) == eltype(xgpu) - @test Array(ygpu[2]) ≈ y[2] - else - @test ygpu isa CuArray - @test eltype(ygpu) == eltype(xgpu) - @test Array(ygpu) ≈ y - end - end - - # TEST x INPUT GRADIENT - x̄ = gradient(x -> loss(l, g, x, e), x)[1] - x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1] - @test eltype(x̄) == eltype(x) - @test x̄≈x̄_fd 
atol=atol rtol=rtol - - if test_gpu - x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1] - @test x̄gpu isa CuArray - @test eltype(x̄gpu) == eltype(x) - @test Array(x̄gpu)≈x̄ atol=atol rtol=rtol - end - - # TEST e INPUT GRADIENT - if e !== nothing - verbose && println("Test e gradient cpu") - ē = gradient(e -> loss(l, g, x, e), e)[1] - ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1] - @test eltype(ē) == eltype(e) - @test ē≈ē_fd atol=atol rtol=rtol - - if test_gpu - verbose && println("Test e gradient gpu") - ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1] - @test ēgpu isa CuArray - @test eltype(ēgpu) == eltype(ē) - @test Array(ēgpu)≈ē atol=atol rtol=rtol - end - end - - # TEST LAYER GRADIENT - l(g, x, e) - l̄ = gradient(l -> loss(l, g, x, e), l)[1] - l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1] - test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) - - if test_gpu - l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1] - test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose) - end - - # TEST LAYER GRADIENT - l(g) - l̄ = gradient(l -> loss(l, g), l)[1] - test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose) - - return true -end - -function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5, - exclude_grad_fields = [], - verbose = false) - l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue - l̄fd = l̄fd isa Base.RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue - - for f in fieldnames(typeof(l)) - f ∈ exclude_grad_fields && continue - verbose && println("Test gradient of field $f...") - x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f) - test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose) - verbose && println("... 
field $f done!") - end - return true -end - -function test_approx_structs(x, g::Nothing, gfd; atol, rtol, kws...) - # finite diff gradients has to be zero if present - @test !(gfd isa AbstractArray) || isapprox(gfd, fill!(similar(gfd), 0); atol, rtol) -end - -function test_approx_structs(x::Union{AbstractArray, Number}, - g::Union{AbstractArray, Number}, gfd; atol, rtol, kws...) - @test eltype(g) == eltype(x) - if x isa CuArray - @test g isa CuArray - g = Array(g) - end - @test g≈gfd atol=atol rtol=rtol -end - -""" - to32(m) - -Convert the `eltype` of model's float parameters to `Float32`. -Preserves integer arrays. -""" -to32(m) = _paramtype(Float32, m) - -""" - to64(m) - -Convert the `eltype` of model's float parameters to `Float64`. -Preserves integer arrays. -""" -to64(m) = _paramtype(Float64, m) - -struct GNNEltypeAdaptor{T} end - -Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where T = convert(AbstractArray{T}, x) -Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Integer}) where T = x -Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x) - -_paramtype(::Type{T}, m) where T = fmap(adapt(GNNEltypeAdaptor{T}()), m) - -[.\test\examples\node_classification_cora.jl] -using Flux -using Flux: onecold, onehotbatch -using Flux.Losses: logitcrossentropy -using GraphNeuralNetworks -using MLDatasets: Cora -using Statistics, Random -using CUDA -CUDA.allowscalar(false) - -function eval_loss_accuracy(X, y, ids, model, g) - ŷ = model(g, X) - l = logitcrossentropy(ŷ[:, ids], y[:, ids]) - acc = mean(onecold(ŷ[:, ids]) .== onecold(y[:, ids])) - return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2)) -end - -# arguments for the `train` function -Base.@kwdef mutable struct Args - η = 5.0f-3 # learning rate - epochs = 10 # number of epochs - seed = 17 # set seed > 0 for reproducibility - usecuda = false # if true use cuda (if available) - nhidden = 64 # dimension 
of hidden features -end - -function train(Layer; verbose = false, kws...) - args = Args(; kws...) - args.seed > 0 && Random.seed!(args.seed) - - if args.usecuda && CUDA.functional() - device = Flux.gpu - args.seed > 0 && CUDA.seed!(args.seed) - else - device = Flux.cpu - end - - # LOAD DATA - dataset = Cora() - classes = dataset.metadata["classes"] - g = mldataset2gnngraph(dataset) |> device - X = g.ndata.features - y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged - train_mask = g.ndata.train_mask - test_mask = g.ndata.test_mask - ytrain = y[:, train_mask] - - nin, nhidden, nout = size(X, 1), args.nhidden, length(classes) - - ## DEFINE MODEL - model = GNNChain(Layer(nin, nhidden), - # Dropout(0.5), - Layer(nhidden, nhidden), - Dense(nhidden, nout)) |> device - - opt = Flux.setup(Adam(args.η), model) - - ## TRAINING - function report(epoch) - train = eval_loss_accuracy(X, y, train_mask, model, g) - test = eval_loss_accuracy(X, y, test_mask, model, g) - println("Epoch: $epoch Train: $(train) Test: $(test)") - end - - verbose && report(0) - @time for epoch in 1:(args.epochs) - grad = Flux.gradient(model) do model - ŷ = model(g, X) - logitcrossentropy(ŷ[:, train_mask], ytrain) - end - Flux.update!(opt, model, grad[1]) - verbose && report(epoch) - end - - train_res = eval_loss_accuracy(X, y, train_mask, model, g) - test_res = eval_loss_accuracy(X, y, test_mask, model, g) - return train_res, test_res -end - -function train_many(; usecuda = false) - for (layer, Layer) in [ - ("GCNConv", (nin, nout) -> GCNConv(nin => nout, relu)), - ("ResGatedGraphConv", (nin, nout) -> ResGatedGraphConv(nin => nout, relu)), - ("GraphConv", (nin, nout) -> GraphConv(nin => nout, relu, aggr = mean)), - ("SAGEConv", (nin, nout) -> SAGEConv(nin => nout, relu)), - ("GATConv", (nin, nout) -> GATConv(nin => nout, relu)), - ("GINConv", (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr = mean)), - 
("TransformerConv", - (nin, nout) -> TransformerConv(nin => nout, concat = false, - add_self_loops = true, root_weight = false, - heads = 2)), - ## ("ChebConv", (nin, nout) -> ChebConv(nin => nout, 2)), # not working on gpu - ## ("NNConv", (nin, nout) -> NNConv(nin => nout)), # needs edge features - ## ("GatedGraphConv", (nin, nout) -> GatedGraphConv(nout, 2)), # needs nin = nout - ## ("EdgeConv",(nin, nout) -> EdgeConv(Dense(2nin, nout, relu))), # Fits the training set but does not generalize well - ] - @show layer - @time train_res, test_res = train(Layer; usecuda, verbose = false) - # @show train_res, test_res - @test train_res.acc > 94 - @test test_res.acc > 69 - end -end - -train_many(usecuda = false) -if TEST_GPU - train_many(usecuda = true) -end - -[.\test\layers\basic.jl] -@testset "GNNChain" begin - n, din, d, dout = 10, 3, 4, 2 - deg = 4 - - g = GNNGraph(random_regular_graph(n, deg), - graph_type = GRAPH_T, - ndata = randn(Float32, din, n)) - x = g.ndata.x - - gnn = GNNChain(GCNConv(din => d), - LayerNorm(d), - x -> tanh.(x), - GraphConv(d => d, tanh), - Dropout(0.5), - Dense(d, dout)) - - testmode!(gnn) - - test_layer(gnn, g, rtol = 1e-5, exclude_grad_fields = [:μ, :σ²]) - - @testset "constructor with names" begin - m = GNNChain(GCNConv(din => d), - LayerNorm(d), - x -> tanh.(x), - Dense(d, dout)) - - m2 = GNNChain(enc = m, - dec = DotDecoder()) - - @test m2[:enc] === m - @test m2(g, x) == m2[:dec](g, m2[:enc](g, x)) - end - - @testset "constructor with vector" begin - m = GNNChain(GCNConv(din => d), - LayerNorm(d), - x -> tanh.(x), - Dense(d, dout)) - m2 = GNNChain([m.layers...]) - @test m2(g, x) == m(g, x) - end - - @testset "Parallel" begin - AddResidual(l) = Parallel(+, identity, l) - - gnn = GNNChain(GraphConv(din => d, tanh), - LayerNorm(d), - AddResidual(GraphConv(d => d, tanh)), - BatchNorm(d), - Dense(d, dout)) - - trainmode!(gnn) - - test_layer(gnn, g, rtol = 1e-4, atol=1e-4, exclude_grad_fields = [:μ, :σ²]) - end - - @testset "Only graph 
input" begin - nin, nout = 2, 4 - ndata = rand(Float32, nin, 3) - edata = rand(Float32, nin, 3) - g = GNNGraph([1, 1, 2], [2, 3, 3], ndata = ndata, edata = edata) - m = NNConv(nin => nout, Dense(2, nin * nout, tanh)) - chain = GNNChain(m) - y = m(g, g.ndata.x, g.edata.e) - @test m(g).ndata.x == y - @test chain(g).ndata.x == y - end -end - -@testset "WithGraph" begin - x = rand(Float32, 2, 3) - g = GNNGraph([1, 2, 3], [2, 3, 1], ndata = x) - model = SAGEConv(2 => 3) - wg = WithGraph(model, g) - # No need to feed the graph to `wg` - @test wg(x) == model(g, x) - @test Flux.params(wg) == Flux.params(model) - g2 = GNNGraph([1, 1, 2, 3], [2, 4, 1, 1]) - x2 = rand(Float32, 2, 4) - # WithGraph will ignore the internal graph if fed with a new one. - @test wg(g2, x2) == model(g2, x2) - - wg = WithGraph(model, g, traingraph = false) - @test length(Flux.params(wg)) == length(Flux.params(model)) - - wg = WithGraph(model, g, traingraph = true) - @test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g)) -end - -@testset "Flux restructure" begin - chain = GNNChain(GraphConv(2 => 2)) - params, restructure = Flux.destructure(chain) - @test restructure(params) isa GNNChain -end - -[.\test\layers\conv.jl] -RTOL_LOW = 1e-2 -RTOL_HIGH = 1e-5 -ATOL_LOW = 1e-3 - -in_channel = 3 -out_channel = 5 -N = 4 -T = Float32 - -adj1 = [0 1 0 1 - 1 0 1 0 - 0 1 0 1 - 1 0 1 0] - -g1 = GNNGraph(adj1, - ndata = rand(T, in_channel, N), - graph_type = GRAPH_T) - -adj_single_vertex = [0 0 0 1 - 0 0 0 0 - 0 0 0 1 - 1 0 1 0] - -g_single_vertex = GNNGraph(adj_single_vertex, - ndata = rand(T, in_channel, N), - graph_type = GRAPH_T) - -test_graphs = [g1, g_single_vertex] - -@testset "GCNConv" begin - l = GCNConv(in_channel => out_channel) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - l = GCNConv(in_channel => out_channel, tanh, bias = false) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, 
g.num_nodes)) - end - - l = GCNConv(in_channel => out_channel, add_self_loops = false) - test_layer(l, g1, rtol = RTOL_HIGH, outsize = (out_channel, g1.num_nodes)) - - @testset "edge weights & custom normalization" begin - s = [2, 3, 1, 3, 1, 2] - t = [1, 1, 2, 2, 3, 3] - w = T[1, 2, 3, 4, 5, 6] - g = GNNGraph((s, t, w), ndata = ones(T, 1, 3), graph_type = GRAPH_T) - x = g.ndata.x - custom_norm_fn(d) = 1 ./ sqrt.(d) - l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true) - l.weight .= 1 - d = degree(g, dir = :in, edge_weight = true) - y = l(g, x) - @test y[1, 1] ≈ w[1] / √(d[1] * d[2]) + w[2] / √(d[1] * d[3]) - @test y[1, 2] ≈ w[3] / √(d[2] * d[1]) + w[4] / √(d[2] * d[3]) - @test y ≈ l(g, x, w; norm_fn = custom_norm_fn) # checking without custom - - # test gradient with respect to edge weights - w = rand(T, 6) - x = rand(T, 1, 3) - g = GNNGraph((s, t, w), ndata = x, graph_type = GRAPH_T, edata = w) - l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true) - @test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{T} # redundant test but more explicit - test_layer(l, g, rtol = RTOL_HIGH, outsize = (1, g.num_nodes), test_gpu = false) - end - - @testset "conv_weight" begin - l = GraphNeuralNetworks.GCNConv(in_channel => out_channel) - w = zeros(T, out_channel, in_channel) - g1 = GNNGraph(adj1, ndata = ones(T, in_channel, N)) - @test l(g1, g1.ndata.x, conv_weight = w) == zeros(T, out_channel, N) - a = rand(T, in_channel, N) - g2 = GNNGraph(adj1, ndata = a) - @test l(g2, g2.ndata.x, conv_weight = w) == w * a - end -end - -@testset "ChebConv" begin - k = 2 - l = ChebConv(in_channel => out_channel, k) - @test size(l.weight) == (out_channel, in_channel, k) - @test size(l.bias) == (out_channel,) - @test l.k == k - for g in test_graphs - g = add_self_loops(g) - test_layer(l, g, rtol = RTOL_HIGH, test_gpu = TEST_GPU, - outsize = (out_channel, g.num_nodes)) - end - - @testset "bias=false" begin - @test length(Flux.params(ChebConv(2 => 3, 3))) == 
2 - @test length(Flux.params(ChebConv(2 => 3, 3, bias = false))) == 1 - end -end - -@testset "GraphConv" begin - l = GraphConv(in_channel => out_channel) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - l = GraphConv(in_channel => out_channel, tanh, bias = false, aggr = mean) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - @testset "bias=false" begin - @test length(Flux.params(GraphConv(2 => 3))) == 3 - @test length(Flux.params(GraphConv(2 => 3, bias = false))) == 2 - end -end - -@testset "GATConv" begin - for heads in (1, 2), concat in (true, false) - l = GATConv(in_channel => out_channel; heads, concat, dropout=0) - for g in test_graphs - test_layer(l, g, rtol = RTOL_LOW, - exclude_grad_fields = [:negative_slope, :dropout], - outsize = (concat ? heads * out_channel : out_channel, - g.num_nodes)) - end - end - - @testset "edge features" begin - ein = 3 - l = GATConv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0) - g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges)) - test_layer(l, g, rtol = RTOL_LOW, - exclude_grad_fields = [:negative_slope, :dropout], - outsize = (out_channel, g.num_nodes)) - end - - @testset "num params" begin - l = GATConv(2 => 3, add_self_loops = false) - @test length(Flux.params(l)) == 3 - l = GATConv((2, 4) => 3, add_self_loops = false) - @test length(Flux.params(l)) == 4 - l = GATConv((2, 4) => 3, add_self_loops = false, bias = false) - @test length(Flux.params(l)) == 3 - end -end - -@testset "GATv2Conv" begin - for heads in (1, 2), concat in (true, false) - l = GATv2Conv(in_channel => out_channel, tanh; heads, concat, dropout=0) - for g in test_graphs - test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW, - exclude_grad_fields = [:negative_slope, :dropout], - outsize = (concat ? 
heads * out_channel : out_channel, - g.num_nodes)) - end - end - - @testset "edge features" begin - ein = 3 - l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0) - g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges)) - test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW, - exclude_grad_fields = [:negative_slope, :dropout], - outsize = (out_channel, g.num_nodes)) - end - - @testset "num params" begin - l = GATv2Conv(2 => 3, add_self_loops = false) - @test length(Flux.params(l)) == 5 - l = GATv2Conv((2, 4) => 3, add_self_loops = false) - @test length(Flux.params(l)) == 6 - l = GATv2Conv((2, 4) => 3, add_self_loops = false, bias = false) - @test length(Flux.params(l)) == 4 - end -end - -@testset "GatedGraphConv" begin - num_layers = 3 - l = GatedGraphConv(out_channel, num_layers) - @test size(l.weight) == (out_channel, out_channel, num_layers) - - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end -end - -@testset "EdgeConv" begin - l = EdgeConv(Dense(2 * in_channel, out_channel), aggr = +) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end -end - -@testset "GINConv" begin - nn = Dense(in_channel, out_channel) - - l = GINConv(nn, 0.01f0, aggr = mean) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - @test !in(:eps, Flux.trainable(l)) -end - -@testset "NNConv" begin - edim = 10 - nn = Dense(edim, out_channel * in_channel) - - l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +) - for g in test_graphs - g = GNNGraph(g, edata = rand(T, edim, g.num_edges)) - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end -end - -@testset "SAGEConv" begin - l = SAGEConv(in_channel => out_channel) - @test l.aggr == mean - - l = SAGEConv(in_channel => out_channel, tanh, bias = false, aggr = +) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, 
outsize = (out_channel, g.num_nodes)) - end -end - -@testset "ResGatedGraphConv" begin - l = ResGatedGraphConv(in_channel => out_channel, tanh, bias = true) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end -end - -@testset "CGConv" begin - edim = 10 - l = CGConv((in_channel, edim) => out_channel, tanh, residual = false, bias = true) - for g in test_graphs - g = GNNGraph(g, edata = rand(T, edim, g.num_edges)) - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - # no edge features - l1 = CGConv(in_channel => out_channel, tanh, residual = false, bias = true) - @test l1(g1, g1.ndata.x) == l1(g1).ndata.x - @test l1(g1, g1.ndata.x, nothing) == l1(g1).ndata.x -end - -@testset "AGNNConv" begin - l = AGNNConv(trainable=false, add_self_loops=false) - @test l.β == [1.0f0] - @test l.add_self_loops == false - @test l.trainable == false - Flux.trainable(l) == (;) - - l = AGNNConv(init_beta=2.0f0) - @test l.β == [2.0f0] - @test l.add_self_loops == true - @test l.trainable == true - Flux.trainable(l) == (; β = [1f0]) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (in_channel, g.num_nodes)) - end -end - -@testset "MEGNetConv" begin - l = MEGNetConv(in_channel => out_channel, aggr = +) - for g in test_graphs - g = GNNGraph(g, edata = rand(T, in_channel, g.num_edges)) - test_layer(l, g, rtol = RTOL_LOW, - outtype = :node_edge, - outsize = ((out_channel, g.num_nodes), (out_channel, g.num_edges))) - end -end - -@testset "GMMConv" begin - ein_channel = 10 - K = 5 - l = GMMConv((in_channel, ein_channel) => out_channel, K = K) - for g in test_graphs - g = GNNGraph(g, edata = rand(Float32, ein_channel, g.num_edges)) - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end -end - -@testset "SGConv" begin - K = [1, 2, 3] # for different number of hops - for k in K - l = SGConv(in_channel => out_channel, k, add_self_loops = true) - for g in test_graphs - 
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - l = SGConv(in_channel => out_channel, k, add_self_loops = true) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - end -end - -@testset "TAGConv" begin - K = [1, 2, 3] - for k in K - l = TAGConv(in_channel => out_channel, k, add_self_loops = true) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - - l = TAGConv(in_channel => out_channel, k, add_self_loops = true) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - end -end - -@testset "EGNNConv" begin - hin = 5 - hout = 5 - hidden = 5 - l = EGNNConv(hin => hout, hidden) - g = rand_graph(10, 20, graph_type = GRAPH_T) - x = rand(T, in_channel, g.num_nodes) - h = randn(T, hin, g.num_nodes) - hnew, xnew = l(g, h, x) - @test size(hnew) == (hout, g.num_nodes) - @test size(xnew) == (in_channel, g.num_nodes) -end - -@testset "TransformerConv" begin - ein = 2 - heads = 3 - # used like in Kool et al., 2019 - l = TransformerConv(in_channel * heads => in_channel; heads, add_self_loops = true, - root_weight = false, ff_channels = 10, skip_connection = true, - batch_norm = false) - # batch_norm=false here for tests to pass; true in paper - for adj in [adj1, adj_single_vertex] - g = GNNGraph(adj, ndata = rand(T, in_channel * heads, size(adj, 1)), - graph_type = GRAPH_T) - test_layer(l, g, rtol = RTOL_LOW, - exclude_grad_fields = [:negative_slope], - outsize = (in_channel * heads, g.num_nodes)) - end - # used like in Shi et al., 2021 - l = TransformerConv((in_channel, ein) => in_channel; heads, gating = true, - bias_qkv = true) - for g in test_graphs - g = GNNGraph(g, edata = rand(T, ein, g.num_edges)) - test_layer(l, g, rtol = RTOL_LOW, - exclude_grad_fields = [:negative_slope], - outsize = (in_channel * heads, g.num_nodes)) - end - # test averaging heads - l = TransformerConv(in_channel 
=> in_channel; heads, concat = false, - bias_root = false, - root_weight = false) - for g in test_graphs - test_layer(l, g, rtol = RTOL_LOW, - exclude_grad_fields = [:negative_slope], - outsize = (in_channel, g.num_nodes)) - end -end - -@testset "DConv" begin - K = [1, 2, 3] # for different number of hops - for k in K - l = DConv(in_channel => out_channel, k) - for g in test_graphs - test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) - end - end -end -[.\test\layers\heteroconv.jl] -@testset "HeteroGraphConv" begin - d, n = 3, 5 - g = rand_bipartite_heterograph((n, 2*n), 15) - hg = rand_bipartite_heterograph((2,3), 6) - - model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d), - (:B,:to,:A) => GraphConv(d => d)]) - - for x in [ - (A = rand(Float32, d, n), B = rand(Float32, d, 2n)), - Dict(:A => rand(Float32, d, n), :B => rand(Float32, d, 2n)) - ] - # x = (A = rand(Float32, d, n), B = rand(Float32, d, 2n)) - x = Dict(:A => rand(Float32, d, n), :B => rand(Float32, d, 2n)) - - y = model(g, x) - - grad, dx = gradient((model, x) -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model, x) - ngrad, ndx = ngradient((model, x) -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model, x) - - @test grad.layers[1].weight1 ≈ ngrad.layers[1].weight1 rtol=1e-4 - @test grad.layers[1].weight2 ≈ ngrad.layers[1].weight2 rtol=1e-4 - @test grad.layers[1].bias ≈ ngrad.layers[1].bias rtol=1e-4 - @test grad.layers[2].weight1 ≈ ngrad.layers[2].weight1 rtol=1e-4 - @test grad.layers[2].weight2 ≈ ngrad.layers[2].weight2 rtol=1e-4 - @test grad.layers[2].bias ≈ ngrad.layers[2].bias rtol=1e-4 - - @test dx[:A] ≈ ndx[:A] rtol=1e-4 - @test dx[:B] ≈ ndx[:B] rtol=1e-4 - end - - @testset "Constructor from pairs" begin - layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, tanh), - (:B, :to, :A) => GraphConv(64 => 32, tanh)); - @test length(layer.etypes) == 2 - end - - @testset "Destination node aggregation" begin - # deterministic setup to validate the aggregation - d, n = 
3, 5 - g = GNNHeteroGraph(((:A, :to, :B) => ([1, 1, 2, 3], [1, 2, 2, 3]), - (:B, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3]), - (:C, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3])); num_nodes = Dict(:A => n, :B => n, :C => n)) - model = HeteroGraphConv([ - (:A, :to, :B) => GraphConv(d => d, init = ones, bias = false), - (:B, :to, :A) => GraphConv(d => d, init = ones, bias = false), - (:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = +) - x = (A = rand(Float32, d, n), B = rand(Float32, d, n), C = rand(Float32, d, n)) - y = model(g, x) - weights = ones(Float32, d, d) - - ### Test default summation aggregation - # B2 has 2 edges from A and itself (sense check) - expected = sum(weights * x.A[:, [1, 2]]; dims = 2) .+ weights * x.B[:, [2]] - output = y.B[:, [2]] - @test expected ≈ output - - # B5 has only itself - @test weights * x.B[:, [5]] ≈ y.B[:, [5]] - - # A1 has 1 edge from B, 1 from C and twice itself - expected = sum(weights * x.B[:, [1]] + weights * x.C[:, [1]]; dims = 2) .+ - 2 * weights * x.A[:, [1]] - output = y.A[:, [1]] - @test expected ≈ output - - # A2 has 2 edges from B, 2 from C and twice itself - expected = sum(weights * x.B[:, [1, 2]] + weights * x.C[:, [1, 2]]; dims = 2) .+ - 2 * weights * x.A[:, [2]] - output = y.A[:, [2]] - @test expected ≈ output - - # A5 has only itself but twice - @test 2 * weights * x.A[:, [5]] ≈ y.A[:, [5]] - - #### Test different aggregation function - model2 = HeteroGraphConv([ - (:A, :to, :B) => GraphConv(d => d, init = ones, bias = false), - (:B, :to, :A) => GraphConv(d => d, init = ones, bias = false), - (:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = -) - y2 = model2(g, x) - # B no change - @test y.B ≈ y2.B - - # A1 has 1 edge from B, 1 from C, itself cancels out - expected = sum(weights * x.B[:, [1]] - weights * x.C[:, [1]]; dims = 2) - output = y2.A[:, [1]] - @test expected ≈ output - - # A2 has 2 edges from B, 2 from C, itself cancels out - expected = sum(weights * x.B[:, [1, 2]] - 
weights * x.C[:, [1, 2]]; dims = 2) - output = y2.A[:, [2]] - @test expected ≈ output - end - - @testset "CGConv" begin - x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, tanh), - (:B, :to, :A) => CGConv(4 => 2, tanh)); - y = layers(hg, x); - @test size(y.A) == (2,2) && size(y.B) == (2,3) - end - - @testset "EdgeConv" begin - x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), - (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); - y = layers(hg, x); - @test size(y.A) == (2,2) && size(y.B) == (2,3) - end - - @testset "SAGEConv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, tanh, bias = false, aggr = +), - (:B, :to, :A) => SAGEConv(4 => 2, tanh, bias = false, aggr = +)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - - @testset "GATConv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => GATConv(4 => 2), - (:B, :to, :A) => GATConv(4 => 2)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - - @testset "GINConv" begin - x = (A = rand(4, 2), B = rand(4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4), - (:B, :to, :A) => GINConv(Dense(4, 2), 0.4)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - - @testset "ResGatedGraphConv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => ResGatedGraphConv(4 => 2), - (:B, :to, :A) => ResGatedGraphConv(4 => 2)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - - @testset "GATv2Conv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => GATv2Conv(4 => 2), - (:B, :to, :A) => GATv2Conv(4 => 2)); - y 
= layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - - @testset "GCNConv" begin - g = rand_bipartite_heterograph((2,3), 6) - x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), - (:B, :to, :A) => GCNConv(4 => 2, tanh)); - y = layers(g, x); - @test size(y.A) == (2,2) && size(y.B) == (2,3) - end -end - -[.\test\layers\pool.jl] -@testset "GlobalPool" begin - p = GlobalPool(+) - n = 10 - chin = 6 - X = rand(Float32, 6, n) - g = GNNGraph(random_regular_graph(n, 4), ndata = X, graph_type = GRAPH_T) - u = p(g, X) - @test u ≈ sum(X, dims = 2) - - ng = 3 - g = Flux.batch([GNNGraph(random_regular_graph(n, 4), - ndata = rand(Float32, chin, n), - graph_type = GRAPH_T) - for i in 1:ng]) - u = p(g, g.ndata.x) - @test size(u) == (chin, ng) - @test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2) - @test p(g).gdata.u == u - - test_layer(p, g, rtol = 1e-5, exclude_grad_fields = [:aggr], outtype = :graph) -end - -@testset "GlobalAttentionPool" begin - n = 10 - chin = 6 - chout = 5 - ng = 3 - - fgate = Dense(chin, 1) - ffeat = Dense(chin, chout) - p = GlobalAttentionPool(fgate, ffeat) - @test length(Flux.params(p)) == 4 - - g = Flux.batch([GNNGraph(random_regular_graph(n, 4), - ndata = rand(Float32, chin, n), - graph_type = GRAPH_T) - for i in 1:ng]) - - test_layer(p, g, rtol = 1e-5, outtype = :graph, outsize = (chout, ng)) -end - -@testset "TopKPool" begin - N = 10 - k, in_channel = 4, 7 - X = rand(in_channel, N) - for T in [Bool, Float64] - adj = rand(T, N, N) - p = TopKPool(adj, k, in_channel) - @test eltype(p.p) === Float32 - @test size(p.p) == (in_channel,) - @test eltype(p.Ã) === T - @test size(p.Ã) == (k, k) - y = p(X) - @test size(y) == (in_channel, k) - end -end - -@testset "topk_index" begin - X = [8, 7, 6, 5, 4, 3, 2, 1] - @test topk_index(X, 4) == [1, 2, 3, 4] - @test topk_index(X', 4) == [1, 2, 3, 4] -end - -@testset "Set2Set" begin - n_in = 3 - n_iters = 2 - n_layers = 1 - g = 
batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5]) - g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes)) - l = Set2Set(n_in, n_iters, n_layers) - y = l(g, node_features(g)) - @test size(y) == (2 * n_in, g.num_graphs) - - ## TODO the numerical gradient seems to be 3 times smaller than zygote one - # test_layer(l, g, rtol = 1e-4, atol=1e-4, outtype = :graph, outsize = (2 * n_in, g.num_graphs), - # verbose=true, exclude_grad_fields = [:state0, :state]) -end -[.\test\layers\temporalconv.jl] -in_channel = 3 -out_channel = 5 -N = 4 -S = 5 -T = Float32 - -g1 = GNNGraph(rand_graph(N,8), - ndata = rand(T, in_channel, N), - graph_type = :sparse) - -tg = TemporalSnapshotsGNNGraph([g1 for _ in 1:S]) - -@testset "TGCNCell" begin - tgcn = GraphNeuralNetworks.TGCNCell(in_channel => out_channel) - h, x̃ = tgcn(tgcn.state0, g1, g1.ndata.x) - @test size(h) == (out_channel, N) - @test size(x̃) == (out_channel, N) - @test h == x̃ -end - -@testset "TGCN" begin - tgcn = TGCN(in_channel => out_channel) - @test size(Flux.gradient(x -> sum(tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) - model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) - @test size(model(g1, g1.ndata.x)) == (1, N) - @test model(g1) isa GNNGraph -end - -@testset "A3TGCN" begin - a3tgcn = A3TGCN(in_channel => out_channel) - @test size(Flux.gradient(x -> sum(a3tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) - model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1)) - @test size(model(g1, g1.ndata.x)) == (1, N) - @test model(g1) isa GNNGraph -end - -@testset "GConvLSTMCell" begin - gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes) - (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) - @test size(h) == (out_channel, N) - @test size(c) == (out_channel, N) -end - -@testset "GConvLSTM" begin - gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes) - @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), 
g1.ndata.x)[1]) == (in_channel, N) - model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -end - -@testset "GConvGRUCell" begin - gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes) - h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) - @test size(h) == (out_channel, N) -end - -@testset "GConvGRU" begin - gconvlstm = GConvGRU(in_channel => out_channel, 2, g1.num_nodes) - @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) - model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) - @test size(model(g1, g1.ndata.x)) == (1, N) - @test model(g1) isa GNNGraph -end - -@testset "DCGRU" begin - dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) - @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) - model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) - @test size(model(g1, g1.ndata.x)) == (1, N) - @test model(g1) isa GNNGraph -end - -@testset "GINConv" begin - ginconv = GINConv(Dense(in_channel => out_channel),0.3) - @test length(ginconv(tg, tg.ndata.x)) == S - @test size(ginconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "ChebConv" begin - chebconv = ChebConv(in_channel => out_channel, 5) - @test length(chebconv(tg, tg.ndata.x)) == S - @test size(chebconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(chebconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "GATConv" begin - gatconv = GATConv(in_channel => out_channel) - @test length(gatconv(tg, tg.ndata.x)) == S - @test size(gatconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(gatconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "GATv2Conv" begin - gatv2conv = GATv2Conv(in_channel => out_channel) - @test length(gatv2conv(tg, 
tg.ndata.x)) == S - @test size(gatv2conv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(gatv2conv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "GatedGraphConv" begin - gatedgraphconv = GatedGraphConv(5, 5) - @test length(gatedgraphconv(tg, tg.ndata.x)) == S - @test size(gatedgraphconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(gatedgraphconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "CGConv" begin - cgconv = CGConv(in_channel => out_channel) - @test length(cgconv(tg, tg.ndata.x)) == S - @test size(cgconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(cgconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "SGConv" begin - sgconv = SGConv(in_channel => out_channel) - @test length(sgconv(tg, tg.ndata.x)) == S - @test size(sgconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(sgconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "TransformerConv" begin - transformerconv = TransformerConv(in_channel => out_channel) - @test length(transformerconv(tg, tg.ndata.x)) == S - @test size(transformerconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(transformerconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "GCNConv" begin - gcnconv = GCNConv(in_channel => out_channel) - @test length(gcnconv(tg, tg.ndata.x)) == S - @test size(gcnconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(gcnconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "ResGatedGraphConv" begin - resgatedconv = ResGatedGraphConv(in_channel => out_channel, tanh) - @test length(resgatedconv(tg, tg.ndata.x)) == S - @test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "SAGEConv" begin - sageconv = SAGEConv(in_channel => out_channel) - @test length(sageconv(tg, tg.ndata.x)) == S - @test 
size(sageconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(sageconv(tg, x))), tg.ndata.x)[1]) == S -end - -@testset "GraphConv" begin - graphconv = GraphConv(in_channel => out_channel, tanh) - @test length(graphconv(tg, tg.ndata.x)) == S - @test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N) - @test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S -end - - From 01ec78b454f7e4cdc25cbbf54c447b97737c4635 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:25:37 +0530 Subject: [PATCH 14/41] Delete redundant file --- sccript.py | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 sccript.py diff --git a/sccript.py b/sccript.py deleted file mode 100644 index 9250023d2..000000000 --- a/sccript.py +++ /dev/null @@ -1,16 +0,0 @@ -import os - -def main(): - with open("data.txt", "w", encoding="utf-8") as outfile: - for root, _, files in os.walk("."): - for filename in files: - if filename.endswith(".jl"): - filepath = os.path.join(root, filename) - print(f"Processing: {filepath}") # Add print statement here - outfile.write(f"[{filepath}]\n") - with open(filepath, "r", encoding="utf-8") as infile: - outfile.write(infile.read()) - outfile.write("\n") - -if __name__ == "__main__": - main() \ No newline at end of file From ff012bb3f10fc3a689e8beb070bd19073c3ec866 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 15:27:54 +0530 Subject: [PATCH 15/41] trying test fix --- GNNLux/test/layers/conv_tests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 3af1c0430..380b9bc7a 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) - g = rand_graph(10, 40) + g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) 
in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) @@ -97,7 +97,6 @@ @testset "NNConv" begin edim = 10 nn = Dense(edim, out_dims * in_dims) - g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) end From cf7d30a3f02b7c297a3194a6daf07c0f86f67bbc Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 15:37:53 +0530 Subject: [PATCH 16/41] trying test fix --- GNNLux/test/layers/conv_tests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 380b9bc7a..60cac0fe4 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,5 +1,6 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) + edims = 10 g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) in_dims = 3 out_dims = 5 From 1c60d1ce2d6cca4a5c2144a03a48d2dee8fb262e Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:53:33 +0530 Subject: [PATCH 17/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 60cac0fe4..e6db4e130 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) - edims = 10 + edim = 10 g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) in_dims = 3 out_dims = 5 From 39b9c7482ed195ffffd80a1bca89c6dd47144386 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:08:37 +0530 Subject: [PATCH 18/41] Update basic_tests.jl --- GNNLux/test/layers/basic_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index cac2a45fa..ac937d128 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/basic" setup=[SharedTestSetup] begin rng = StableRNG(17) - g = rand_graph(10, 40) + g = rand_graph(rng, 10, 40) x = randn(rng, Float32, 3, 10) @testset "GNNLayer" begin From 894bdb31293ad35a0003c41b645819345d93cecf Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:09:10 +0530 Subject: [PATCH 19/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index e6db4e130..0c743a048 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,11 +1,13 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) edim = 10 - g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) + g = rand_graph(10, 40) in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) + g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) + @testset "GCNConv" begin l = GCNConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) From 6e610c1e3a8aafcfc70f2950b12ce36192af4cea Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:09:41 +0530 Subject: [PATCH 20/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 0c743a048..87c40924b 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,7 +1,7 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) edim = 10 - g = rand_graph(10, 40) + g = rand_graph(rng, 10, 40) in_dims = 3 out_dims = 5 x = 
randn(rng, Float32, in_dims, 10) From caf355c01ded16961119379f7440283cbe0ace0c Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:24:38 +0530 Subject: [PATCH 21/41] Update conv_tests.jl: edata issues --- GNNLux/test/layers/conv_tests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 87c40924b..9b4f7abb0 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -6,8 +6,6 @@ out_dims = 5 x = randn(rng, Float32, in_dims, 10) - g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) - @testset "GCNConv" begin l = GCNConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) @@ -100,6 +98,7 @@ @testset "NNConv" begin edim = 10 nn = Dense(edim, out_dims * in_dims) + g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) end From 3e322613ce6743e9e9f0a9ee5a9cd17805d89b04 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:58:47 +0530 Subject: [PATCH 22/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9b4f7abb0..cdd5f6642 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -6,6 +6,8 @@ out_dims = 5 x = randn(rng, Float32, in_dims, 10) + g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) + @testset "GCNConv" begin l = GCNConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) @@ -97,9 +99,8 @@ @testset "NNConv" begin edim = 10 - nn = Dense(edim, out_dims * in_dims) - g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) + nn = Dense(edim, out_dims * 
in_dims) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true) + test_lux_layer(rng, l, g2, x, sizey=(out_dims, g.num_nodes), container=true) end end From 24da4c4c76356950397784ddd4ad8dd660739908 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:18:29 +0530 Subject: [PATCH 23/41] Update conv_tests.jl: edata --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index cdd5f6642..4fbebb863 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -101,6 +101,6 @@ edim = 10 nn = Dense(edim, out_dims * in_dims) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - test_lux_layer(rng, l, g2, x, sizey=(out_dims, g.num_nodes), container=true) + test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true, edge_weight=g.edata.e) end end From f2ff073a70a85a85aaa6f042ef0ccbb4f91a77e4 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:19:51 +0530 Subject: [PATCH 24/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 4fbebb863..9323446cc 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -101,6 +101,6 @@ edim = 10 nn = Dense(edim, out_dims * in_dims) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true, edge_weight=g.edata.e) + test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) end end From 93affd23016dba2a4272e88349cd26027bb3ca74 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 19 Aug 2024 17:45:01 +0530 Subject: 
[PATCH 25/41] change lux testing --- GNNLux/test/shared_testsetup.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index b6b80df49..332b578c6 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -14,7 +14,7 @@ export test_lux_layer function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; outputsize=nothing, sizey=nothing, container=false, - atol=1.0f-2, rtol=1.0f-2) + atol=1.0f-2, rtol=1.0f-2, edge_weight=nothing) if container @test l isa GNNContainerLayer @@ -27,7 +27,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) @test LuxCore.statelength(l) == LuxCore.statelength(st) - y, st′ = l(g, x, ps, st) + y, st′ = l(g, x, edge_weight, ps, st) @test eltype(y) == eltype(x) if outputsize !== nothing @test LuxCore.outputsize(l) == outputsize From f0481b4a700c9004d37a0c41576ff587242da59f Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Aug 2024 12:43:49 +0530 Subject: [PATCH 26/41] Update conv_tests.jl: Trying to fix tests --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9323446cc..bff70eb3a 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -99,7 +99,7 @@ @testset "NNConv" begin edim = 10 - nn = Dense(edim, out_dims * in_dims) + nn = Dense(edim, in_dims * out_dims) l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) end From b3e26495fb23f7cd7e89b4cf8772e56cab841a1b Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 23 Aug 2024 19:42:54 +0530 Subject: [PATCH 27/41] Update conv.jl: trying to fix --- 
GNNLux/src/layers/conv.jl | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index bca5eef7d..64403e092 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -635,45 +635,42 @@ end in_dims::Int out_dims::Int use_bias::Bool - add_self_loops::Bool - use_edge_weight::Bool init_weight init_bias σ end - function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, init_bias = zeros32, use_bias::Bool = true, init_weight = glorot_uniform, - add_self_loops::Bool = true, - use_edge_weight::Bool = false, allow_fast_activation::Bool = true) in_dims, out_dims = ch σ = allow_fast_activation ? NNlib.fast_act(σ) : σ - return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) + return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ) end function (l::NNConv)(g, x, edge_weight, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn) - m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), - l.add_self_loops, l.use_edge_weight, l.σ) + m = (; nn, l.aggr, ps.weight, bias = ps.bias, l.σ) y = GNNlib.nn_conv(m, g, x, edge_weight) - stnew = _getstate(nn) + stnew = (; nn = _getstate(nn)) return y, stnew end + +function LuxCore.initialstates(rng::AbstractRNG, l::NNConv) + return (; nn = LuxCore.initialstates(rng, l.nn)) +end + +LuxCore.statelength(l::NNConv) = statelength(l.nn) LuxCore.outputsize(d::NNConv) = (d.out_dims,) function Base.show(io::IO, l::NNConv) - print(io, "NNConv($(l.nn)") - print(io, ", $(l.ϵ)") - l.σ == identity || print(io, ", ", l.σ) - l.use_bias || print(io, ", use_bias=false") - l.add_self_loops || print(io, ", add_self_loops=false") - !l.use_edge_weight || print(io, ", use_edge_weight=true") + out, in = size(l.weight) + print(io, "NNConv($in => $out") + print(io, ", aggr=", l.aggr) print(io, ")") end From 
23b89c26dcfcd9c4aae566ab3f36ad87a24f79f6 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 23 Aug 2024 20:18:48 +0530 Subject: [PATCH 28/41] Update conv.jl: reverted --- GNNLux/src/layers/conv.jl | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 64403e092..bca5eef7d 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -635,42 +635,45 @@ end in_dims::Int out_dims::Int use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool init_weight init_bias σ end + function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, init_bias = zeros32, use_bias::Bool = true, init_weight = glorot_uniform, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, allow_fast_activation::Bool = true) in_dims, out_dims = ch σ = allow_fast_activation ? NNlib.fast_act(σ) : σ - return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ) + return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) end function (l::NNConv)(g, x, edge_weight, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn) + nn = StatefulLuxLayer{true}(l.nn, ps, st) - m = (; nn, l.aggr, ps.weight, bias = ps.bias, l.σ) + m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.σ) y = GNNlib.nn_conv(m, g, x, edge_weight) - stnew = (; nn = _getstate(nn)) + stnew = _getstate(nn) return y, stnew end - -function LuxCore.initialstates(rng::AbstractRNG, l::NNConv) - return (; nn = LuxCore.initialstates(rng, l.nn)) -end - -LuxCore.statelength(l::NNConv) = statelength(l.nn) LuxCore.outputsize(d::NNConv) = (d.out_dims,) function Base.show(io::IO, l::NNConv) - out, in = size(l.weight) - print(io, "NNConv($in => $out") - print(io, ", aggr=", l.aggr) + print(io, "NNConv($(l.nn)") + print(io, ", $(l.ϵ)") + l.σ == identity || print(io, ", ", 
l.σ) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") print(io, ")") end From 4b32e2fe3797b745b156b92635e6e05cfe7b3724 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 26 Aug 2024 03:17:31 +0530 Subject: [PATCH 29/41] fixing --- GNNLux/src/layers/conv.jl | 3 - temp.jl | 129 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 temp.jl diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index ca6daf258..c48bcd9ef 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -666,11 +666,8 @@ function (l::NNConv)(g, x, edge_weight, ps, st) return y, stnew end -LuxCore.outputsize(d::NNConv) = (d.out_dims,) - function Base.show(io::IO, l::NNConv) print(io, "NNConv($(l.nn)") - print(io, ", $(l.ϵ)") l.σ == identity || print(io, ", ", l.σ) l.use_bias || print(io, ", use_bias=false") l.add_self_loops || print(io, ", add_self_loops=false") diff --git a/temp.jl b/temp.jl new file mode 100644 index 000000000..a8e6f5c13 --- /dev/null +++ b/temp.jl @@ -0,0 +1,129 @@ + +using StableRNGs +using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme + +using Lux: Lux, Chain, Dense, GRUCell, + glorot_uniform, zeros32 , + StatefulLuxLayer + +import Reexport: @reexport + +@reexport using Test +@reexport using GNNLux +@reexport using Lux +@reexport using StableRNGs +@reexport using Random, Statistics + +using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme + + +rng = StableRNG(1234) + edim = 10 + g = rand_graph(10, 40) + in_dims = 3 + out_dims = 5 + x = randn(Float32, in_dims, 10) + + g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) + +@testset "GINConv" begin + nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) + l = GINConv(nn, 0.5) + test_lux_layer(rng, l, g, x, 
sizey=(out_dims,g.num_nodes), container=true) +end + + + +function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; + outputsize=nothing, sizey=nothing, container=false, + atol=1.0f-2, rtol=1.0f-2, edge_weight=nothing) + + if container + @test l isa GNNContainerLayer + else + @test l isa GNNLayer + end + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) + @test LuxCore.statelength(l) == LuxCore.statelength(st) + + if edge_weight !== nothing + y, st′ = l(g, x, edge_weight, ps, st) + else + y, st′ = l(g, x, ps, st) + end + @test eltype(y) == eltype(x) + if outputsize !== nothing + @test LuxCore.outputsize(l) == outputsize + end + if sizey !== nothing + @test size(y) == sizey + elseif outputsize !== nothing + @test size(y) == (outputsize..., g.num_nodes) + end + + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) +end + + +""" +MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean) +""" + +g = rand_graph(10, 40, seed=1234) + in_dims = 3 + out_dims = 5 + x = randn(Float32, in_dims, 10) + rng = StableRNG(1234) + l = MEGNetConv(in_dims => out_dims) + l + l isa GNNContainerLayer + test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) + + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + edata = rand(T, in_channel, g.num_edges) + + (x_new, e_new), st_new = l(g, x, ps, st) + + @test size(x_new) == (out_dims, g.num_nodes) + @test 
size(e_new) == (out_dims, g.num_edges) + + + + + edim = 10 + in_dims = 3 # Example + out_dims = 5 # Example +using Flux + g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) + nn = Dense(edim, out_dims * in_dims) + l = NNConv(in_dims => out_dims, nn, tanh) + test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) + + + + hin = 6 + hout = 7 + hidden = 8 + l = EGNNConv(hin => hout, hidden) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + h = randn(rng, Float32, hin, g.num_nodes) + (hnew, xnew), stnew = l(g, h, x, ps, st) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_dims, g.num_nodes) + + + l = MEGNetConv(in_dims => out_dims) + l + l isa GNNContainerLayer + test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) + + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) \ No newline at end of file From 6227cd368f02500790b7907846584d5b6b76fc31 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 26 Aug 2024 03:39:37 +0530 Subject: [PATCH 30/41] Update shared_testsetup.jl: dont make other tests fail --- GNNLux/test/shared_testsetup.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index 332b578c6..bcd243df3 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -26,8 +26,13 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; st = LuxCore.initialstates(rng, l) @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) @test LuxCore.statelength(l) == LuxCore.statelength(st) - - y, st′ = l(g, x, edge_weight, ps, st) + + if edge_weight !== nothing + y, st′ = l(g, x, ps, st) + else + y, st′ = l(g, x, edge_weight, ps, st) + end + @test eltype(y) == eltype(x) if outputsize !== nothing @test LuxCore.outputsize(l) == 
outputsize @@ -42,4 +47,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) end -end \ No newline at end of file +end From b1d185fdd1f8510996077b9f9acee5cb075b4f0c Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Mon, 26 Aug 2024 03:51:50 +0530 Subject: [PATCH 31/41] Update shared_testsetup.jl: fixing other tests --- GNNLux/test/shared_testsetup.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index bcd243df3..797e1577d 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -28,9 +28,9 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; @test LuxCore.statelength(l) == LuxCore.statelength(st) if edge_weight !== nothing - y, st′ = l(g, x, ps, st) - else y, st′ = l(g, x, edge_weight, ps, st) + else + y, st′ = l(g, x, ps, st) end @test eltype(y) == eltype(x) From ef68f79d78c0e022afc4a6ea5edd1c20b7a1b9f2 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 26 Aug 2024 03:52:52 +0530 Subject: [PATCH 32/41] gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 91820619c..fc157c7be 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ Manifest.toml LocalPreferences.toml .DS_Store docs/src/democards/gridtheme.css -test.jl \ No newline at end of file +temp.jl \ No newline at end of file From 4f0d60f7430ac650e273a6685cba33fa8faea1a4 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 26 Aug 2024 03:54:37 +0530 Subject: [PATCH 33/41] ignore --- .gitignore | 2 +- temp.jl | 129 ----------------------------------------------------- 2 files changed, 1 insertion(+), 130 deletions(-) delete mode 100644 temp.jl diff --git a/.gitignore b/.gitignore index fc157c7be..91820619c 100644 --- a/.gitignore +++ 
b/.gitignore @@ -10,4 +10,4 @@ Manifest.toml LocalPreferences.toml .DS_Store docs/src/democards/gridtheme.css -temp.jl \ No newline at end of file +test.jl \ No newline at end of file diff --git a/temp.jl b/temp.jl deleted file mode 100644 index a8e6f5c13..000000000 --- a/temp.jl +++ /dev/null @@ -1,129 +0,0 @@ - -using StableRNGs -using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme - -using Lux: Lux, Chain, Dense, GRUCell, - glorot_uniform, zeros32 , - StatefulLuxLayer - -import Reexport: @reexport - -@reexport using Test -@reexport using GNNLux -@reexport using Lux -@reexport using StableRNGs -@reexport using Random, Statistics - -using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme - - -rng = StableRNG(1234) - edim = 10 - g = rand_graph(10, 40) - in_dims = 3 - out_dims = 5 - x = randn(Float32, in_dims, 10) - - g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) - -@testset "GINConv" begin - nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) - l = GINConv(nn, 0.5) - test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) -end - - - -function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; - outputsize=nothing, sizey=nothing, container=false, - atol=1.0f-2, rtol=1.0f-2, edge_weight=nothing) - - if container - @test l isa GNNContainerLayer - else - @test l isa GNNLayer - end - - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) - @test LuxCore.statelength(l) == LuxCore.statelength(st) - - if edge_weight !== nothing - y, st′ = l(g, x, edge_weight, ps, st) - else - y, st′ = l(g, x, ps, st) - end - @test eltype(y) == eltype(x) - if outputsize !== nothing - @test LuxCore.outputsize(l) == outputsize - end - if sizey !== nothing - @test size(y) == sizey - elseif outputsize !== nothing - @test size(y) == (outputsize..., g.num_nodes) - end - - loss 
= (x, ps) -> sum(first(l(g, x, ps, st))) - test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) -end - - -""" -MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean) -""" - -g = rand_graph(10, 40, seed=1234) - in_dims = 3 - out_dims = 5 - x = randn(Float32, in_dims, 10) - rng = StableRNG(1234) - l = MEGNetConv(in_dims => out_dims) - l - l isa GNNContainerLayer - test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) - - - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - edata = rand(T, in_channel, g.num_edges) - - (x_new, e_new), st_new = l(g, x, ps, st) - - @test size(x_new) == (out_dims, g.num_nodes) - @test size(e_new) == (out_dims, g.num_edges) - - - - - edim = 10 - in_dims = 3 # Example - out_dims = 5 # Example -using Flux - g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges)) - nn = Dense(edim, out_dims * in_dims) - l = NNConv(in_dims => out_dims, nn, tanh) - test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) - - - - hin = 6 - hout = 7 - hidden = 8 - l = EGNNConv(hin => hout, hidden) - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) - h = randn(rng, Float32, hin, g.num_nodes) - (hnew, xnew), stnew = l(g, h, x, ps, st) - @test size(hnew) == (hout, g.num_nodes) - @test size(xnew) == (in_dims, g.num_nodes) - - - l = MEGNetConv(in_dims => out_dims) - l - l isa GNNContainerLayer - test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true) - 
- - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) \ No newline at end of file From e7661f2aef1b2dfca5aeb3a720c93b51a7db9a6d Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 26 Aug 2024 03:56:22 +0530 Subject: [PATCH 34/41] remove useless params --- GNNLux/src/layers/conv.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index c48bcd9ef..cfe8157df 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -635,32 +635,26 @@ end in_dims::Int out_dims::Int use_bias::Bool - add_self_loops::Bool - use_edge_weight::Bool init_weight init_bias σ end - function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, init_bias = zeros32, use_bias::Bool = true, init_weight = glorot_uniform, - add_self_loops::Bool = true, - use_edge_weight::Bool = false, allow_fast_activation::Bool = true) in_dims, out_dims = ch σ = allow_fast_activation ? NNlib.fast_act(σ) : σ - return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) + return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ) end function (l::NNConv)(g, x, edge_weight, ps, st) nn = StatefulLuxLayer{true}(l.nn, ps, st) - m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), - l.add_self_loops, l.use_edge_weight, l.σ) + m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ) y = GNNlib.nn_conv(m, g, x, edge_weight) stnew = _getstate(nn) return y, stnew @@ -670,8 +664,6 @@ function Base.show(io::IO, l::NNConv) print(io, "NNConv($(l.nn)") l.σ == identity || print(io, ", ", l.σ) l.use_bias || print(io, ", use_bias=false") - l.add_self_loops || print(io, ", add_self_loops=false") - !l.use_edge_weight || print(io, ", use_edge_weight=true") print(io, ")") end From 67bc8fd62baae578d7a10bd51de15cc627cafb68 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 30 Aug 2024 15:08:07 +0530 
Subject: [PATCH 35/41] Update GNNLux.jl: ordering --- GNNLux/src/GNNLux.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 7b3598b2f..e0a199276 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -31,8 +31,8 @@ export AGNNConv, GINConv, # GMMConv, GraphConv, - NNConv, MEGNetConv, + NNConv, # ResGatedGraphConv, # SAGEConv, SGConv @@ -44,4 +44,4 @@ export TGCN export A3TGCN end #module - \ No newline at end of file + From b94b1f6e01a9c59723182b0ea13edc928a1bc59e Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Tue, 3 Sep 2024 18:00:53 +0530 Subject: [PATCH 36/41] Update Project.toml: fixed --- GNNLux/Project.toml | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/GNNLux/Project.toml b/GNNLux/Project.toml index d98416ece..9f27ee3d1 100644 --- a/GNNLux/Project.toml +++ b/GNNLux/Project.toml @@ -4,29 +4,15 @@ authors = ["Carlo Lucibello and contributors"] version = "0.1.0" [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SparseArrays = 
"2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ConcreteStructs = "0.2.3" From 91fed90cfa565f496059fde033078e9ecfbdd239 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:20:35 +0530 Subject: [PATCH 37/41] Update conv_tests.jl: checking test --- GNNLux/test/layers/conv_tests.jl | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index f7c19523d..7aa2bf256 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -97,11 +97,26 @@ test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end + + @testset "NNConv" begin - edim = 10 - nn = Dense(edim, in_dims * out_dims) - l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) + n_in = 3 + n_in_edge = 10 + n_out = 5 + + s = [1,1,2,3] + t = [2,3,1,1] + g2 = GNNGraph(s, t) + + nn = Dense(n_in_edge => n_out * n_in) + l = NNConv(n_in => n_out, nn, tanh, aggr = +) + x = randn(Float32, n_in, g2.num_nodes) + e = randn(Float32, n_in_edge, g2.num_edges) + y = l(g, x, e) # just to see if it runs without an error + #edim = 10 + #nn = Dense(edim, in_dims * out_dims) + #l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) + #test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) end @testset "MEGNetConv" begin From a5875531fc6f3b7b3b81c9f6089732a6162932f7 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Wed, 4 Sep 2024 23:09:26 +0530 Subject: [PATCH 38/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 7aa2bf256..906151a23 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -112,11 +112,11 @@ l = NNConv(n_in => n_out, nn, tanh, aggr = +) x = randn(Float32, n_in, g2.num_nodes) e = randn(Float32, n_in_edge, g2.num_edges) - y = l(g, x, e) # just to see if it runs without an error + #y = l(g, x, e) # just to see if it runs without an error #edim = 10 #nn = Dense(edim, in_dims * out_dims) #l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - #test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e) + test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=g2.edata.e) end @testset "MEGNetConv" begin From 232a1b401063f52690e8e77e80482202b4c28074 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Wed, 4 Sep 2024 23:20:04 +0530 Subject: [PATCH 39/41] Update conv_tests.jl --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 906151a23..06578e05b 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -116,7 +116,7 @@ #edim = 10 #nn = Dense(edim, in_dims * out_dims) #l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=g2.edata.e) + test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=e) end @testset "MEGNetConv" begin From e2de74c47129257e67ceaa2a89efd1690e980267 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sat, 7 Sep 2024 22:39:00 +0530 Subject: [PATCH 40/41] checking tests --- GNNLux/test/layers/conv_tests.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl 
b/GNNLux/test/layers/conv_tests.jl index 06578e05b..2ee18b984 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -112,11 +112,15 @@ l = NNConv(n_in => n_out, nn, tanh, aggr = +) x = randn(Float32, n_in, g2.num_nodes) e = randn(Float32, n_in_edge, g2.num_edges) - #y = l(g, x, e) # just to see if it runs without an error + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + + y = l(g, x, e, ps, st) # just to see if it runs without an error #edim = 10 #nn = Dense(edim, in_dims * out_dims) #l = NNConv(in_dims => out_dims, nn, tanh, aggr = +) - test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=e) + #test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=e) end @testset "MEGNetConv" begin From faa4df3c83b432b5b023f7eff6dbf78fbde18e72 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 12 Sep 2024 12:03:20 +0530 Subject: [PATCH 41/41] Update conv_tests.jl: typo in test --- GNNLux/test/layers/conv_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 2ee18b984..20908cb3a 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -116,7 +116,7 @@ ps = LuxCore.initialparameters(rng, l) st = LuxCore.initialstates(rng, l) - y = l(g, x, e, ps, st) # just to see if it runs without an error + y = l(g2, x, e, ps, st) # just to see if it runs without an error #edim = 10 #nn = Dense(edim, in_dims * out_dims) #l = NNConv(in_dims => out_dims, nn, tanh, aggr = +)