Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add radius_graph api #180

Merged
merged 5 commits into from
May 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ export add_nodes,

include("generate.jl")
export rand_graph,
knn_graph
knn_graph,
radius_graph

include("sampling.jl")
export sample_neighbors
Expand Down
75 changes: 75 additions & 0 deletions src/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,78 @@ function knn_graph(points::AbstractMatrix, k::Int;
end
return g
end

"""
radius_graph(points::AbstractMatrix,
r::AbstractFloat;
graph_indicator = nothing,
self_loops = false,
dir = :in,
kws...)
Create a graph where each node is linked
to its neighbors within a given distance `r`.
# Arguments
- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
- `r`: The radius.
- `graph_indicator`: Either nothing or a vector containing the graph assigment of each node,
in which case the returned graph will be a batch of graphs.
- `self_loops`: If `true`, consider the node itself among its neighbors, in which
case the graph will contain self-loops.
- `dir`: The direction of the edges. If `dir=:in` edges go from the
neighbors to the central node. If `dir=:out` we have the opposite
direction.
- `kws`: Further keyword arguments will be passed to the [`GNNGraph ](@ref) constructor.
# Examples
```juliarepl
julia> n, r = 10, 0.75;
julia> x = rand(3, n);
julia> g = radius_graph(x, r)
GNNGraph:
num_nodes = 10
num_edges = 46
julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
julia> g = radius_graph(x, r; graph_indicator)
GNNGraph:
num_nodes = 10
num_edges = 20
num_graphs = 2
```
"""
function radius_graph(points::AbstractMatrix, r::AbstractFloat;
graph_indicator = nothing,
self_loops = false,
dir = :in,
kws...)

if graph_indicator !== nothing
d, n = size(points)
@assert graph_indicator isa AbstractVector{<:Integer}
@assert length(graph_indicator) == n

# Make sure that the distance between points in different graphs
# is always larger than r.
dummy_feature = 2r .* reshape(graph_indicator, 1, n)
points = vcat(points, dummy_feature)
end

balltree = NearestNeighbors.BallTree(points)

sortres = false
idxs = NearestNeighbors.inrange(balltree, points, r, sortres)

g = GNNGraph(idxs; dir, graph_indicator, kws...)
if !self_loops
g = remove_self_loops(g)
end
return g
end
18 changes: 18 additions & 0 deletions test/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,22 @@
@test all(6 .<= s[ne+1:end] .<= 10)
@test all(6 .<= t[ne+1:end] .<= 10)
end

@testset "radius_graph" begin
n, r = 10, 0.5
x = rand(3, n)
g = radius_graph(x, r; graph_type=GRAPH_T)
@test g.num_nodes == 10
@test has_self_loops(g) == false

g = radius_graph(x, r; dir=:out, self_loops=true, graph_type=GRAPH_T)
@test g.num_nodes == 10
@test has_self_loops(g) == true

graph_indicator = [1,1,1,1,1,2,2,2,2,2]
g = radius_graph(x, r; graph_indicator, graph_type=GRAPH_T)
@test g.num_graphs == 2
s, t = edge_index(g)
@test (s.>5) == (t.>5)
end
end