Skip to content

Commit

Permalink
Typecast GNNGraph.num_nodes to Int (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
abieler authored Jan 20, 2022
1 parent ce0396b commit df5b508
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/GNNGraphs/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

function to_coo(coo::COO_T; dir=:out, num_nodes=nothing, weighted=true)
s, t, val = coo
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
@assert isnothing(val) || length(val) == length(s)
@assert length(s) == length(t)
if !isempty(s)
Expand Down Expand Up @@ -114,7 +114,7 @@ function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=t
# `dir` will be ignored since the input `coo` is always in source -> target format.
# The output will always be a adjmat in :out format (e.g. A[i,j] denotes from i to j)
s, t, val = coo
n = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
n::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
val = isnothing(val) ? eltype(s)(1) : val
T = T === nothing ? eltype(val) : T
if !weighted
Expand Down Expand Up @@ -164,9 +164,9 @@ function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=
eweight = fill!(similar(s, T), 1)
end

num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
A = sparse(s, t, eweight, num_nodes, num_nodes)
num_edges = nnz(A)
num_edges::Int = nnz(A)
if eltype(A) != T
A = T.(A)
end
Expand Down
3 changes: 2 additions & 1 deletion src/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ function GNNGraph(g::AbstractGraph; kws...)
# add reverse edges since GNNGraph is directed
s, t = [s; t], [t; s]
end
GNNGraph((s, t); num_nodes=Graphs.nv(g), kws...)
num_nodes::Int = Graphs.nv(g)
GNNGraph((s, t); num_nodes=num_nodes, kws...)
end


Expand Down
7 changes: 6 additions & 1 deletion test/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,12 @@
for e in Graphs.edges(lg)
i, j = src(e), dst(e)
@test has_edge(g, i, j)
@test has_edge(g, j, i)
@test has_edge(g, j, i)
end

@testset "SimpleGraph{Int32}" begin
g = GNNGraph(SimpleGraph{Int32}(6), graph_type=GRAPH_T)
@test g.num_nodes == 6
end
end

Expand Down

0 comments on commit df5b508

Please sign in to comment.