Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
YichengDWu committed May 29, 2022
1 parent a4c139a commit 7e0556f
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
if bidirected
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
end
m2 = bidirected ? m÷2 : m
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
m2 = bidirected ? m ÷ 2 : m
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
end


Expand Down Expand Up @@ -92,11 +92,11 @@ GNNGraph:
```
"""
function knn_graph(points::AbstractMatrix, k::Int;
graph_indicator = nothing,
self_loops = false,
dir = :in,
kws...)
function knn_graph(points::AbstractMatrix, k::Int;
graph_indicator=nothing,
self_loops=false,
dir=:in,
kws...)

if graph_indicator !== nothing
d, n = size(points)
Expand All @@ -105,22 +105,22 @@ function knn_graph(points::AbstractMatrix, k::Int;
# All graphs in the batch must have at least k nodes.
cm = StatsBase.countmap(graph_indicator)
@assert all(values(cm) .>= k)

# Make sure that the distance between points in different graphs
# is always larger than any distance within the same graph.
points = points .- minimum(points)
points = points ./ maximum(points)
dummy_feature = 2d .* reshape(graph_indicator, 1, n)
points = vcat(points, dummy_feature)
end

kdtree = NearestNeighbors.KDTree(points)
if !self_loops
k += 1
end
sortres = false
idxs, dists = NearestNeighbors.knn(kdtree, points, k, sortres)

g = GNNGraph(idxs; dir, graph_indicator, kws...)
if !self_loops
g = remove_self_loops(g)
Expand Down Expand Up @@ -174,17 +174,17 @@ GNNGraph:
```
"""
function radius_graph(points::AbstractMatrix, r::AbstractFloat;
graph_indicator = nothing,
self_loops = false,
dir = :in,
kws...)
function radius_graph(points::AbstractMatrix, r::AbstractFloat;
graph_indicator=nothing,
self_loops=false,
dir=:in,
kws...)

if graph_indicator !== nothing
d, n = size(points)
@assert graph_indicator isa AbstractVector{<:Integer}
@assert length(graph_indicator) == n

# Make sure that the distance between points in different graphs
# is always larger than r.
dummy_feature = 2r .* reshape(graph_indicator, 1, n)
Expand All @@ -195,7 +195,7 @@ function radius_graph(points::AbstractMatrix, r::AbstractFloat;

sortres = false
idxs = NearestNeighbors.inrange(balltree, points, r, sortres)

g = GNNGraph(idxs; dir, graph_indicator, kws...)
if !self_loops
g = remove_self_loops(g)
Expand Down

0 comments on commit 7e0556f

Please sign in to comment.