Skip to content

Commit

Permalink
implement hash for GNNGraph (#121)
Browse files Browse the repository at this point in the history
* error message

* == for graphs

* implement hash

* relax gatconv tests

* fix test util
  • Loading branch information
CarloLucibello authored Jan 29, 2022
1 parent 2493acf commit ff74af4
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 7 deletions.
11 changes: 10 additions & 1 deletion src/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,13 @@ Flux.Data._nobs(g::GNNGraph) = g.num_graphs
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)

#########################
Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1)))

function Base.:(==)(g1::GNNGraph, g2::GNNGraph)
g1 === g2 && return true
all(k -> getfield(g1, k) == getfield(g2, k), fieldnames(typeof(g1)))
end

function Base.hash(g::T, h::UInt) where T<:GNNGraph
fs = (getfield(g, k) for k in fieldnames(typeof(g)))
return foldl((h, f) -> hash(f, h), fs, init=hash(T, h))
end
4 changes: 2 additions & 2 deletions src/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
if duplicate_if_needed
# Used to copy edge features on reverse edges
@assert all(s -> s == 0 || s == n || s == n÷2, sz)
@assert all(s -> s == 0 || s == n || s == n÷2, sz) "Wrong size in last dimension for feature array."

function duplicate(v)
if v isa AbstractArray && size(v)[end] == n÷2
Expand All @@ -65,7 +65,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
end
data = NamedTuple{keys(data)}(duplicate.(values(data)))
else
@assert all(s -> s == 0 || s == n, sz)
@assert all(s -> s == 0 || s == n, sz) "Wrong size in last dimension for feature array."
end
return data
end
Expand Down
31 changes: 30 additions & 1 deletion test/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,38 @@
end

@testset "Graphs.jl integration" begin
g = GNNGraph(erdos_renyi(10, 20))
g = GNNGraph(erdos_renyi(10, 20), graph_type=GRAPH_T)
@test g isa Graphs.AbstractGraph
end

@testset "==" begin
g1 = rand_graph(5, 6, ndata=rand(5), edata=rand(6), graph_type=GRAPH_T)
@test g1 == g1
@test g1 == deepcopy(g1)
@test g1 !== deepcopy(g1)

g2 = GNNGraph(g1, graph_type=GRAPH_T)
@test g1 == g2
@test g1 === g2 # this is true since GNNGraph is immutable

g2 = GNNGraph(g1, ndata=rand(5), graph_type=GRAPH_T)
@test g1 != g2
@test g1 !== g2

g2 = GNNGraph(g1, edata=rand(6), graph_type=GRAPH_T)
@test g1 != g2
@test g1 !== g2
end

@testset "hash" begin
g1 = rand_graph(5, 6, ndata=rand(5), edata=rand(6), graph_type=GRAPH_T)
@test hash(g1) == hash(g1)
@test hash(g1) == hash(deepcopy(g1))
@test hash(g1) == hash(GNNGraph(g1, ndata=g1.ndata, graph_type=GRAPH_T))
@test hash(g1) == hash(GNNGraph(g1, ndata=g1.ndata, graph_type=GRAPH_T))
@test hash(g1) != hash(GNNGraph(g1, ndata=rand(5), graph_type=GRAPH_T))
@test hash(g1) != hash(GNNGraph(g1, edata=rand(6), graph_type=GRAPH_T))
end
end


4 changes: 2 additions & 2 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
for heads in (1, 2), concat in (true, false)
l = GATConv(in_channel => out_channel; heads, concat)
for g in test_graphs
test_layer(l, g, rtol=1e-3, atol=1e-3,
test_layer(l, g, rtol=1e-3,
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
end
end
Expand All @@ -114,7 +114,7 @@
for heads in (1, 2), concat in (true, false)
l = GATv2Conv(in_channel => out_channel; heads, concat)
for g in test_graphs
test_layer(l, g, rtol=1e-3, atol=1e-3,
test_layer(l, g, rtol=1e-3,
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
end
else
verbose && println("C")
test_approx_structs(x, f̄, f̄2; exclude_grad_fields, broken_grad_fields, verbose)
test_approx_structs(x, f̄, f̄2; atol, rtol, exclude_grad_fields, broken_grad_fields, verbose)
end
end
return true
Expand Down

0 comments on commit ff74af4

Please sign in to comment.