-
Notifications
You must be signed in to change notification settings - Fork 0
/
neighbors.cpp
76 lines (64 loc) · 2.8 KB
/
neighbors.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
//
// Copyright (c) 2017 – Technicolor R&D France
//
// The source code form of this open source project is subject to the terms of the
// Clear BSD license.
//
// You can redistribute it and/or modify it under the terms of the Clear BSD
// License (See LICENSE file).
//
#include <algorithm>
#include <memory>
#include "binheap.hpp"
#include "distances.hpp"
// Tile sizes for the blocked k-NN search below: how many query vectors
// (BLOCK_VECS) and how many candidate vectors (BLOCK_NEIGHS) are handled per
// tile. One tile of distances therefore occupies BLOCK_VECS * BLOCK_NEIGHS floats.
static constexpr int BLOCK_VECS{256};
static constexpr int BLOCK_NEIGHS{256};
static inline void add_candidates_heaps(float* dist_block, kv_binheap<int, float>* heaps,
int block_count_vec, int block_count_neigh, int base_neigh) {
float* dist_line = dist_block;
for(int vec_i = 0; vec_i < block_count_vec; ++vec_i) {
kv_binheap<int, float>& h = heaps[vec_i];
for(int neigh_i = 0; neigh_i < block_count_neigh; ++neigh_i) {
h.push(base_neigh + neigh_i, dist_line[neigh_i]);
}
dist_line += block_count_neigh;
}
}
/**
 * For each of the `vector_count` query vectors, selects `k` candidates among
 * `neighbors` using one kv_binheap of capacity k per query, and writes their
 * indices to `assignements`.
 *
 * Work is tiled: up to BLOCK_VECS queries and BLOCK_NEIGHS candidates at a
 * time, so one tile of distances fits in a reusable scratch buffer. Each query's
 * heap accumulates (candidate index, distance) pairs over all candidate tiles
 * before being sorted into the output.
 * NOTE(review): "k nearest" assumes kv_binheap retains the k smallest distances;
 * confirm in binheap.hpp.
 *
 * @param vector_count   number of query vectors
 * @param neighbor_count number of candidate vectors
 * @param dim            number of floats per vector (queries and candidates)
 * @param k              number of candidates kept per query
 * @param vectors        queries, vector_count * dim floats, one vector after another
 * @param neighbors      candidates, neighbor_count * dim floats, one vector after
 *                       another (assumed same layout as `vectors`)
 * @param assignements   output, k ints per query (vector_count * k total), in the
 *                       order produced by kv_binheap::sort
 */
void find_k_neighbors(
		const int vector_count, const int neighbor_count, const int dim, const int k,
		const float* vectors, const float* neighbors, int* assignements) {
	// One heap per query of the current query tile; capacity is set once here
	// and the heaps are reused for every tile.
	std::unique_ptr<kv_binheap<int, float>[]> heaps(new kv_binheap<int, float>[BLOCK_VECS]);
	for(int h_i = 0; h_i < BLOCK_VECS; ++h_i) {
		heaps[h_i].reset_capacity(k);
	}
	// Scratch buffer for one tile of distances: one row of block_count_neigh
	// floats per query (the layout add_candidates_heaps reads).
	std::unique_ptr<float[]> dists_block(new float[BLOCK_VECS * BLOCK_NEIGHS]);
	// kv_binheap::sort also emits the distances; they are not returned to the caller.
	std::unique_ptr<float[]> sorted_distances(new float[k]);
	cross_dists_func dists_func = get_cross_dists_func(dim);
	// Cursors into the caller's output and query arrays, advanced tile by tile.
	int* shifted_assignements = assignements;
	const float* shifted_vectors = vectors;
	for(int vec_i = 0; vec_i < vector_count; vec_i += BLOCK_VECS) {
		const int block_count_vec = std::min(BLOCK_VECS, vector_count - vec_i);
		// Start a fresh candidate set for every query of this tile.
		for(int v_i = 0; v_i < block_count_vec; ++v_i) {
			heaps[v_i].reset();
		}
		// Sweep all candidate tiles, feeding each tile's distances into the heaps.
		const float* shifted_neighbors = neighbors;
		for(int neigh_i = 0; neigh_i < neighbor_count; neigh_i += BLOCK_NEIGHS) {
			const int block_count_neigh = std::min(BLOCK_NEIGHS, neighbor_count - neigh_i);
			dists_func(dists_block.get(), shifted_neighbors, block_count_neigh, shifted_vectors,
					block_count_vec, block_count_neigh);
			add_candidates_heaps(dists_block.get(), heaps.get(), block_count_vec, block_count_neigh,
					neigh_i);
			// Each candidate spans dim floats, so skip whole vectors -- the same
			// stride rule used for shifted_vectors. Advancing by block_count_neigh
			// floats alone would misalign every candidate tile after the first.
			shifted_neighbors += block_count_neigh * dim;
		}
		// Emit k indices per query, in sorted heap order.
		for(int v_i = 0; v_i < block_count_vec; ++v_i) {
			heaps[v_i].sort(shifted_assignements, sorted_distances.get());
			shifted_assignements += k;
		}
		shifted_vectors += block_count_vec * dim;
	}
}