diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 4af7ac81d..dd8911088 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -192,7 +192,7 @@ function add_edges(g::GNNHeteroGraph{<:COO_T}, if node_t ∉ ntypes push!(ntypes, node_t) if haskey(num_nodes, node_t) - _num_nodes[node_t] == num_nodes[node_t] + _num_nodes[node_t] = num_nodes[node_t] else _num_nodes[node_t] = maximum(st) end diff --git a/test/GNNGraphs/gnnheterograph.jl b/test/GNNGraphs/gnnheterograph.jl index 84f0ba3d0..e79b74772 100644 --- a/test/GNNGraphs/gnnheterograph.jl +++ b/test/GNNGraphs/gnnheterograph.jl @@ -116,6 +116,52 @@ end @test size(g[:B].y) == (d, 2*n) end +@testset "add_edges" begin + d, n = 3, 5 + g = rand_bipartite_heterograph(n, 2 * n, 15) + s, t = [1, 2, 3], [3, 2, 1] + ## Keep the same ntypes - construct with args + g1 = add_edges(g, (:A, :rel1, :B), s, t) + @test num_node_types(g1) == 2 + @test num_edge_types(g1) == 3 + for i in eachindex(sources, targets) + @test has_edge(g1, (:A, :rel1, :B), s[i], t[i]) + end + # no change to num_nodes + @test g1.num_nodes[:A] == n + @test g1.num_nodes[:B] == 2n + + ## Keep the same ntypes - construct with a pair + g2 = add_edges(g, (:A, :rel1, :B) => (s, t)) + @test num_node_types(g2) == 2 + @test num_edge_types(g2) == 3 + for i in eachindex(sources, targets) + @test has_edge(g2, (:A, :rel1, :B), s[i], t[i]) + end + # no change to num_nodes + @test g2.num_nodes[:A] == n + @test g2.num_nodes[:B] == 2n + + ## New ntype with num_nodes (applies only to the new ntype) and edata + edata = rand(Float32, d, length(s)) + g3 = add_edges(g, + (:A, :rel1, :C) => (s, t); + num_nodes = Dict(:A => 1, :B => 1, :C => 10), + edata) + @test num_node_types(g3) == 3 + @test num_edge_types(g3) == 3 + for i in eachindex(sources, targets) + @test has_edge(g3, (:A, :rel1, :C), s[i], t[i]) + end + # added edata + @test g3.edata[(:A, :rel1, :C)].e == edata + # no change to existing num_nodes + @test g3.num_nodes[:A] == n + @test g3.num_nodes[:B] == 2n + # new num_nodes added as per kwarg + @test g3.num_nodes[:C] == 10 +end + ## Cannot test this because DataStore is not an ordered collection ## Uncomment when/if it will be based on OrderedDict # @testset "show" begin