From 977b2c608a998fc6e950b1ba244bfa35fdfdced4 Mon Sep 17 00:00:00 2001 From: ShawnShawnYou <58975154+ShawnShawnYou@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:56:19 +0800 Subject: [PATCH 1/6] update several minor in conjugate graph (#147) Signed-off-by: zhongxiaoyao.zxy --- src/impl/conjugate_graph.cpp | 14 ++++- src/impl/conjugate_graph.h | 4 ++ src/index/hnsw.cpp | 60 +++++++++++++++---- src/index/hnsw.h | 3 + tests/fixtures/fixtures.cpp | 88 +++++++++++++++++++++++++++- tests/fixtures/fixtures.h | 10 ++++ tests/test_index_old.cpp | 109 +++++++++++++++++++++++++++++++++++ tests/test_multi_thread.cpp | 84 +++++++++++++++++++++++++++ 8 files changed, 360 insertions(+), 12 deletions(-) diff --git a/src/impl/conjugate_graph.cpp b/src/impl/conjugate_graph.cpp index daff8c3f..67351ba3 100644 --- a/src/impl/conjugate_graph.cpp +++ b/src/impl/conjugate_graph.cpp @@ -27,6 +27,9 @@ ConjugateGraph::AddNeighbor(int64_t from_tag_id, int64_t to_tag_id) { return false; } auto& neighbor_set = conjugate_graph_[from_tag_id]; + if (neighbor_set.size() >= MAXIMUM_DEGREE) { + return false; + } auto insert_result = neighbor_set.insert(to_tag_id); if (!insert_result.second) { return false; @@ -54,6 +57,10 @@ ConjugateGraph::get_neighbors(int64_t from_tag_id) const { tl::expected ConjugateGraph::EnhanceResult(std::priority_queue>& results, const std::function& distance_of_tag) const { + if (this->is_empty()) { + return 0; + } + int64_t k = results.size(); int64_t look_at_k = std::min(LOOK_AT_K, k); std::priority_queue> old_results(results); @@ -205,4 +212,9 @@ ConjugateGraph::Deserialize(StreamReader& in_stream) { } } -} // namespace vsag \ No newline at end of file +bool +ConjugateGraph::is_empty() const { + return (this->memory_usage_ == sizeof(this->memory_usage_) + FOOTER_SIZE); +} + +} // namespace vsag diff --git a/src/impl/conjugate_graph.h b/src/impl/conjugate_graph.h index 04e39b17..3c1c7b8b 100644 --- a/src/impl/conjugate_graph.h +++ b/src/impl/conjugate_graph.h @@ -28,6 +28,7 @@ namespace vsag { static const int64_t LOOK_AT_K = 20; +static const int64_t MAXIMUM_DEGREE = 128; class ConjugateGraph { public: @@ -63,6 +64,9 @@ class ConjugateGraph { void clear(); + bool + is_empty() const; + private: uint32_t memory_usage_; diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 2fb090d4..6dbadbac 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -229,10 +229,14 @@ HNSW::knn_search(const DatasetPtr& query, auto params = HnswSearchParameters::FromJson(parameters); // perform search + int64_t original_k = k; std::priority_queue> results; double time_cost; try { Timer t(time_cost); + if (use_conjugate_graph_ and params.use_conjugate_graph_search) { + k = std::max(k, LOOK_AT_K); + } results = alg_hnsw_->searchKnn( (const void*)(vector), k, std::max(params.ef_search, k), filter_ptr); } catch (const std::runtime_error& e) { @@ -264,9 +268,13 @@ HNSW::knn_search(const DatasetPtr& query, return this->alg_hnsw_->getDistanceByLabel(label, vector); }; conjugate_graph_->EnhanceResult(results, func); + k = original_k; } // return result + while (results.size() > k) { + results.pop(); + } result->Dim(results.size())->NumElements(1)->Owner(true, allocator_->GetRawAllocator()); @@ -720,10 +728,13 @@ HNSW::brute_force(const DatasetPtr& query, int64_t k) { float* dists = (float*)allocator_->Allocate(sizeof(float) * k); result->Distances(dists); - auto vector = query->GetFloat32Vectors(); + void* vector = nullptr; + size_t data_size = 0; + get_vectors(query, &vector, &data_size); + std::shared_lock lock(rw_mutex_); std::priority_queue> bf_result = - alg_hnsw_->bruteForce((const void*)vector, k); + alg_hnsw_->bruteForce(vector, k); result->Dim(std::min(k, (int64_t)bf_result.size())); for (int i = result->GetDim() - 1; i >= 0; i--) { @@ -758,18 +769,26 @@ HNSW::pretrain(const std::vector& base_tag_ids, return 0; } + uint32_t data_size = 0; uint32_t add_edges = 0; int64_t topk_neighbor_tag_id; - const float* topk_data; - std::shared_ptr generated_data(new float[dim_]); + const void* topk_data; + const void* base_data; auto base = Dataset::Make(); auto generated_query = Dataset::Make(); - base->Dim(dim_)->NumElements(1)->Owner(false); - generated_query->Dim(dim_)->NumElements(1)->Float32Vectors(generated_data.get())->Owner(false); + if (type_ == DataTypes::DATA_TYPE_INT8) { + data_size = dim_; + } else { + data_size = dim_ * 4; + } + + std::shared_ptr generated_data(new int8_t[data_size]); + set_dataset(generated_query, generated_data.get(), 1); for (const int64_t& base_tag_id : base_tag_ids) { + base_data = (const void*)this->alg_hnsw_->getDataByLabel(base_tag_id); try { - base->Float32Vectors(this->alg_hnsw_->getDataByLabel(base_tag_id)); + set_dataset(base, base_data, 1); } catch (const std::runtime_error& e) { LOG_ERROR_AND_RETURNS( ErrorType::INVALID_ARGUMENT, @@ -795,11 +814,18 @@ HNSW::pretrain(const std::vector& base_tag_ids, if (topk_neighbor_tag_id == base_tag_id) { continue; } - topk_data = this->alg_hnsw_->getDataByLabel(topk_neighbor_tag_id); + topk_data = (const void*)this->alg_hnsw_->getDataByLabel(topk_neighbor_tag_id); for (int d = 0; d < dim_; d++) { - generated_data.get()[d] = vsag::GENERATE_OMEGA * base->GetFloat32Vectors()[d] + - (1 - vsag::GENERATE_OMEGA) * topk_data[d]; + if (type_ == DataTypes::DATA_TYPE_INT8) { + generated_data.get()[d] = + vsag::GENERATE_OMEGA * (float)(((int8_t*)base_data)[d]) + + (1 - vsag::GENERATE_OMEGA) * (float)(((int8_t*)topk_data)[d]); + } else { + ((float*)generated_data.get())[d] = + vsag::GENERATE_OMEGA * ((float*)base_data)[d] + + (1 - vsag::GENERATE_OMEGA) * ((float*)topk_data)[d]; + } } auto feedback_result = this->Feedback(generated_query, k, parameters, base_tag_id); @@ -843,4 +869,18 @@ HNSW::get_vectors(const vsag::DatasetPtr& base, void** vectors_ptr, size_t* data } } +void +HNSW::set_dataset(const DatasetPtr& base, const void* vectors_ptr, uint32_t num_element) const { + if (type_ == DataTypes::DATA_TYPE_FLOAT) { + base->Float32Vectors((float*)vectors_ptr) + ->Dim(dim_) + ->Owner(false) + ->NumElements(num_element); + } else if (type_ == DataTypes::DATA_TYPE_INT8) { + base->Int8Vectors((int8_t*)vectors_ptr)->Dim(dim_)->Owner(false)->NumElements(num_element); + } else { + throw std::invalid_argument(fmt::format("no support for this type: {}", (int)type_)); + } +} + } // namespace vsag diff --git a/src/index/hnsw.h b/src/index/hnsw.h index f1b5993f..89d3252c 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -256,6 +256,9 @@ class HNSW : public Index { void get_vectors(const DatasetPtr& base, void** vectors_ptr, size_t* data_size_ptr) const; + void + set_dataset(const DatasetPtr& base, const void* vectors_ptr, uint32_t num_element) const; + BinarySet empty_binaryset() const; diff --git a/tests/fixtures/fixtures.cpp b/tests/fixtures/fixtures.cpp index fcdd923c..e2051c51 100644 --- a/tests/fixtures/fixtures.cpp +++ b/tests/fixtures/fixtures.cpp @@ -29,7 +29,12 @@ namespace vsag { extern float L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr); -} +extern float +InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr); + +extern float +INT8InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr); +} // namespace vsag namespace fixtures { @@ -86,6 +91,27 @@ generate_vectors(int64_t num_vectors, int64_t dim, bool need_normalize, int seed return vectors; } +std::vector +generate_int8_codes(uint64_t count, uint32_t dim, int seed) { + auto code_size = dim; + std::vector codes(count * code_size, 0); + auto vec = fixtures::generate_vectors(count, dim, true, seed); + + for (int i = 0; i < count; i++) { + for (int d = 0; d < dim; d++) { + float delta = vec[d + i * dim]; + if (delta < 0) { + delta = 0; + } else if (delta > 0.999) { + delta = 1; + } + // Note that here we use overflowed uint8 value to obtain int8 value + codes[code_size * i + d] = 255.0 * delta; + } + } + return codes; +} + std::vector generate_int4_codes(uint64_t count, uint32_t dim, int seed) { auto code_size = (dim + 1) / 2; @@ -237,6 +263,66 @@ generate_hnsw_build_parameters_string(const std::string& metric_type, int64_t di return build_parameters; } +vsag::DatasetPtr +brute_force(const vsag::DatasetPtr& query, + const vsag::DatasetPtr& base, + int64_t k, + const std::string& metric_type, + const std::string& data_type) { + assert(query->GetDim() == base->GetDim()); + assert(query->GetNumElements() == 1); + + auto result = vsag::Dataset::Make(); + int64_t* ids = new int64_t[k]; + float* dists = new float[k]; + result->Ids(ids)->Distances(dists)->NumElements(k); + + std::priority_queue> bf_result; + + size_t dim = query->GetDim(); + const void* query_vec = nullptr; + const void* base_vec = nullptr; + float dist = 0; + for (uint32_t i = 0; i < base->GetNumElements(); i++) { + if (data_type == "float32") { + query_vec = query->GetFloat32Vectors(); + base_vec = base->GetFloat32Vectors() + i * base->GetDim(); + } else if (data_type == "int8") { + query_vec = query->GetInt8Vectors(); + base_vec = base->GetInt8Vectors() + i * base->GetDim(); + } else { + throw std::runtime_error("un-support data type"); + } + + if (metric_type == "l2") { + dist = vsag::L2Sqr(query_vec, base_vec, &dim); + } else if (metric_type == "ip") { + if (data_type == "float32") { + dist = vsag::InnerProductDistance(query_vec, base_vec, &dim); + } else { + dist = vsag::INT8InnerProductDistance(query_vec, base_vec, &dim); + } + } + + if (bf_result.size() < k) { + bf_result.push({dist, base->GetIds()[i]}); + } else { + if (dist < bf_result.top().first) { + bf_result.pop(); + bf_result.push({dist, base->GetIds()[i]}); + } + } + } + + for (int i = k - 1; i >= 0; i--) { + ids[i] = bf_result.top().second; + dists[i] = bf_result.top().first; + bf_result.pop(); + } + + return std::move(result); +} + vsag::DatasetPtr brute_force(const vsag::DatasetPtr& query, const vsag::DatasetPtr& base, diff --git a/tests/fixtures/fixtures.h b/tests/fixtures/fixtures.h index 28568f0d..cd6d9f84 100644 --- a/tests/fixtures/fixtures.h +++ b/tests/fixtures/fixtures.h @@ -34,6 +34,9 @@ generate_vectors(int64_t num_vectors, int64_t dim, bool need_normalize = true, i std::vector generate_int4_codes(uint64_t count, uint32_t dim, int seed = 47); +std::vector +generate_int8_codes(uint64_t count, uint32_t dim, int seed = 47); + std::vector generate_uint8_codes(uint64_t count, uint32_t dim, int seed = 47); @@ -77,6 +80,13 @@ brute_force(const vsag::DatasetPtr& query, int64_t k, const std::string& metric_type); +vsag::DatasetPtr +brute_force(const vsag::DatasetPtr& query, + const vsag::DatasetPtr& base, + int64_t k, + const std::string& metric_type, + const std::string& data_type); + template typename std::enable_if::value, T>::type RandomValue(const T& min, const T& max) { diff --git a/tests/test_index_old.cpp b/tests/test_index_old.cpp index 90302d5c..2b6328d6 100644 --- a/tests/test_index_old.cpp +++ b/tests/test_index_old.cpp @@ -1060,6 +1060,115 @@ TEST_CASE("build index with generated_build_parameters", "[ft][index]") { REQUIRE(recall > 0.95); } +TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") { + auto logger = vsag::Options::Instance().logger(); + logger->SetLevel(vsag::Logger::Level::kDEBUG); + + // parameters + int dim = 256; + int num_base = 10000; + int num_query = 1000; + int64_t k = 3; + auto metric_type = GENERATE("ip"); + constexpr auto build_parameter_json = R"( + {{ + "dtype": "int8", + "metric_type": "{}", + "dim": {}, + "hnsw": {{ + "max_degree": 16, + "ef_construction": 200, + "use_conjugate_graph": true + }} + }} + )"; + auto build_parameter = fmt::format(build_parameter_json, metric_type, dim); + + // create index + auto createindex = vsag::Factory::CreateIndex("fresh_hnsw", build_parameter); + REQUIRE(createindex.has_value()); + auto index = createindex.value(); + + // generate dataset + std::vector base_ids(num_base); + for (int64_t i = 0; i < num_base; ++i) { + base_ids[i] = i; + } + auto base_vectors = fixtures::generate_int8_codes(num_base, dim); + auto base = vsag::Dataset::Make(); + auto queries = vsag::Dataset::Make(); + base->NumElements(num_base) + ->Dim(dim) + ->Ids(base_ids.data()) + ->Int8Vectors(base_vectors.data()) + ->Owner(false); + + auto query_vectors = fixtures::generate_int8_codes(num_query, dim); + queries->NumElements(num_query)->Dim(dim)->Int8Vectors(query_vectors.data())->Owner(false); + + // build index + auto buildindex = index->Build(base); + REQUIRE(buildindex.has_value()); + + // train and search + float recall[2]; + int correct; + uint32_t error_fix = 0; + bool use_conjugate_graph_search = false; + for (int round = 0; round < 2; round++) { + correct = 0; + + if (round == 0) { + logger->Debug("====train stage===="); + } else { + logger->Debug("====test stage===="); + } + + logger->Debug(fmt::format(R"(Memory Usage: {:.3f} KB)", index->GetMemoryUsage() / 1024.0)); + + use_conjugate_graph_search = (round != 0); + constexpr auto search_parameters_json = R"( + {{ + "hnsw": {{ + "ef_search": 100, + "use_conjugate_graph_search": {} + }} + }} + )"; + auto search_parameters = fmt::format(search_parameters_json, use_conjugate_graph_search); + + for (int i = 0; i < num_query; i++) { + auto query = vsag::Dataset::Make(); + query->Dim(dim) + ->Int8Vectors(queries->GetInt8Vectors() + i * dim) + ->NumElements(1) + ->Owner(false); + + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + auto bf_result = fixtures::brute_force(query, base, 1, metric_type, "int8"); + int64_t global_optimum = bf_result->GetIds()[0]; + int64_t local_optimum = result.value()->GetIds()[0]; + + if (local_optimum != global_optimum and round == 0) { + error_fix += *index->Feedback(query, k, search_parameters, global_optimum); + REQUIRE(*index->Feedback(query, k, search_parameters) == 0); + } + + if (local_optimum == global_optimum) { + correct++; + } + } + recall[round] = correct / (1.0 * num_query); + logger->Debug(fmt::format(R"(Recall: {:.4f})", recall[round])); + } + + logger->Debug("====summary===="); + logger->Debug(fmt::format(R"(Error fix: {})", error_fix)); + + REQUIRE(fixtures::time_t(recall[1]) == fixtures::time_t(1.0f)); +} + TEST_CASE("hnsw + feedback with global optimum id", "[ft][index][hnsw]") { auto logger = vsag::Options::Instance().logger(); logger->SetLevel(vsag::Logger::Level::kDEBUG); diff --git a/tests/test_multi_thread.cpp b/tests/test_multi_thread.cpp index e4941b7b..67f7bdd2 100644 --- a/tests/test_multi_thread.cpp +++ b/tests/test_multi_thread.cpp @@ -310,3 +310,87 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { REQUIRE(search_results[i].get()); } } + +TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hnsw]") { + // avoid too much slow task logs + vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN); + + int thread_num = 32; + int dim = 256; + int max_elements = 10000; + int max_degree = 32; + int ef_construction = 200; + int ef_search = 100; + int k = 10; + nlohmann::json hnsw_parameters{{"max_degree", max_degree}, + {"ef_construction", ef_construction}, + {"ef_search", ef_search}, + {"use_conjugate_graph", true}}; + nlohmann::json index_parameters{ + {"dtype", "int8"}, {"metric_type", "ip"}, {"dim", dim}, {"hnsw", hnsw_parameters}}; + auto index = vsag::Factory::CreateIndex("hnsw", index_parameters.dump()).value(); + std::shared_ptr ids(new int64_t[max_elements]); + std::shared_ptr data(new int8_t[dim * max_elements]); + + ThreadPool pool(thread_num); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real(-128, 127); + for (int i = 0; i < max_elements; i++) ids[i] = i; + for (int i = 0; i < dim * max_elements; i++) data[i] = (int8_t)distrib_real(rng); + nlohmann::json parameters{ + {"hnsw", {{"ef_search", ef_search}, {"use_conjugate_graph_search", true}}}, + }; + std::string str_parameters = parameters.dump(); + + std::vector> insert_results; + std::vector> feedback_results; + std::vector> search_results; + + for (int64_t i = 0; i < max_elements; ++i) { + // insert + insert_results.push_back(pool.enqueue([&ids, &data, &index, dim, i]() -> int64_t { + auto dataset = vsag::Dataset::Make(); + dataset->Dim(dim) + ->NumElements(1) + ->Ids(ids.get() + i) + ->Int8Vectors(data.get() + i * dim) + ->Owner(false); + auto add_res = index->Add(dataset); + return add_res.value().size(); + })); + } + + for (int64_t i = 0; i < max_elements; ++i) { + // feedback + feedback_results.push_back( + pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> uint64_t { + auto query = vsag::Dataset::Make(); + query->Dim(dim)->NumElements(1)->Int8Vectors(data.get() + i * dim)->Owner(false); + auto feedback_res = index->Feedback(query, k, str_parameters); + return feedback_res.value(); + })); + + // search + search_results.push_back(pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> bool { + auto query = vsag::Dataset::Make(); + query->Dim(dim)->NumElements(1)->Int8Vectors(data.get() + i * dim)->Owner(false); + auto result = index->KnnSearch(query, k, str_parameters); + return result.has_value(); + })); + } + + for (auto& res : insert_results) { + REQUIRE(res.get() == 0); + } + + for (auto& res : feedback_results) { + REQUIRE(res.get() >= 0); + } + + for (auto& res : search_results) { + REQUIRE(res.get()); + } +} From 339cab39930930a66b1afa344cf381c97dd1533c Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 3 Dec 2024 11:03:46 +0800 Subject: [PATCH 2/6] add resize for hgraph (#140) - visitlist for hgraph is fixed, now make it flexible - fix typo error on graph interface Signed-off-by: LHT129 --- src/data_cell/graph_interface.h | 2 +- src/index/hgraph_index.cpp | 14 +++++++++++++- src/index/hgraph_index.h | 3 +++ src/index/hgraph_zparameters.cpp | 2 +- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/data_cell/graph_interface.h b/src/data_cell/graph_interface.h index a23f5a50..1ce08b22 100644 --- a/src/data_cell/graph_interface.h +++ b/src/data_cell/graph_interface.h @@ -102,7 +102,7 @@ class GraphInterface { virtual void SetMaximumDegree(uint32_t maximum_degree) { - this->max_capacity_ = maximum_degree; + this->maximum_degree_ = maximum_degree; } virtual void diff --git a/src/index/hgraph_index.cpp b/src/index/hgraph_index.cpp index 76317368..6119619d 100644 --- a/src/index/hgraph_index.cpp +++ b/src/index/hgraph_index.cpp @@ -17,6 +17,8 @@ #include +#include + #include "common.h" #include "data_cell/sparse_graph_datacell.h" #include "hgraph_zparameters.h" @@ -233,7 +235,7 @@ HGraphIndex::hnsw_add(const DatasetPtr& data) { auto* ids = data->GetIds(); auto* datas = data->GetFloat32Vectors(); auto cur_count = this->bottom_graph_->TotalCount(); - vsag::Vector(total + cur_count, allocator_).swap(this->neighbors_mutex_); + this->resize(total + cur_count); std::mutex add_mutex; @@ -690,4 +692,14 @@ HGraphIndex::add_one_point(const float* data, int level, InnerIdType inner_id) { bottom_graph_->IncreaseTotalCount(1); } +void +HGraphIndex::resize(uint64_t new_size) { + auto cur_size = this->bottom_graph_->MaxCapacity(); + if (new_size > cur_size) { + vsag::Vector(new_size, allocator_).swap(this->neighbors_mutex_); + pool_ = std::make_shared(new_size, allocator_); + this->bottom_graph_->SetMaxCapacity(new_size); + } +} + } // namespace vsag diff --git a/src/index/hgraph_index.h b/src/index/hgraph_index.h index e2d7cdc8..c13f37d2 100644 --- a/src/index/hgraph_index.h +++ b/src/index/hgraph_index.h @@ -240,6 +240,9 @@ class HGraphIndex : public Index { void hnsw_add(const DatasetPtr& data); + void + resize(uint64_t new_size); + GraphInterfacePtr generate_one_route_graph(); diff --git a/src/index/hgraph_zparameters.cpp b/src/index/hgraph_zparameters.cpp index e5ced184..c6119337 100644 --- a/src/index/hgraph_zparameters.cpp +++ b/src/index/hgraph_zparameters.cpp @@ -95,7 +95,7 @@ const std::string HGraphParameters::DEFAULT_HGRAPH_PARAMS = format_map( "type": "nsw", "{GRAPH_PARAMS_KEY}": { "{GRAPH_PARAM_MAX_DEGREE}": 64, - "{GRAPH_PARAM_INIT_MAX_CAPACITY}": 2000000 + "{GRAPH_PARAM_INIT_MAX_CAPACITY}": 1000 } }, "{HGRAPH_BASE_CODES_KEY}": { From 28ad5734407bf343043094a789c4601816423ba7 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 3 Dec 2024 11:14:42 +0800 Subject: [PATCH 3/6] introduce some constants config for hgraph (#168) - ef_construction & max_degree Signed-off-by: LHT129 --- examples/cpp/simple_hgraph_sq8.cpp | 4 +++- include/vsag/constants.h | 1 + src/constants.cpp | 1 + src/index/hgraph_index.cpp | 10 ++++++++++ src/index/hgraph_zparameters.cpp | 7 ++++--- src/inner_string_params.h | 4 ++++ 6 files changed, 23 insertions(+), 4 deletions(-) diff --git a/examples/cpp/simple_hgraph_sq8.cpp b/examples/cpp/simple_hgraph_sq8.cpp index 8567e378..09722fac 100644 --- a/examples/cpp/simple_hgraph_sq8.cpp +++ b/examples/cpp/simple_hgraph_sq8.cpp @@ -44,7 +44,9 @@ main(int argc, char** argv) { "metric_type": "l2", "dim": 128, "index_param": { - "base_quantization_type": "sq8" + "base_quantization_type": "sq8", + "max_degree": 26, + "ef_construction": 100 } } )"; diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 3f3d09fd..bf2d2d31 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -99,5 +99,6 @@ extern const char* const SERIALIZE_VERSION; extern const char* const HGRAPH_USE_REORDER; extern const char* const HGRAPH_BASE_QUANTIZATION_TYPE; extern const char* const HGRAPH_GRAPH_MAX_DEGREE; +extern const char* const HGRAPH_BUILD_EF_CONSTRUCTION; } // namespace vsag diff --git a/src/constants.cpp b/src/constants.cpp index 9f6f4a05..17c06b2c 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -100,5 +100,6 @@ const char* const SERIALIZE_VERSION = "VERSION"; const char* const HGRAPH_USE_REORDER = HGRAPH_USE_REORDER_KEY; const char* const HGRAPH_BASE_QUANTIZATION_TYPE = "base_quantization_type"; const char* const HGRAPH_GRAPH_MAX_DEGREE = "max_degree"; +const char* const HGRAPH_BUILD_EF_CONSTRUCTION = "ef_construction"; }; // namespace vsag diff --git a/src/index/hgraph_index.cpp b/src/index/hgraph_index.cpp index 6119619d..d0c92886 100644 --- a/src/index/hgraph_index.cpp +++ b/src/index/hgraph_index.cpp @@ -82,6 +82,16 @@ HGraphIndex::Init() { this->pool_ = std::make_shared(this->bottom_graph_->MaxCapacity(), allocator_); + if (this->index_param_.contains(BUILD_PARAMS_KEY)) { + auto& build_params = this->index_param_[BUILD_PARAMS_KEY]; + if (build_params.contains(BUILD_EF_CONSTRUCTION)) { + this->ef_construct_ = build_params[BUILD_EF_CONSTRUCTION]; + } + if (build_params.contains(BUILD_THREAD_COUNT)) { + this->build_thread_count_ = build_params[BUILD_THREAD_COUNT]; + } + } + if (this->build_thread_count_ > 1) { this->build_pool_ = std::make_unique(this->build_thread_count_); } diff --git a/src/index/hgraph_zparameters.cpp b/src/index/hgraph_zparameters.cpp index c6119337..48a1426d 100644 --- a/src/index/hgraph_zparameters.cpp +++ b/src/index/hgraph_zparameters.cpp @@ -26,7 +26,8 @@ namespace vsag { const std::unordered_map> HGraphParameters::EXTERNAL_MAPPING = {{HGRAPH_USE_REORDER, {HGRAPH_USE_REORDER_KEY}}, {HGRAPH_BASE_QUANTIZATION_TYPE, {HGRAPH_BASE_CODES_KEY, QUANTIZATION_TYPE_KEY}}, - {HGRAPH_GRAPH_MAX_DEGREE, {HGRAPH_GRAPH_KEY, GRAPH_PARAMS_KEY, GRAPH_PARAM_MAX_DEGREE}}}; + {HGRAPH_GRAPH_MAX_DEGREE, {HGRAPH_GRAPH_KEY, GRAPH_PARAMS_KEY, GRAPH_PARAM_MAX_DEGREE}}, + {HGRAPH_BUILD_EF_CONSTRUCTION, {BUILD_PARAMS_KEY, BUILD_EF_CONSTRUCTION}}}; HGraphParameters::HGraphParameters(JsonType& hgraph_param, const IndexCommonParam& common_param) : common_param_(common_param) { @@ -119,8 +120,8 @@ const std::string HGraphParameters::DEFAULT_HGRAPH_PARAMS = format_map( "{QUANTIZATION_TYPE_KEY}": "{QUANTIZATION_TYPE_VALUE_SQ8}", "{QUANTIZATION_PARAMS_KEY}": {} }, - "build_params": { - "ef_construction": 400, + "{BUILD_PARAMS_KEY}": { + "{BUILD_EF_CONSTRUCTION}": 400, "{BUILD_THREAD_COUNT}": 5 } })", diff --git a/src/inner_string_params.h b/src/inner_string_params.h index 10fc1284..f92e907d 100644 --- a/src/inner_string_params.h +++ b/src/inner_string_params.h @@ -55,7 +55,9 @@ const char* const GRAPH_PARAMS_KEY = "graph_params"; const char* const GRAPH_PARAM_MAX_DEGREE = "max_degree"; const char* const GRAPH_PARAM_INIT_MAX_CAPACITY = "init_capacity"; +const char* const BUILD_PARAMS_KEY = "build_params"; const char* const BUILD_THREAD_COUNT = "build_thread_count"; +const char* const BUILD_EF_CONSTRUCTION = "ef_construction"; const std::unordered_map DEFAULT_MAP = { {"HGRAPH_USE_REORDER_KEY", HGRAPH_USE_REORDER_KEY}, @@ -75,7 +77,9 @@ const std::unordered_map DEFAULT_MAP = { {"GRAPH_PARAMS_KEY", GRAPH_PARAMS_KEY}, {"GRAPH_PARAM_MAX_DEGREE", GRAPH_PARAM_MAX_DEGREE}, {"GRAPH_PARAM_INIT_MAX_CAPACITY", GRAPH_PARAM_INIT_MAX_CAPACITY}, + {"BUILD_PARAMS_KEY", BUILD_PARAMS_KEY}, {"BUILD_THREAD_COUNT", BUILD_THREAD_COUNT}, + {"BUILD_EF_CONSTRUCTION", BUILD_EF_CONSTRUCTION}, }; } // namespace vsag From f2ab0469f6f0d31870541a07ad18c8762ddcf10b Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 3 Dec 2024 15:01:50 +0800 Subject: [PATCH 4/6] fix ci script bug (#170) - remove asan test for the asan bug https://stackoverflow.com/questions/77894856/possible-bug-in-gcc-sanitizers Signed-off-by: LHT129 --- .circleci/config.yml | 8 ++++---- Makefile | 9 +++++++-- scripts/{test_asan_bg.sh => test_parallel_bg.sh} | 11 +++++++---- src/index/hnsw.cpp | 2 +- 4 files changed, 19 insertions(+), 11 deletions(-) rename scripts/{test_asan_bg.sh => test_parallel_bg.sh} (82%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 85ccd5c3..7a384174 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -27,12 +27,12 @@ jobs: - restore_cache: keys: - fork-cache-{{ checksum "CMakeLists.txt" }}-{{ checksum ".circleci/fresh_ci_cache.commit" }} - - run: make asan + - run: make debug - save_cache: key: fork-cache-{{ checksum "CMakeLists.txt" }}-{{ checksum ".circleci/fresh_ci_cache.commit" }} paths: - ./build - - run: make test_asan_parallel + - run: make test_parallel main-branch-check: docker: @@ -44,9 +44,9 @@ jobs: - restore_cache: keys: - main-ccache-{{ checksum "CMakeLists.txt" }}-{{ checksum ".circleci/fresh_ci_cache.commit" }} - - run: make asan + - run: make debug - save_cache: key: main-ccache-{{ checksum "CMakeLists.txt" }}-{{ checksum ".circleci/fresh_ci_cache.commit" }} paths: - ./build - - run: make test_asan_parallel + - run: make test_parallel diff --git a/Makefile b/Makefile index ae665793..0eea123f 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ help: ## Show the help. .PHONY: debug debug: ## Build vsag with debug options. - cmake ${VSAG_CMAKE_ARGS} -DCMAKE_BUILD_TYPE=Debug -DENABLE_CCACHE=ON + cmake ${VSAG_CMAKE_ARGS} -DCMAKE_BUILD_TYPE=Debug -DENABLE_CCACHE=ON -DENABLE_ASAN=OFF cmake --build build --parallel ${COMPILE_JOBS} .PHONY: release @@ -57,6 +57,11 @@ test: ## Build and run unit tests. ./build/tests/functests -d yes ${UT_FILTER} --allow-running-no-tests ${UT_SHARD} ./build/mockimpl/tests_mockimpl -d yes ${UT_FILTER} --allow-running-no-tests ${UT_SHARD} +.PHONY: test_parallel +test_parallel: debug + @./scripts/test_parallel_bg.sh + ./build/mockimpl/tests_mockimpl -d yes ${UT_FILTER} --allow-running-no-tests ${UT_SHARD} + .PHONY: asan asan: ## Build with AddressSanitizer option. cmake ${VSAG_CMAKE_ARGS} -DCMAKE_BUILD_TYPE=Debug -DENABLE_ASAN=ON -DENABLE_CCACHE=ON @@ -64,7 +69,7 @@ asan: ## Build with AddressSanitizer option. .PHONY: test_asan_parallel test_asan_parallel: asan ## Run unit tests parallel with AddressSanitizer option. - @./scripts/test_asan_bg.sh + @./scripts/test_parallel_bg.sh ./build/mockimpl/tests_mockimpl -d yes ${UT_FILTER} --allow-running-no-tests ${UT_SHARD} .PHONY: test_asan diff --git a/scripts/test_asan_bg.sh b/scripts/test_parallel_bg.sh similarity index 82% rename from scripts/test_asan_bg.sh rename to scripts/test_parallel_bg.sh index b8955adc..3a02d945 100755 --- a/scripts/test_asan_bg.sh +++ b/scripts/test_parallel_bg.sh @@ -8,22 +8,25 @@ othertag="" mkdir ./log -./build/tests/unittests -d yes ${UT_FILTER} --allow-running-no-tests 2>&1 | tee ./log/unittest.log -exit_codes+=($?) +./build/tests/unittests -d yes ${UT_FILTER} --allow-running-no-tests &> "./log/unittest.log" & +pids+=($!) +tail -f "./log/unittest.log" & logger_files+=("./log/unittest.log") for tag in ${parallel_tags} do othertag="~"${tag}${othertag} - ./build/tests/functests -d yes ${UT_FILTER} --allow-running-no-tests ${tag} 2>&1 | tee ./log/${tag}.log & + ./build/tests/functests -d yes ${UT_FILTER} --allow-running-no-tests &> ./log/${tag}.log & pids+=($!) logname="./log/"${tag}".log" logger_files+=($logname) + tail -f ./log/${tag}.log & done -./build/tests/functests -d yes ${UT_FILTER} --allow-running-no-tests ${othertag} 2>&1 | tee ./log/other.log & +./build/tests/functests -d yes ${UT_FILTER} --allow-running-no-tests ${othertag} &> ./log/other.log & pids+=($!) logger_files+=("./log/other.log") +tail -f "./log/other.log" & for pid in "${pids[@]}" do diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 6dbadbac..4b24831b 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -786,8 +786,8 @@ HNSW::pretrain(const std::vector& base_tag_ids, set_dataset(generated_query, generated_data.get(), 1); for (const int64_t& base_tag_id : base_tag_ids) { - base_data = (const void*)this->alg_hnsw_->getDataByLabel(base_tag_id); try { + base_data = (const void*)this->alg_hnsw_->getDataByLabel(base_tag_id); set_dataset(base, base_data, 1); } catch (const std::runtime_error& e) { LOG_ERROR_AND_RETURNS( From 6cd6f5862deb2f02a88691170b9fef91206487f6 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 3 Dec 2024 16:55:09 +0800 Subject: [PATCH 5/6] [skip ci] assign new reviewers (#173) Signed-off-by: LHT129 --- .github/CODEOWNERS | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 75e63e57..70b84212 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,8 +10,8 @@ *.md @wxyucs @jiaweizone # GitHub routines -/.circleci/ @wxyucs -/.github/ @wxyucs +/.circleci/ @wxyucs @LHT129 +/.github/ @wxyucs @LHT129 /docker/ @wxyucs /.* @wxyucs /LICENSE @wxyucs @@ -22,5 +22,29 @@ # Core codes, maybe assign reviewers later /src/ @jiaweizone +# index codes +/src/index @inabao @wxyucs @LHT129 @ShawnShawnYou + +# datacell codes +/src/data_cell/ @LHT129 @ShawnShawnYou + +# io codes +/src/io/ @LHT129 + +# quantization codes +/src/quantization/ @LHT129 @ShawnShawnYou + +# impl codes (conjugate graph) +/src/impl/ @ShawnShawnYou + +# hnswlib codes +/src/algorithm/ @inabao @wxyucs + # All the SIMD related codes /src/simd/ @LHT129 @ShawnShawnYou + +# pybinds +/python_bindings/ @wxyucs @LHT129 + +# scripts +/scripts/ @LHT129 @wxyucs @inabao From 63af8766a074dcffd6b2442dfd3bb777b983d8de Mon Sep 17 00:00:00 2001 From: LHT129 Date: Mon, 2 Dec 2024 09:07:26 +0000 Subject: [PATCH 6/6] fix illegal instruction on platform has avx only Signed-off-by: LHT129 --- .github/workflows/DCO.yml | 15 ++++++++ .github/workflows/build_test.yml | 39 +++++++++++++++++++ src/simd/CMakeLists.txt | 15 +++++--- src/simd/avx.cpp | 14 +++---- src/simd/avx512_test.cpp | 3 +- src/simd/avx_test.cpp | 15 ++++---- src/simd/fp32_simd.h | 6 --- src/simd/fp32_simd_test.cpp | 26 ++++++++----- src/simd/normalize.cpp | 60 ++++++++++++++++++++++++++++++ src/simd/normalize.h | 37 ++---------------- src/simd/normalize_test.cpp | 36 +++++++++++++----- src/simd/simd.cpp | 12 +++--- src/simd/sq4_simd.h | 6 --- src/simd/sq4_uniform_simd.h | 6 --- src/simd/sq4_uniform_simd_test.cpp | 32 ++++++++++------ src/simd/sq8_simd.h | 4 -- src/simd/sq8_simd_test.cpp | 33 +++++++++------- src/simd/sq8_uniform_simd.h | 6 --- src/simd/sq8_uniform_simd_test.cpp | 31 +++++++++------ 19 files changed, 254 insertions(+), 142 deletions(-) create mode 100644 .github/workflows/DCO.yml create mode 100644 .github/workflows/build_test.yml create mode 100644 src/simd/normalize.cpp diff --git a/.github/workflows/DCO.yml b/.github/workflows/DCO.yml new file mode 100644 index 00000000..b546a2be --- /dev/null +++ b/.github/workflows/DCO.yml @@ -0,0 +1,15 @@ +name: DCO Check + +on: + pull_request: + branches: [ "main" ] + +jobs: + dco-check: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Validate DCO + uses: tisonkun/actions-dco@v1.1 diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml new file mode 100644 index 00000000..a12d74a0 --- /dev/null +++ b/.github/workflows/build_test.yml @@ -0,0 +1,39 @@ +name: build & test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + clang-format-check: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Clang format + run: sudo apt-get install clang-format + - name: Run Clang format check + run: make fmt + + build: + runs-on: ubuntu-latest + container: + image: vsaglib/vsag:ubuntu + steps: + - uses: actions/checkout@v4 + - name: Load Cache + uses: actions/cache@v4.1.2 + with: + path: ./build/ + key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }} + - name: Make Asan + run: make asan + - name: Save Cache + uses: actions/cache@v4.1.2 + with: + path: ./build/ + key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }} + - name: Test + run: make test_asan \ No newline at end of file diff --git a/src/simd/CMakeLists.txt b/src/simd/CMakeLists.txt index 4501d97a..68501f06 100644 --- a/src/simd/CMakeLists.txt +++ b/src/simd/CMakeLists.txt @@ -6,6 +6,9 @@ set (SIMD_SRCS sq4_simd.cpp sq4_uniform_simd.cpp sq8_uniform_simd.cpp + sse.cpp + avx.cpp + avx512.cpp ) if (DIST_CONTAINS_SSE) set (SIMD_SRCS ${SIMD_SRCS} sse.cpp) @@ -18,11 +21,13 @@ if (DIST_CONTAINS_AVX) set (SIMD_SRCS ${SIMD_SRCS} avx.cpp) set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx") endif () -if (DIST_CONTAINS_AVX2) - set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") -endif () +# FIXME(LHT): cause illegal instruction on platform which has avx only +#if (DIST_CONTAINS_AVX2) +# set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +#endif () if (DIST_CONTAINS_AVX512) - set (SIMD_SRCS ${SIMD_SRCS} avx512.cpp) + set (SIMD_SRCS ${SIMD_SRCS} avx512.cpp + normalize.cpp) set_source_files_properties ( avx512.cpp PROPERTIES @@ -50,7 +55,7 @@ endmacro () simd_add_definitions (DIST_CONTAINS_SSE -DENABLE_SSE=1) simd_add_definitions (DIST_CONTAINS_AVX -DENABLE_AVX=1) -simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1) +#simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1) simd_add_definitions (DIST_CONTAINS_AVX512 -DENABLE_AVX512=1) target_link_libraries (simd PRIVATE cpuinfo) diff --git a/src/simd/avx.cpp b/src/simd/avx.cpp index ed7b3fca..486d0e1a 100644 --- a/src/simd/avx.cpp +++ b/src/simd/avx.cpp @@ -209,7 +209,7 @@ FP32ComputeIP(const float* query, const float* codes, uint64_t dim) { ip += sse::FP32ComputeIP(query + n * 8, codes + n * 8, dim - n * 8); return ip; #else - return vsag::Generic::FP32ComputeIP(query, codes, dim); + return vsag::generic::FP32ComputeIP(query, codes, dim); #endif } @@ -235,7 +235,7 @@ FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim) { l2 += sse::FP32ComputeL2Sqr(query + n * 8, codes + n * 8, dim - n * 8); return l2; #else - return vsag::Generic::FP32ComputeL2Sqr(query, codes, dim); + return vsag::generic::FP32ComputeL2Sqr(query, codes, dim); #endif } @@ -275,7 +275,7 @@ SQ8ComputeIP(const float* query, finalResult += sse::SQ8ComputeIP(query + i, codes + i, lowerBound + i, diff + i, dim - i); return finalResult; #else - return Generic::SQ8ComputeIP(query, codes, lowerBound, diff, dim); + return generic::SQ8ComputeIP(query, codes, lowerBound, diff, dim); #endif } @@ -320,7 +320,7 @@ SQ8ComputeL2Sqr(const float* query, result += sse::SQ8ComputeL2Sqr(query + i, codes + i, lowerBound + i, diff + i, dim - i); return result; #else - return vsag::Generic::SQ8ComputeL2Sqr(query, codes, lowerBound, diff, dim); // TODO + return vsag::generic::SQ8ComputeL2Sqr(query, codes, lowerBound, diff, dim); // TODO #endif } @@ -364,7 +364,7 @@ SQ8ComputeCodesIP(const uint8_t* codes1, result += sse::SQ8ComputeCodesIP(codes1 + i, codes2 + i, lowerBound + i, diff + i, dim - i); return result; #else - return Generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim); + return generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim); #endif } @@ -407,7 +407,7 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1, result += sse::SQ8ComputeCodesL2Sqr(codes1 + i, codes2 + i, lowerBound + i, diff + i, dim - i); return result; #else - return Generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim); + return generic::SQ8ComputeCodesL2Sqr(codes1, codes2, lowerBound, diff, dim); #endif } @@ -511,7 +511,7 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t result += static_cast(sse::SQ8UniformComputeCodesIP(codes1 + d, codes2 + d, dim - d)); return static_cast(result); #else - return sse::S8UniformComputeCodesIP(codes1, codes2, dim); + return sse::SQ8UniformComputeCodesIP(codes1, codes2, dim); #endif } diff --git a/src/simd/avx512_test.cpp b/src/simd/avx512_test.cpp index 155ea12b..39f27559 100644 --- a/src/simd/avx512_test.cpp +++ b/src/simd/avx512_test.cpp @@ -22,10 +22,11 @@ #include "cpuinfo.h" #include "fixtures.h" #include "simd.h" +#include "simd_status.h" TEST_CASE("avx512 int8", "[ut][simd][avx]") { #if defined(ENABLE_AVX512) - if (cpuinfo_has_x86_sse()) { + if (vsag::SimdStatus::SupportAVX512()) { auto common_dims = fixtures::get_common_used_dims(); for (size_t dim : common_dims) { auto vectors = fixtures::generate_vectors(2, dim); diff --git a/src/simd/avx_test.cpp b/src/simd/avx_test.cpp index ea9ee7ee..dc787821 100644 --- a/src/simd/avx_test.cpp +++ b/src/simd/avx_test.cpp @@ -14,16 +14,15 @@ // limitations under the License. #include -#include #include "./simd.h" #include "catch2/catch_approx.hpp" -#include "cpuinfo.h" #include "fixtures.h" +#include "simd_status.h" TEST_CASE("avx l2 simd16", "[ut][simd][avx]") { -#if defined(ENABLE_AVX) - if (cpuinfo_has_x86_sse()) { +#if defined(ENABLE_AVX2) + if (vsag::SimdStatus::SupportAVX2()) { size_t dim = 16; auto vectors = fixtures::generate_vectors(2, dim); @@ -37,8 +36,8 @@ TEST_CASE("avx l2 simd16", "[ut][simd][avx]") { } TEST_CASE("avx ip simd16", "[ut][simd][avx]") { -#if defined(ENABLE_AVX) - if (cpuinfo_has_x86_sse()) { +#if defined(ENABLE_AVX2) + if (vsag::SimdStatus::SupportAVX2()) { size_t dim = 16; auto vectors = fixtures::generate_vectors(2, dim); @@ -52,8 +51,8 @@ TEST_CASE("avx ip simd16", "[ut][simd][avx]") { } TEST_CASE("avx pq calculation", "[ut][simd][avx]") { -#if defined(ENABLE_AVX) - if (cpuinfo_has_x86_avx2()) { +#if defined(ENABLE_AVX2) + if (vsag::SimdStatus::SupportAVX2()) { size_t dim = 256; float single_dim_value = 0.571; float results_expected[256]{0.0f}; diff --git a/src/simd/fp32_simd.h b/src/simd/fp32_simd.h index 024cfce4..47027f3d 100644 --- a/src/simd/fp32_simd.h +++ b/src/simd/fp32_simd.h @@ -27,32 +27,26 @@ float FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim); } // namespace generic -#if defined(ENABLE_SSE) namespace sse { float FP32ComputeIP(const float* query, const float* codes, uint64_t dim); float FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim); } // namespace sse -#endif -#if defined(ENABLE_AVX2) namespace avx2 { float FP32ComputeIP(const float* query, const float* codes, uint64_t dim); float FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim); } // namespace avx2 -#endif -#if defined(ENABLE_AVX512) namespace avx512 { float FP32ComputeIP(const float* query, const float* codes, uint64_t dim); float FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim); } // namespace avx512 -#endif using FP32ComputeType = float (*)(const float* query, const float* codes, uint64_t dim); extern FP32ComputeType FP32ComputeIP; diff --git a/src/simd/fp32_simd_test.cpp b/src/simd/fp32_simd_test.cpp index a20e744f..36a41da4 100644 --- a/src/simd/fp32_simd_test.cpp +++ b/src/simd/fp32_simd_test.cpp @@ -18,6 +18,7 @@ #include "catch2/benchmark/catch_benchmark.hpp" #include "catch2/catch_test_macros.hpp" #include "fixtures.h" +#include "simd_status.h" using namespace vsag; @@ -33,15 +34,22 @@ namespace avx2 = sse; namespace avx512 = avx2; #endif -#define TEST_ACCURACY(Func) \ - { \ - auto gt = generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ - auto sse = sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ - auto avx2 = avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ - auto avx512 = avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ +#define TEST_ACCURACY(Func) \ + { \ + float gt, sse, avx2, avx512; \ + gt = generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ + if (SimdStatus::SupportSSE()) { \ + sse = sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ + } \ + if (SimdStatus::SupportAVX2()) { \ + avx2 = avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ + } \ + if (SimdStatus::SupportAVX512()) { \ + avx512 = avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ + } \ }; TEST_CASE("FP32 SIMD Compute", "[FP32SIMD]") { diff --git a/src/simd/normalize.cpp b/src/simd/normalize.cpp new file mode 100644 index 00000000..74ccca03 --- /dev/null +++ b/src/simd/normalize.cpp @@ -0,0 +1,60 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "normalize.h" + +#include "simd_status.h" + +namespace vsag { + +static NormalizeType +SetNormalize() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::Normalize; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::Normalize; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::Normalize; +#endif + } + return generic::Normalize; +} +NormalizeType Normalize = SetNormalize(); + +static DivScalarType +SetDivScalar() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::DivScalar; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::DivScalar; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::DivScalar; +#endif + } + return generic::DivScalar; +} +DivScalarType DivScalar = SetDivScalar(); + +} // namespace vsag \ No newline at end of file diff --git a/src/simd/normalize.h b/src/simd/normalize.h index ab59c934..30268ba1 100644 --- a/src/simd/normalize.h +++ b/src/simd/normalize.h @@ -26,7 +26,6 @@ float Normalize(const float* from, float* to, uint64_t dim); } // namespace generic -#if defined(ENABLE_SSE) namespace sse { void DivScalar(const float* from, float* to, uint64_t dim, float scalar); @@ -34,9 +33,7 @@ DivScalar(const float* from, float* to, uint64_t dim, float scalar); float Normalize(const float* from, float* to, uint64_t dim); } // namespace sse -#endif -#if defined(ENABLE_AVX2) namespace avx2 { void DivScalar(const float* from, float* to, uint64_t dim, float scalar); @@ -44,9 +41,7 @@ DivScalar(const float* from, float* to, uint64_t dim, float scalar); float Normalize(const float* from, float* to, uint64_t dim); } // namespace avx2 -#endif -#if defined(ENABLE_AVX512) namespace avx512 { void DivScalar(const float* from, float* to, uint64_t dim, float scalar); @@ -54,34 +49,10 @@ DivScalar(const float* from, float* to, uint64_t dim, float scalar); float Normalize(const float* from, float* to, uint64_t dim); } // namespace avx512 -#endif -inline void -DivScalar(const float* from, float* to, uint64_t dim, float scalar) { -#if defined(ENABLE_AVX512) - avx512::DivScalar(from, to, dim, scalar); -#endif -#if defined(ENABLE_AVX2) - avx2::DivScalar(from, to, dim, scalar); -#endif -#if defined(ENABLE_SSE) - sse::DivScalar(from, to, dim, scalar); -#endif - generic::DivScalar(from, to, dim, scalar); -} - -inline float -Normalize(const float* from, float* to, uint64_t dim) { -#if defined(ENABLE_AVX512) - return avx512::Normalize(from, to, dim); -#endif -#if defined(ENABLE_AVX2) - return avx2::Normalize(from, to, dim); -#endif -#if defined(ENABLE_SSE) - return sse::Normalize(from, to, dim); -#endif - return generic::Normalize(from, to, dim); -} +using NormalizeType = float (*)(const float* from, float* to, uint64_t dim); +extern NormalizeType Normalize; +using DivScalarType = void (*)(const float* from, float* to, uint64_t dim, float scalar); +extern DivScalarType DivScalar; } // namespace vsag diff --git a/src/simd/normalize_test.cpp b/src/simd/normalize_test.cpp index 4f3c6ccd..6b174016 100644 --- a/src/simd/normalize_test.cpp +++ b/src/simd/normalize_test.cpp @@ -18,6 +18,7 @@ #include "catch2/benchmark/catch_benchmark.hpp" #include "catch2/catch_test_macros.hpp" #include "fixtures.h" +#include "simd_status.h" using namespace vsag; @@ -41,21 +42,36 @@ TEST_CASE("Normalize SIMD Compute", "[simd]") { std::vector tmp_value(dim * 4); for (uint64_t i = 0; i < count; ++i) { auto gt = generic::Normalize(vec1.data() + i * dim, tmp_value.data(), dim); - auto sse = sse::Normalize(vec1.data() + i * dim, tmp_value.data() + dim, dim); - auto avx2 = avx2::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 2, dim); - auto avx512 = avx512::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 3, dim); - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); - for (int j = 0; j < dim; ++j) { - REQUIRE(fixtures::dist_t(tmp_value[j]) == fixtures::dist_t(tmp_value[j + dim])); - REQUIRE(fixtures::dist_t(tmp_value[j]) == fixtures::dist_t(tmp_value[j + dim * 2])); - REQUIRE(fixtures::dist_t(tmp_value[j]) == fixtures::dist_t(tmp_value[j + dim * 3])); + if (SimdStatus::SupportSSE()) { + auto sse = sse::Normalize(vec1.data() + i * dim, tmp_value.data() + dim, dim); + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); + for (int j = 0; j < dim; ++j) { + REQUIRE(fixtures::dist_t(tmp_value[j]) == + fixtures::dist_t(tmp_value[j + dim * 1])); + } + } + if (SimdStatus::SupportAVX2()) { + auto avx2 = avx2::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 2, dim); + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); + for (int j = 0; j < dim; ++j) { + REQUIRE(fixtures::dist_t(tmp_value[j]) == + fixtures::dist_t(tmp_value[j + dim * 2])); + } + } + if (SimdStatus::SupportAVX512()) { + auto avx512 = + avx512::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 3, dim); + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); + for (int j = 0; j < dim; ++j) { + REQUIRE(fixtures::dist_t(tmp_value[j]) == + fixtures::dist_t(tmp_value[j + dim * 3])); + } } } } } + #define BENCHMARK_SIMD_COMPUTE(Simd, Comp) \ BENCHMARK_ADVANCED(#Simd #Comp) { \ for (int i = 0; i < count; ++i) { \ diff --git a/src/simd/simd.cpp b/src/simd/simd.cpp index 5c5f6f07..ac3cfae8 100644 --- a/src/simd/simd.cpp +++ b/src/simd/simd.cpp @@ -126,13 +126,15 @@ GetInnerProductDistanceFunc(size_t dim) { DistanceFunc GetINT8InnerProductDistanceFunc(size_t dim) { + if (SimdStatus::SupportAVX512()) { #ifdef ENABLE_AVX512 - if (dim > 32) { - return vsag::INT8InnerProduct512ResidualsAVX512Distance; - } else if (dim > 16) { - return vsag::INT8InnerProduct256ResidualsAVX512Distance; - } + if (dim > 32) { + return vsag::INT8InnerProduct512ResidualsAVX512Distance; + } else if (dim > 16) { + return vsag::INT8InnerProduct256ResidualsAVX512Distance; + } #endif + } return vsag::INT8InnerProductDistance; } diff --git a/src/simd/sq4_simd.h b/src/simd/sq4_simd.h index 9d390f90..9e0715c0 100644 --- a/src/simd/sq4_simd.h +++ b/src/simd/sq4_simd.h @@ -45,7 +45,6 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1, uint64_t dim); } // namespace generic -#if defined(ENABLE_SSE) namespace sse { float SQ4ComputeIP(const float* query, @@ -72,9 +71,7 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1, const float* diff, uint64_t dim); } // namespace sse -#endif -#if defined(ENABLE_AVX2) namespace avx2 { float SQ4ComputeIP(const float* query, @@ -101,9 +98,7 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1, const float* diff, uint64_t dim); } // namespace avx2 -#endif -#if defined(ENABLE_AVX512) namespace avx512 { float SQ4ComputeIP(const float* query, @@ -130,7 +125,6 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1, const float* diff, uint64_t dim); } // namespace avx512 -#endif using SQ4ComputeType = float (*)(const float* query, const uint8_t* codes, diff --git a/src/simd/sq4_uniform_simd.h b/src/simd/sq4_uniform_simd.h index c240776b..0a6f352c 100644 --- a/src/simd/sq4_uniform_simd.h +++ b/src/simd/sq4_uniform_simd.h @@ -23,26 +23,20 @@ float SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace generic -#if defined(ENABLE_SSE) namespace sse { float SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace sse -#endif -#if defined(ENABLE_AVX2) namespace avx2 { float SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace avx2 -#endif -#if defined(ENABLE_AVX512) namespace avx512 { float SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace avx512 -#endif using SQ4UniformComputeCodesType = float (*)(const uint8_t* codes1, const uint8_t* codes2, diff --git a/src/simd/sq4_uniform_simd_test.cpp b/src/simd/sq4_uniform_simd_test.cpp index b074c6a3..73e37660 100644 --- a/src/simd/sq4_uniform_simd_test.cpp +++ b/src/simd/sq4_uniform_simd_test.cpp @@ -17,9 +17,9 @@ #include -#include "../logger.h" #include "catch2/benchmark/catch_benchmark.hpp" #include "fixtures.h" +#include "simd_status.h" using namespace vsag; @@ -35,17 +35,25 @@ namespace avx2 = sse; namespace avx512 = avx2; #endif -#define TEST_ACCURACY(Func) \ - { \ - auto gt = \ - generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - auto sse = sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - auto avx2 = avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - auto avx512 = \ - avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ +#define TEST_ACCURACY(Func) \ + { \ + auto gt = \ + generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + if (SimdStatus::SupportSSE()) { \ + auto sse = \ + sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ + } \ + if (SimdStatus::SupportAVX2()) { \ + auto avx2 = \ + avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ + } \ + if (SimdStatus::SupportAVX512()) { \ + auto avx512 = \ + avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ + } \ } TEST_CASE("SQ4 Uniform SIMD Compute Codes", "[SQ4 Uniform SIMD]") { diff --git a/src/simd/sq8_simd.h b/src/simd/sq8_simd.h index 1bba09fc..e6c58858 100644 --- a/src/simd/sq8_simd.h +++ b/src/simd/sq8_simd.h @@ -74,7 +74,6 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1, } // namespace sse #endif -#if defined(ENABLE_AVX2) namespace avx2 { float SQ8ComputeIP(const float* query, @@ -101,9 +100,7 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1, const float* diff, uint64_t dim); } // namespace avx2 -#endif -#if defined(ENABLE_AVX512) namespace avx512 { float SQ8ComputeIP(const float* query, @@ -130,7 +127,6 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1, const float* diff, uint64_t dim); } // namespace avx512 -#endif using SQ8ComputeType = float (*)(const float* query, const uint8_t* codes, diff --git a/src/simd/sq8_simd_test.cpp b/src/simd/sq8_simd_test.cpp index 77cdb9f2..e399714c 100644 --- a/src/simd/sq8_simd_test.cpp +++ b/src/simd/sq8_simd_test.cpp @@ -18,6 +18,7 @@ #include "catch2/benchmark/catch_benchmark.hpp" #include "catch2/catch_test_macros.hpp" #include "fixtures.h" +#include "simd_status.h" using namespace vsag; @@ -33,19 +34,25 @@ namespace avx2 = sse; namespace avx512 = avx2; #endif -#define TEST_ACCURACY(Func) \ - { \ - auto gt = generic::Func( \ - vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ - auto sse = \ - sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ - auto avx2 = \ - avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ - auto avx512 = avx512::Func( \ - vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ +#define TEST_ACCURACY(Func) \ + { \ + auto gt = generic::Func( \ + vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ + if (SimdStatus::SupportSSE()) { \ + auto sse = sse::Func( \ + vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ + } \ + if (SimdStatus::SupportAVX2()) { \ + auto avx2 = avx2::Func( \ + vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ + } \ + if (SimdStatus::SupportAVX512()) { \ + auto avx512 = avx512::Func( \ + vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ + } \ } TEST_CASE("SQ8 SIMD Compute Codes", "[SQ8 SIMD]") { diff --git a/src/simd/sq8_uniform_simd.h b/src/simd/sq8_uniform_simd.h index 4f6dca78..ae8050c3 100644 --- a/src/simd/sq8_uniform_simd.h +++ b/src/simd/sq8_uniform_simd.h @@ -23,26 +23,20 @@ float SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace generic -#if defined(ENABLE_SSE) namespace sse { float SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace sse -#endif -#if defined(ENABLE_AVX2) namespace avx2 { float SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace avx2 -#endif -#if defined(ENABLE_AVX512) namespace avx512 { float SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim); } // namespace avx512 -#endif using SQ8UniformComputeCodesType = float (*)(const uint8_t* codes1, const uint8_t* codes2, diff --git a/src/simd/sq8_uniform_simd_test.cpp b/src/simd/sq8_uniform_simd_test.cpp index ff06275a..13db3c19 100644 --- a/src/simd/sq8_uniform_simd_test.cpp +++ b/src/simd/sq8_uniform_simd_test.cpp @@ -19,6 +19,7 @@ #include "catch2/benchmark/catch_benchmark.hpp" #include "fixtures.h" +#include "simd_status.h" using namespace vsag; @@ -34,17 +35,25 @@ namespace avx2 = sse; namespace avx512 = avx2; #endif -#define TEST_ACCURACY(Func) \ - { \ - auto gt = \ - generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - auto sse = sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - auto avx2 = avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - auto avx512 = \ - avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ - REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ +#define TEST_ACCURACY(Func) \ + { \ + auto gt = \ + generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + if (SimdStatus::SupportSSE()) { \ + auto sse = \ + sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \ + } \ + if (SimdStatus::SupportAVX2()) { \ + auto avx2 = \ + avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \ + } \ + if (SimdStatus::SupportAVX512()) { \ + auto avx512 = \ + avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \ + REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \ + } \ } TEST_CASE("SQ8 Uniform SIMD Compute Codes", "[SQ8 Uniform SIMD]") {