Skip to content

Commit

Permalink
Merge pull request #180 from MilkshakeForReal/master
Browse files Browse the repository at this point in the history
add radius_graph api
  • Loading branch information
CarloLucibello authored May 29, 2022
2 parents 78792f0 + 20739e4 commit 9398a53
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ export add_nodes,

include("generate.jl")
export rand_graph,
knn_graph
knn_graph,
radius_graph

include("sampling.jl")
export sample_neighbors
Expand Down
75 changes: 75 additions & 0 deletions src/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,78 @@ function knn_graph(points::AbstractMatrix, k::Int;
end
return g
end

"""
radius_graph(points::AbstractMatrix,
r::AbstractFloat;
graph_indicator = nothing,
self_loops = false,
dir = :in,
kws...)
Create a graph where each node is linked
to its neighbors within a given distance `r`.
# Arguments
- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
- `r`: The radius.
- `graph_indicator`: Either nothing or a vector containing the graph assigment of each node,
in which case the returned graph will be a batch of graphs.
- `self_loops`: If `true`, consider the node itself among its neighbors, in which
case the graph will contain self-loops.
- `dir`: The direction of the edges. If `dir=:in` edges go from the
neighbors to the central node. If `dir=:out` we have the opposite
direction.
- `kws`: Further keyword arguments will be passed to the [`GNNGraph ](@ref) constructor.
# Examples
```juliarepl
julia> n, r = 10, 0.75;
julia> x = rand(3, n);
julia> g = radius_graph(x, r)
GNNGraph:
num_nodes = 10
num_edges = 46
julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
julia> g = radius_graph(x, r; graph_indicator)
GNNGraph:
num_nodes = 10
num_edges = 20
num_graphs = 2
```
"""
function radius_graph(points::AbstractMatrix, r::AbstractFloat;
graph_indicator = nothing,
self_loops = false,
dir = :in,
kws...)

if graph_indicator !== nothing
d, n = size(points)
@assert graph_indicator isa AbstractVector{<:Integer}
@assert length(graph_indicator) == n

# Make sure that the distance between points in different graphs
# is always larger than r.
dummy_feature = 2r .* reshape(graph_indicator, 1, n)
points = vcat(points, dummy_feature)
end

balltree = NearestNeighbors.BallTree(points)

sortres = false
idxs = NearestNeighbors.inrange(balltree, points, r, sortres)

g = GNNGraph(idxs; dir, graph_indicator, kws...)
if !self_loops
g = remove_self_loops(g)
end
return g
end
18 changes: 18 additions & 0 deletions test/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,22 @@
@test all(6 .<= s[ne+1:end] .<= 10)
@test all(6 .<= t[ne+1:end] .<= 10)
end

@testset "radius_graph" begin
n, r = 10, 0.5
x = rand(3, n)
g = radius_graph(x, r; graph_type=GRAPH_T)
@test g.num_nodes == 10
@test has_self_loops(g) == false

g = radius_graph(x, r; dir=:out, self_loops=true, graph_type=GRAPH_T)
@test g.num_nodes == 10
@test has_self_loops(g) == true

graph_indicator = [1,1,1,1,1,2,2,2,2,2]
g = radius_graph(x, r; graph_indicator, graph_type=GRAPH_T)
@test g.num_graphs == 2
s, t = edge_index(g)
@test (s.>5) == (t.>5)
end
end

0 comments on commit 9398a53

Please sign in to comment.