Skip to content

Commit

Permalink
edges iterates over Graphs.Edge (#171)
Browse files Browse the repository at this point in the history
* edges iterates over Graphs.Edge

* add test for zero

* fix test

* fix test

* fix io
  • Loading branch information
CarloLucibello authored May 22, 2022
1 parent 1c261d4 commit 97c75be
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
4 changes: 3 additions & 1 deletion src/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T<:Integer}
return GNNGraph(s, t; num_nodes, kws...)
end

Base.zero(::Type{G}) where G<:GNNGraph = G(0)

# COO convenience constructors
GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...)
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
Expand Down Expand Up @@ -209,7 +211,7 @@ function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph)
print(io, "GNNGraph:
num_nodes = $(g.num_nodes)
num_edges = $(g.num_edges)")
g.num_graphs > 1 && print("\n num_graphs = $(g.num_graphs)")
g.num_graphs > 1 && print(io, "\n num_graphs = $(g.num_graphs)")
if !isempty(g.ndata)
print(io, "\n ndata:")
for k in keys(g.ndata)
Expand Down
10 changes: 5 additions & 5 deletions src/GNNGraphs/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]

get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3]

Graphs.edges(g::GNNGraph) = zip(edge_index(g)...)
Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...)

Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int}
Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)}

# """
# eltype(g::GNNGraph)
Expand All @@ -42,9 +42,9 @@ end

Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0

graph_type_symbol(g::GNNGraph{<:COO_T}) = :coo
graph_type_symbol(g::GNNGraph{<:SPARSE_T}) = :sparse
graph_type_symbol(g::GNNGraph{<:ADJMAT_T}) = :dense
graph_type_symbol(::GNNGraph{<:COO_T}) = :coo
graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse
graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense

Graphs.nv(g::GNNGraph) = g.num_nodes
Graphs.ne(g::GNNGraph) = g.num_edges
Expand Down
10 changes: 8 additions & 2 deletions test/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
@test g.num_nodes == 4
@test nv(g) == g.num_nodes
@test ne(g) == g.num_edges
@test collect(edges(g)) |> sort == collect(zip(s, t)) |> sort
@test Tuple.(collect(edges(g))) |> sort == collect(zip(s, t)) |> sort
@test sort(outneighbors(g, 1)) == [2, 4]
@test sort(inneighbors(g, 1)) == [2, 4]
@test is_directed(g) == true
Expand Down Expand Up @@ -150,7 +150,7 @@

@test g.num_edges == 4
@test g.num_nodes == 4
@test collect(edges(g)) |> sort == collect(zip(s, t)) |> sort
@test length(edges(g)) == 4
@test sort(outneighbors(g, 1)) == [2]
@test sort(inneighbors(g, 1)) == [4]
@test is_directed(g) == true
Expand All @@ -168,6 +168,12 @@
@test adjacency_list(g, dir=:in) == adj_list_in
end

@testset "zero" begin
g = rand_graph(4, 6, graph_type=GRAPH_T)
G = typeof(g)
@test zero(G) == G(0)
end

@testset "Graphs.jl constructor" begin
lg = random_regular_graph(10, 4)
@test !Graphs.is_directed(lg)
Expand Down
8 changes: 8 additions & 0 deletions test/GNNGraphs/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@
end
end

@testset "edges" begin
g = rand_graph(4, 10, graph_type=GRAPH_T)
@test edgetype(g) <: Graphs.Edge
for e in edges(g)
@test e isa Graphs.Edge
end
end

@testset "has_self_loops" begin
s = [1, 1, 2, 3]
t = [2, 2, 2, 4]
Expand Down

0 comments on commit 97c75be

Please sign in to comment.