Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi threaded #36

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[targets]
test = ["Test", "NearestNeighbors", "StaticArrays"]
test = ["Test", "NearestNeighbors", "StaticArrays", "Statistics"]
42 changes: 31 additions & 11 deletions src/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,42 @@
Insert index `query` referring to data point `data[q]` into the graph.
"""
function insert_point!(hnsw, query, l = get_random_level(hnsw.lgraph))
add_vertex!(hnsw.lgraph, query, l)

# Get enterpoint and highest level in a threadsafe way
lock(hnsw.ep_lock)
enter_point = get_enter_point(hnsw)
L = get_top_layer(hnsw)
add_vertex!(hnsw.lgraph, query, l)
L = get_entry_level(hnsw)

# Special Case for the very first entry
if enter_point == 0
set_enter_point!(hnsw, query)
unlock(hnsw.ep_lock)
return nothing
end
unlock(hnsw.ep_lock)

ep = Neighbor(enter_point, distance(hnsw, enter_point, query))
for level ∈ L:-1:l+1 #Find nearest point within each layer and traverse down

# Traverse through levels to l (assuming l < L)
for level ∈ L:-1:l+1
W = search_layer(hnsw, query, ep, 1,level)
ep = nearest(W) #nearest element from q in W
end

# Insert query on all levels < min(L,l)
for level ∈ min(L,l):-1:1
W = search_layer(hnsw, query, ep, hnsw.efConstruction, level)
add_connections!(hnsw, level, query, W)
ep = nearest(W)
end
l > L && set_enter_point!(hnsw, query) #another lock here

# Update enter point if inserted point has highest layer
if l > L
lock(hnsw.ep_lock)
set_enter_point!(hnsw, query)
unlock(hnsw.ep_lock)
end
return nothing
end

Expand All @@ -39,7 +54,7 @@ function search_layer(hnsw, query, enter_point, num_points, level)
while length(C) > 0
c = pop_nearest!(C) # from query in C
c.dist > furthest(W).dist && break #Stopping condition
#lock(lg.locklist[c.idx])
lock(hnsw.lgraph.locklist[c.idx])
for e ∈ neighbors(hnsw.lgraph, level, c)
if !isvisited(vl, e)
visit!(vl, e)
Expand All @@ -51,7 +66,7 @@ function search_layer(hnsw, query, enter_point, num_points, level)
end
end
end
#unlock(lg.locklist[c.idx])
unlock(hnsw.lgraph.locklist[c.idx])
end
release_list(hnsw.vlp, vl)
return W #num_points closest neighbors
Expand Down Expand Up @@ -87,10 +102,9 @@ function knn_search(hnsw, q, K)
@assert length(q)==length(hnsw.data[1])
ep = get_enter_point(hnsw)
epN = Neighbor(ep, distance(hnsw, q, ep))
L = get_top_layer(hnsw) #layer of ep , top layer of hnsw
L = get_entry_level(hnsw) #layer of ep , top layer of hnsw
for level ∈ L:-1:2 # Iterate from top to second lowest
epN = search_layer(hnsw, q, epN, 1, level)[1]
#TODO: better upper layer implementation here as well
end
W = search_layer(hnsw, q, epN, ef, 1)
list = nearest(W, K)
Expand All @@ -101,11 +115,17 @@ end

function knn_search(hnsw::HierarchicalNSW{T,F},
q::AbstractVector{<:AbstractVector}, # query
K) where {T,F}
K; multithreading=false) where {T,F}
idxs = Vector{Vector{T}}(undef,length(q))
dists = Vector{Vector{F}}(undef,length(q))
for n = 1:length(q)
idxs[n], dists[n] = knn_search(hnsw, q[n], K)
if multithreading
Threads.@threads for n = 1:length(q)
idxs[n], dists[n] = knn_search(hnsw, q[n], K)
end
else
for n = 1:length(q)
idxs[n], dists[n] = knn_search(hnsw, q[n], K)
end
end
idxs, dists
end
44 changes: 25 additions & 19 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
###########################################################################################
mutable struct HierarchicalNSW{T, F, V, M}
lgraph::LayeredGraph{T}
added::Vector{Bool}
data::V
ep::T
entry_level::Int
ep_lock::Mutex
vlp::VisitedListPool
metric::M
Expand All @@ -29,32 +31,33 @@ function HierarchicalNSW(data;
F = eltype(data[1])
vlp = VisitedListPool(1,max_elements)
return HierarchicalNSW{T,F,typeof(data),typeof(metric)}(
lg, data, ep, Mutex(), vlp, metric, efConstruction, ef)
lg, fill(false, max_elements), data, ep, 0, Mutex(), vlp, metric, efConstruction, ef)
end

"""
add_to_graph!(hnsw, indices, multithreading=false)
Add `i ∈ indices` referring to `data[i]` into the graph.

ATM does not check if already added.
Adding index twice leads to segfault.
Indices already added previously will be ignored.
"""
function add_to_graph!(hnsw::HierarchicalNSW{T}, indices, multithreading=false) where {T}
#Does not check if index has already been added
function add_to_graph!(hnsw::HierarchicalNSW{T}, indices; multithreading=false) where {T}
any(hnsw.added[indices]) && @warn "Some of the points have already been added!"

if multithreading == false
for i ∈ indices
insert_point!(hnsw, T(i))
hnsw.added[i] || insert_point!(hnsw, T(i))
hnsw.added[i] = true
end
else
#levels = [get_random_level(hnsw) for i ∈ 1:maximum(indices)]
println("multithreading does not work yet")
#Threads.@threads for i ∈ 1:maximum(indices)#indices
# insert_point!(hnsw, i, levels[i])
#end
levels = [get_random_level(hnsw.lgraph) for _ ∈ 1:maximum(indices)]
Threads.@threads for i ∈ indices
hnsw.added[i] || insert_point!(hnsw, T(i), levels[i])
hnsw.added[i] = true
end
end
return nothing
end
add_to_graph!(hnsw::HierarchicalNSW) = add_to_graph!(hnsw, eachindex(hnsw.data))
add_to_graph!(hnsw::HierarchicalNSW; kwargs...) = add_to_graph!(hnsw, eachindex(hnsw.data); kwargs...)


set_ef!(hnsw::HierarchicalNSW, ef) = hnsw.ef = ef
Expand All @@ -63,17 +66,20 @@ set_ef!(hnsw::HierarchicalNSW, ef) = hnsw.ef = ef
# Utility Functions #
###########################################################################################
get_enter_point(hnsw::HierarchicalNSW) = hnsw.ep
set_enter_point!(hnsw::HierarchicalNSW, ep) = hnsw.ep = ep
get_top_layer(hnsw::HierarchicalNSW) = hnsw.lgraph.numlayers
function set_enter_point!(hnsw::HierarchicalNSW, ep)
hnsw.ep = ep
hnsw.entry_level = levelof(hnsw.lgraph, ep)
end
get_entry_level(hnsw::HierarchicalNSW) = hnsw.entry_level

distance(hnsw, i, j) = @inbounds evaluate(hnsw.metric, hnsw.data[i], hnsw.data[j])
distance(hnsw, i, q::AbstractVector) = @inbounds evaluate(hnsw.metric, hnsw.data[i], q)
distance(hnsw, q::AbstractVector, j) = @inbounds evaluate(hnsw.metric, hnsw.data[j], q)
@inline distance(hnsw, i, j) = @inbounds evaluate(hnsw.metric, hnsw.data[i], hnsw.data[j])
@inline distance(hnsw, i, q::AbstractVector) = @inbounds evaluate(hnsw.metric, hnsw.data[i], q)
@inline distance(hnsw, q::AbstractVector, j) = @inbounds evaluate(hnsw.metric, hnsw.data[j], q)

function Base.show(io::IO, hnsw::HierarchicalNSW)
lg = hnsw.lgraph
println(io, "Hierarchical Navigable Small World with $(get_top_layer(lg)) layers")
for i = get_top_layer(lg):-1:1
println(io, "Hierarchical Navigable Small World with $(get_entry_level(hnsw)) layers")
for i = get_entry_level(hnsw):-1:1
nodes = count(x->length(x)>=i, lg.linklist)
λ = x -> length(x)>=i ? length(x[i]) : 0
edges = sum(map(λ, lg.linklist))
Expand Down
46 changes: 19 additions & 27 deletions src/layered_graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ function LinkList{T}(num_elements::Int) where {T}
Vector{Vector{T}}(undef, num_elements)
end

mutable struct LayeredGraph{T}
struct LayeredGraph{T}
linklist::LinkList{T} #linklist[index][level][link]
locklist::Vector{Mutex}
numlayers::Int
M0::Int
M::Int
m_L::Float64
Expand All @@ -31,7 +30,6 @@ function LayeredGraph{T}(num_elements::Int, M, M0, m_L) where {T}
LayeredGraph{T}(
LinkList{T}(num_elements),
[Mutex() for i=1:num_elements],
0,
M,
M0,
m_L)
Expand All @@ -40,7 +38,6 @@ end

function add_vertex!(lg::LayeredGraph{T}, i, level) where {T}
lg.linklist[i] = fill(zero(T), lg.M0 + (level-1)*lg.M)
lg.numlayers > level || (lg.numlayers = level)
return nothing
end

Expand All @@ -58,24 +55,24 @@ add_edge!(lg, level, s::Neighbor, t) = add_edge!(lg, level, s.idx, t)
add_edge!(lg, level, s::Integer, t::Neighbor) = add_edge!(lg, level, s, t.idx)


function replace_edge!(lg, level, source, target, newtarget)
offset = index_offset(lg,level)
for m ∈ 1:max_connections(lg, level)
if lg.linklist[source][offset + m] == target
lg.linklist[source][offset + m] = newtarget
return true
end
function set_edges!(lg, level, source, targets)
offset = index_offset(lg, level)
M = max_connections(lg, level)
T = length(targets)
for m ∈ 1:min(M,T)
lg.linklist[source][offset + m] = targets[m].idx
end
for m ∈ T+1:M
lg.linklist[source][offset + m] = 0 #type ?
end
@warn "target link to be replaced was not found"
return false
end
set_edges!(lg, level, source::Neighbor, targets) = set_edges!(lg, level, source.idx, targets)

############################################################################################
# Utility Functions #
############################################################################################

Base.length(lg::LayeredGraph) = lg.numlayers
get_top_layer(lg::LayeredGraph) = lg.numlayers
get_random_level(lg) = floor(Int, -log(rand())* lg.m_L) + 1
max_connections(lg::LayeredGraph, level) = level==1 ? lg.M0 : lg.M
index_offset(lg, level) = level > 1 ? lg.M0 + lg.M*(level-2) : 0
Expand Down Expand Up @@ -123,27 +120,22 @@ function add_connections!(hnsw, level, query, candidates)
lg = hnsw.lgraph
M = max_connections(lg, level)
W = neighbor_heuristic(hnsw, level, candidates)
#set neighbors
for n in W
add_edge!(lg, level, query, n)
end
#Set links from query
set_edges!(lg, level, query, W)
#set links to query
for n in W
q = Neighbor(query, n.dist)
lock(lg.locklist[n.idx]) #lock() linklist of n here
if add_edge!(lg, level, n, q)
else
#remove weakest link and replace it
#TODO: likely needs neighbor_heuristic here
weakest_link = q # dist to query
#conditionally remove weakest link and replace it
C = NeighborSet(q)
for c in neighbors(lg, level, n)
dist = distance(hnsw, n.idx, c)
if weakest_link.dist < dist
weakest_link = Neighbor(c, dist)
end
end
if weakest_link.dist > q.dist
replace_edge!(lg, level, n.idx, weakest_link.idx, q.idx)
insert!(C, Neighbor(c, dist))
end
C = neighbor_heuristic(hnsw, level, C)
q ∈ C && set_edges!(lg, level, n, C)
end
unlock(lg.locklist[n.idx]) #unlock here
end
Expand Down
6 changes: 3 additions & 3 deletions test/lowlevel_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using HNSW
import HNSW: LayeredGraph, add_vertex!, add_edge!, get_top_layer, levelof, neighbors,max_connections
import HNSW: LayeredGraph, add_vertex!, add_edge!, get_entry_level, levelof, neighbors,max_connections
using Test
using LinearAlgebra
using NearestNeighbors
Expand Down Expand Up @@ -57,10 +57,10 @@ end
lg = HNSW.LayeredGraph{UInt32}(num_elements, M0, M, m_L)
@test max_connections(lg, 1) == M0
@test max_connections(lg, 2) == M
@testset "add_vertex! & get_top_layer" for i=1:10
@testset "add_vertex! & get_entry_level" for i=1:10
level = rand(1:10)
add_vertex!(lg, i, level)
@test get_top_layer(lg) >= level
#@test get_entry(lg) >= level
end
@testset "add_edge!" begin
for i = 1:10, j=1:10
Expand Down