Skip to content

Commit

Permalink
fix add_edges
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp committed Sep 3, 2023
1 parent 9a4f4b7 commit e6480a4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
if node_t ntypes
push!(ntypes, node_t)
if haskey(num_nodes, node_t)
_num_nodes[node_t] == num_nodes[node_t]
_num_nodes[node_t] = num_nodes[node_t]
else
_num_nodes[node_t] = maximum(st)
end
Expand Down
46 changes: 46 additions & 0 deletions test/GNNGraphs/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,52 @@ end
@test size(g[:B].y) == (d, 2*n)
end

@testset "add_edges" begin
d, n = 3, 5
g = rand_bipartite_heterograph(n, 2 * n, 15)
s, t = [1, 2, 3], [3, 2, 1]
## Keep the same ntypes - construct with args
g1 = add_edges(g, (:A, :rel1, :B), s, t)
@test num_node_types(g1) == 2
@test num_edge_types(g1) == 3
for i in eachindex(sources, targets)
@test has_edge(g1, (:A, :rel1, :B), s[i], t[i])
end
# no change to num_nodes
@test g1.num_nodes[:A] == n
@test g1.num_nodes[:B] == 2n

## Keep the same ntypes - construct with a pair
g2 = add_edges(g, (:A, :rel1, :B) => (s, t))
@test num_node_types(g2) == 2
@test num_edge_types(g2) == 3
for i in eachindex(sources, targets)
@test has_edge(g2, (:A, :rel1, :B), s[i], t[i])
end
# no change to num_nodes
@test g2.num_nodes[:A] == n
@test g2.num_nodes[:B] == 2n

## New ntype with num_nodes (applies only to the new ntype) and edata
edata = rand(Float32, d, length(s))
g3 = add_edges(g,
(:A, :rel1, :C) => (s, t);
num_nodes = Dict(:A => 1, :B => 1, :C => 10),
edata)
@test num_node_types(g3) == 3
@test num_edge_types(g3) == 3
for i in eachindex(sources, targets)
@test has_edge(g3, (:A, :rel1, :C), s[i], t[i])
end
# added edata
@test g3.edata[(:A, :rel1, :C)].e == edata
# no change to existing num_nodes
@test g3.num_nodes[:A] == n
@test g3.num_nodes[:B] == 2n
# new num_nodes added as per kwarg
@test g3.num_nodes[:C] == 10
end

## Cannot test this because DataStore is not an ordered collection
## Uncomment when/if it will be based on OrderedDict
# @testset "show" begin
Expand Down

0 comments on commit e6480a4

Please sign in to comment.