-
Notifications
You must be signed in to change notification settings - Fork 0
/
indexdb_create1.cpp
134 lines (109 loc) · 4.13 KB
/
indexdb_create1.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//
// Copyright (c) 2017 – Technicolor R&D France
//
// The source code form of this open source project is subject to the terms of the
// Clear BSD license.
//
// You can redistribute it and/or modify it under the terms of the Clear BSD
// License (See LICENSE file).
//
#include <algorithm>
#include <cerrno>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>

#include <cereal/archives/binary.hpp>

#include "vector_io.hpp"
#include "databases.hpp"
// Command-line options of the indexdb_create1 tool, as filled in by
// parse_args(). Every member has a defined default so that a
// default-constructed instance never holds indeterminate values.
struct cmdargs {
    // Number of centroids requested on the command line (argv[1]).
    int centroid_count = 0;
    // Path of the file the learn vectors are loaded from (argv[2]).
    const char* learn_filename = nullptr;
    // Path the serialized database is written to (argv[3]).
    const char* db_filename = nullptr;
    // Path the residual vectors are written to (argv[4]).
    const char* residuals_filename = nullptr;
};
// Writes the command-line synopsis to stderr and terminates the process with
// exit status 1. This function never returns to its caller.
static void usage() {
    static const char synopsis[] =
        "Usage: indexdb_create1 [centroid_count] [learn_file] "
        "[db_file] [residuals_file]";
    std::cerr << synopsis << std::endl;
    std::exit(1);
}
// Fills `args` from the command line.
//
// Expects at least four positional arguments, in this order: centroid_count,
// learn_file, db_file, residuals_file. Any further arguments are ignored.
// centroid_count must be a strictly positive base-10 integer that fits in an
// int. On a missing or invalid argument, usage() is called, which prints the
// synopsis and exits the process.
static void parse_args(cmdargs& args, int argc, char* argv[]) {
    if (argc < 5) {
        usage();
    }

    // std::atoi cannot report failure (non-numeric text parses as 0 and an
    // out-of-range value is undefined behavior), so parse with std::strtol
    // and validate the outcome explicitly.
    char* end = nullptr;
    errno = 0;
    const long parsed = std::strtol(argv[1], &end, 10);
    const bool fully_consumed = (end != argv[1]) && (*end == '\0');
    const bool in_range = (errno != ERANGE) && (parsed > 0)
        && (parsed <= std::numeric_limits<int>::max());
    if (!fully_consumed || !in_range) {
        std::cerr << "Invalid centroid_count: " << argv[1] << std::endl;
        usage();
    }

    args.centroid_count = static_cast<int>(parsed);
    args.learn_filename = argv[2];
    args.db_filename = argv[3];
    args.residuals_filename = argv[4];
}
// Debugging aid: prints on stdout, one count per line, how many of the
// `count` entries of `assignements` refer to each centroid index in
// [0, centroid_count).
//
// Entries outside that range are not counted (indexing the histogram with
// them would be out of bounds); their number is reported on stderr instead.
// Nothing is printed when centroid_count is not positive.
void check_assignements(int centroid_count, int* assignements, int count) {
    if (centroid_count <= 0) {
        return;
    }

    // make_unique<int[]> value-initializes, so every bucket starts at zero.
    const std::unique_ptr<int[]> hist = std::make_unique<int[]>(centroid_count);
    int out_of_range = 0;
    for (int vec_i = 0; vec_i < count; ++vec_i) {
        const int cent_i = assignements[vec_i];
        if (cent_i < 0 || cent_i >= centroid_count) {
            ++out_of_range;
            continue;
        }
        ++hist[cent_i];
    }

    // Emit all lines, then flush once rather than once per line.
    for (int cent_i = 0; cent_i < centroid_count; ++cent_i) {
        std::cout << hist[cent_i] << '\n';
    }
    std::cout.flush();

    if (out_of_range > 0) {
        std::cerr << "check_assignements: " << out_of_range
                  << " assignment(s) outside [0, " << centroid_count << ")"
                  << std::endl;
    }
}
// Verifies that every original vector equals its assigned centroid plus its
// residual, component by component.
//
//   residuals    the residual vectors; also provides the vector count and
//                the dimension
//   backup       the original vectors, read as consecutive runs of
//                `dimension` floats
//   assignments  for each vector, the index of its centroid
//   centroids    the centroids, read as consecutive runs of `dimension` floats
//
// On the first component whose reconstruction differs from the original by
// more than 1e-5, prints "Residual error: <vector> <component>" on stderr and
// exits the process with status 1.
void check_residuals(vectors_owner<float>& residuals, const float* backup,
        const int* assignments, const float* centroids) {
    const int dim = residuals.dimension;
    // Offsets are computed in std::size_t: an int product such as
    // vec_i * dim overflows once the data set holds more than INT_MAX floats.
    const std::size_t stride = static_cast<std::size_t>(dim);
    for (int vec_i = 0; vec_i < residuals.count; ++vec_i) {
        const float* vec = backup + static_cast<std::size_t>(vec_i) * stride;
        const float* residual = residuals.get(vec_i);
        const float* cent = centroids
            + static_cast<std::size_t>(assignments[vec_i]) * stride;
        for (int i = 0; i < dim; ++i) {
            if (std::abs(vec[i] - (cent[i] + residual[i])) > 1e-5) {
                std::cerr << "Residual error: " << vec_i << " " << i << std::endl;
                std::exit(1);
            }
        }
    }
}
/*
void check_residuals()
*/
// Builds the index database from the learn vectors.
//
// Steps, in order:
//   1. copy the learn vectors into a backup buffer,
//   2. learn `centroid_count` coarse-quantizer centroids with
//      learn_coarse_quantizer(),
//   3. assign the learn vectors and compute residuals with
//      index_db::assign_single_compute_residuals(),
//   4. validate the result against the backup with check_residuals(), which
//      exits the process on a mismatch.
//
// NOTE(review): learn_vectors appears to hold the residuals on return (it is
// what gets passed to check_residuals as the residuals) -- confirm against
// assign_single_compute_residuals.
//
// Returns the newly created database, owned by the caller.
static std::unique_ptr<base_db> create_database(
        vectors_owner<float>& learn_vectors, int centroid_count) {
    // Backup. The element count is computed in std::size_t: a narrower
    // product wraps for large learn sets, which would truncate the backup
    // and make the residual check read past its end.
    const std::size_t full_dim = static_cast<std::size_t>(learn_vectors.count)
        * static_cast<std::size_t>(learn_vectors.dimension);
    std::unique_ptr<float[]> backup = std::make_unique<float[]>(full_dim);
    std::copy(learn_vectors.get(0), learn_vectors.get(0) + full_dim,
              backup.get());

    // Learn coarse quantizer
    std::unique_ptr<float[]> centroids = learn_coarse_quantizer(learn_vectors,
                                                                centroid_count);
    // NOTE(review): the meaning of the two leading 8s is defined by base_pq's
    // constructor, which is not visible here -- confirm before changing them.
    auto pq = std::make_unique<base_pq>(8, 8, learn_vectors.dimension);
    // Keep the concrete index_db type while index_db-specific members are
    // needed; ownership is converted to base_db only on return.
    auto idb = std::make_unique<index_db>(std::move(pq), centroid_count,
                                          std::move(centroids));
    std::cerr << "Done K-Means" << std::endl;

    // Compute residuals
    const int thread_count = optimal_thread_count(learn_vectors.count);
    std::unique_ptr<int[]> assignements =
        std::make_unique<int[]>(learn_vectors.count);
    idb->assign_single_compute_residuals(learn_vectors.get(0),
            learn_vectors.count, assignements.get(), thread_count);

    // Check
    //check_assignements(centroid_count, assignements.get(), learn_vectors.count);
    check_residuals(learn_vectors, backup.get(), assignements.get(),
                    idb->centroids.get());

    return std::unique_ptr<base_db>(std::move(idb));
}
// Writes the database to args.db_filename as a cereal binary archive, then
// writes learn_vectors to args.residuals_filename through save_vectors().
//
// If the database file cannot be opened, or the stream is in a failed state
// after the archive has been written, an error is printed on stderr and the
// process exits with status 1.
static void save_database_residuals(cmdargs& args, vectors_owner<float>& learn_vectors,
        const std::unique_ptr<base_db>& db) {
    // Save database. The stream must be opened in binary mode: a text-mode
    // stream may translate bytes on some platforms and corrupt the archive.
    std::ofstream out_file(args.db_filename, std::ios::binary);
    if (!out_file) {
        std::cerr << "Cannot open database file: " << args.db_filename << std::endl;
        std::exit(1);
    }
    {
        // Scoped so the archive is destroyed before the stream is checked.
        cereal::BinaryOutputArchive out_archive(out_file);
        out_archive(db);
    }
    out_file.flush();
    if (!out_file) {
        std::cerr << "Error writing database file: " << args.db_filename << std::endl;
        std::exit(1);
    }

    // Save residuals
    save_vectors(learn_vectors, args.residuals_filename);
}
int main(int argc, char* argv[]) {
// Parse command line arguments
cmdargs args;
parse_args(args, argc, argv);
// Load learn vectors
vectors_owner<float> learn_vectors = load_vectors_by_extension(args.learn_filename);
// Create database
std::unique_ptr<base_db> db = create_database(learn_vectors, args.centroid_count);
// Save database and residuals
save_database_residuals(args, learn_vectors, db);
}