From a6bd30c782fcbc60451ab46de3b8b39d61c9b2dc Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Fri, 22 Nov 2024 17:52:35 +0800 Subject: [PATCH] enhance: speed up search iter stage 1 Signed-off-by: Patrick Weizhi Xu --- go.mod | 2 +- internal/core/src/common/QueryInfo.h | 7 + internal/core/src/index/Utils.cpp | 35 + internal/core/src/index/Utils.h | 8 + internal/core/src/index/VectorDiskIndex.cpp | 27 +- internal/core/src/index/VectorMemIndex.cpp | 12 +- .../core/src/query/CachedSearchIterator.cpp | 355 +++++++++ .../core/src/query/CachedSearchIterator.h | 182 +++++ internal/core/src/query/PlanProto.cpp | 13 + internal/core/src/query/SearchBruteForce.cpp | 65 +- internal/core/src/query/SearchBruteForce.h | 22 +- internal/core/src/query/SearchOnGrowing.cpp | 27 +- internal/core/src/query/SearchOnIndex.cpp | 13 +- internal/core/src/query/SearchOnSealed.cpp | 58 +- .../unittest/test_cached_search_iterator.cpp | 716 ++++++++++++++++++ internal/proto/plan.proto | 7 + internal/proxy/proxy.go | 8 + internal/proxy/search_util.go | 90 ++- internal/proxy/task.go | 5 + internal/proxy/task_search.go | 45 +- internal/proxy/task_search_test.go | 261 ++++++- 21 files changed, 1853 insertions(+), 105 deletions(-) create mode 100644 internal/core/src/query/CachedSearchIterator.cpp create mode 100644 internal/core/src/query/CachedSearchIterator.h create mode 100644 internal/core/unittest/test_cached_search_iterator.cpp diff --git a/go.mod b/go.mod index 0b6c7c89aae3a..c66dee2f6146f 100644 --- a/go.mod +++ b/go.mod @@ -64,6 +64,7 @@ require ( github.com/cenkalti/backoff/v4 v4.2.1 github.com/cockroachdb/redact v1.1.3 github.com/goccy/go-json v0.10.3 + github.com/google/uuid v1.6.0 github.com/greatroar/blobloom v0.0.0-00010101000000-000000000000 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jolestar/go-commons-pool/v2 v2.1.2 @@ -142,7 +143,6 @@ require ( github.com/golang/snappy v0.0.4 // indirect 
github.com/google/flatbuffers v2.0.8+incompatible // indirect github.com/google/s2a-go v0.1.7 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect github.com/gorilla/websocket v1.4.2 // indirect diff --git a/internal/core/src/common/QueryInfo.h b/internal/core/src/common/QueryInfo.h index 760409820ee47..a08653026eb76 100644 --- a/internal/core/src/common/QueryInfo.h +++ b/internal/core/src/common/QueryInfo.h @@ -24,6 +24,12 @@ namespace milvus { +struct SearchIteratorV2Info { + std::string token = ""; + uint32_t batch_size = 0; + std::optional last_bound = std::nullopt; +}; + struct SearchInfo { int64_t topk_{0}; int64_t group_size_{1}; @@ -35,6 +41,7 @@ struct SearchInfo { std::optional group_by_field_id_; tracer::TraceContext trace_ctx_; bool materialized_view_involved = false; + std::optional iterator_v2_info_ = std::nullopt; }; using SearchInfoPtr = std::shared_ptr; diff --git a/internal/core/src/index/Utils.cpp b/internal/core/src/index/Utils.cpp index 8ed904b52f8e4..3f6c4f720ece6 100644 --- a/internal/core/src/index/Utils.cpp +++ b/internal/core/src/index/Utils.cpp @@ -362,4 +362,39 @@ ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size) { } } +bool +CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info, + const int64_t topk, + const MetricType& metric_type, + knowhere::Json& search_config) { + const auto radius = + index::GetValueFromConfig(search_info.search_params_, RADIUS); + if (!radius.has_value()) { + return false; + } + + search_config[RADIUS] = radius.value(); + // `range_search_k` is only used as one of the conditions for iterator early termination. + // not gurantee to return exactly `range_search_k` results, which may be more or less. + // set it to -1 will return all results in the range. 
+ search_config[knowhere::meta::RANGE_SEARCH_K] = topk; + + const auto range_filter = + GetValueFromConfig(search_info.search_params_, RANGE_FILTER); + if (range_filter.has_value()) { + search_config[RANGE_FILTER] = range_filter.value(); + CheckRangeSearchParam(search_config[RADIUS], + search_config[RANGE_FILTER], + metric_type); + } + + const auto page_retain_order = + GetValueFromConfig(search_info.search_params_, PAGE_RETAIN_ORDER); + if (page_retain_order.has_value()) { + search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] = + page_retain_order.value(); + } + return true; +} + } // namespace milvus::index diff --git a/internal/core/src/index/Utils.h b/internal/core/src/index/Utils.h index 1c5f175e26cb5..862562687738c 100644 --- a/internal/core/src/index/Utils.h +++ b/internal/core/src/index/Utils.h @@ -30,6 +30,8 @@ #include "common/Types.h" #include "common/FieldData.h" +#include "common/QueryInfo.h" +#include "common/RangeSearchHelper.h" #include "index/IndexInfo.h" #include "storage/Types.h" @@ -147,4 +149,10 @@ AssembleIndexDatas(std::map& index_datas, void ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size = 0x7ffff000); +bool +CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info, + const int64_t topk, + const MetricType& metric_type, + knowhere::Json& search_config); + } // namespace milvus::index diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index e6360eb199159..e4be33267c153 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -266,32 +266,9 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix; auto final = [&] { - auto radius = - GetValueFromConfig(search_info.search_params_, RADIUS); - if (radius.has_value()) { - search_config[RADIUS] = radius.value(); - // `range_search_k` is only used as one of the conditions for iterator early termination. 
- // not gurantee to return exactly `range_search_k` results, which may be more or less. - // set it to -1 will return all results in the range. - search_config[knowhere::meta::RANGE_SEARCH_K] = topk; - auto range_filter = GetValueFromConfig( - search_info.search_params_, RANGE_FILTER); - if (range_filter.has_value()) { - search_config[RANGE_FILTER] = range_filter.value(); - CheckRangeSearchParam(search_config[RADIUS], - search_config[RANGE_FILTER], - GetMetricType()); - } - - auto page_retain_order = GetValueFromConfig( - search_info.search_params_, PAGE_RETAIN_ORDER); - if (page_retain_order.has_value()) { - search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] = - page_retain_order.value(); - } - + if (CheckAndUpdateKnowhereRangeSearchParam( + search_info, topk, GetMetricType(), search_config)) { auto res = index_.RangeSearch(dataset, search_config, bitset); - if (!res.has_value()) { PanicInfo(ErrorCode::UnexpectedError, fmt::format("failed to range search: {}: {}", diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index c471fb02a1325..1e4ff425d6f4a 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -380,16 +380,8 @@ VectorMemIndex::Query(const DatasetPtr dataset, // TODO :: check dim of search data auto final = [&] { auto index_type = GetIndexType(); - if (CheckKeyInConfig(search_conf, RADIUS)) { - if (CheckKeyInConfig(search_conf, RANGE_FILTER)) { - CheckRangeSearchParam(search_conf[RADIUS], - search_conf[RANGE_FILTER], - GetMetricType()); - } - // `range_search_k` is only used as one of the conditions for iterator early termination. - // not gurantee to return exactly `range_search_k` results, which may be more or less. - // set it to -1 will return all results in the range. 
- search_conf[knowhere::meta::RANGE_SEARCH_K] = topk; + if (CheckAndUpdateKnowhereRangeSearchParam( + search_info, topk, GetMetricType(), search_conf)) { milvus::tracer::AddEvent("start_knowhere_index_range_search"); auto res = index_.RangeSearch(dataset, search_conf, bitset); milvus::tracer::AddEvent("finish_knowhere_index_range_search"); diff --git a/internal/core/src/query/CachedSearchIterator.cpp b/internal/core/src/query/CachedSearchIterator.cpp new file mode 100644 index 0000000000000..4f0e01240bc8d --- /dev/null +++ b/internal/core/src/query/CachedSearchIterator.cpp @@ -0,0 +1,355 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include "query/CachedSearchIterator.h" +#include "query/SearchBruteForce.h" +#include + +namespace milvus::query { + +CachedSearchIterator::CachedSearchIterator( + const milvus::index::VectorIndex& index, + const knowhere::DataSetPtr& query_ds, + const SearchInfo& search_info, + const BitsetView& bitset) { + if (query_ds == nullptr) { + PanicInfo(ErrorCode::UnexpectedError, + "Query dataset is nullptr, cannot initialize iterator"); + } + nq_ = query_ds->GetRows(); + Init(search_info); + + auto search_json = index.PrepareSearchParams(search_info); + index::CheckAndUpdateKnowhereRangeSearchParam( + search_info, batch_size_, index.GetMetricType(), search_json); + + auto expected_iterators = + index.VectorIterators(query_ds, search_json, bitset); + if (expected_iterators.has_value()) { + iterators_ = std::move(expected_iterators.value()); + } else { + PanicInfo(ErrorCode::UnexpectedError, + "Failed to create iterators from index"); + } +} + +CachedSearchIterator::CachedSearchIterator( + const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type) { + nq_ = query_ds.num_queries; + Init(search_info); + + auto expected_iterators = GetBruteForceSearchIterators( + query_ds, raw_ds, search_info, index_info, bitset, data_type); + if (expected_iterators.has_value()) { + iterators_ = std::move(expected_iterators.value()); + } else { + PanicInfo(ErrorCode::UnexpectedError, + "Failed to create iterators from index"); + } +} + +void +CachedSearchIterator::InitializeChunkedIterators( + const dataset::SearchDataset& query_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type, + const GetChunkDataFunc& get_chunk_data) { + int64_t offset = 0; + 
chunked_heaps_.resize(nq_); + for (int64_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) { + auto [chunk_data, chunk_size] = get_chunk_data(chunk_id); + auto sub_data = query::dataset::RawDataset{ + offset, query_ds.dim, chunk_size, chunk_data}; + + auto expected_iterators = GetBruteForceSearchIterators( + query_ds, sub_data, search_info, index_info, bitset, data_type); + if (expected_iterators.has_value()) { + auto& chunk_iterators = expected_iterators.value(); + iterators_.insert(iterators_.end(), + std::make_move_iterator(chunk_iterators.begin()), + std::make_move_iterator(chunk_iterators.end())); + } else { + PanicInfo(ErrorCode::UnexpectedError, + "Failed to create iterators from index"); + } + offset += chunk_size; + } +} + +CachedSearchIterator::CachedSearchIterator( + const dataset::SearchDataset& query_ds, + const segcore::VectorBase* vec_data, + const int64_t row_count, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type) { + if (vec_data == nullptr) { + PanicInfo(ErrorCode::UnexpectedError, + "Vector data is nullptr, cannot initialize iterator"); + } + + if (row_count <= 0) { + PanicInfo(ErrorCode::UnexpectedError, + "Number of rows is 0, cannot initialize iterator"); + } + + const int64_t vec_size_per_chunk = vec_data->get_size_per_chunk(); + num_chunks_ = upper_div(row_count, vec_size_per_chunk); + nq_ = query_ds.num_queries; + Init(search_info); + + iterators_.reserve(nq_ * num_chunks_); + InitializeChunkedIterators( + query_ds, + search_info, + index_info, + bitset, + data_type, + [&vec_data, vec_size_per_chunk, row_count]( + int64_t chunk_id) -> std::pair { + const auto chunk_data = vec_data->get_chunk_data(chunk_id); + int64_t chunk_size = std::min( + vec_size_per_chunk, row_count - chunk_id * vec_size_per_chunk); + return {chunk_data, chunk_size}; + }); +} + +CachedSearchIterator::CachedSearchIterator( + const std::shared_ptr& column, + const dataset::SearchDataset& 
query_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type) { + if (column == nullptr) { + PanicInfo(ErrorCode::UnexpectedError, + "Column is nullptr, cannot initialize iterator"); + } + + num_chunks_ = column->num_chunks(); + nq_ = query_ds.num_queries; + Init(search_info); + + iterators_.reserve(nq_ * num_chunks_); + InitializeChunkedIterators( + query_ds, + search_info, + index_info, + bitset, + data_type, + [&column](int64_t chunk_id) { + const char* chunk_data = column->Data(chunk_id); + int64_t chunk_size = column->chunk_row_nums(chunk_id); + return std::make_pair(static_cast(chunk_data), + chunk_size); + }); +} + +void +CachedSearchIterator::NextBatch(const SearchInfo& search_info, + SearchResult& search_result) { + if (iterators_.empty()) { + return; + } + + if (iterators_.size() != nq_ * num_chunks_) { + PanicInfo(ErrorCode::UnexpectedError, + "Iterator size mismatch, expect %d, but got %d", + nq_ * num_chunks_, + iterators_.size()); + } + + ValidateSearchInfo(search_info); + + search_result.total_nq_ = nq_; + search_result.unity_topK_ = batch_size_; + search_result.seg_offsets_.resize(nq_ * batch_size_); + search_result.distances_.resize(nq_ * batch_size_); + + for (size_t query_idx = 0; query_idx < nq_; ++query_idx) { + auto rst = GetBatchedNextResults(query_idx, search_info); + WriteSingleQuerySearchResult( + search_result, query_idx, rst, search_info.round_decimal_); + } +} + +void +CachedSearchIterator::ValidateSearchInfo(const SearchInfo& search_info) { + if (!search_info.iterator_v2_info_.has_value()) { + PanicInfo(ErrorCode::UnexpectedError, + "Iterator v2 SearchInfo is not set"); + } + + auto iterator_v2_info = search_info.iterator_v2_info_.value(); + if (iterator_v2_info.batch_size != batch_size_) { + PanicInfo(ErrorCode::UnexpectedError, + "Batch size mismatch, expect %d, but got %d", + batch_size_, + iterator_v2_info.batch_size); + } +} + +std::optional 
+CachedSearchIterator::GetNextValidResult( + const size_t iterator_idx, + const std::optional& last_bound, + const std::optional& radius, + const std::optional& range_filter) { + auto& iterator = iterators_[iterator_idx]; + while (iterator->HasNext()) { + auto result = ConvertIteratorResult(iterator->Next()); + if (IsValid(result, last_bound, radius, range_filter)) { + return result; + } + } + return std::nullopt; +} + +// TODO: Optimize this method +void +CachedSearchIterator::MergeChunksResults( + size_t query_idx, + const std::optional& last_bound, + const std::optional& radius, + const std::optional& range_filter, + std::vector& rst) { + auto& heap = chunked_heaps_[query_idx]; + + if (heap.empty()) { + for (size_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) { + const size_t iterator_idx = query_idx + chunk_id * nq_; + if (auto next_result = GetNextValidResult( + iterator_idx, last_bound, radius, range_filter); + next_result.has_value()) { + heap.emplace(iterator_idx, next_result.value()); + } + } + } + + while (!heap.empty() && rst.size() < batch_size_) { + const auto [iterator_idx, cur_rst] = heap.top(); + heap.pop(); + + // last_bound may change between NextBatch calls, discard any invalid results + if (!IsValid(cur_rst, last_bound, radius, range_filter)) { + continue; + } + rst.emplace_back(cur_rst); + + if (auto next_result = GetNextValidResult( + iterator_idx, last_bound, radius, range_filter); + next_result.has_value()) { + heap.emplace(iterator_idx, next_result.value()); + } + } +} + +std::vector +CachedSearchIterator::GetBatchedNextResults(size_t query_idx, + const SearchInfo& search_info) { + auto last_bound = ConvertIncomingDistance( + search_info.iterator_v2_info_.value().last_bound); + auto radius = ConvertIncomingDistance( + index::GetValueFromConfig(search_info.search_params_, RADIUS)); + auto range_filter = + ConvertIncomingDistance(index::GetValueFromConfig( + search_info.search_params_, RANGE_FILTER)); + + std::vector rst; + 
rst.reserve(batch_size_); + + if (num_chunks_ == 1) { + auto& iterator = iterators_[query_idx]; + while (iterator->HasNext() && rst.size() < batch_size_) { + auto result = ConvertIteratorResult(iterator->Next()); + if (IsValid(result, last_bound, radius, range_filter)) { + rst.emplace_back(result); + } + } + } else { + MergeChunksResults(query_idx, last_bound, radius, range_filter, rst); + } + std::sort(rst.begin(), rst.end()); + if (sign_ == -1) { + std::for_each(rst.begin(), rst.end(), [this](DisIdPair& x) { + x.first = x.first * sign_; + }); + } + while (rst.size() < batch_size_) { + rst.emplace_back(1.0f / 0.0f, -1); + } + return rst; +} + +void +CachedSearchIterator::WriteSingleQuerySearchResult( + SearchResult& search_result, + const size_t idx, + std::vector& rst, + const int64_t round_decimal) { + const float multiplier = pow(10.0, round_decimal); + + std::transform(rst.begin(), + rst.end(), + search_result.distances_.begin() + idx * batch_size_, + [multiplier, round_decimal](DisIdPair& x) { + if (round_decimal != -1) { + x.first = + std::round(x.first * multiplier) / multiplier; + } + return x.first; + }); + + std::transform(rst.begin(), + rst.end(), + search_result.seg_offsets_.begin() + idx * batch_size_, + [](const DisIdPair& x) { return x.second; }); +} + +void +CachedSearchIterator::Init(const SearchInfo& search_info) { + if (!search_info.iterator_v2_info_.has_value()) { + PanicInfo(ErrorCode::UnexpectedError, + "Iterator v2 info is not set, cannot initialize iterator"); + } + + auto iterator_v2_info = search_info.iterator_v2_info_.value(); + if (iterator_v2_info.batch_size == 0) { + PanicInfo(ErrorCode::UnexpectedError, + "Batch size is 0, cannot initialize iterator"); + } + batch_size_ = iterator_v2_info.batch_size; + + if (search_info.metric_type_.empty()) { + PanicInfo(ErrorCode::UnexpectedError, + "Metric type is empty, cannot initialize iterator"); + } + if (PositivelyRelated(search_info.metric_type_)) { + sign_ = -1; + } else { + sign_ = 1; + } 
+ + if (nq_ == 0) { + PanicInfo(ErrorCode::UnexpectedError, + "Number of queries is 0, cannot initialize iterator"); + } +} + +} // namespace milvus::query \ No newline at end of file diff --git a/internal/core/src/query/CachedSearchIterator.h b/internal/core/src/query/CachedSearchIterator.h new file mode 100644 index 0000000000000..8391a5b54bc8b --- /dev/null +++ b/internal/core/src/query/CachedSearchIterator.h @@ -0,0 +1,182 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include "common/BitsetView.h" +#include "common/QueryInfo.h" +#include "common/QueryResult.h" +#include "query/helper.h" +#include "segcore/ConcurrentVector.h" +#include "index/VectorIndex.h" + +namespace milvus::query { + +// This class is used to cache the search results from Knowhere +// search iterators and filter the results based on the last_bound, +// radius and range_filter. +// It provides a number of constructors to support different scenarios, +// including growing/sealed, chunked/non-chunked. 
+// +// It does not care about TopK in search_info +// The topk in SearchResult will be set to the batch_size for compatibility +// +// TODO: introduce the pool of results in the near future +// TODO: replace VectorIterator class +class CachedSearchIterator { + public: + // For sealed segment with vector index + CachedSearchIterator(const milvus::index::VectorIndex& index, + const knowhere::DataSetPtr& dataset, + const SearchInfo& search_info, + const BitsetView& bitset); + + // For sealed segment, BF + CachedSearchIterator(const dataset::SearchDataset& dataset, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type); + + // For growing segment with chunked data, BF + CachedSearchIterator(const dataset::SearchDataset& dataset, + const segcore::VectorBase* vec_data, + const int64_t row_count, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type); + + // For sealed segment with chunked data, BF + CachedSearchIterator(const std::shared_ptr& column, + const dataset::SearchDataset& dataset, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type); + + // This method fetches the next batch of search results based on the provided search information + // and updates the search_result object with the new batch of results. 
+ void + NextBatch(const SearchInfo& search_info, SearchResult& search_result); + + // Disable copy and move + CachedSearchIterator(const CachedSearchIterator&) = delete; + CachedSearchIterator& + operator=(const CachedSearchIterator&) = delete; + CachedSearchIterator(CachedSearchIterator&&) = delete; + CachedSearchIterator& + operator=(CachedSearchIterator&&) = delete; + + private: + using DisIdPair = std::pair; + using IterIdx = size_t; + using IterIdDisIdPair = std::pair; + using GetChunkDataFunc = + std::function(int64_t)>; + + int64_t batch_size_ = 0; + std::vector iterators_; + int8_t sign_ = 1; + size_t num_chunks_ = 1; + size_t nq_ = 0; + + struct IterIdDisIdPairComparator { + bool + operator()(const IterIdDisIdPair& lhs, const IterIdDisIdPair& rhs) { + if (lhs.second.first == rhs.second.first) { + return lhs.second.second > rhs.second.second; + } + return lhs.second.first > rhs.second.first; + } + }; + std::vector, + IterIdDisIdPairComparator>> + chunked_heaps_; + + inline bool + IsValid(const DisIdPair& result, + const std::optional& last_bound, + const std::optional& radius, + const std::optional& range_filter) { + const float dist = result.first; + const bool is_valid = + !last_bound.has_value() || dist > last_bound.value(); + + if (!radius.has_value()) { + return is_valid; + } + + if (!range_filter.has_value()) { + return is_valid && dist < radius.value(); + } + + return is_valid && dist < radius.value() && + dist >= range_filter.value(); + } + + inline DisIdPair + ConvertIteratorResult(const std::pair& iter_rst) { + DisIdPair rst; + rst.first = iter_rst.second * sign_; + rst.second = iter_rst.first; + return rst; + } + + inline std::optional + ConvertIncomingDistance(std::optional dist) { + if (dist.has_value()) { + dist = dist.value() * sign_; + } + return dist; + } + + std::optional + GetNextValidResult(size_t iterator_idx, + const std::optional& last_bound, + const std::optional& radius, + const std::optional& range_filter); + + void + 
MergeChunksResults(size_t query_idx, + const std::optional& last_bound, + const std::optional& radius, + const std::optional& range_filter, + std::vector& rst); + + void + ValidateSearchInfo(const SearchInfo& search_info); + + std::vector + GetBatchedNextResults(size_t query_idx, const SearchInfo& search_info); + + void + WriteSingleQuerySearchResult(SearchResult& search_result, + const size_t idx, + std::vector& rst, + const int64_t round_decimal); + + void + Init(const SearchInfo& search_info); + + void + InitializeChunkedIterators( + const dataset::SearchDataset& dataset, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + const milvus::DataType& data_type, + const GetChunkDataFunc& get_chunk_data); +}; +} // namespace milvus::query diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 3d7e325ce116c..8b302d5daaf26 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -69,6 +69,19 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { search_info.strict_group_size_ = query_info_proto.strict_group_size(); } + + if (query_info_proto.has_search_iterator_v2_info()) { + auto& iterator_v2_info_proto = query_info_proto.search_iterator_v2_info(); + search_info.iterator_v2_info_ = SearchIteratorV2Info{ + .token = iterator_v2_info_proto.token(), + .batch_size = iterator_v2_info_proto.batch_size(), + }; + if (iterator_v2_info_proto.has_last_bound()) { + search_info.iterator_v2_info_->last_bound = + iterator_v2_info_proto.last_bound(); + } + } + return search_info; }; diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 9df66690b8396..6a17bc2c00bd1 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -226,45 +226,66 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, return sub_result; } -SubSearchResult 
-BruteForceSearchIterators(const dataset::SearchDataset& query_ds, - const dataset::RawDataset& raw_ds, - const SearchInfo& search_info, - const std::map& index_info, - const BitsetView& bitset, - DataType data_type) { - auto nq = query_ds.num_queries; - auto [query_dataset, base_dataset] = - PrepareBFDataSet(query_ds, raw_ds, data_type); - auto search_cfg = PrepareBFSearchParams(search_info, index_info); - - knowhere::expected> - iterators_val; +knowhere::expected> +DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset, + const knowhere::DataSetPtr& query_dataset, + const knowhere::Json& config, + const BitsetView& bitset, + const milvus::DataType& data_type) { switch (data_type) { case DataType::VECTOR_FLOAT: - iterators_val = knowhere::BruteForce::AnnIterator( - base_dataset, query_dataset, search_cfg, bitset); + return knowhere::BruteForce::AnnIterator( + base_dataset, query_dataset, config, bitset); break; case DataType::VECTOR_FLOAT16: //todo: if knowhere support real fp16/bf16 bf, change it - iterators_val = knowhere::BruteForce::AnnIterator( - base_dataset, query_dataset, search_cfg, bitset); + return knowhere::BruteForce::AnnIterator( + base_dataset, query_dataset, config, bitset); break; case DataType::VECTOR_BFLOAT16: //todo: if knowhere support real fp16/bf16 bf, change it - iterators_val = knowhere::BruteForce::AnnIterator( - base_dataset, query_dataset, search_cfg, bitset); + return knowhere::BruteForce::AnnIterator( + base_dataset, query_dataset, config, bitset); break; case DataType::VECTOR_SPARSE_FLOAT: - iterators_val = knowhere::BruteForce::AnnIterator< + return knowhere::BruteForce::AnnIterator< knowhere::sparse::SparseRow>( - base_dataset, query_dataset, search_cfg, bitset); + base_dataset, query_dataset, config, bitset); break; default: PanicInfo(ErrorCode::Unsupported, "Unsupported dataType for chunk brute force iterator:{}", data_type); } +} + +knowhere::expected> +GetBruteForceSearchIterators( + const 
dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + DataType data_type) { + auto nq = query_ds.num_queries; + auto [query_dataset, base_dataset] = + PrepareBFDataSet(query_ds, raw_ds, data_type); + auto search_cfg = PrepareBFSearchParams(search_info, index_info); + return DispatchBruteForceIteratorByDataType( + base_dataset, query_dataset, search_cfg, bitset, data_type); +} + +SubSearchResult +PackBruteForceSearchIteratorsIntoSubResult( + const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + DataType data_type) { + auto nq = query_ds.num_queries; + auto iterators_val = GetBruteForceSearchIterators( + query_ds, raw_ds, search_info, index_info, bitset, data_type); if (iterators_val.has_value()) { AssertInfo( iterators_val.value().size() == nq, diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index 15fb5697abfdf..51e348bfa23c8 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -31,12 +31,22 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, const BitsetView& bitset, DataType data_type); +knowhere::expected> +GetBruteForceSearchIterators( + const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const std::map& index_info, + const BitsetView& bitset, + DataType data_type); + SubSearchResult -BruteForceSearchIterators(const dataset::SearchDataset& query_ds, - const dataset::RawDataset& raw_ds, - const SearchInfo& search_info, - const std::map& index_info, - const BitsetView& bitset, - DataType data_type); +PackBruteForceSearchIteratorsIntoSubResult( + const dataset::SearchDataset& query_ds, + const dataset::RawDataset& raw_ds, + const SearchInfo& search_info, + const 
std::map& index_info, + const BitsetView& bitset, + DataType data_type); } // namespace milvus::query diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 2add5f5a1fde8..4883a45e5e400 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -18,6 +18,7 @@ #include "knowhere/comp/index_param.h" #include "knowhere/config.h" #include "log/Log.h" +#include "query/CachedSearchIterator.h" #include "query/SearchBruteForce.h" #include "query/SearchOnIndex.h" @@ -124,6 +125,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, // step 3: brute force search where small indexing is unavailable auto vec_ptr = record.get_data_base(vecfield_id); + + if (info.iterator_v2_info_.has_value()) { + CachedSearchIterator cached_iter(search_dataset, + vec_ptr, + active_count, + info, + index_info, + bitset, + data_type); + cached_iter.NextBatch(info, search_result); + return; + } + auto vec_size_per_chunk = vec_ptr->get_size_per_chunk(); auto max_chunk = upper_div(active_count, vec_size_per_chunk); @@ -139,12 +153,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, auto sub_data = query::dataset::RawDataset{ element_begin, dim, size_per_chunk, chunk_data}; if (info.group_by_field_id_.has_value()) { - auto sub_qr = BruteForceSearchIterators(search_dataset, - sub_data, - info, - index_info, - bitset, - data_type); + auto sub_qr = + PackBruteForceSearchIteratorsIntoSubResult(search_dataset, + sub_data, + info, + index_info, + bitset, + data_type); final_qr.merge(sub_qr); } else { auto sub_qr = BruteForceSearch(search_dataset, diff --git a/internal/core/src/query/SearchOnIndex.cpp b/internal/core/src/query/SearchOnIndex.cpp index 0204f791ce217..2cffd7ec68388 100644 --- a/internal/core/src/query/SearchOnIndex.cpp +++ b/internal/core/src/query/SearchOnIndex.cpp @@ -10,6 +10,7 @@ // or implied. 
See the License for the specific language governing permissions and limitations under the License #include "SearchOnIndex.h" +#include "CachedSearchIterator.h" #include "exec/operator/groupby/SearchGroupByOperator.h" namespace milvus::query { @@ -26,14 +27,22 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset, auto dataset = knowhere::GenDataSet(num_queries, dim, search_dataset.query_data); dataset->SetIsSparse(is_sparse); - if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_conf, + if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf, num_queries, dataset, search_result, bitset, indexing)) { - indexing.Query(dataset, search_conf, bitset, search_result); + return; } + + if (search_conf.iterator_v2_info_.has_value()) { + auto iter = CachedSearchIterator(indexing, dataset, search_conf, bitset); + iter.NextBatch(search_conf, search_result); + return; + } + + indexing.Query(dataset, search_conf, bitset, search_result); } } // namespace milvus::query diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 59146b2447a0b..6de154c4e033f 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -18,6 +18,7 @@ #include "common/QueryInfo.h" #include "common/Types.h" #include "mmap/Column.h" +#include "query/CachedSearchIterator.h" #include "query/SearchBruteForce.h" #include "query/SearchOnSealed.h" #include "query/helper.h" @@ -55,13 +56,19 @@ SearchOnSealedIndex(const Schema& schema, dataset->SetIsSparse(is_sparse); auto vec_index = dynamic_cast(field_indexing->indexing_.get()); + + if (search_info.iterator_v2_info_.has_value()) { + CachedSearchIterator cached_iter(*vec_index, dataset, search_info, bitset); + cached_iter.NextBatch(search_info, search_result); + return; + } + if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info, num_queries, dataset, search_result, bitset, *vec_index)) { - auto index_type = vec_index->GetIndexType(); 
vec_index->Query(dataset, search_info, bitset, search_result); float* distances = search_result.distances_.data(); auto total_num = num_queries * topK; @@ -104,6 +111,18 @@ SearchOnSealed(const Schema& schema, auto data_type = field.get_data_type(); CheckBruteForceSearchParam(field, search_info); + + if (search_info.iterator_v2_info_.has_value()) { + CachedSearchIterator cached_iter(column, + query_dataset, + search_info, + index_info, + bitview, + data_type); + cached_iter.NextBatch(search_info, result); + return; + } + auto num_chunk = column->num_chunks(); SubSearchResult final_qr(num_queries, @@ -115,17 +134,16 @@ SearchOnSealed(const Schema& schema, for (int i = 0; i < num_chunk; ++i) { auto vec_data = column->Data(i); auto chunk_size = column->chunk_row_nums(i); - const uint8_t* bitset_ptr = nullptr; - auto data_id = offset; auto raw_dataset = query::dataset::RawDataset{offset, dim, chunk_size, vec_data}; if (search_info.group_by_field_id_.has_value()) { - auto sub_qr = BruteForceSearchIterators(query_dataset, - raw_dataset, - search_info, - index_info, - bitview, - data_type); + auto sub_qr = + PackBruteForceSearchIteratorsIntoSubResult(query_dataset, + raw_dataset, + search_info, + index_info, + bitview, + data_type); final_qr.merge(sub_qr); } else { auto sub_qr = BruteForceSearch(query_dataset, @@ -136,7 +154,6 @@ SearchOnSealed(const Schema& schema, data_type); final_qr.merge(sub_qr); } - offset += chunk_size; } if (search_info.group_by_field_id_.has_value()) { @@ -181,14 +198,23 @@ SearchOnSealed(const Schema& schema, CheckBruteForceSearchParam(field, search_info); auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data}; if (search_info.group_by_field_id_.has_value()) { - auto sub_qr = BruteForceSearchIterators(query_dataset, - raw_dataset, - search_info, - index_info, - bitset, - data_type); + auto sub_qr = PackBruteForceSearchIteratorsIntoSubResult(query_dataset, + raw_dataset, + search_info, + index_info, + bitset, + data_type); 
result.AssembleChunkVectorIterators( num_queries, 1, {0}, sub_qr.chunk_iterators()); + } else if (search_info.iterator_v2_info_.has_value()) { + CachedSearchIterator cached_iter(query_dataset, + raw_dataset, + search_info, + index_info, + bitset, + data_type); + cached_iter.NextBatch(search_info, result); + return; } else { auto sub_qr = BruteForceSearch(query_dataset, raw_dataset, diff --git a/internal/core/unittest/test_cached_search_iterator.cpp b/internal/core/unittest/test_cached_search_iterator.cpp new file mode 100644 index 0000000000000..97d82f45e821b --- /dev/null +++ b/internal/core/unittest/test_cached_search_iterator.cpp @@ -0,0 +1,716 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include "common/BitsetView.h" +#include "common/QueryInfo.h" +#include "common/QueryResult.h" +#include "common/Utils.h" +#include "knowhere/comp/index_param.h" +#include "query/CachedSearchIterator.h" +#include "index/VectorIndex.h" +#include "index/IndexFactory.h" +#include "knowhere/dataset.h" +#include "query/helper.h" +#include "segcore/ConcurrentVector.h" +#include "segcore/InsertRecord.h" +#include "mmap/ChunkedColumn.h" +#include "test_utils/DataGen.h" + +using namespace milvus; +using namespace milvus::query; +using namespace milvus::segcore; +using namespace milvus::index; + +namespace { +constexpr int64_t kDim = 16; +constexpr int64_t kNumVectors = 1000; +// constexpr int64_t kNumQueries = 5; +constexpr int64_t kNumQueries = 1; +constexpr int64_t kBatchSize = 100; +constexpr size_t kSizePerChunk = 128; +constexpr size_t kHnswM = 24; +constexpr size_t kHnswEfConstruction = 360; +constexpr size_t kHnswEf = 128; + +const MetricType kMetricType = knowhere::metric::L2; +} // namespace + +enum class ConstructorType { + VectorIndex = 0, + RawData, + VectorBase, + ChunkedColumn +}; +using Param = ConstructorType; + +class CachedSearchIteratorTest : public ::testing::TestWithParam { + private: + protected: + static SearchInfo + GetDefaultNormalSearchInfo() { + return SearchInfo{ + .topk_ = kBatchSize, + .metric_type_ = kMetricType, + .search_params_ = + { + {knowhere::indexparam::EF, std::to_string(kHnswEf)}, + }, + .iterator_v2_info_ = + SearchIteratorV2Info{ + .batch_size = kBatchSize, + }, + }; + } + + static DataType data_type_; + static MetricType metric_type_; + static int64_t dim_; + static int64_t nb_; + static int64_t nq_; + static FixedVector base_dataset_; + static FixedVector query_dataset_; + static IndexBasePtr index_hnsw_; + static knowhere::DataSetPtr knowhere_query_dataset_; + static dataset::SearchDataset 
search_dataset_; + static std::unique_ptr> vector_base_; + static std::shared_ptr column_; + static std::vector> column_data_; + + std::unique_ptr + DispatchIterator(const ConstructorType& constructor_type, + const SearchInfo& search_info, + const BitsetView& bitset) { + switch (constructor_type) { + case ConstructorType::VectorIndex: + return std::make_unique( + dynamic_cast(*index_hnsw_), + knowhere_query_dataset_, + search_info, + bitset); + + case ConstructorType::RawData: + return std::make_unique( + search_dataset_, + dataset::RawDataset{0, dim_, nb_, base_dataset_.data()}, + search_info, + std::map{}, + bitset, + data_type_); + + case ConstructorType::VectorBase: + return std::make_unique( + search_dataset_, + vector_base_.get(), + nb_, + search_info, + std::map{}, + bitset, + data_type_); + + case ConstructorType::ChunkedColumn: + return std::make_unique( + column_, + search_dataset_, + search_info, + std::map{}, + bitset, + data_type_); + default: + return nullptr; + } + } + + // use last distance of the first batch as range_filter + // use first distance of the last batch as radius + std::pair + GetRadiusAndRangeFilter() { + const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize; + SearchResult search_result; + float radius, range_filter; + bool get_radius_success = false; + bool get_range_filter_sucess = false; + SearchInfo search_info = GetDefaultNormalSearchInfo(); + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + for (size_t rnd = 0; rnd < num_rnds; ++rnd) { + iterator->NextBatch(search_info, search_result); + if (rnd == 0) { + for (size_t i = kBatchSize - 1; i >= 0; --i) { + if (search_result.seg_offsets_[i] != -1) { + range_filter = search_result.distances_[i]; + get_range_filter_sucess = true; + break; + } + } + } else { + for (size_t i = 0; i < kBatchSize; ++i) { + if (search_result.seg_offsets_[i] != -1) { + radius = search_result.distances_[i]; + get_radius_success = true; + break; + } + } + } + } + if 
(!get_radius_success || !get_range_filter_sucess) { + throw std::runtime_error("Failed to get radius and range filter"); + } + return {radius, range_filter}; + } + + static void + BuildIndex() { + auto dataset = knowhere::GenDataSet(nb_, dim_, base_dataset_.data()); + + // build Flat + milvus::index::CreateIndexInfo create_index_info; + create_index_info.field_type = data_type_; + create_index_info.metric_type = metric_type_; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); + auto build_conf = + knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, + {knowhere::meta::DIM, std::to_string(dim_)}, + {knowhere::indexparam::M, std::to_string(kHnswM)}, + {knowhere::indexparam::EFCONSTRUCTION, + std::to_string(kHnswEfConstruction)}}; + create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW; + index_hnsw_ = milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, milvus::storage::FileManagerContext()); + index_hnsw_->BuildWithDataset(dataset, build_conf); + ASSERT_EQ(index_hnsw_->Count(), nb_); + } + + static void + SetUpVectorBase() { + vector_base_ = std::make_unique>( + dim_, kSizePerChunk); + vector_base_->set_data_raw(0, base_dataset_.data(), nb_); + + ASSERT_EQ(vector_base_->num_chunk(), + (nb_ + kSizePerChunk - 1) / kSizePerChunk); + } + + static void + SetUpChunkedColumn() { + column_ = std::make_unique(); + const size_t num_chunks_ = (nb_ + kSizePerChunk - 1) / kSizePerChunk; + column_data_.resize(num_chunks_); + + size_t offset = 0; + for (size_t i = 0; i < num_chunks_; ++i) { + const size_t rows = std::min(nb_ - offset, kSizePerChunk); + const size_t chunk_bitset_size = (rows + 7) / 8; + const size_t buf_size = + chunk_bitset_size + rows * dim_ * sizeof(float); + auto& chunk_data = column_data_[i]; + chunk_data.resize(buf_size); + memcpy(chunk_data.data() + chunk_bitset_size, + base_dataset_.cbegin() + offset * dim_, + rows * dim_ * sizeof(float)); + 
column_->AddChunk(std::make_shared( + rows, dim_, chunk_data.data(), buf_size, sizeof(float), false)); + offset += rows; + } + } + + static void + SetUpTestSuite() { + auto schema = std::make_shared(); + auto fakevec_id = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, dim_, metric_type_); + + // generate base dataset + base_dataset_ = + segcore::DataGen(schema, nb_).get_col(fakevec_id); + + // generate query dataset + query_dataset_ = {base_dataset_.cbegin(), + base_dataset_.cbegin() + nq_ * dim_}; + knowhere_query_dataset_ = + knowhere::GenDataSet(nq_, dim_, query_dataset_.data()); + search_dataset_ = dataset::SearchDataset{ + .metric_type = metric_type_, + .num_queries = nq_, + .topk = kBatchSize, + .round_decimal = -1, + .dim = dim_, + .query_data = query_dataset_.data(), + }; + + BuildIndex(); + SetUpVectorBase(); + SetUpChunkedColumn(); + } + + static void + TearDownTestSuite() { + base_dataset_.clear(); + query_dataset_.clear(); + index_hnsw_.reset(); + knowhere_query_dataset_.reset(); + vector_base_.reset(); + column_.reset(); + } + + void + SetUp() override { + } + + void + TearDown() override { + } +}; + +// initialize static variables +DataType CachedSearchIteratorTest::data_type_ = DataType::VECTOR_FLOAT; +int64_t CachedSearchIteratorTest::dim_ = kDim; +int64_t CachedSearchIteratorTest::nb_ = kNumVectors; +int64_t CachedSearchIteratorTest::nq_ = kNumQueries; +MetricType CachedSearchIteratorTest::metric_type_ = kMetricType; +IndexBasePtr CachedSearchIteratorTest::index_hnsw_ = nullptr; +knowhere::DataSetPtr CachedSearchIteratorTest::knowhere_query_dataset_ = + nullptr; +dataset::SearchDataset CachedSearchIteratorTest::search_dataset_; +FixedVector CachedSearchIteratorTest::base_dataset_; +FixedVector CachedSearchIteratorTest::query_dataset_; +std::unique_ptr> + CachedSearchIteratorTest::vector_base_ = nullptr; +std::shared_ptr CachedSearchIteratorTest::column_ = nullptr; +std::vector> CachedSearchIteratorTest::column_data_; + +/********* 
Testcases Start **********/ + +TEST_P(CachedSearchIteratorTest, NextBatchNormal) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + const std::vector kBatchSizes = { + 1, 7, 43, 99, 100, 101, 1000, 1005}; + + for (size_t batch_size : kBatchSizes) { + std::cout << "batch_size: " << batch_size << std::endl; + search_info.iterator_v2_info_->batch_size = batch_size; + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + SearchResult search_result; + + iterator->NextBatch(search_info, search_result); + + for (size_t i = 0; i < nq_; ++i) { + std::unordered_set seg_offsets; + size_t cnt = 0; + for (size_t j = 0; j < batch_size; ++j) { + if (search_result.seg_offsets_[i * batch_size + j] == -1) { + break; + } + ++cnt; + seg_offsets.insert( + search_result.seg_offsets_[i * batch_size + j]); + } + EXPECT_EQ(seg_offsets.size(), cnt); + EXPECT_EQ(search_result.distances_[i * batch_size], 0); + } + EXPECT_EQ(search_result.unity_topK_, batch_size); + EXPECT_EQ(search_result.total_nq_, nq_); + EXPECT_EQ(search_result.seg_offsets_.size(), nq_ * batch_size); + EXPECT_EQ(search_result.distances_.size(), nq_ * batch_size); + } +} + +TEST_P(CachedSearchIteratorTest, NextBatchDistBound) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + const size_t batch_size = kBatchSize; + const float dist_bound_factor = PositivelyRelated(metric_type_) ? 
0.5 : 1.5; + float dist_bound = 0; + + { + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + SearchResult search_result; + iterator->NextBatch(search_info, search_result); + + bool found_dist_bound = false; + // use the last distance of the first query * factor as the dist bound + for (size_t j = batch_size - 1; j >= 0; --j) { + if (search_result.seg_offsets_[j] != -1) { + dist_bound = search_result.distances_[j] * dist_bound_factor; + found_dist_bound = true; + break; + } + } + ASSERT_TRUE(found_dist_bound); + + search_info.iterator_v2_info_->last_bound = dist_bound; + for (size_t rnd = 1; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) { + iterator->NextBatch(search_info, search_result); + for (size_t i = 0; i < nq_; ++i) { + for (size_t j = 0; j < batch_size; ++j) { + if (search_result.seg_offsets_[i * batch_size + j] == -1) { + break; + } + if (PositivelyRelated(metric_type_)) { + EXPECT_LT(search_result.distances_[i * batch_size + j], + dist_bound); + } else { + EXPECT_GT(search_result.distances_[i * batch_size + j], + dist_bound); + } + } + } + } + } +} + +TEST_P(CachedSearchIteratorTest, NextBatchDistBoundEmptyResults) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + const size_t batch_size = kBatchSize; + const float dist_bound = PositivelyRelated(metric_type_) + ? 
std::numeric_limits::min() + : std::numeric_limits::max(); + + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + SearchResult search_result; + + search_info.iterator_v2_info_->last_bound = dist_bound; + size_t total_cnt = 0; + for (size_t rnd = 0; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) { + iterator->NextBatch(search_info, search_result); + for (size_t i = 0; i < nq_; ++i) { + for (size_t j = 0; j < batch_size; ++j) { + if (search_result.seg_offsets_[i * batch_size + j] == -1) { + break; + } + ++total_cnt; + } + } + } + EXPECT_EQ(total_cnt, 0); +} + +TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadius) { + const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize; + const auto [radius, range_filter] = GetRadiusAndRangeFilter(); + SearchResult search_result; + + SearchInfo search_info = GetDefaultNormalSearchInfo(); + search_info.search_params_[knowhere::meta::RADIUS] = radius; + + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + for (size_t rnd = 0; rnd < num_rnds; ++rnd) { + iterator->NextBatch(search_info, search_result); + for (size_t i = 0; i < nq_; ++i) { + for (size_t j = 0; j < kBatchSize; ++j) { + if (search_result.seg_offsets_[i * kBatchSize + j] == -1) { + break; + } + float dist = search_result.distances_[i * kBatchSize + j]; + if (PositivelyRelated(metric_type_)) { + ASSERT_GT(dist, radius); + } else { + ASSERT_LT(dist, radius); + } + } + } + } +} + +TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadiusAndRangeFilter) { + const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize; + const auto [radius, range_filter] = GetRadiusAndRangeFilter(); + SearchResult search_result; + + SearchInfo search_info = GetDefaultNormalSearchInfo(); + search_info.search_params_[knowhere::meta::RADIUS] = radius; + search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter; + + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + for (size_t rnd = 0; rnd < num_rnds; ++rnd) { + 
iterator->NextBatch(search_info, search_result); + for (size_t i = 0; i < nq_; ++i) { + for (size_t j = 0; j < kBatchSize; ++j) { + if (search_result.seg_offsets_[i * kBatchSize + j] == -1) { + break; + } + float dist = search_result.distances_[i * kBatchSize + j]; + if (PositivelyRelated(metric_type_)) { + ASSERT_GT(dist, radius); + ASSERT_LE(dist, range_filter); + } else { + ASSERT_LT(dist, radius); + ASSERT_GE(dist, range_filter); + } + } + } + } +} + +TEST_P(CachedSearchIteratorTest, + NextBatchRangeSearchLastBoundRadiusRangeFilter) { + const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize; + const auto [radius, range_filter] = GetRadiusAndRangeFilter(); + SearchResult search_result; + const float diff = (radius + range_filter) / 2; + const std::vector last_bounds = {radius - diff, + radius, + radius + diff, + range_filter, + range_filter + diff}; + + SearchInfo search_info = GetDefaultNormalSearchInfo(); + search_info.search_params_[knowhere::meta::RADIUS] = radius; + search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter; + for (float last_bound : last_bounds) { + search_info.iterator_v2_info_->last_bound = last_bound; + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + for (size_t rnd = 0; rnd < num_rnds; ++rnd) { + iterator->NextBatch(search_info, search_result); + for (size_t i = 0; i < nq_; ++i) { + for (size_t j = 0; j < kBatchSize; ++j) { + if (search_result.seg_offsets_[i * kBatchSize + j] == -1) { + break; + } + float dist = search_result.distances_[i * kBatchSize + j]; + if (PositivelyRelated(metric_type_)) { + ASSERT_LE(dist, last_bound); + ASSERT_GT(dist, radius); + ASSERT_LE(dist, range_filter); + } else { + ASSERT_GT(dist, last_bound); + ASSERT_LT(dist, radius); + ASSERT_GE(dist, range_filter); + } + } + } + } + } +} + +TEST_P(CachedSearchIteratorTest, NextBatchZeroBatchSize) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + 
SearchResult search_result; + + search_info.iterator_v2_info_->batch_size = 0; + EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError); +} + +TEST_P(CachedSearchIteratorTest, NextBatchDiffBatchSizeComparedToInit) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + SearchResult search_result; + + search_info.iterator_v2_info_->batch_size = kBatchSize + 1; + EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError); +} + +TEST_P(CachedSearchIteratorTest, NextBatchEmptySearchInfo) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + SearchResult search_result; + + SearchInfo empty_search_info; + EXPECT_THROW(iterator->NextBatch(empty_search_info, search_result), + SegcoreError); +} + +TEST_P(CachedSearchIteratorTest, NextBatchEmptyIteratorV2Info) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + SearchResult search_result; + + search_info.iterator_v2_info_ = std::nullopt; + EXPECT_THROW(iterator->NextBatch(search_info, search_result), SegcoreError); +} + +TEST_P(CachedSearchIteratorTest, NextBatchtAllBatchesNormal) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + const std::vector kBatchSizes = { + 1, 7, 43, 99, 100, 101, 1000, 1005}; + // const std::vector kBatchSizes = {1005}; + + for (size_t batch_size : kBatchSizes) { + search_info.iterator_v2_info_->batch_size = batch_size; + auto iterator = DispatchIterator(GetParam(), search_info, nullptr); + size_t total_cnt = 0; + + for (size_t rnd = 0; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) { + SearchResult search_result; + iterator->NextBatch(search_info, search_result); + for (size_t i = 0; i < nq_; ++i) { + std::unordered_set seg_offsets; + size_t cnt = 0; + for (size_t j = 0; j < batch_size; ++j) { + if 
(search_result.seg_offsets_[i * batch_size + j] == -1) { + break; + } + ++cnt; + seg_offsets.insert( + search_result.seg_offsets_[i * batch_size + j]); + } + total_cnt += cnt; + // check no duplicate + EXPECT_EQ(seg_offsets.size(), cnt); + + // only check if the first distance of the first batch is 0 + if (rnd == 0) { + EXPECT_EQ(search_result.distances_[i * batch_size], 0); + } + } + EXPECT_EQ(search_result.unity_topK_, batch_size); + EXPECT_EQ(search_result.total_nq_, nq_); + EXPECT_EQ(search_result.seg_offsets_.size(), nq_ * batch_size); + EXPECT_EQ(search_result.distances_.size(), nq_ * batch_size); + } + if (GetParam() == ConstructorType::VectorIndex) { + EXPECT_GE(total_cnt, nb_ * nq_ * 0.9); + } else { + EXPECT_EQ(total_cnt, nb_ * nq_); + } + } +} + +TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidSearchInfo) { + EXPECT_THROW(DispatchIterator(GetParam(), SearchInfo{}, nullptr), + SegcoreError); + + EXPECT_THROW( + DispatchIterator(GetParam(), SearchInfo{.metric_type_ = ""}, nullptr), + SegcoreError); + + EXPECT_THROW( + DispatchIterator( + GetParam(), SearchInfo{.metric_type_ = kMetricType}, nullptr), + SegcoreError); + + EXPECT_THROW(DispatchIterator(GetParam(), + SearchInfo{.metric_type_ = kMetricType, + .iterator_v2_info_ = {}}, + nullptr), + SegcoreError); + + EXPECT_THROW( + DispatchIterator(GetParam(), + SearchInfo{.metric_type_ = kMetricType, + .iterator_v2_info_ = + SearchIteratorV2Info{.batch_size = 0}}, + nullptr), + SegcoreError); +} + +TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidParams) { + SearchInfo search_info = GetDefaultNormalSearchInfo(); + if (GetParam() == ConstructorType::VectorIndex) { + EXPECT_THROW(auto iterator = std::make_unique( + dynamic_cast(*index_hnsw_), + nullptr, + search_info, + nullptr), + SegcoreError); + + EXPECT_THROW(auto iterator = std::make_unique( + dynamic_cast(*index_hnsw_), + std::make_shared(), + search_info, + nullptr), + SegcoreError); + } else if (GetParam() == ConstructorType::RawData) { + 
EXPECT_THROW( + auto iterator = std::make_unique( + dataset::SearchDataset{}, + dataset::RawDataset{0, dim_, nb_, base_dataset_.data()}, + search_info, + std::map{}, + nullptr, + data_type_), + SegcoreError); + } else if (GetParam() == ConstructorType::VectorBase) { + EXPECT_THROW(auto iterator = std::make_unique( + dataset::SearchDataset{}, + vector_base_.get(), + nb_, + search_info, + std::map{}, + nullptr, + data_type_), + SegcoreError); + + EXPECT_THROW(auto iterator = std::make_unique( + search_dataset_, + nullptr, + nb_, + search_info, + std::map{}, + nullptr, + data_type_), + SegcoreError); + + EXPECT_THROW(auto iterator = std::make_unique( + search_dataset_, + vector_base_.get(), + 0, + search_info, + std::map{}, + nullptr, + data_type_), + SegcoreError); + } else if (GetParam() == ConstructorType::ChunkedColumn) { + EXPECT_THROW(auto iterator = std::make_unique( + nullptr, + search_dataset_, + search_info, + std::map{}, + nullptr, + data_type_), + SegcoreError); + EXPECT_THROW(auto iterator = std::make_unique( + column_, + dataset::SearchDataset{}, + search_info, + std::map{}, + nullptr, + data_type_), + SegcoreError); + } +} + +/********* Testcases End **********/ + +static const std::vector constructor_types = { + ConstructorType::VectorIndex, + ConstructorType::RawData, + ConstructorType::VectorBase, + ConstructorType::ChunkedColumn, +}; + +INSTANTIATE_TEST_SUITE_P(CachedSearchIteratorTests, + CachedSearchIteratorTest, + ::testing::ValuesIn(constructor_types), + [](const testing::TestParamInfo& info) { + std::string constructor_type_str; + switch (info.param) { + case ConstructorType::VectorIndex: + constructor_type_str = "VectorIndex"; + break; + case ConstructorType::RawData: + constructor_type_str = "RawData"; + break; + case ConstructorType::VectorBase: + constructor_type_str = "VectorBase"; + break; + case ConstructorType::ChunkedColumn: + constructor_type_str = "ChunkedColumn"; + break; + default: + constructor_type_str = + "Unknown constructor 
type"; + }; + return constructor_type_str; + }); diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index 3fa2cf3b8b63c..de7af158f7bca 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -55,6 +55,12 @@ message Array { schema.DataType element_type = 3; } +message SearchIteratorV2Info { + string token = 1; + uint32 batch_size = 2; + optional float last_bound = 3; +} + message QueryInfo { int64 topk = 1; string metric_type = 3; @@ -66,6 +72,7 @@ message QueryInfo { bool strict_group_size = 9; double bm25_avgdl = 10; int64 query_field_id =11; + optional SearchIteratorV2Info search_iterator_v2_info = 12; } message ColumnInfo { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8486526dd620e..418f25937bf0f 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -26,6 +26,7 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/google/uuid" "github.com/hashicorp/golang-lru/v2/expirable" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/atomic" @@ -303,6 +304,13 @@ func (node *Proxy) Init() error { node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool() + // Enable internal rand pool for UUIDv4 generation + // This is NOT thread-safe and should only be called before the service starts and + // there is no possibility that New or any other UUID V4 generation function will be called concurrently + // Only proxy generates UUID for now, and one Milvus process only has one proxy + uuid.EnableRandPool() + log.Debug("enable rand pool for UUIDv4 generation") + log.Info("init proxy done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address)) return nil } diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index cea3dc49c63a6..c4e4b992fb771 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/cockroachdb/errors" + 
"github.com/google/uuid" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -82,6 +83,75 @@ type SearchInfo struct { isIterator bool } +func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) { + isIteratorV2Str, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterV2Key, searchParamsPair) + isIteratorV2, _ := strconv.ParseBool(isIteratorV2Str) + if !isIteratorV2 { + return nil, nil + } + + // iteratorV1 and iteratorV2 should be set together for compatibility + if !isIterator { + return nil, fmt.Errorf("both %s and %s must be set in the SDK", IteratorField, SearchIterV2Key) + } + + // disable groupBy when doing iteratorV2 + // same behavior with V1 + if isIteratorV2 && groupByFieldId > 0 { + return nil, merr.WrapErrParameterInvalid("", "", + "Not allowed to groupBy when performing search iterator") + } + + // parse token, generate if not exist + token, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterTokenKey, searchParamsPair) + if token == "" { + generatedToken, err := uuid.NewRandom() + if err != nil { + return nil, err + } + token = generatedToken.String() + } else { + // Validate existing token is a valid UUID + if _, err := uuid.Parse(token); err != nil { + return nil, fmt.Errorf("invalid token format") + } + } + + // parse batch size, required non-zero value + batchSizeStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterBatchSizeKey, searchParamsPair) + if batchSizeStr == "" { + return nil, fmt.Errorf("batch size is required") + } + batchSize, err := strconv.ParseInt(batchSizeStr, 0, 64) + if err != nil { + return nil, fmt.Errorf("batch size is invalid, %w", err) + } + // use the same validation logic as topk + if err := validateLimit(batchSize); err != nil { + return nil, fmt.Errorf("batch size is invalid, %w", err) + } + *queryTopK = batchSize // for compatibility + + // prepare plan 
iterator v2 info proto + planIteratorV2Info := &planpb.SearchIteratorV2Info{ + Token: token, + BatchSize: uint32(batchSize), + } + + // append optional last bound if applicable + lastBoundStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterLastBoundKey, searchParamsPair) + if lastBoundStr != "" { + lastBound, err := strconv.ParseFloat(lastBoundStr, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse input last bound, %w", err) + } + lastBoundFloat32 := float32(lastBound) + planIteratorV2Info.LastBound = &lastBoundFloat32 // escape pointer + } + + return planIteratorV2Info, nil +} + // parseSearchInfo returns QueryInfo and offset func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo { var topK int64 @@ -191,15 +261,21 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb "Not allowed to do range-search when doing search-group-by")} } + planSearchIteratorV2Info, err := parseSearchIteratorV2Info(searchParamsPair, groupByFieldId, isIterator, &queryTopK) + if err != nil { + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("parse iterator v2 info failed: %w", err)} + } + return &SearchInfo{ planInfo: &planpb.QueryInfo{ - Topk: queryTopK, - MetricType: metricType, - SearchParams: searchParamStr, - RoundDecimal: roundDecimal, - GroupByFieldId: groupByFieldId, - GroupSize: groupSize, - StrictGroupSize: strictGroupSize, + Topk: queryTopK, + MetricType: metricType, + SearchParams: searchParamStr, + RoundDecimal: roundDecimal, + GroupByFieldId: groupByFieldId, + GroupSize: groupSize, + StrictGroupSize: strictGroupSize, + SearchIteratorV2Info: planSearchIteratorV2Info, }, offset: offset, isIterator: isIterator, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 07b6cf6f864e1..42ebc280fa3ce 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -67,6 +67,11 @@ const ( OffsetKey = "offset" 
LimitKey = "limit" + SearchIterV2Key = "search_iter_v2" + SearchIterBatchSizeKey = "search_iter_batch_size" + SearchIterLastBoundKey = "search_iter_last_bound" + SearchIterTokenKey = "search_iter_token" + InsertTaskName = "InsertTask" CreateCollectionTaskName = "CreateCollectionTask" DropCollectionTaskName = "DropCollectionTask" diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 3dc48cfe9503c..91ebf619b8f63 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -588,12 +589,15 @@ func (t *searchTask) Execute(ctx context.Context) error { return nil } -func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) { +func getMetricType(toReduceResults []*internalpb.SearchResults) string { metricType := "" if len(toReduceResults) >= 1 { metricType = toReduceResults[0].GetMetricType() } + return metricType +} +func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults") defer sp.End() @@ -628,6 +632,24 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter return result, nil } +// find the last bound based on reduced results and metric type +// only support nq == 1, for search iterator v2 
+func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 { + len := len(result.Results.Scores) + if len > 0 && result.GetResults().GetNumQueries() == 1 { + return result.Results.Scores[len-1] + } + // if no results found and incoming last bound is not nil, return it + if incomingLastBound != nil { + return *incomingLastBound + } + // if no results found and it is the first call, return the closest bound + if metric.PositivelyRelated(metricType) { + return math.MaxFloat32 + } + return -math.MaxFloat32 +} + func (t *searchTask) PostExecute(ctx context.Context) error { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute") defer sp.End() @@ -663,6 +685,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { return err } + metricType := getMetricType(toReduceResults) // reduce if t.SearchRequest.GetIsAdvanced() { multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs())) @@ -689,16 +712,12 @@ func (t *searchTask) PostExecute(ctx context.Context) error { multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs())) for index, internalResults := range multipleInternalResults { subReq := t.SearchRequest.GetSubReqs()[index] - - metricType := "" - if len(internalResults) >= 1 { - metricType = internalResults[0].GetMetricType() - } - result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true) + subMetricType := getMetricType(internalResults) + result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true) if err != nil { return err } - t.reScorers[index].setMetricType(metricType) + t.reScorers[index].setMetricType(subMetricType) t.reScorers[index].reScore(result) multipleMilvusResults[index] = result } @@ -714,7 +733,7 @@ func (t *searchTask) PostExecute(ctx 
context.Context) error { return err } } else { - t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false) + t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false) if err != nil { return err } @@ -743,6 +762,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } t.result.Results.OutputFields = t.userOutputFields t.result.CollectionName = t.request.GetCollectionName() + if t.isIterator && len(t.queryInfos) == 1 && t.queryInfos[0] != nil { + if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil { + t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{ + Token: iterInfo.GetToken(), + LastBound: getLastBound(t.result, iterInfo.LastBound, metricType), + } + } + } if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 { // first page for iteration, need to set up sessionTs for iterator t.result.SessionTs = getMaxMvccTsFromChannels(t.queryChannelsTs, t.BeginTs()) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 3d89adcae8f72..0480b45f3e457 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -18,12 +18,14 @@ package proxy import ( "context" "fmt" + "math" "strconv" "strings" "testing" "time" "github.com/cockroachdb/errors" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -103,9 +105,124 @@ func TestSearchTask_PostExecute(t *testing.T) { assert.Equal(t, qt.resultSizeInsufficient, true) assert.Equal(t, qt.isTopkReduce, false) }) + + t.Run("test search iterator v2", func(t *testing.T) { + const ( + kRows = 10 + kToken = "test-token" + ) + + collName := "test_collection_search_iterator_v2" 
+ funcutil.GenRandomStr() + collSchema := createColl(t, collName, rc) + + createIteratorSearchTask := func(t *testing.T, metricType string, rows int) *searchTask { + ids := make([]int64, rows) + for i := range ids { + ids[i] = int64(i) + } + resultIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + } + scores := make([]float32, rows) + // proxy needs to reverse the score for negatively related metrics + for i := range scores { + if metric.PositivelyRelated(metricType) { + scores[i] = float32(len(scores) - i) + } else { + scores[i] = -float32(i + 1) + } + } + resultData := &schemapb.SearchResultData{ + Ids: resultIDs, + Scores: scores, + NumQueries: 1, + } + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + Nq: 1, + }, + schema: newSchemaInfo(collSchema), + request: &milvuspb.SearchRequest{ + CollectionName: collName, + }, + queryInfos: []*planpb.QueryInfo{{ + SearchIteratorV2Info: &planpb.SearchIteratorV2Info{ + Token: kToken, + BatchSize: 1, + }, + }}, + result: &milvuspb.SearchResults{ + Results: resultData, + }, + resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](), + tr: timerecord.NewTimeRecorder("search"), + isIterator: true, + } + bytes, err := proto.Marshal(resultData) + assert.NoError(t, err) + qt.resultBuf.Insert(&internalpb.SearchResults{ + MetricType: metricType, + SlicedBlob: bytes, + }) + return qt + } + + t.Run("test search iterator v2", func(t *testing.T) { + metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25} + for _, metricType := range metrics { + qt := createIteratorSearchTask(t, metricType, kRows) + err = qt.PostExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token) + if metric.PositivelyRelated(metricType) { + assert.Equal(t, float32(1), 
qt.result.Results.SearchIteratorV2Results.LastBound) + } else { + assert.Equal(t, float32(kRows), qt.result.Results.SearchIteratorV2Results.LastBound) + } + } + }) + + t.Run("test search iterator v2 with empty result", func(t *testing.T) { + metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25} + for _, metricType := range metrics { + qt := createIteratorSearchTask(t, metricType, 0) + err = qt.PostExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token) + if metric.PositivelyRelated(metricType) { + assert.Equal(t, float32(math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound) + } else { + assert.Equal(t, float32(-math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound) + } + } + }) + + t.Run("test search iterator v2 with empty result and incoming last bound", func(t *testing.T) { + metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25} + kLastBound := float32(10) + for _, metricType := range metrics { + qt := createIteratorSearchTask(t, metricType, 0) + qt.queryInfos[0].SearchIteratorV2Info.LastBound = &kLastBound + err = qt.PostExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token) + assert.Equal(t, kLastBound, qt.result.Results.SearchIteratorV2Results.LastBound) + } + }) + }) } -func createColl(t *testing.T, name string, rc types.RootCoordClient) { +func createColl(t *testing.T, name string, rc types.RootCoordClient) *schemapb.CollectionSchema { schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name) marshaledSchema, err := proto.Marshal(schema) require.NoError(t, err) @@ -126,6 +243,8 @@ func createColl(t *testing.T, name string, rc types.RootCoordClient) { require.NoError(t, createColT.PreExecute(ctx)) require.NoError(t, createColT.Execute(ctx)) require.NoError(t, createColT.PostExecute(ctx)) + + return schema } func getBaseSearchParams() 
[]*commonpb.KeyValuePair { @@ -2599,6 +2718,146 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size")) } }) + + t.Run("check search iterator v2", func(t *testing.T) { + kBatchSize := uint32(10) + generateValidParamsForSearchIteratorV2 := func() []*commonpb.KeyValuePair { + param := getValidSearchParams() + return append(param, + &commonpb.KeyValuePair{ + Key: SearchIterV2Key, + Value: "True", + }, + &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }, + &commonpb.KeyValuePair{ + Key: SearchIterBatchSizeKey, + Value: fmt.Sprintf("%d", kBatchSize), + }, + ) + } + + t.Run("iteratorV2 normal", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + searchInfo := parseSearchInfo(param, nil, nil) + assert.NoError(t, searchInfo.parseError) + assert.NotNil(t, searchInfo.planInfo) + assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token) + assert.Equal(t, kBatchSize, searchInfo.planInfo.SearchIteratorV2Info.BatchSize) + assert.Len(t, searchInfo.planInfo.SearchIteratorV2Info.Token, 36) + assert.Equal(t, int64(kBatchSize), searchInfo.planInfo.GetTopk()) // compatibility + }) + + t.Run("iteratorV2 without isIterator", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + resetSearchParamsValue(param, IteratorField, "False") + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "both") + }) + + t.Run("iteratorV2 with groupBy", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + param = append(param, &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "string_field", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + searchInfo := 
parseSearchInfo(param, schema, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "groupBy") + }) + + t.Run("iteratorV2 invalid token", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + param = append(param, &commonpb.KeyValuePair{ + Key: SearchIterTokenKey, + Value: "invalid_token", + }) + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "invalid token format") + }) + + t.Run("iteratorV2 passed token must be same", func(t *testing.T) { + token, err := uuid.NewRandom() + assert.NoError(t, err) + param := generateValidParamsForSearchIteratorV2() + param = append(param, &commonpb.KeyValuePair{ + Key: SearchIterTokenKey, + Value: token.String(), + }) + searchInfo := parseSearchInfo(param, nil, nil) + assert.NoError(t, searchInfo.parseError) + assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token) + assert.Equal(t, token.String(), searchInfo.planInfo.SearchIteratorV2Info.Token) + }) + + t.Run("iteratorV2 batch size", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + resetSearchParamsValue(param, SearchIterBatchSizeKey, "1.123") + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid") + }) + + t.Run("iteratorV2 batch size empty", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + resetSearchParamsValue(param, SearchIterBatchSizeKey, "") + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "batch size is required") + }) + + t.Run("iteratorV2 batch size negative", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + resetSearchParamsValue(param, SearchIterBatchSizeKey, "-1") + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, 
searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid") + }) + + t.Run("iteratorV2 batch size too large", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + resetSearchParamsValue(param, SearchIterBatchSizeKey, fmt.Sprintf("%d", Params.QuotaConfig.TopKLimit.GetAsInt64()+1)) + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid") + }) + + t.Run("iteratorV2 last bound", func(t *testing.T) { + kLastBound := float32(1.123) + param := generateValidParamsForSearchIteratorV2() + param = append(param, &commonpb.KeyValuePair{ + Key: SearchIterLastBoundKey, + Value: fmt.Sprintf("%f", kLastBound), + }) + searchInfo := parseSearchInfo(param, nil, nil) + assert.NoError(t, searchInfo.parseError) + assert.NotNil(t, searchInfo.planInfo) + assert.Equal(t, kLastBound, *searchInfo.planInfo.SearchIteratorV2Info.LastBound) + }) + + t.Run("iteratorV2 invalid last bound", func(t *testing.T) { + param := generateValidParamsForSearchIteratorV2() + param = append(param, &commonpb.KeyValuePair{ + Key: SearchIterLastBoundKey, + Value: "xxx", + }) + searchInfo := parseSearchInfo(param, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.ErrorContains(t, searchInfo.parseError, "failed to parse input last bound") + }) + }) } func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {