diff --git a/.github/workflows/test_GNNlib.yml b/.github/workflows/test_GNNlib.yml
new file mode 100644
index 000000000..176d92920
--- /dev/null
+++ b/.github/workflows/test_GNNlib.yml
@@ -0,0 +1,48 @@
+name: GNNlib
+on:
+  pull_request:
+    branches:
+      - master
+  push:
+    branches:
+      - master
+jobs:
+  test:
+    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        version:
+          - '1.10' # Replace this with the minimum Julia version that your package supports.
+          # - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
+          # - 'pre'
+        os:
+          - ubuntu-latest
+        arch:
+          - x64
+
+    steps:
+      - uses: actions/checkout@v4
+      - uses: julia-actions/setup-julia@v2
+        with:
+          version: ${{ matrix.version }}
+          arch: ${{ matrix.arch }}
+      - uses: julia-actions/cache@v2
+      - uses: julia-actions/julia-buildpkg@v1
+      - name: Install Julia dependencies and run tests
+        shell: julia --project=monorepo {0}
+        run: |
+          using Pkg
+          # dev mono repo versions
+          pkg"registry up"
+          Pkg.update()
+          pkg"dev ./GNNGraphs ./GNNlib"
+          Pkg.test("GNNlib"; coverage=true)
+      - uses: julia-actions/julia-processcoverage@v1
+        with:
+          # directories: ./GNNlib/src, ./GNNlib/ext
+          directories: ./GNNlib/src
+      - uses: codecov/codecov-action@v4
+        with:
+          files: lcov.info
diff --git a/GNNlib/Project.toml b/GNNlib/Project.toml
index 17a2755d3..0e680560e 100644
--- a/GNNlib/Project.toml
+++ b/GNNlib/Project.toml
@@ -20,8 +20,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 GNNlibCUDAExt = "CUDA"
 
 [compat]
-ChainRulesCore = "1.24"
 CUDA = "4, 5"
+ChainRulesCore = "1.24"
 DataStructures = "0.18"
 GNNGraphs = "1.0"
 LinearAlgebra = "1"
@@ -32,7 +32,10 @@ Statistics = "1"
 julia = "1.10"
 
 [extras]
+ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Test", "ReTestItems", "Reexport", "SparseArrays"]
diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl
index 6c7f95a6d..4a6735a06 100644
--- a/GNNlib/src/layers/pool.jl
+++ b/GNNlib/src/layers/pool.jl
@@ -1,12 +1,12 @@
-function global_pool(aggr, g::GNNGraph, x::AbstractArray)
-    return reduce_nodes(aggr, g, x)
+function global_pool(l, g::GNNGraph, x::AbstractArray)
+    return reduce_nodes(l.aggr, g, x)
 end
 
-function global_attention_pool(fgate, ffeat, g::GNNGraph, x::AbstractArray)
-    α = softmax_nodes(g, fgate(x))
-    feats = α .* ffeat(x)
+function global_attention_pool(l, g::GNNGraph, x::AbstractArray)
+    α = softmax_nodes(g, l.fgate(x))
+    feats = α .* l.ffeat(x)
     u = reduce_nodes(+, g, feats)
     return u
 end
@@ -26,11 +26,11 @@ end
 
 topk_index(y::Adjoint, k::Int) = topk_index(y', k)
 
-function set2set_pool(lstm, num_iters, g::GNNGraph, x::AbstractMatrix)
+function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
     n_in = size(x, 1)
     qstar = zeros_like(x, (2*n_in, g.num_graphs))
-    for t in 1:num_iters
-        q = lstm(qstar) # [n_in, n_graphs]
+    for t in 1:l.num_iters
+        q = l.lstm(qstar) # [n_in, n_graphs]
         qn = broadcast_nodes(g, q) # [n_in, n_nodes]
         α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
         r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
diff --git a/GNNlib/test/msgpass_tests.jl b/GNNlib/test/msgpass_tests.jl
new file mode 100644
index 000000000..60d13fcb0
--- /dev/null
+++ b/GNNlib/test/msgpass_tests.jl
@@ -0,0 +1,140 @@
+@testitem "msgpass" setup=[SharedTestSetup] begin
+    #TODO test all graph types
+    GRAPH_T = :coo
+    in_channel = 10
+    out_channel = 5
+    num_V = 6
+    num_E = 14
+    T = Float32
+
+    adj = [0 1 0 0 0 0
+           1 0 0 1 1 1
+           0 0 0 0 0 1
+           0 1 0 0 1 0
+           0 1 0 1 0 1
+           0 1 1 0 1 0]
+
+    X = rand(T, in_channel, num_V)
+    E = rand(T, in_channel, num_E)
+
+    g = GNNGraph(adj, graph_type = GRAPH_T)
+
+    @testset "propagate" begin
+        function message(xi, xj, e)
+            @test xi === nothing
+            @test e === nothing
+            ones(T, out_channel, size(xj, 2))
+        end
+
+        m = propagate(message, g, +, xj = X)
+
+        @test size(m) == (out_channel, num_V)
+
+        @testset "isolated nodes" begin
+            x1 = rand(1, 6)
+            g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6)
+            y1 = propagate((xi, xj, e) -> xj, g1, +, xj = x1)
+            @test size(y1) == (1, 6)
+        end
+    end
+
+    @testset "apply_edges" begin
+        m = apply_edges(g, e = E) do xi, xj, e
+            @test xi === nothing
+            @test xj === nothing
+            ones(out_channel, size(e, 2))
+        end
+
+        @test m == ones(out_channel, num_E)
+
+        # With NamedTuple input
+        m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e
+            @test xi === nothing
+            @test xj.b == 2 * xj.a
+            @test size(xj.a, 2) == size(xj.b, 2) == size(e, 2)
+            ones(out_channel, size(e, 2))
+        end
+
+        # NamedTuple output
+        m = apply_edges(g, e = E) do xi, xj, e
+            @test xi === nothing
+            @test xj === nothing
+            (; a = ones(out_channel, size(e, 2)))
+        end
+
+        @test m.a == ones(out_channel, num_E)
+
+        @testset "sizecheck" begin
+            x = rand(3, g.num_nodes - 1)
+            @test_throws AssertionError apply_edges(copy_xj, g, xj = x)
+            @test_throws AssertionError apply_edges(copy_xj, g, xi = x)
+
+            x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1))
+            @test_throws AssertionError apply_edges(copy_xj, g, xj = x)
+            @test_throws AssertionError apply_edges(copy_xj, g, xi = x)
+
+            e = rand(3, g.num_edges - 1)
+            @test_throws AssertionError apply_edges(copy_xj, g, e = e)
+        end
+    end
+
+    @testset "copy_xj" begin
+        n = 128
+        A = sprand(n, n, 0.1)
+        Adj = map(x -> x > 0 ? 1 : 0, A)
+        X = rand(10, n)
+
+        g = GNNGraph(A, ndata = X, graph_type = GRAPH_T)
+
+        function spmm_copyxj_fused(g)
+            propagate(copy_xj,
+                      g, +; xj = g.ndata.x)
+        end
+
+        function spmm_copyxj_unfused(g)
+            propagate((xi, xj, e) -> xj,
+                      g, +; xj = g.ndata.x)
+        end
+
+        @test spmm_copyxj_unfused(g) ≈ X * Adj
+        @test spmm_copyxj_fused(g) ≈ X * Adj
+    end
+
+    @testset "e_mul_xj and w_mul_xj for weighted conv" begin
+        n = 128
+        A = sprand(n, n, 0.1)
+        Adj = map(x -> x > 0 ? 1 : 0, A)
+        X = rand(10, n)
+
+        g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T)
+
+        function spmm_unfused(g)
+            propagate((xi, xj, e) -> reshape(e, 1, :) .* xj,
+                      g, +; xj = g.ndata.x, e = g.edata.e)
+        end
+        function spmm_fused(g)
+            propagate(e_mul_xj,
+                      g, +; xj = g.ndata.x, e = g.edata.e)
+        end
+
+        function spmm_fused2(g)
+            propagate(w_mul_xj,
+                      g, +; xj = g.ndata.x)
+        end
+
+        @test spmm_unfused(g) ≈ X * A
+        @test spmm_fused(g) ≈ X * A
+        @test spmm_fused2(g) ≈ X * A
+    end
+
+    @testset "aggregate_neighbors" begin
+        @testset "sizecheck" begin
+            m = rand(2, g.num_edges - 1)
+            @test_throws AssertionError aggregate_neighbors(g, +, m)
+
+            m = (a = rand(2, g.num_edges + 1), b = nothing)
+            @test_throws AssertionError aggregate_neighbors(g, +, m)
+        end
+    end
+
+end
\ No newline at end of file
diff --git a/GNNlib/test/runtests.jl b/GNNlib/test/runtests.jl
new file mode 100644
index 000000000..e4c4512b4
--- /dev/null
+++ b/GNNlib/test/runtests.jl
@@ -0,0 +1,6 @@
+using GNNlib
+using Test
+using ReTestItems
+using Random, Statistics
+
+runtests(GNNlib)
diff --git a/GNNlib/test/shared_testsetup.jl b/GNNlib/test/shared_testsetup.jl
new file mode 100644
index 000000000..106db5159
--- /dev/null
+++ b/GNNlib/test/shared_testsetup.jl
@@ -0,0 +1,12 @@
+@testsetup module SharedTestSetup
+
+import Reexport: @reexport
+
+@reexport using GNNlib
+@reexport using GNNGraphs
+@reexport using NNlib
+@reexport using MLUtils
+@reexport using SparseArrays
+@reexport using Test, Random, Statistics
+
+end
\ No newline at end of file
diff --git a/GNNlib/test/utils_tests.jl b/GNNlib/test/utils_tests.jl
new file mode 100644
index 000000000..762ba58b9
--- /dev/null
+++ b/GNNlib/test/utils_tests.jl
@@ -0,0 +1,68 @@
+@testitem "utils" setup=[SharedTestSetup] begin
+    # TODO test all graph types
+    GRAPH_T = :coo
+    De, Dx = 3, 2
+    g = MLUtils.batch([rand_graph(10, 60, bidirected=true,
+                                  ndata = rand(Dx, 10),
+                                  edata = rand(De, 30),
+                                  graph_type = GRAPH_T) for i in 1:5])
+    x = g.ndata.x
+    e = g.edata.e
+
+    @testset "reduce_nodes" begin
+        r = reduce_nodes(mean, g, x)
+        @test size(r) == (Dx, g.num_graphs)
+        @test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2)
+
+        r2 = reduce_nodes(mean, graph_indicator(g), x)
+        @test r2 == r
+    end
+
+    @testset "reduce_edges" begin
+        r = reduce_edges(mean, g, e)
+        @test size(r) == (De, g.num_graphs)
+        @test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2)
+    end
+
+    @testset "softmax_nodes" begin
+        r = softmax_nodes(g, x)
+        @test size(r) == size(x)
+        @test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2)
+    end
+
+    @testset "softmax_edges" begin
+        r = softmax_edges(g, e)
+        @test size(r) == size(e)
+        @test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2)
+    end
+
+    @testset "broadcast_nodes" begin
+        z = rand(4, g.num_graphs)
+        r = broadcast_nodes(g, z)
+        @test size(r) == (4, g.num_nodes)
+        @test r[:, 1] ≈ z[:, 1]
+        @test r[:, 10] ≈ z[:, 1]
+        @test r[:, 11] ≈ z[:, 2]
+    end
+
+    @testset "broadcast_edges" begin
+        z = rand(4, g.num_graphs)
+        r = broadcast_edges(g, z)
+        @test size(r) == (4, g.num_edges)
+        @test r[:, 1] ≈ z[:, 1]
+        @test r[:, 60] ≈ z[:, 1]
+        @test r[:, 61] ≈ z[:, 2]
+    end
+
+    @testset "softmax_edge_neighbors" begin
+        s = [1, 2, 3, 4]
+        t = [5, 5, 6, 6]
+        g2 = GNNGraph(s, t)
+        e2 = randn(Float32, 3, g2.num_edges)
+        z = softmax_edge_neighbors(g2, e2)
+        @test size(z) == size(e2)
+        @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2)
+        @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2)
+    end
+end
+
diff --git a/Project.toml b/Project.toml
index 7a429a998..d0f9e76e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,7 +5,6 @@ version = "0.6.20"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
@@ -27,7 +26,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 [compat]
 CUDA = "4, 5"
 ChainRulesCore = "1"
-DataStructures = "0.18"
 Flux = "0.14"
 Functors = "0.4.1"
 GNNGraphs = "1.0"
diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index 0f5c43a18..021d4d8b2 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -9,7 +9,6 @@ using NNlib
 using NNlib: scatter, gather
 using ChainRulesCore
 using Reexport
-using DataStructures: nlargest
 using MLUtils: zeros_like
 using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 28f6532cc..a8dd15652 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -2,3 +2,4 @@
 # V1.0 deprecations
 # TODO doe some reason this is not working
 # @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight)
+# @deprecate (l::GNNLayer)(gs::AbstractVector{<:GNNGraph}, args...; kws...) l(MLUtils.batch(gs), args...; kws...)
\ No newline at end of file
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index c9322f0ec..22fd029f9 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -11,12 +11,6 @@ abstract type GNNLayer end
 # To be specialized by layers also needing edge features as input (e.g. NNConv).
 (l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
 
-function (l::GNNLayer)(g::AbstractVector{<:GNNGraph}, args...; kws...)
-    @warn "Passing an array of graphs to a `GNNLayer` is discouraged.
-          Explicitely call `Flux.batch(graphs)` first instead." maxlog=1
-    return l(batch(g), args...; kws...)
-end
-
 """
     WithGraph(model, g::GNNGraph; traingraph=false)
diff --git a/src/layers/pool.jl b/src/layers/pool.jl
index 4201a8d15..ed2f7eca6 100644
--- a/src/layers/pool.jl
+++ b/src/layers/pool.jl
@@ -36,9 +36,7 @@ struct GlobalPool{F} <: GNNLayer
     aggr::F
 end
 
-function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
-    return reduce_nodes(l.aggr, g, x)
-end
+(l::GlobalPool)(g::GNNGraph, x::AbstractArray) = GNNlib.global_pool(l, g, x)
 
 (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
 
@@ -96,12 +94,7 @@ end
 
 GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
 
-function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
-    α = softmax_nodes(g, l.fgate(x))
-    feats = α .* l.ffeat(x)
-    u = reduce_nodes(+, g, feats)
-    return u
-end
+(l::GlobalAttentionPool)(g, x) = GNNlib.global_attention_pool(l, g, x)
 
 (l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
 
@@ -127,20 +120,7 @@ function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_un
     TopKPool(adj, k, init(in_channel), similar(adj, k, k))
 end
 
-function (t::TopKPool)(X::AbstractArray)
-    y = t.p' * X / norm(t.p)
-    idx = topk_index(y, t.k)
-    t.Ã .= view(t.A, idx, idx)
-    X_ = view(X, :, idx) .* σ.(view(y, idx)')
-    return X_
-end
-
-function topk_index(y::AbstractVector, k::Int)
-    v = nlargest(k, y)
-    return collect(1:length(y))[y .>= v[end]]
-end
-
-topk_index(y::Adjoint, k::Int) = topk_index(y', k)
+(t::TopKPool)(x::AbstractArray) = topk_pool(t, x)
 
 
 @doc raw"""
@@ -185,18 +165,9 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
     return Set2Set(lstm, n_iters)
 end
 
-function (l::Set2Set)(g::GNNGraph, x::AbstractMatrix)
-    n_in = size(x, 1)
+function (l::Set2Set)(g, x)
     Flux.reset!(l.lstm)
-    qstar = zeros_like(x, (2*n_in, g.num_graphs))
-    for t in 1:l.num_iters
-        q = l.lstm(qstar) # [n_in, n_graphs]
-        qn = broadcast_nodes(g, q) # [n_in, n_nodes]
-        α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
-        r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
-        qstar = vcat(q, r) # [2*n_in, n_graphs]
-    end
-    return qstar
+    return GNNlib.set2set_pool(l, g, x)
 end
 
 (l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
diff --git a/test/msgpass.jl b/test/msgpass.jl
deleted file mode 100644
index a88556806..000000000
--- a/test/msgpass.jl
+++ /dev/null
@@ -1,135 +0,0 @@
-in_channel = 10
-out_channel = 5
-num_V = 6
-num_E = 14
-T = Float32
-
-adj = [0 1 0 0 0 0
-       1 0 0 1 1 1
-       0 0 0 0 0 1
-       0 1 0 0 1 0
-       0 1 0 1 0 1
-       0 1 1 0 1 0]
-
-X = rand(T, in_channel, num_V)
-E = rand(T, in_channel, num_E)
-
-g = GNNGraph(adj, graph_type = GRAPH_T)
-
-@testset "propagate" begin
-    function message(xi, xj, e)
-        @test xi === nothing
-        @test e === nothing
-        ones(T, out_channel, size(xj, 2))
-    end
-
-    m = propagate(message, g, +, xj = X)
-
-    @test size(m) == (out_channel, num_V)
-
-    @testset "isolated nodes" begin
-        x1 = rand(1, 6)
-        g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6)
-        y1 = propagate((xi, xj, e) -> xj, g, +, xj = x1)
-        @test size(y1) == (1, 6)
-    end
-end
-
-@testset "apply_edges" begin
-    m = apply_edges(g, e = E) do xi, xj, e
-        @test xi === nothing
-        @test xj === nothing
-        ones(out_channel, size(e, 2))
-    end
-
-    @test m == ones(out_channel, num_E)
-
-    # With NamedTuple input
-    m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e
-        @test xi === nothing
-        @test xj.b == 2 * xj.a
-        @test size(xj.a, 2) == size(xj.b, 2) == size(e, 2)
-        ones(out_channel, size(e, 2))
-    end
-
-    # NamedTuple output
-    m = apply_edges(g, e = E) do xi, xj, e
-        @test xi === nothing
-        @test xj === nothing
-        (; a = ones(out_channel, size(e, 2)))
-    end
-
-    @test m.a == ones(out_channel, num_E)
-
-    @testset "sizecheck" begin
-        x = rand(3, g.num_nodes - 1)
-        @test_throws AssertionError apply_edges(copy_xj, g, xj = x)
-        @test_throws AssertionError apply_edges(copy_xj, g, xi = x)
-
-        x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1))
-        @test_throws AssertionError apply_edges(copy_xj, g, xj = x)
-        @test_throws AssertionError apply_edges(copy_xj, g, xi = x)
-
-        e = rand(3, g.num_edges - 1)
-        @test_throws AssertionError apply_edges(copy_xj, g, e = e)
-    end
-end
-
-@testset "copy_xj" begin
-    n = 128
-    A = sprand(n, n, 0.1)
-    Adj = map(x -> x > 0 ? 1 : 0, A)
-    X = rand(10, n)
-
-    g = GNNGraph(A, ndata = X, graph_type = GRAPH_T)
-
-    function spmm_copyxj_fused(g)
-        propagate(copy_xj,
-                  g, +; xj = g.ndata.x)
-    end
-
-    function spmm_copyxj_unfused(g)
-        propagate((xi, xj, e) -> xj,
-                  g, +; xj = g.ndata.x)
-    end
-
-    @test spmm_copyxj_unfused(g) ≈ X * Adj
-    @test spmm_copyxj_fused(g) ≈ X * Adj
-end
-
-@testset "e_mul_xj and w_mul_xj for weighted conv" begin
-    n = 128
-    A = sprand(n, n, 0.1)
-    Adj = map(x -> x > 0 ? 1 : 0, A)
-    X = rand(10, n)
-
-    g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T)
-
-    function spmm_unfused(g)
-        propagate((xi, xj, e) -> reshape(e, 1, :) .* xj,
-                  g, +; xj = g.ndata.x, e = g.edata.e)
-    end
-    function spmm_fused(g)
-        propagate(e_mul_xj,
-                  g, +; xj = g.ndata.x, e = g.edata.e)
-    end
-
-    function spmm_fused2(g)
-        propagate(w_mul_xj,
-                  g, +; xj = g.ndata.x)
-    end
-
-    @test spmm_unfused(g) ≈ X * A
-    @test spmm_fused(g) ≈ X * A
-    @test spmm_fused2(g) ≈ X * A
-end
-
-@testset "aggregate_neighbors" begin
-    @testset "sizecheck" begin
-        m = rand(2, g.num_edges - 1)
-        @test_throws AssertionError aggregate_neighbors(g, +, m)
-
-        m = (a = rand(2, g.num_edges + 1), b = nothing)
-        @test_throws AssertionError aggregate_neighbors(g, +, m)
    end
-end
diff --git a/test/runtests.jl b/test/runtests.jl
index 1e56c7110..e41c7c1ae 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,8 +24,6 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets
 include("test_utils.jl")
 
 tests = [
-    "utils",
-    "msgpass",
     "layers/basic",
     "layers/conv",
     "layers/heteroconv",
diff --git a/test/utils.jl b/test/utils.jl
deleted file mode 100644
index 7fb423c18..000000000
--- a/test/utils.jl
+++ /dev/null
@@ -1,64 +0,0 @@
-De, Dx = 3, 2
-g = Flux.batch([GNNGraph(erdos_renyi(10, 30),
-                         ndata = rand(Dx, 10),
-                         edata = rand(De, 30),
-                         graph_type = GRAPH_T) for i in 1:5])
-x = g.ndata.x
-e = g.edata.e
-
-@testset "reduce_nodes" begin
-    r = reduce_nodes(mean, g, x)
-    @test size(r) == (Dx, g.num_graphs)
-    @test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2)
-
-    r2 = reduce_nodes(mean, graph_indicator(g), x)
-    @test r2 == r
-end
-
-@testset "reduce_edges" begin
-    r = reduce_edges(mean, g, e)
-    @test size(r) == (De, g.num_graphs)
-    @test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2)
-end
-
-@testset "softmax_nodes" begin
-    r = softmax_nodes(g, x)
-    @test size(r) == size(x)
-    @test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2)
-end
-
-@testset "softmax_edges" begin
-    r = softmax_edges(g, e)
-    @test size(r) == size(e)
-    @test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2)
-end
-
-@testset "broadcast_nodes" begin
-    z = rand(4, g.num_graphs)
-    r = broadcast_nodes(g, z)
-    @test size(r) == (4, g.num_nodes)
-    @test r[:, 1] ≈ z[:, 1]
-    @test r[:, 10] ≈ z[:, 1]
-    @test r[:, 11] ≈ z[:, 2]
-end
-
-@testset "broadcast_edges" begin
-    z = rand(4, g.num_graphs)
-    r = broadcast_edges(g, z)
-    @test size(r) == (4, g.num_edges)
-    @test r[:, 1] ≈ z[:, 1]
-    @test r[:, 60] ≈ z[:, 1]
-    @test r[:, 61] ≈ z[:, 2]
-end
-
-@testset "softmax_edge_neighbors" begin
-    s = [1, 2, 3, 4]
-    t = [5, 5, 6, 6]
-    g2 = GNNGraph(s, t)
-    e2 = randn(Float32, 3, g2.num_edges)
-    z = softmax_edge_neighbors(g2, e2)
-    @test size(z) == size(e2)
-    @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2)
-    @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2)
-end
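
Note: the recurring pattern in this refactor is worth spelling out. Each Flux-facing pooling layer in GraphNeuralNetworks.jl now keeps only its struct definition and delegates the forward pass to a framework-agnostic function in GNNlib, which receives the layer itself and reads hyperparameters and sub-modules from its fields. A minimal sketch of the pattern, mirroring the GlobalPool/global_pool pair above — `MyPool` and `my_pool` are hypothetical names used purely for illustration, not part of either package:

    # In GNNlib: framework-agnostic forward pass, takes the layer as first argument
    # and reads its fields (here the aggregation operator).
    function my_pool(l, g::GNNGraph, x::AbstractArray)
        return reduce_nodes(l.aggr, g, x)
    end

    # In GraphNeuralNetworks.jl: a thin wrapper that only stores the fields
    # and forwards the call.
    struct MyPool{F} <: GNNLayer
        aggr::F
    end
    (l::MyPool)(g::GNNGraph, x::AbstractArray) = GNNlib.my_pool(l, g, x)

This split keeps Flux-specific state handling on the GraphNeuralNetworks.jl side (e.g. the `Flux.reset!(l.lstm)` call that remains in the Set2Set forward above) while the numerical core in GNNlib can be shared across front ends and tested in isolation, as the new GNNlib test suite in this diff does.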