diff --git a/Project.toml b/Project.toml index 7031607f6..81cf23015 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/GNNGraphs/gnngraph.jl b/src/GNNGraphs/gnngraph.jl index 1e2d3d957..7bb9c9812 100644 --- a/src/GNNGraphs/gnngraph.jl +++ b/src/GNNGraphs/gnngraph.jl @@ -155,6 +155,8 @@ function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T<:Integer} return GNNGraph(s, t; num_nodes, kws...) end +Base.zero(::Type{G}) where G<:GNNGraph = G(0) + # COO convenience constructors GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...) GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...) @@ -209,7 +211,7 @@ function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph) print(io, "GNNGraph: num_nodes = $(g.num_nodes) num_edges = $(g.num_edges)") - g.num_graphs > 1 && print("\n num_graphs = $(g.num_graphs)") + g.num_graphs > 1 && print(io, "\n num_graphs = $(g.num_graphs)") if !isempty(g.ndata) print(io, "\n ndata:") for k in keys(g.ndata) diff --git a/src/GNNGraphs/query.jl b/src/GNNGraphs/query.jl index f4efcf842..c892c7339 100644 --- a/src/GNNGraphs/query.jl +++ b/src/GNNGraphs/query.jl @@ -17,9 +17,9 @@ get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3] get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3] -Graphs.edges(g::GNNGraph) = zip(edge_index(g)...) +Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...) -Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int} +Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)} # """ # eltype(g::GNNGraph) @@ -42,9 +42,9 @@ end Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0 -graph_type_symbol(g::GNNGraph{<:COO_T}) = :coo -graph_type_symbol(g::GNNGraph{<:SPARSE_T}) = :sparse -graph_type_symbol(g::GNNGraph{<:ADJMAT_T}) = :dense +graph_type_symbol(::GNNGraph{<:COO_T}) = :coo +graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse +graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense Graphs.nv(g::GNNGraph) = g.num_nodes Graphs.ne(g::GNNGraph) = g.num_edges diff --git a/test/GNNGraphs/gnngraph.jl b/test/GNNGraphs/gnngraph.jl index e9654b7a8..05feeda34 100644 --- a/test/GNNGraphs/gnngraph.jl +++ b/test/GNNGraphs/gnngraph.jl @@ -63,7 +63,7 @@ @test g.num_nodes == 4 @test nv(g) == g.num_nodes @test ne(g) == g.num_edges - @test collect(edges(g)) |> sort == collect(zip(s, t)) |> sort + @test Tuple.(collect(edges(g))) |> sort == collect(zip(s, t)) |> sort @test sort(outneighbors(g, 1)) == [2, 4] @test sort(inneighbors(g, 1)) == [2, 4] @test is_directed(g) == true @@ -150,7 +150,7 @@ @test g.num_edges == 4 @test g.num_nodes == 4 - @test collect(edges(g)) |> sort == collect(zip(s, t)) |> sort + @test length(edges(g)) == 4 @test sort(outneighbors(g, 1)) == [2] @test sort(inneighbors(g, 1)) == [4] @test is_directed(g) == true @@ -168,6 +168,12 @@ @test adjacency_list(g, dir=:in) == adj_list_in end + @testset "zero" begin + g = rand_graph(4, 6, graph_type=GRAPH_T) + G = typeof(g) + @test zero(G) == G(0) + end + @testset "Graphs.jl constructor" begin lg = random_regular_graph(10, 4) @test !Graphs.is_directed(lg) diff --git a/test/GNNGraphs/query.jl b/test/GNNGraphs/query.jl index 4d4c88a14..cae095a6b 100644 --- a/test/GNNGraphs/query.jl +++ b/test/GNNGraphs/query.jl @@ -21,6 +21,14 @@ end end + @testset "edges" begin + g = rand_graph(4, 10, graph_type=GRAPH_T) + @test edgetype(g) <: Graphs.Edge + for e in edges(g) + @test e isa Graphs.Edge + end + end + @testset "has_self_loops" begin s = [1, 1, 2, 3] t = [2, 2, 2, 4]